mirror of
https://github.com/immich-app/immich.git
synced 2025-01-21 00:52:43 -05:00
chore(ml): removed vit-b check and st warning (#4422)
This commit is contained in:
parent
b8d6cc1e09
commit
d8ecefaea5
1 changed files with 3 additions and 25 deletions
|
@ -16,13 +16,6 @@ from ..config import log
|
|||
from ..schemas import ModelType
|
||||
from .base import InferenceModel
|
||||
|
||||
_ST_TO_JINA_MODEL_NAME = {
|
||||
"clip-ViT-B-16": "ViT-B-16::openai",
|
||||
"clip-ViT-B-32": "ViT-B-32::openai",
|
||||
"clip-ViT-B-32-multilingual-v1": "M-CLIP/XLM-Roberta-Large-Vit-B-32",
|
||||
"clip-ViT-L-14": "ViT-L-14::openai",
|
||||
}
|
||||
|
||||
|
||||
class CLIPEncoder(InferenceModel):
|
||||
_model_type = ModelType.CLIP
|
||||
|
@ -36,11 +29,10 @@ class CLIPEncoder(InferenceModel):
|
|||
) -> None:
|
||||
if mode is not None and mode not in ("text", "vision"):
|
||||
raise ValueError(f"Mode must be 'text', 'vision', or omitted; got '{mode}'")
|
||||
if "vit-b" not in model_name.lower():
|
||||
raise ValueError(f"Only ViT-B models are currently supported; got '{model_name}'")
|
||||
if model_name not in _MODELS:
|
||||
raise ValueError(f"Unknown model name {model_name}.")
|
||||
self.mode = mode
|
||||
jina_model_name = self._get_jina_model_name(model_name)
|
||||
super().__init__(jina_model_name, cache_dir, **model_kwargs)
|
||||
super().__init__(model_name, cache_dir, **model_kwargs)
|
||||
|
||||
def _download(self) -> None:
|
||||
models: tuple[tuple[str, str], tuple[str, str]] = _MODELS[self.model_name]
|
||||
|
@ -104,20 +96,6 @@ class CLIPEncoder(InferenceModel):
|
|||
|
||||
return outputs[0][0].tolist()
|
||||
|
||||
def _get_jina_model_name(self, model_name: str) -> str:
|
||||
if model_name in _MODELS:
|
||||
return model_name
|
||||
elif model_name in _ST_TO_JINA_MODEL_NAME:
|
||||
log.warn(
|
||||
(
|
||||
f"Sentence-Transformer models like '{model_name}' are not supported."
|
||||
f"Using '{_ST_TO_JINA_MODEL_NAME[model_name]}' instead as it is the best match for '{model_name}'."
|
||||
),
|
||||
)
|
||||
return _ST_TO_JINA_MODEL_NAME[model_name]
|
||||
else:
|
||||
raise ValueError(f"Unknown model name {model_name}.")
|
||||
|
||||
def _download_model(self, model_name: str, model_md5: str) -> bool:
|
||||
# downloading logic is adapted from clip-server's CLIPOnnxModel class
|
||||
download_model(
|
||||
|
|
Loading…
Add table
Reference in a new issue