2023-10-31 05:02:04 -05:00
|
|
|
import json
|
|
|
|
from abc import abstractmethod
|
|
|
|
from functools import cached_property
|
2023-08-29 08:58:00 -05:00
|
|
|
from io import BytesIO
|
2023-10-31 05:02:04 -05:00
|
|
|
from pathlib import Path
|
2023-08-24 23:28:51 -05:00
|
|
|
from typing import Any, Literal
|
2023-06-24 22:18:09 -05:00
|
|
|
|
2023-10-31 05:02:04 -05:00
|
|
|
import numpy as np
|
2023-08-24 23:28:51 -05:00
|
|
|
import onnxruntime as ort
|
2023-08-29 08:58:00 -05:00
|
|
|
from PIL import Image
|
2023-12-20 20:47:56 -05:00
|
|
|
from tokenizers import Encoding, Tokenizer
|
2023-10-31 05:02:04 -05:00
|
|
|
|
2023-11-11 20:04:49 -05:00
|
|
|
from app.config import clean_name, log
|
2023-10-31 05:02:04 -05:00
|
|
|
from app.models.transforms import crop, get_pil_resampling, normalize, resize, to_numpy
|
2023-12-20 20:47:56 -05:00
|
|
|
from app.schemas import ModelType, ndarray_f32, ndarray_i32
|
2023-06-24 22:18:09 -05:00
|
|
|
|
|
|
|
from .base import InferenceModel
|
|
|
|
|
2023-08-24 23:28:51 -05:00
|
|
|
|
2023-10-31 05:02:04 -05:00
|
|
|
class BaseCLIPEncoder(InferenceModel):
|
2023-06-24 22:18:09 -05:00
|
|
|
_model_type = ModelType.CLIP
|
|
|
|
|
2023-08-24 23:28:51 -05:00
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
model_name: str,
|
|
|
|
cache_dir: str | None = None,
|
|
|
|
mode: Literal["text", "vision"] | None = None,
|
|
|
|
**model_kwargs: Any,
|
|
|
|
) -> None:
|
|
|
|
self.mode = mode
|
2023-10-10 12:26:30 -05:00
|
|
|
super().__init__(model_name, cache_dir, **model_kwargs)
|
2023-08-24 23:28:51 -05:00
|
|
|
|
2023-09-09 04:02:44 -05:00
|
|
|
def _load(self) -> None:
|
2023-08-24 23:28:51 -05:00
|
|
|
if self.mode == "text" or self.mode is None:
|
2023-09-09 04:02:44 -05:00
|
|
|
log.debug(f"Loading clip text model '{self.model_name}'")
|
2023-10-31 05:02:04 -05:00
|
|
|
|
2023-08-24 23:28:51 -05:00
|
|
|
self.text_model = ort.InferenceSession(
|
2023-10-31 05:02:04 -05:00
|
|
|
self.textual_path.as_posix(),
|
2023-08-24 23:28:51 -05:00
|
|
|
sess_options=self.sess_options,
|
|
|
|
providers=self.providers,
|
|
|
|
provider_options=self.provider_options,
|
|
|
|
)
|
2023-12-20 20:47:56 -05:00
|
|
|
log.debug(f"Loaded clip text model '{self.model_name}'")
|
2023-08-24 23:28:51 -05:00
|
|
|
|
|
|
|
if self.mode == "vision" or self.mode is None:
|
2023-09-09 04:02:44 -05:00
|
|
|
log.debug(f"Loading clip vision model '{self.model_name}'")
|
2023-10-31 05:02:04 -05:00
|
|
|
|
2023-08-24 23:28:51 -05:00
|
|
|
self.vision_model = ort.InferenceSession(
|
2023-10-31 05:02:04 -05:00
|
|
|
self.visual_path.as_posix(),
|
2023-08-24 23:28:51 -05:00
|
|
|
sess_options=self.sess_options,
|
|
|
|
providers=self.providers,
|
|
|
|
provider_options=self.provider_options,
|
|
|
|
)
|
2023-12-20 20:47:56 -05:00
|
|
|
log.debug(f"Loaded clip vision model '{self.model_name}'")
|
2023-06-24 22:18:09 -05:00
|
|
|
|
2023-11-13 11:18:46 -05:00
|
|
|
def _predict(self, image_or_text: Image.Image | str) -> ndarray_f32:
|
2023-08-29 08:58:00 -05:00
|
|
|
if isinstance(image_or_text, bytes):
|
|
|
|
image_or_text = Image.open(BytesIO(image_or_text))
|
|
|
|
|
2023-08-24 23:28:51 -05:00
|
|
|
match image_or_text:
|
2023-08-29 08:58:00 -05:00
|
|
|
case Image.Image():
|
2023-08-24 23:28:51 -05:00
|
|
|
if self.mode == "text":
|
|
|
|
raise TypeError("Cannot encode image as text-only model")
|
2023-10-31 05:02:04 -05:00
|
|
|
|
2023-11-13 11:18:46 -05:00
|
|
|
outputs: ndarray_f32 = self.vision_model.run(None, self.transform(image_or_text))[0][0]
|
2023-08-24 23:28:51 -05:00
|
|
|
case str():
|
|
|
|
if self.mode == "vision":
|
|
|
|
raise TypeError("Cannot encode text as vision-only model")
|
2023-10-31 05:02:04 -05:00
|
|
|
|
2023-11-13 11:18:46 -05:00
|
|
|
outputs = self.text_model.run(None, self.tokenize(image_or_text))[0][0]
|
2023-08-24 23:28:51 -05:00
|
|
|
case _:
|
|
|
|
raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")
|
|
|
|
|
2023-11-13 11:18:46 -05:00
|
|
|
return outputs
|
2023-08-24 23:28:51 -05:00
|
|
|
|
2023-10-31 05:02:04 -05:00
|
|
|
@abstractmethod
|
|
|
|
def tokenize(self, text: str) -> dict[str, ndarray_i32]:
|
|
|
|
pass
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
def transform(self, image: Image.Image) -> dict[str, ndarray_f32]:
|
|
|
|
pass
|
|
|
|
|
|
|
|
@property
|
|
|
|
def textual_dir(self) -> Path:
|
|
|
|
return self.cache_dir / "textual"
|
|
|
|
|
|
|
|
@property
|
|
|
|
def visual_dir(self) -> Path:
|
|
|
|
return self.cache_dir / "visual"
|
|
|
|
|
|
|
|
@property
|
|
|
|
def model_cfg_path(self) -> Path:
|
|
|
|
return self.cache_dir / "config.json"
|
|
|
|
|
|
|
|
@property
|
|
|
|
def textual_path(self) -> Path:
|
|
|
|
return self.textual_dir / "model.onnx"
|
|
|
|
|
|
|
|
@property
|
|
|
|
def visual_path(self) -> Path:
|
|
|
|
return self.visual_dir / "model.onnx"
|
|
|
|
|
2023-12-20 20:47:56 -05:00
|
|
|
@property
|
|
|
|
def tokenizer_file_path(self) -> Path:
|
|
|
|
return self.textual_dir / "tokenizer.json"
|
|
|
|
|
|
|
|
@property
|
|
|
|
def tokenizer_cfg_path(self) -> Path:
|
|
|
|
return self.textual_dir / "tokenizer_config.json"
|
|
|
|
|
2023-10-31 05:02:04 -05:00
|
|
|
@property
|
|
|
|
def preprocess_cfg_path(self) -> Path:
|
|
|
|
return self.visual_dir / "preprocess_cfg.json"
|
2023-08-24 23:28:51 -05:00
|
|
|
|
2023-09-05 20:48:40 -05:00
|
|
|
@property
|
|
|
|
def cached(self) -> bool:
|
2023-10-31 05:02:04 -05:00
|
|
|
return self.textual_path.is_file() and self.visual_path.is_file()
|
|
|
|
|
2023-12-20 20:47:56 -05:00
|
|
|
@cached_property
|
|
|
|
def model_cfg(self) -> dict[str, Any]:
|
|
|
|
log.debug(f"Loading model config for CLIP model '{self.model_name}'")
|
|
|
|
model_cfg: dict[str, Any] = json.load(self.model_cfg_path.open())
|
|
|
|
log.debug(f"Loaded model config for CLIP model '{self.model_name}'")
|
|
|
|
return model_cfg
|
|
|
|
|
|
|
|
@cached_property
|
|
|
|
def tokenizer_file(self) -> dict[str, Any]:
|
|
|
|
log.debug(f"Loading tokenizer file for CLIP model '{self.model_name}'")
|
|
|
|
tokenizer_file: dict[str, Any] = json.load(self.tokenizer_file_path.open())
|
|
|
|
log.debug(f"Loaded tokenizer file for CLIP model '{self.model_name}'")
|
|
|
|
return tokenizer_file
|
|
|
|
|
|
|
|
@cached_property
|
|
|
|
def tokenizer_cfg(self) -> dict[str, Any]:
|
|
|
|
log.debug(f"Loading tokenizer config for CLIP model '{self.model_name}'")
|
|
|
|
tokenizer_cfg: dict[str, Any] = json.load(self.tokenizer_cfg_path.open())
|
|
|
|
log.debug(f"Loaded tokenizer config for CLIP model '{self.model_name}'")
|
|
|
|
return tokenizer_cfg
|
|
|
|
|
|
|
|
@cached_property
|
|
|
|
def preprocess_cfg(self) -> dict[str, Any]:
|
|
|
|
log.debug(f"Loading visual preprocessing config for CLIP model '{self.model_name}'")
|
|
|
|
preprocess_cfg: dict[str, Any] = json.load(self.preprocess_cfg_path.open())
|
|
|
|
log.debug(f"Loaded visual preprocessing config for CLIP model '{self.model_name}'")
|
|
|
|
return preprocess_cfg
|
|
|
|
|
2023-10-31 05:02:04 -05:00
|
|
|
|
|
|
|
class OpenCLIPEncoder(BaseCLIPEncoder):
    """CLIP encoder that builds its tokenizer and image preprocessing from
    the JSON config files stored alongside the model."""

    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        mode: Literal["text", "vision"] | None = None,
        **model_kwargs: Any,
    ) -> None:
        """Initialize the encoder, passing ``model_name`` through ``clean_name``.

        Args:
            model_name: Model name; normalized with ``clean_name`` before
                being handed to the parent initializer.
            cache_dir: Optional cache directory, forwarded to the parent.
            mode: ``"text"``, ``"vision"``, or ``None`` for both.
            **model_kwargs: Forwarded unchanged to the parent initializer.
        """
        super().__init__(clean_name(model_name), cache_dir, mode, **model_kwargs)

    def _load(self) -> None:
        """Load the sessions, then set up preprocessing values and the tokenizer.

        The preprocessing attributes and the tokenizer are prepared regardless
        of ``self.mode``.
        """
        super()._load()

        context_length = self.model_cfg["text_cfg"]["context_length"]
        pad_token = self.tokenizer_cfg["pad_token"]

        # "size" may be stored as a list or as a bare value; when it is a
        # list, only its first element is used.
        size = self.preprocess_cfg["size"]
        self.size = size[0] if isinstance(size, list) else size
        self.resampling = get_pil_resampling(self.preprocess_cfg["interpolation"])
        self.mean = np.array(self.preprocess_cfg["mean"], dtype=np.float32)
        self.std = np.array(self.preprocess_cfg["std"], dtype=np.float32)

        log.debug(f"Loading tokenizer for CLIP model '{self.model_name}'")
        self.tokenizer: Tokenizer = Tokenizer.from_file(self.tokenizer_file_path.as_posix())
        pad_id = self.tokenizer.token_to_id(pad_token)
        # Pad and truncate to the same length so every encoding comes out
        # with exactly ``context_length`` ids.
        self.tokenizer.enable_padding(length=context_length, pad_token=pad_token, pad_id=pad_id)
        self.tokenizer.enable_truncation(max_length=context_length)
        log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")

    def tokenize(self, text: str) -> dict[str, ndarray_i32]:
        """Encode ``text`` into the text session's single ``"text"`` input.

        Returns:
            A dict with key ``"text"`` holding the token ids as an int32
            array with a leading axis of length 1.
        """
        tokens: Encoding = self.tokenizer.encode(text)
        return {"text": np.array([tokens.ids], dtype=np.int32)}

    def transform(self, image: Image.Image) -> dict[str, ndarray_f32]:
        """Preprocess ``image`` into the vision session's ``"image"`` input.

        The image is resized and cropped to ``self.size``, converted to a
        NumPy array, and normalized with ``self.mean`` / ``self.std``.

        Returns:
            A dict with key ``"image"`` holding the array with its last axis
            moved to the front and a new leading axis of length 1.
        """
        image = resize(image, self.size)
        image = crop(image, self.size)
        image_np = to_numpy(image)
        image_np = normalize(image_np, self.mean, self.std)
        # Presumably height-width-channel -> channel-height-width, plus a
        # batch axis -- TODO confirm against ``to_numpy``.
        return {"image": np.expand_dims(image_np.transpose(2, 0, 1), 0)}
|
|
|
|
|
|
|
|
|
|
|
|
class MCLIPEncoder(OpenCLIPEncoder):
    """OpenCLIP-style encoder whose text session takes ``input_ids`` and
    ``attention_mask`` inputs instead of a single ``text`` input."""

    def tokenize(self, text: str) -> dict[str, ndarray_i32]:
        """Encode ``text`` into the two named inputs of the text session.

        Returns:
            A dict with ``"input_ids"`` and ``"attention_mask"``, each an
            int32 array with a leading axis of length 1.
        """
        encoding: Encoding = self.tokenizer.encode(text)
        # Wrap each sequence in a list to add the leading length-1 axis.
        input_ids = np.array([encoding.ids], dtype=np.int32)
        attention_mask = np.array([encoding.attention_mask], dtype=np.int32)
        return {"input_ids": input_ids, "attention_mask": attention_mask}
|