0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-07 00:50:23 -05:00

chore(ml): use strict mypy (#5001)

* improved typing

* improved export typing

* strict mypy & check export folder

* formatting

* add formatting checks for export folder

* re-added init call
This commit is contained in:
Mert 2023-11-13 11:18:46 -05:00 committed by GitHub
parent 9fa9ad05b1
commit 935f471ccb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 70 additions and 55 deletions

View file

@ -168,13 +168,13 @@ jobs:
poetry install --with dev
- name: Lint with ruff
run: |
poetry run ruff check --format=github app
poetry run ruff check --format=github app export
- name: Check black formatting
run: |
poetry run black --check app
poetry run black --check app export
- name: Run mypy type checking
run: |
poetry run mypy --install-types --non-interactive app/
poetry run mypy --install-types --non-interactive --strict app/ export/
- name: Run tests and coverage
run: |
poetry run pytest --cov app

View file

@ -36,7 +36,8 @@ def deployed_app() -> TestClient:
@pytest.fixture(scope="session")
def responses() -> dict[str, Any]:
return json.load(open("responses.json", "r"))
responses: dict[str, Any] = json.load(open("responses.json", "r"))
return responses
@pytest.fixture(scope="session")

View file

@ -7,7 +7,7 @@ from zipfile import BadZipFile
import orjson
from fastapi import FastAPI, Form, HTTPException, UploadFile
from fastapi.responses import ORJSONResponse
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile # type: ignore
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
from starlette.formparsers import MultiPartParser
from app.models.base import InferenceModel

View file

@ -8,6 +8,7 @@ from typing import Any
import onnxruntime as ort
from huggingface_hub import snapshot_download
from typing_extensions import Buffer
from ..config import get_cache_dir, get_hf_model_name, log, settings
from ..schemas import ModelType
@ -139,11 +140,12 @@ class InferenceModel(ABC):
# HF deep copies configs, so we need to make session options picklable
class PicklableSessionOptions(ort.SessionOptions):
class PicklableSessionOptions(ort.SessionOptions): # type: ignore[misc]
def __getstate__(self) -> bytes:
return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))])
def __setstate__(self, state: Any) -> None:
self.__init__() # type: ignore
for attr, val in pickle.loads(state):
def __setstate__(self, state: Buffer) -> None:
self.__init__() # type: ignore[misc]
attrs: list[tuple[str, Any]] = pickle.loads(state)
for attr, val in attrs:
setattr(self, attr, val)

View file

@ -6,7 +6,7 @@ from aiocache.plugins import BasePlugin, TimingPlugin
from app.models import from_model_type
from ..schemas import ModelType
from ..schemas import ModelType, has_profiling
from .base import InferenceModel
@ -50,20 +50,20 @@ class ModelCache:
key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"
async with OptimisticLock(self.cache, key) as lock:
model = await self.cache.get(key)
model: InferenceModel | None = await self.cache.get(key)
if model is None:
model = from_model_type(model_type, model_name, **model_kwargs)
await lock.cas(model, ttl=self.ttl)
return model
async def get_profiling(self) -> dict[str, float] | None:
if not hasattr(self.cache, "profiling"):
if not has_profiling(self.cache):
return None
return self.cache.profiling # type: ignore
return self.cache.profiling
class RevalidationPlugin(BasePlugin):
class RevalidationPlugin(BasePlugin): # type: ignore[misc]
"""Revalidates cache item's TTL after cache hit."""
async def post_get(

View file

@ -51,7 +51,7 @@ class BaseCLIPEncoder(InferenceModel):
provider_options=self.provider_options,
)
def _predict(self, image_or_text: Image.Image | str) -> list[float]:
def _predict(self, image_or_text: Image.Image | str) -> ndarray_f32:
if isinstance(image_or_text, bytes):
image_or_text = Image.open(BytesIO(image_or_text))
@ -60,16 +60,16 @@ class BaseCLIPEncoder(InferenceModel):
if self.mode == "text":
raise TypeError("Cannot encode image as text-only model")
outputs = self.vision_model.run(None, self.transform(image_or_text))
outputs: ndarray_f32 = self.vision_model.run(None, self.transform(image_or_text))[0][0]
case str():
if self.mode == "vision":
raise TypeError("Cannot encode text as vision-only model")
outputs = self.text_model.run(None, self.tokenize(image_or_text))
outputs = self.text_model.run(None, self.tokenize(image_or_text))[0][0]
case _:
raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")
return outputs[0][0].tolist()
return outputs
@abstractmethod
def tokenize(self, text: str) -> dict[str, ndarray_i32]:
@ -151,11 +151,13 @@ class OpenCLIPEncoder(BaseCLIPEncoder):
@cached_property
def model_cfg(self) -> dict[str, Any]:
return json.load(self.model_cfg_path.open())
model_cfg: dict[str, Any] = json.load(self.model_cfg_path.open())
return model_cfg
@cached_property
def preprocess_cfg(self) -> dict[str, Any]:
return json.load(self.preprocess_cfg_path.open())
preprocess_cfg: dict[str, Any] = json.load(self.preprocess_cfg_path.open())
return preprocess_cfg
class MCLIPEncoder(OpenCLIPEncoder):

View file

@ -8,7 +8,7 @@ from insightface.model_zoo import ArcFaceONNX, RetinaFace
from insightface.utils.face_align import norm_crop
from app.config import clean_name
from app.schemas import ModelType, ndarray_f32
from app.schemas import BoundingBox, Face, ModelType, ndarray_f32
from .base import InferenceModel
@ -52,7 +52,7 @@ class FaceRecognizer(InferenceModel):
)
self.rec_model.prepare(ctx_id=0)
def _predict(self, image: ndarray_f32 | bytes) -> list[dict[str, Any]]:
def _predict(self, image: ndarray_f32 | bytes) -> list[Face]:
if isinstance(image, bytes):
image = cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR)
bboxes, kpss = self.det_model.detect(image)
@ -67,9 +67,8 @@ class FaceRecognizer(InferenceModel):
height, width, _ = image.shape
for (x1, y1, x2, y2), score, kps in zip(bboxes, scores, kpss):
cropped_img = norm_crop(image, kps)
embedding = self.rec_model.get_feat(cropped_img)[0].tolist()
results.append(
{
embedding: ndarray_f32 = self.rec_model.get_feat(cropped_img)[0]
face: Face = {
"imageWidth": width,
"imageHeight": height,
"boundingBox": {
@ -81,7 +80,7 @@ class FaceRecognizer(InferenceModel):
"score": score,
"embedding": embedding,
}
)
results.append(face)
return results
@property

View file

@ -66,7 +66,7 @@ class ImageClassifier(InferenceModel):
def _predict(self, image: Image.Image | bytes) -> list[str]:
if isinstance(image, bytes):
image = Image.open(BytesIO(image))
predictions: list[dict[str, Any]] = self.model(image) # type: ignore
predictions: list[dict[str, Any]] = self.model(image)
tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]
return tags

View file

@ -1,17 +1,12 @@
from enum import StrEnum
from typing import TypeAlias
from typing import Any, Protocol, TypeAlias, TypedDict, TypeGuard
import numpy as np
from pydantic import BaseModel
def to_lower_camel(string: str) -> str:
tokens = [token.capitalize() if i > 0 else token for i, token in enumerate(string.split("_"))]
return "".join(tokens)
class TextModelRequest(BaseModel):
text: str
ndarray_f32: TypeAlias = np.ndarray[int, np.dtype[np.float32]]
ndarray_i64: TypeAlias = np.ndarray[int, np.dtype[np.int64]]
ndarray_i32: TypeAlias = np.ndarray[int, np.dtype[np.int32]]
class TextResponse(BaseModel):
@ -22,7 +17,7 @@ class MessageResponse(BaseModel):
message: str
class BoundingBox(BaseModel):
class BoundingBox(TypedDict):
x1: int
y1: int
x2: int
@ -35,6 +30,17 @@ class ModelType(StrEnum):
FACIAL_RECOGNITION = "facial-recognition"
ndarray_f32: TypeAlias = np.ndarray[int, np.dtype[np.float32]]
ndarray_i64: TypeAlias = np.ndarray[int, np.dtype[np.int64]]
ndarray_i32: TypeAlias = np.ndarray[int, np.dtype[np.int32]]
class HasProfiling(Protocol):
profiling: dict[str, float]
class Face(TypedDict):
boundingBox: BoundingBox
embedding: ndarray_f32
imageWidth: int
imageHeight: int
score: float
def has_profiling(obj: Any) -> TypeGuard[HasProfiling]:
return hasattr(obj, "profiling") and type(obj.profiling) == dict

View file

@ -1,6 +1,7 @@
import tempfile
import warnings
from dataclasses import dataclass, field
from math import e
from pathlib import Path
import open_clip
@ -69,10 +70,12 @@ def export_image_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig,
output_path = Path(output_path)
def encode_image(image: torch.Tensor) -> torch.Tensor:
return model.encode_image(image, normalize=True)
output = model.encode_image(image, normalize=True)
assert isinstance(output, torch.Tensor)
return output
args = (torch.randn(1, 3, model_cfg.image_size, model_cfg.image_size),)
traced = torch.jit.trace(encode_image, args)
traced = torch.jit.trace(encode_image, args) # type: ignore[no-untyped-call]
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
@ -91,10 +94,12 @@ def export_text_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig, o
output_path = Path(output_path)
def encode_text(text: torch.Tensor) -> torch.Tensor:
return model.encode_text(text, normalize=True)
output = model.encode_text(text, normalize=True)
assert isinstance(output, torch.Tensor)
return output
args = (torch.ones(1, model_cfg.sequence_length, dtype=torch.int32),)
traced = torch.jit.trace(encode_text, args)
traced = torch.jit.trace(encode_text, args) # type: ignore[no-untyped-call]
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)