From 4d3ce0a65e8442b9993578872b1facad258b7c17 Mon Sep 17 00:00:00 2001 From: Mert <101130780+mertalev@users.noreply.github.com> Date: Tue, 27 Jun 2023 13:21:50 -0400 Subject: [PATCH] fixed setting different clip, removed unused stubs (#2987) --- machine-learning/app/main.py | 15 ++++----------- machine-learning/app/models/__init__.py | 2 +- machine-learning/app/models/base.py | 2 +- machine-learning/app/models/clip.py | 10 ---------- machine-learning/app/schemas.py | 2 -- 5 files changed, 6 insertions(+), 25 deletions(-) diff --git a/machine-learning/app/main.py b/machine-learning/app/main.py index 49436977bf..e59d0d8382 100644 --- a/machine-learning/app/main.py +++ b/machine-learning/app/main.py @@ -27,13 +27,10 @@ app = FastAPI() @app.on_event("startup") async def startup_event() -> None: app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True) - same_clip = settings.clip_image_model == settings.clip_text_model - app.state.clip_vision_type = ModelType.CLIP if same_clip else ModelType.CLIP_VISION - app.state.clip_text_type = ModelType.CLIP if same_clip else ModelType.CLIP_TEXT models = [ (settings.classification_model, ModelType.IMAGE_CLASSIFICATION), - (settings.clip_image_model, app.state.clip_vision_type), - (settings.clip_text_model, app.state.clip_text_type), + (settings.clip_image_model, ModelType.CLIP), + (settings.clip_text_model, ModelType.CLIP), (settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION), ] @@ -87,9 +84,7 @@ async def image_classification( async def clip_encode_image( image: Image.Image = Depends(dep_pil_image), ) -> list[float]: - model = await app.state.model_cache.get( - settings.clip_image_model, app.state.clip_vision_type - ) + model = await app.state.model_cache.get(settings.clip_image_model, ModelType.CLIP) embedding = model.predict(image) return embedding @@ -100,9 +95,7 @@ async def clip_encode_image( status_code=200, ) async def clip_encode_text(payload: TextModelRequest) -> list[float]: - model = await app.state.model_cache.get( - settings.clip_text_model, app.state.clip_text_type - ) + model = await app.state.model_cache.get(settings.clip_text_model, ModelType.CLIP) embedding = model.predict(payload.text) return embedding diff --git a/machine-learning/app/models/__init__.py b/machine-learning/app/models/__init__.py index b646135053..e5b5aa7599 100644 --- a/machine-learning/app/models/__init__.py +++ b/machine-learning/app/models/__init__.py @@ -1,3 +1,3 @@ -from .clip import CLIPSTTextEncoder, CLIPSTVisionEncoder +from .clip import CLIPSTEncoder from .facial_recognition import FaceRecognizer from .image_classification import ImageClassifier diff --git a/machine-learning/app/models/base.py b/machine-learning/app/models/base.py index 122f3627e6..0ef3173ce8 100644 --- a/machine-learning/app/models/base.py +++ b/machine-learning/app/models/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from abc import abstractmethod, ABC +from abc import ABC, abstractmethod from pathlib import Path from typing import Any diff --git a/machine-learning/app/models/clip.py b/machine-learning/app/models/clip.py index 51731f790e..9e55b28d57 100644 --- a/machine-learning/app/models/clip.py +++ b/machine-learning/app/models/clip.py @@ -25,13 +25,3 @@ class CLIPSTEncoder(InferenceModel): def predict(self, image_or_text: Image | str) -> list[float]: return self.model.encode(image_or_text).tolist() - - -# stubs to allow different behavior between the two in the future -# and handle loading different image and text clip models -class CLIPSTVisionEncoder(CLIPSTEncoder): - _model_type = ModelType.CLIP_VISION - - -class CLIPSTTextEncoder(CLIPSTEncoder): - _model_type = ModelType.CLIP_TEXT diff --git a/machine-learning/app/schemas.py b/machine-learning/app/schemas.py index db6b7b50bc..16618faa68 100644 --- a/machine-learning/app/schemas.py +++ b/machine-learning/app/schemas.py @@ -61,6 +61,4 @@ class FaceResponse(BaseModel): class ModelType(Enum): IMAGE_CLASSIFICATION = "image-classification" CLIP = "clip" - CLIP_VISION = "clip-vision" - CLIP_TEXT = "clip-text" FACIAL_RECOGNITION = "facial-recognition"