2023-08-24 23:28:51 -05:00
|
|
|
import os
|
|
|
|
import zipfile
|
2023-08-29 08:58:00 -05:00
|
|
|
from io import BytesIO
|
2023-08-24 23:28:51 -05:00
|
|
|
from typing import Any, Literal
|
2023-06-24 22:18:09 -05:00
|
|
|
|
2023-08-24 23:28:51 -05:00
|
|
|
import onnxruntime as ort
|
|
|
|
import torch
|
|
|
|
from clip_server.model.clip import BICUBIC, _convert_image_to_rgb
|
|
|
|
from clip_server.model.clip_onnx import _MODELS, _S3_BUCKET_V2, CLIPOnnxModel, download_model
|
|
|
|
from clip_server.model.pretrained_models import _VISUAL_MODEL_IMAGE_SIZE
|
|
|
|
from clip_server.model.tokenization import Tokenizer
|
2023-08-29 08:58:00 -05:00
|
|
|
from PIL import Image
|
2023-08-24 23:28:51 -05:00
|
|
|
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
|
2023-06-24 22:18:09 -05:00
|
|
|
|
2023-08-30 03:22:01 -05:00
|
|
|
from ..config import log
|
2023-06-24 22:18:09 -05:00
|
|
|
from ..schemas import ModelType
|
|
|
|
from .base import InferenceModel
|
|
|
|
|
2023-08-24 23:28:51 -05:00
|
|
|
|
|
|
|
class CLIPEncoder(InferenceModel):
|
2023-06-24 22:18:09 -05:00
|
|
|
_model_type = ModelType.CLIP
|
|
|
|
|
2023-08-24 23:28:51 -05:00
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
model_name: str,
|
|
|
|
cache_dir: str | None = None,
|
|
|
|
mode: Literal["text", "vision"] | None = None,
|
|
|
|
**model_kwargs: Any,
|
|
|
|
) -> None:
|
|
|
|
if mode is not None and mode not in ("text", "vision"):
|
|
|
|
raise ValueError(f"Mode must be 'text', 'vision', or omitted; got '{mode}'")
|
2023-10-10 12:26:30 -05:00
|
|
|
if model_name not in _MODELS:
|
|
|
|
raise ValueError(f"Unknown model name {model_name}.")
|
2023-08-24 23:28:51 -05:00
|
|
|
self.mode = mode
|
2023-10-10 12:26:30 -05:00
|
|
|
super().__init__(model_name, cache_dir, **model_kwargs)
|
2023-08-24 23:28:51 -05:00
|
|
|
|
2023-09-09 04:02:44 -05:00
|
|
|
def _download(self) -> None:
|
2023-08-24 23:28:51 -05:00
|
|
|
models: tuple[tuple[str, str], tuple[str, str]] = _MODELS[self.model_name]
|
|
|
|
text_onnx_path = self.cache_dir / "textual.onnx"
|
|
|
|
vision_onnx_path = self.cache_dir / "visual.onnx"
|
|
|
|
|
|
|
|
if not text_onnx_path.is_file():
|
|
|
|
self._download_model(*models[0])
|
|
|
|
|
|
|
|
if not vision_onnx_path.is_file():
|
|
|
|
self._download_model(*models[1])
|
2023-08-05 21:45:13 -05:00
|
|
|
|
2023-09-09 04:02:44 -05:00
|
|
|
def _load(self) -> None:
|
2023-08-24 23:28:51 -05:00
|
|
|
if self.mode == "text" or self.mode is None:
|
2023-09-09 04:02:44 -05:00
|
|
|
log.debug(f"Loading clip text model '{self.model_name}'")
|
2023-08-24 23:28:51 -05:00
|
|
|
self.text_model = ort.InferenceSession(
|
|
|
|
self.cache_dir / "textual.onnx",
|
|
|
|
sess_options=self.sess_options,
|
|
|
|
providers=self.providers,
|
|
|
|
provider_options=self.provider_options,
|
|
|
|
)
|
|
|
|
self.text_outputs = [output.name for output in self.text_model.get_outputs()]
|
|
|
|
self.tokenizer = Tokenizer(self.model_name)
|
|
|
|
|
|
|
|
if self.mode == "vision" or self.mode is None:
|
2023-09-09 04:02:44 -05:00
|
|
|
log.debug(f"Loading clip vision model '{self.model_name}'")
|
2023-08-24 23:28:51 -05:00
|
|
|
self.vision_model = ort.InferenceSession(
|
|
|
|
self.cache_dir / "visual.onnx",
|
|
|
|
sess_options=self.sess_options,
|
|
|
|
providers=self.providers,
|
|
|
|
provider_options=self.provider_options,
|
|
|
|
)
|
|
|
|
self.vision_outputs = [output.name for output in self.vision_model.get_outputs()]
|
|
|
|
|
|
|
|
image_size = _VISUAL_MODEL_IMAGE_SIZE[CLIPOnnxModel.get_model_name(self.model_name)]
|
|
|
|
self.transform = _transform_pil_image(image_size)
|
2023-06-24 22:18:09 -05:00
|
|
|
|
2023-08-29 08:58:00 -05:00
|
|
|
def _predict(self, image_or_text: Image.Image | str) -> list[float]:
|
|
|
|
if isinstance(image_or_text, bytes):
|
|
|
|
image_or_text = Image.open(BytesIO(image_or_text))
|
|
|
|
|
2023-08-24 23:28:51 -05:00
|
|
|
match image_or_text:
|
2023-08-29 08:58:00 -05:00
|
|
|
case Image.Image():
|
2023-08-24 23:28:51 -05:00
|
|
|
if self.mode == "text":
|
|
|
|
raise TypeError("Cannot encode image as text-only model")
|
|
|
|
pixel_values = self.transform(image_or_text)
|
|
|
|
assert isinstance(pixel_values, torch.Tensor)
|
|
|
|
pixel_values = torch.unsqueeze(pixel_values, 0).numpy()
|
|
|
|
outputs = self.vision_model.run(self.vision_outputs, {"pixel_values": pixel_values})
|
|
|
|
case str():
|
|
|
|
if self.mode == "vision":
|
|
|
|
raise TypeError("Cannot encode text as vision-only model")
|
|
|
|
text_inputs: dict[str, torch.Tensor] = self.tokenizer(image_or_text)
|
|
|
|
inputs = {
|
|
|
|
"input_ids": text_inputs["input_ids"].int().numpy(),
|
|
|
|
"attention_mask": text_inputs["attention_mask"].int().numpy(),
|
|
|
|
}
|
|
|
|
outputs = self.text_model.run(self.text_outputs, inputs)
|
|
|
|
case _:
|
|
|
|
raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")
|
|
|
|
|
|
|
|
return outputs[0][0].tolist()
|
|
|
|
|
|
|
|
def _download_model(self, model_name: str, model_md5: str) -> bool:
|
|
|
|
# downloading logic is adapted from clip-server's CLIPOnnxModel class
|
|
|
|
download_model(
|
|
|
|
url=_S3_BUCKET_V2 + model_name,
|
|
|
|
target_folder=self.cache_dir.as_posix(),
|
|
|
|
md5sum=model_md5,
|
|
|
|
with_resume=True,
|
|
|
|
)
|
|
|
|
file = self.cache_dir / model_name.split("/")[1]
|
|
|
|
if file.suffix == ".zip":
|
|
|
|
with zipfile.ZipFile(file, "r") as zip_ref:
|
|
|
|
zip_ref.extractall(self.cache_dir)
|
|
|
|
os.remove(file)
|
|
|
|
return True
|
|
|
|
|
2023-09-05 20:48:40 -05:00
|
|
|
@property
|
|
|
|
def cached(self) -> bool:
|
|
|
|
return (self.cache_dir / "textual.onnx").is_file() and (self.cache_dir / "visual.onnx").is_file()
|
|
|
|
|
2023-08-24 23:28:51 -05:00
|
|
|
|
|
|
|
# same as `_transform_blob` without `_blob2image`
def _transform_pil_image(n_px: int) -> Compose:
    """Build the preprocessing pipeline for an already-decoded PIL image.

    The pipeline runs, in order:
    1. bicubic ``Resize(n_px)``;
    2. ``CenterCrop(n_px)``;
    3. conversion to RGB;
    4. ``ToTensor``;
    5. per-channel ``Normalize`` with the fixed mean/std values below.

    Args:
        n_px: Target edge length, in pixels, for resize and crop.

    Returns:
        A torchvision ``Compose`` of the steps above.
    """
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)

    steps = [
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize(channel_mean, channel_std),
    ]
    return Compose(steps)
|