From 935f471ccb26ebd804526a75caa9c6655c444736 Mon Sep 17 00:00:00 2001 From: Mert <101130780+mertalev@users.noreply.github.com> Date: Mon, 13 Nov 2023 11:18:46 -0500 Subject: [PATCH] chore(ml): use strict mypy (#5001) * improved typing * improved export typing * strict mypy & check export folder * formatting * add formatting checks for export folder * re-added init call --- .github/workflows/test.yml | 6 ++-- machine-learning/app/conftest.py | 3 +- machine-learning/app/main.py | 2 +- machine-learning/app/models/base.py | 10 +++--- machine-learning/app/models/cache.py | 10 +++--- machine-learning/app/models/clip.py | 14 ++++---- .../app/models/facial_recognition.py | 33 +++++++++---------- .../app/models/image_classification.py | 2 +- machine-learning/app/schemas.py | 32 ++++++++++-------- machine-learning/export/models/openclip.py | 13 +++++--- 10 files changed, 70 insertions(+), 55 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 58efec3776..07e9f16d3b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -168,13 +168,13 @@ jobs: poetry install --with dev - name: Lint with ruff run: | - poetry run ruff check --format=github app + poetry run ruff check --format=github app export - name: Check black formatting run: | - poetry run black --check app + poetry run black --check app export - name: Run mypy type checking run: | - poetry run mypy --install-types --non-interactive app/ + poetry run mypy --install-types --non-interactive --strict app/ export/ - name: Run tests and coverage run: | poetry run pytest --cov app diff --git a/machine-learning/app/conftest.py b/machine-learning/app/conftest.py index 3bbb89c527..5e2dc1e847 100644 --- a/machine-learning/app/conftest.py +++ b/machine-learning/app/conftest.py @@ -36,7 +36,8 @@ def deployed_app() -> TestClient: @pytest.fixture(scope="session") def responses() -> dict[str, Any]: - return json.load(open("responses.json", "r")) + responses: dict[str, Any] = json.load(open("responses.json", "r")) + return responses @pytest.fixture(scope="session") diff --git a/machine-learning/app/main.py b/machine-learning/app/main.py index 375c14a9e4..e1d71e9fa2 100644 --- a/machine-learning/app/main.py +++ b/machine-learning/app/main.py @@ -7,7 +7,7 @@ from zipfile import BadZipFile import orjson from fastapi import FastAPI, Form, HTTPException, UploadFile from fastapi.responses import ORJSONResponse -from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile # type: ignore +from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile from starlette.formparsers import MultiPartParser from app.models.base import InferenceModel diff --git a/machine-learning/app/models/base.py b/machine-learning/app/models/base.py index 8149502ecc..d3252d0003 100644 --- a/machine-learning/app/models/base.py +++ b/machine-learning/app/models/base.py @@ -8,6 +8,7 @@ from typing import Any import onnxruntime as ort from huggingface_hub import snapshot_download +from typing_extensions import Buffer from ..config import get_cache_dir, get_hf_model_name, log, settings from ..schemas import ModelType @@ -139,11 +140,12 @@ class InferenceModel(ABC): # HF deep copies configs, so we need to make session options picklable -class PicklableSessionOptions(ort.SessionOptions): +class PicklableSessionOptions(ort.SessionOptions): # type: ignore[misc] def __getstate__(self) -> bytes: return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))]) - def __setstate__(self, state: Any) -> None: - self.__init__() # type: ignore - for attr, val in pickle.loads(state): + def __setstate__(self, state: Buffer) -> None: + self.__init__() # type: ignore[misc] + attrs: list[tuple[str, Any]] = pickle.loads(state) + for attr, val in attrs: setattr(self, attr, val) diff --git a/machine-learning/app/models/cache.py b/machine-learning/app/models/cache.py index bd8b59b3e5..1d6a0fc763 100644 --- a/machine-learning/app/models/cache.py +++ b/machine-learning/app/models/cache.py @@ -6,7 +6,7 @@ from aiocache.plugins import BasePlugin, TimingPlugin from app.models import from_model_type -from ..schemas import ModelType +from ..schemas import ModelType, has_profiling from .base import InferenceModel @@ -50,20 +50,20 @@ class ModelCache: key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}" async with OptimisticLock(self.cache, key) as lock: - model = await self.cache.get(key) + model: InferenceModel | None = await self.cache.get(key) if model is None: model = from_model_type(model_type, model_name, **model_kwargs) await lock.cas(model, ttl=self.ttl) return model async def get_profiling(self) -> dict[str, float] | None: - if not hasattr(self.cache, "profiling"): + if not has_profiling(self.cache): return None - return self.cache.profiling # type: ignore + return self.cache.profiling -class RevalidationPlugin(BasePlugin): +class RevalidationPlugin(BasePlugin): # type: ignore[misc] """Revalidates cache item's TTL after cache hit.""" async def post_get( diff --git a/machine-learning/app/models/clip.py b/machine-learning/app/models/clip.py index 296f790c3c..1dee967de2 100644 --- a/machine-learning/app/models/clip.py +++ b/machine-learning/app/models/clip.py @@ -51,7 +51,7 @@ class BaseCLIPEncoder(InferenceModel): provider_options=self.provider_options, ) - def _predict(self, image_or_text: Image.Image | str) -> list[float]: + def _predict(self, image_or_text: Image.Image | str) -> ndarray_f32: if isinstance(image_or_text, bytes): image_or_text = Image.open(BytesIO(image_or_text)) @@ -60,16 +60,16 @@ class BaseCLIPEncoder(InferenceModel): if self.mode == "text": raise TypeError("Cannot encode image as text-only model") - outputs = self.vision_model.run(None, self.transform(image_or_text)) + outputs: ndarray_f32 = self.vision_model.run(None, self.transform(image_or_text))[0][0] case str(): if self.mode == "vision": raise TypeError("Cannot encode text as vision-only model") - outputs = self.text_model.run(None, self.tokenize(image_or_text)) + outputs = self.text_model.run(None, self.tokenize(image_or_text))[0][0] case _: raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}") - return outputs[0][0].tolist() + return outputs @abstractmethod def tokenize(self, text: str) -> dict[str, ndarray_i32]: @@ -151,11 +151,13 @@ class OpenCLIPEncoder(BaseCLIPEncoder): @cached_property def model_cfg(self) -> dict[str, Any]: - return json.load(self.model_cfg_path.open()) + model_cfg: dict[str, Any] = json.load(self.model_cfg_path.open()) + return model_cfg @cached_property def preprocess_cfg(self) -> dict[str, Any]: - return json.load(self.preprocess_cfg_path.open()) + preprocess_cfg: dict[str, Any] = json.load(self.preprocess_cfg_path.open()) + return preprocess_cfg class MCLIPEncoder(OpenCLIPEncoder): diff --git a/machine-learning/app/models/facial_recognition.py b/machine-learning/app/models/facial_recognition.py index a8fa6484d3..24719eb83a 100644 --- a/machine-learning/app/models/facial_recognition.py +++ b/machine-learning/app/models/facial_recognition.py @@ -8,7 +8,7 @@ from insightface.model_zoo import ArcFaceONNX, RetinaFace from insightface.utils.face_align import norm_crop from app.config import clean_name -from app.schemas import ModelType, ndarray_f32 +from app.schemas import BoundingBox, Face, ModelType, ndarray_f32 from .base import InferenceModel @@ -52,7 +52,7 @@ class FaceRecognizer(InferenceModel): ) self.rec_model.prepare(ctx_id=0) - def _predict(self, image: ndarray_f32 | bytes) -> list[dict[str, Any]]: + def _predict(self, image: ndarray_f32 | bytes) -> list[Face]: if isinstance(image, bytes): image = cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR) bboxes, kpss = self.det_model.detect(image) @@ -67,21 +67,20 @@ class FaceRecognizer(InferenceModel): height, width, _ = image.shape for (x1, y1, x2, y2), score, kps in zip(bboxes, scores, kpss): cropped_img = norm_crop(image, kps) - embedding = self.rec_model.get_feat(cropped_img)[0].tolist() - results.append( - { - "imageWidth": width, - "imageHeight": height, - "boundingBox": { - "x1": x1, - "y1": y1, - "x2": x2, - "y2": y2, - }, - "score": score, - "embedding": embedding, - } - ) + embedding: ndarray_f32 = self.rec_model.get_feat(cropped_img)[0] + face: Face = { + "imageWidth": width, + "imageHeight": height, + "boundingBox": { + "x1": x1, + "y1": y1, + "x2": x2, + "y2": y2, + }, + "score": score, + "embedding": embedding, + } + results.append(face) return results @property diff --git a/machine-learning/app/models/image_classification.py b/machine-learning/app/models/image_classification.py index cbf784e5a4..b8c38327cd 100644 --- a/machine-learning/app/models/image_classification.py +++ b/machine-learning/app/models/image_classification.py @@ -66,7 +66,7 @@ class ImageClassifier(InferenceModel): def _predict(self, image: Image.Image | bytes) -> list[str]: if isinstance(image, bytes): image = Image.open(BytesIO(image)) - predictions: list[dict[str, Any]] = self.model(image) # type: ignore + predictions: list[dict[str, Any]] = self.model(image) tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score] return tags diff --git a/machine-learning/app/schemas.py b/machine-learning/app/schemas.py index ad1faac8c3..9e7f62fc8c 100644 --- a/machine-learning/app/schemas.py +++ b/machine-learning/app/schemas.py @@ -1,17 +1,12 @@ from enum import StrEnum -from typing import TypeAlias +from typing import Any, Protocol, TypeAlias, TypedDict, TypeGuard import numpy as np from pydantic import BaseModel - -def to_lower_camel(string: str) -> str: - tokens = [token.capitalize() if i > 0 else token for i, token in enumerate(string.split("_"))] - return "".join(tokens) - - -class TextModelRequest(BaseModel): - text: str +ndarray_f32: TypeAlias = np.ndarray[int, np.dtype[np.float32]] +ndarray_i64: TypeAlias = np.ndarray[int, np.dtype[np.int64]] +ndarray_i32: TypeAlias = np.ndarray[int, np.dtype[np.int32]] class TextResponse(BaseModel): @@ -22,7 +17,7 @@ class MessageResponse(BaseModel): message: str -class BoundingBox(BaseModel): +class BoundingBox(TypedDict): x1: int y1: int x2: int @@ -35,6 +30,17 @@ class ModelType(StrEnum): FACIAL_RECOGNITION = "facial-recognition" -ndarray_f32: TypeAlias = np.ndarray[int, np.dtype[np.float32]] -ndarray_i64: TypeAlias = np.ndarray[int, np.dtype[np.int64]] -ndarray_i32: TypeAlias = np.ndarray[int, np.dtype[np.int32]] +class HasProfiling(Protocol): + profiling: dict[str, float] + + +class Face(TypedDict): + boundingBox: BoundingBox + embedding: ndarray_f32 + imageWidth: int + imageHeight: int + score: float + + +def has_profiling(obj: Any) -> TypeGuard[HasProfiling]: + return hasattr(obj, "profiling") and type(obj.profiling) == dict diff --git a/machine-learning/export/models/openclip.py b/machine-learning/export/models/openclip.py index c29dafce74..46c11cb4ef 100644 --- a/machine-learning/export/models/openclip.py +++ b/machine-learning/export/models/openclip.py @@ -1,6 +1,7 @@ import tempfile import warnings from dataclasses import dataclass, field +from math import e from pathlib import Path import open_clip @@ -69,10 +70,12 @@ def export_image_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig, output_path = Path(output_path) def encode_image(image: torch.Tensor) -> torch.Tensor: - return model.encode_image(image, normalize=True) + output = model.encode_image(image, normalize=True) + assert isinstance(output, torch.Tensor) + return output args = (torch.randn(1, 3, model_cfg.image_size, model_cfg.image_size),) - traced = torch.jit.trace(encode_image, args) + traced = torch.jit.trace(encode_image, args) # type: ignore[no-untyped-call] with warnings.catch_warnings(): warnings.simplefilter("ignore", UserWarning) @@ -91,10 +94,12 @@ def export_text_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig, o output_path = Path(output_path) def encode_text(text: torch.Tensor) -> torch.Tensor: - return model.encode_text(text, normalize=True) + output = model.encode_text(text, normalize=True) + assert isinstance(output, torch.Tensor) + return output args = (torch.ones(1, model_cfg.sequence_length, dtype=torch.int32),) - traced = torch.jit.trace(encode_text, args) + traced = torch.jit.trace(encode_text, args) # type: ignore[no-untyped-call] with warnings.catch_warnings(): warnings.simplefilter("ignore", UserWarning)