0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-21 00:52:43 -05:00

fixed setting different clip, removed unused stubs (#2987)

This commit is contained in:
Mert 2023-06-27 13:21:50 -04:00 committed by GitHub
parent b3e97a1a0c
commit 4d3ce0a65e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 6 additions and 25 deletions

View file

@ -27,13 +27,10 @@ app = FastAPI()
@app.on_event("startup")
async def startup_event() -> None:
app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True)
same_clip = settings.clip_image_model == settings.clip_text_model
app.state.clip_vision_type = ModelType.CLIP if same_clip else ModelType.CLIP_VISION
app.state.clip_text_type = ModelType.CLIP if same_clip else ModelType.CLIP_TEXT
models = [
(settings.classification_model, ModelType.IMAGE_CLASSIFICATION),
(settings.clip_image_model, app.state.clip_vision_type),
(settings.clip_text_model, app.state.clip_text_type),
(settings.clip_image_model, ModelType.CLIP),
(settings.clip_text_model, ModelType.CLIP),
(settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION),
]
@ -87,9 +84,7 @@ async def image_classification(
async def clip_encode_image(
image: Image.Image = Depends(dep_pil_image),
) -> list[float]:
model = await app.state.model_cache.get(
settings.clip_image_model, app.state.clip_vision_type
)
model = await app.state.model_cache.get(settings.clip_image_model, ModelType.CLIP)
embedding = model.predict(image)
return embedding
@ -100,9 +95,7 @@ async def clip_encode_image(
status_code=200,
)
async def clip_encode_text(payload: TextModelRequest) -> list[float]:
model = await app.state.model_cache.get(
settings.clip_text_model, app.state.clip_text_type
)
model = await app.state.model_cache.get(settings.clip_text_model, ModelType.CLIP)
embedding = model.predict(payload.text)
return embedding

View file

@ -1,3 +1,3 @@
from .clip import CLIPSTTextEncoder, CLIPSTVisionEncoder
from .clip import CLIPSTEncoder
from .facial_recognition import FaceRecognizer
from .image_classification import ImageClassifier

View file

@ -1,6 +1,6 @@
from __future__ import annotations
from abc import abstractmethod, ABC
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any

View file

@ -25,13 +25,3 @@ class CLIPSTEncoder(InferenceModel):
def predict(self, image_or_text: Image | str) -> list[float]:
return self.model.encode(image_or_text).tolist()
# stubs to allow different behavior between the two in the future
# and handle loading different image and text clip models
class CLIPSTVisionEncoder(CLIPSTEncoder):
_model_type = ModelType.CLIP_VISION
class CLIPSTTextEncoder(CLIPSTEncoder):
_model_type = ModelType.CLIP_TEXT

View file

@ -61,6 +61,4 @@ class FaceResponse(BaseModel):
class ModelType(Enum):
IMAGE_CLASSIFICATION = "image-classification"
CLIP = "clip"
CLIP_VISION = "clip-vision"
CLIP_TEXT = "clip-text"
FACIAL_RECOGNITION = "facial-recognition"