mirror of
https://github.com/immich-app/immich.git
synced 2025-01-21 00:52:43 -05:00
70 lines
2.6 KiB
Python
70 lines
2.6 KiB
Python
|
import json
|
||
|
from abc import abstractmethod
|
||
|
from functools import cached_property
|
||
|
from pathlib import Path
|
||
|
from typing import Any
|
||
|
|
||
|
import numpy as np
|
||
|
from numpy.typing import NDArray
|
||
|
from PIL import Image
|
||
|
|
||
|
from app.config import log
|
||
|
from app.models.base import InferenceModel
|
||
|
from app.models.transforms import crop_pil, decode_pil, get_pil_resampling, normalize, resize_pil, to_numpy
|
||
|
from app.schemas import ModelSession, ModelTask, ModelType
|
||
|
|
||
|
|
||
|
class BaseCLIPVisualEncoder(InferenceModel):
    """Base class for CLIP image encoders registered for the search task.

    Subclasses implement :meth:`transform`, which converts a decoded PIL image into the
    named input arrays passed to ``self.session.run``. The JSON configs exposed by
    :attr:`model_cfg` and :attr:`preprocess_cfg` are read lazily and cached per instance.
    """

    depends = []
    identity = (ModelType.VISUAL, ModelTask.SEARCH)

    def _predict(self, inputs: Image.Image | bytes, **kwargs: Any) -> NDArray[np.float32]:
        """Encode a single image.

        Args:
            inputs: a PIL image or raw encoded image bytes; both are passed through
                ``decode_pil`` before preprocessing.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            Element ``[0][0]`` of the session outputs, i.e. the first row of the first
            output (presumably the embedding of the single batched image — confirm
            against the exported model).
        """
        image = decode_pil(inputs)
        res: NDArray[np.float32] = self.session.run(None, self.transform(image))[0][0]
        return res

    @abstractmethod
    def transform(self, image: Image.Image) -> dict[str, NDArray[np.float32]]:
        """Preprocess ``image`` into the mapping of input name -> array fed to the session."""

    @property
    def model_cfg_path(self) -> Path:
        """Location of ``config.json``; note it lives in ``cache_dir``, not ``model_dir``."""
        return self.cache_dir / "config.json"

    @property
    def preprocess_cfg_path(self) -> Path:
        """Location of ``preprocess_cfg.json`` inside ``model_dir``."""
        return self.model_dir / "preprocess_cfg.json"

    @cached_property
    def model_cfg(self) -> dict[str, Any]:
        """Model config parsed from :attr:`model_cfg_path`; read once per instance."""
        log.debug(f"Loading model config for CLIP model '{self.model_name}'")
        # The context manager closes the handle deterministically. Binary mode lets json
        # detect UTF-8/16/32 (and a UTF-8 BOM) instead of using the locale's text encoding.
        with self.model_cfg_path.open("rb") as cfg_file:
            model_cfg: dict[str, Any] = json.load(cfg_file)
        log.debug(f"Loaded model config for CLIP model '{self.model_name}'")
        return model_cfg

    @cached_property
    def preprocess_cfg(self) -> dict[str, Any]:
        """Visual preprocessing config parsed from :attr:`preprocess_cfg_path`; read once per instance."""
        log.debug(f"Loading visual preprocessing config for CLIP model '{self.model_name}'")
        # Same handling as model_cfg: close the file and let json detect the encoding.
        with self.preprocess_cfg_path.open("rb") as cfg_file:
            preprocess_cfg: dict[str, Any] = json.load(cfg_file)
        log.debug(f"Loaded visual preprocessing config for CLIP model '{self.model_name}'")
        return preprocess_cfg
|
||
|
|
||
|
|
||
|
class OpenClipVisualEncoder(BaseCLIPVisualEncoder):
    """CLIP visual encoder whose preprocessing parameters come from ``preprocess_cfg``.

    ``_load`` caches the target size, the resampling mode and per-channel mean/std
    on the instance; ``transform`` applies resize -> crop -> normalize and adds a
    leading batch axis.
    """

    def _load(self) -> ModelSession:
        """Cache preprocessing parameters from the config, then run the parent loader."""
        cfg = self.preprocess_cfg

        raw_size: list[int] | int = cfg["size"]
        # A list-valued size is reduced to its first entry.
        if isinstance(raw_size, list):
            self.size = raw_size[0]
        else:
            self.size = raw_size

        # NOTE(review): self.resampling is not referenced by transform() below; confirm
        # whether resize_pil is meant to honour the configured interpolation.
        self.resampling = get_pil_resampling(cfg["interpolation"])
        self.mean = np.array(cfg["mean"], dtype=np.float32)
        self.std = np.array(cfg["std"], dtype=np.float32)

        return super()._load()

    def transform(self, image: Image.Image) -> dict[str, NDArray[np.float32]]:
        """Build the ``{"image": array}`` feed for the session from a PIL image."""
        cropped = crop_pil(resize_pil(image, self.size), self.size)
        pixels = normalize(to_numpy(cropped), self.mean, self.std)
        # Move the last axis to the front (presumably HWC -> CHW; depends on to_numpy's
        # layout), then prepend a batch axis of length 1.
        batched = pixels.transpose(2, 0, 1)[np.newaxis, ...]
        return {"image": batched}
|