0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-21 00:52:43 -05:00

chore(ml): removed vit-b check and st warning (#4422)

This commit is contained in:
Mert 2023-10-10 13:26:30 -04:00 committed by GitHub
parent b8d6cc1e09
commit d8ecefaea5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -16,13 +16,6 @@ from ..config import log
from ..schemas import ModelType
from .base import InferenceModel
_ST_TO_JINA_MODEL_NAME = {
"clip-ViT-B-16": "ViT-B-16::openai",
"clip-ViT-B-32": "ViT-B-32::openai",
"clip-ViT-B-32-multilingual-v1": "M-CLIP/XLM-Roberta-Large-Vit-B-32",
"clip-ViT-L-14": "ViT-L-14::openai",
}
class CLIPEncoder(InferenceModel):
_model_type = ModelType.CLIP
@ -36,11 +29,10 @@ class CLIPEncoder(InferenceModel):
) -> None:
if mode is not None and mode not in ("text", "vision"):
raise ValueError(f"Mode must be 'text', 'vision', or omitted; got '{mode}'")
if "vit-b" not in model_name.lower():
raise ValueError(f"Only ViT-B models are currently supported; got '{model_name}'")
if model_name not in _MODELS:
raise ValueError(f"Unknown model name {model_name}.")
self.mode = mode
jina_model_name = self._get_jina_model_name(model_name)
super().__init__(jina_model_name, cache_dir, **model_kwargs)
super().__init__(model_name, cache_dir, **model_kwargs)
def _download(self) -> None:
models: tuple[tuple[str, str], tuple[str, str]] = _MODELS[self.model_name]
@ -104,20 +96,6 @@ class CLIPEncoder(InferenceModel):
return outputs[0][0].tolist()
def _get_jina_model_name(self, model_name: str) -> str:
if model_name in _MODELS:
return model_name
elif model_name in _ST_TO_JINA_MODEL_NAME:
log.warn(
(
f"Sentence-Transformer models like '{model_name}' are not supported."
f"Using '{_ST_TO_JINA_MODEL_NAME[model_name]}' instead as it is the best match for '{model_name}'."
),
)
return _ST_TO_JINA_MODEL_NAME[model_name]
else:
raise ValueError(f"Unknown model name {model_name}.")
def _download_model(self, model_name: str, model_md5: str) -> bool:
# downloading logic is adapted from clip-server's CLIPOnnxModel class
download_model(