mirror of
https://github.com/immich-app/immich.git
synced 2025-01-21 00:52:43 -05:00
feat(ml): composable ml (#9973)
* modularize model classes * various fixes * expose port * change response * round coordinates * simplify preload * update server * simplify interface simplify * update tests * composable endpoint * cleanup fixes remove unnecessary interface support text input, cleanup * ew camelcase * update server server fixes fix typing * ml fixes update locustfile fixes * cleaner response * better repo response * update tests formatting and typing rename * undo compose change * linting fix type actually fix typing * stricter typing fix detection-only response no need for defaultdict * update spec file update api linting * update e2e * unnecessary dimension * remove commented code * remove duplicate code * remove unused imports * add batch dim
This commit is contained in:
parent
7a46f80ddc
commit
2b1b43a7e4
39 changed files with 982 additions and 999 deletions
|
@ -12,8 +12,6 @@ from rich.logging import RichHandler
|
||||||
from uvicorn import Server
|
from uvicorn import Server
|
||||||
from uvicorn.workers import UvicornWorker
|
from uvicorn.workers import UvicornWorker
|
||||||
|
|
||||||
from .schemas import ModelType
|
|
||||||
|
|
||||||
|
|
||||||
class PreloadModelData(BaseModel):
|
class PreloadModelData(BaseModel):
|
||||||
clip: str | None
|
clip: str | None
|
||||||
|
@ -21,7 +19,7 @@ class PreloadModelData(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
cache_folder: str = "/cache"
|
cache_folder: Path = Path("/cache")
|
||||||
model_ttl: int = 300
|
model_ttl: int = 300
|
||||||
model_ttl_poll_s: int = 10
|
model_ttl_poll_s: int = 10
|
||||||
host: str = "0.0.0.0"
|
host: str = "0.0.0.0"
|
||||||
|
@ -55,14 +53,6 @@ def clean_name(model_name: str) -> str:
|
||||||
return model_name.split("/")[-1].translate(_clean_name)
|
return model_name.split("/")[-1].translate(_clean_name)
|
||||||
|
|
||||||
|
|
||||||
def get_cache_dir(model_name: str, model_type: ModelType) -> Path:
|
|
||||||
return Path(settings.cache_folder) / model_type.value / clean_name(model_name)
|
|
||||||
|
|
||||||
|
|
||||||
def get_hf_model_name(model_name: str) -> str:
|
|
||||||
return f"immich-app/{clean_name(model_name)}"
|
|
||||||
|
|
||||||
|
|
||||||
LOG_LEVELS: dict[str, int] = {
|
LOG_LEVELS: dict[str, int] = {
|
||||||
"critical": logging.ERROR,
|
"critical": logging.ERROR,
|
||||||
"error": logging.ERROR,
|
"error": logging.ERROR,
|
||||||
|
|
|
@ -6,22 +6,34 @@ import threading
|
||||||
import time
|
import time
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from functools import partial
|
||||||
from typing import Any, AsyncGenerator, Callable, Iterator
|
from typing import Any, AsyncGenerator, Callable, Iterator
|
||||||
from zipfile import BadZipFile
|
from zipfile import BadZipFile
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile
|
from fastapi import Depends, FastAPI, File, Form, HTTPException
|
||||||
from fastapi.responses import ORJSONResponse
|
from fastapi.responses import ORJSONResponse
|
||||||
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
|
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
|
||||||
|
from PIL.Image import Image
|
||||||
|
from pydantic import ValidationError
|
||||||
from starlette.formparsers import MultiPartParser
|
from starlette.formparsers import MultiPartParser
|
||||||
|
|
||||||
|
from app.models import get_model_deps
|
||||||
from app.models.base import InferenceModel
|
from app.models.base import InferenceModel
|
||||||
|
from app.models.transforms import decode_pil
|
||||||
|
|
||||||
from .config import PreloadModelData, log, settings
|
from .config import PreloadModelData, log, settings
|
||||||
from .models.cache import ModelCache
|
from .models.cache import ModelCache
|
||||||
from .schemas import (
|
from .schemas import (
|
||||||
|
InferenceEntries,
|
||||||
|
InferenceEntry,
|
||||||
|
InferenceResponse,
|
||||||
MessageResponse,
|
MessageResponse,
|
||||||
|
ModelIdentity,
|
||||||
|
ModelTask,
|
||||||
ModelType,
|
ModelType,
|
||||||
|
PipelineRequest,
|
||||||
|
T,
|
||||||
TextResponse,
|
TextResponse,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -63,12 +75,21 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
async def preload_models(preload_models: PreloadModelData) -> None:
|
async def preload_models(preload: PreloadModelData) -> None:
|
||||||
log.info(f"Preloading models: {preload_models}")
|
log.info(f"Preloading models: {preload}")
|
||||||
if preload_models.clip is not None:
|
if preload.clip is not None:
|
||||||
await load(await model_cache.get(preload_models.clip, ModelType.CLIP))
|
model = await model_cache.get(preload.clip, ModelType.TEXTUAL, ModelTask.SEARCH)
|
||||||
if preload_models.facial_recognition is not None:
|
await load(model)
|
||||||
await load(await model_cache.get(preload_models.facial_recognition, ModelType.FACIAL_RECOGNITION))
|
|
||||||
|
model = await model_cache.get(preload.clip, ModelType.VISUAL, ModelTask.SEARCH)
|
||||||
|
await load(model)
|
||||||
|
|
||||||
|
if preload.facial_recognition is not None:
|
||||||
|
model = await model_cache.get(preload.facial_recognition, ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)
|
||||||
|
await load(model)
|
||||||
|
|
||||||
|
model = await model_cache.get(preload.facial_recognition, ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION)
|
||||||
|
await load(model)
|
||||||
|
|
||||||
|
|
||||||
def update_state() -> Iterator[None]:
|
def update_state() -> Iterator[None]:
|
||||||
|
@ -81,6 +102,27 @@ def update_state() -> Iterator[None]:
|
||||||
active_requests -= 1
|
active_requests -= 1
|
||||||
|
|
||||||
|
|
||||||
|
def get_entries(entries: str = Form()) -> InferenceEntries:
|
||||||
|
try:
|
||||||
|
request: PipelineRequest = orjson.loads(entries)
|
||||||
|
without_deps: list[InferenceEntry] = []
|
||||||
|
with_deps: list[InferenceEntry] = []
|
||||||
|
for task, types in request.items():
|
||||||
|
for type, entry in types.items():
|
||||||
|
parsed: InferenceEntry = {
|
||||||
|
"name": entry["modelName"],
|
||||||
|
"task": task,
|
||||||
|
"type": type,
|
||||||
|
"options": entry.get("options", {}),
|
||||||
|
}
|
||||||
|
dep = get_model_deps(parsed["name"], type, task)
|
||||||
|
(with_deps if dep else without_deps).append(parsed)
|
||||||
|
return without_deps, with_deps
|
||||||
|
except (orjson.JSONDecodeError, ValidationError, KeyError, AttributeError) as e:
|
||||||
|
log.error(f"Invalid request format: {e}")
|
||||||
|
raise HTTPException(422, "Invalid request format.")
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
|
@ -96,42 +138,63 @@ def ping() -> str:
|
||||||
|
|
||||||
@app.post("/predict", dependencies=[Depends(update_state)])
|
@app.post("/predict", dependencies=[Depends(update_state)])
|
||||||
async def predict(
|
async def predict(
|
||||||
model_name: str = Form(alias="modelName"),
|
entries: InferenceEntries = Depends(get_entries),
|
||||||
model_type: ModelType = Form(alias="modelType"),
|
image: bytes | None = File(default=None),
|
||||||
options: str = Form(default="{}"),
|
|
||||||
text: str | None = Form(default=None),
|
text: str | None = Form(default=None),
|
||||||
image: UploadFile | None = None,
|
|
||||||
) -> Any:
|
) -> Any:
|
||||||
if image is not None:
|
if image is not None:
|
||||||
inputs: str | bytes = await image.read()
|
inputs: Image | str = await run(lambda: decode_pil(image))
|
||||||
elif text is not None:
|
elif text is not None:
|
||||||
inputs = text
|
inputs = text
|
||||||
else:
|
else:
|
||||||
raise HTTPException(400, "Either image or text must be provided")
|
raise HTTPException(400, "Either image or text must be provided")
|
||||||
try:
|
response = await run_inference(inputs, entries)
|
||||||
kwargs = orjson.loads(options)
|
return ORJSONResponse(response)
|
||||||
except orjson.JSONDecodeError:
|
|
||||||
raise HTTPException(400, f"Invalid options JSON: {options}")
|
|
||||||
|
|
||||||
model = await load(await model_cache.get(model_name, model_type, ttl=settings.model_ttl, **kwargs))
|
|
||||||
model.configure(**kwargs)
|
|
||||||
outputs = await run(model.predict, inputs)
|
|
||||||
return ORJSONResponse(outputs)
|
|
||||||
|
|
||||||
|
|
||||||
async def run(func: Callable[..., Any], inputs: Any) -> Any:
|
async def run_inference(payload: Image | str, entries: InferenceEntries) -> InferenceResponse:
|
||||||
|
outputs: dict[ModelIdentity, Any] = {}
|
||||||
|
response: InferenceResponse = {}
|
||||||
|
|
||||||
|
async def _run_inference(entry: InferenceEntry) -> None:
|
||||||
|
model = await model_cache.get(entry["name"], entry["type"], entry["task"], ttl=settings.model_ttl)
|
||||||
|
inputs = [payload]
|
||||||
|
for dep in model.depends:
|
||||||
|
try:
|
||||||
|
inputs.append(outputs[dep])
|
||||||
|
except KeyError:
|
||||||
|
message = f"Task {entry['task']} of type {entry['type']} depends on output of {dep}"
|
||||||
|
raise HTTPException(400, message)
|
||||||
|
model = await load(model)
|
||||||
|
output = await run(model.predict, *inputs, **entry["options"])
|
||||||
|
outputs[model.identity] = output
|
||||||
|
response[entry["task"]] = output
|
||||||
|
|
||||||
|
without_deps, with_deps = entries
|
||||||
|
await asyncio.gather(*[_run_inference(entry) for entry in without_deps])
|
||||||
|
if with_deps:
|
||||||
|
await asyncio.gather(*[_run_inference(entry) for entry in with_deps])
|
||||||
|
if isinstance(payload, Image):
|
||||||
|
response["imageHeight"], response["imageWidth"] = payload.height, payload.width
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
async def run(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
|
||||||
if thread_pool is None:
|
if thread_pool is None:
|
||||||
return func(inputs)
|
return func(*args, **kwargs)
|
||||||
return await asyncio.get_running_loop().run_in_executor(thread_pool, func, inputs)
|
partial_func = partial(func, *args, **kwargs)
|
||||||
|
return await asyncio.get_running_loop().run_in_executor(thread_pool, partial_func)
|
||||||
|
|
||||||
|
|
||||||
async def load(model: InferenceModel) -> InferenceModel:
|
async def load(model: InferenceModel) -> InferenceModel:
|
||||||
if model.loaded:
|
if model.loaded:
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def _load(model: InferenceModel) -> None:
|
def _load(model: InferenceModel) -> InferenceModel:
|
||||||
with lock:
|
with lock:
|
||||||
model.load()
|
model.load()
|
||||||
|
return model
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await run(_load, model)
|
await run(_load, model)
|
||||||
|
|
|
@ -1,24 +1,40 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.schemas import ModelType
|
from app.models.base import InferenceModel
|
||||||
|
from app.models.clip.textual import MClipTextualEncoder, OpenClipTextualEncoder
|
||||||
|
from app.models.clip.visual import OpenClipVisualEncoder
|
||||||
|
from app.schemas import ModelSource, ModelTask, ModelType
|
||||||
|
|
||||||
from .base import InferenceModel
|
from .constants import get_model_source
|
||||||
from .clip import MCLIPEncoder, OpenCLIPEncoder
|
from .facial_recognition.detection import FaceDetector
|
||||||
from .constants import is_insightface, is_mclip, is_openclip
|
from .facial_recognition.recognition import FaceRecognizer
|
||||||
from .facial_recognition import FaceRecognizer
|
|
||||||
|
|
||||||
|
|
||||||
def from_model_type(model_type: ModelType, model_name: str, **model_kwargs: Any) -> InferenceModel:
|
def get_model_class(model_name: str, model_type: ModelType, model_task: ModelTask) -> type[InferenceModel]:
|
||||||
match model_type:
|
source = get_model_source(model_name)
|
||||||
case ModelType.CLIP:
|
match source, model_type, model_task:
|
||||||
if is_openclip(model_name):
|
case ModelSource.OPENCLIP | ModelSource.MCLIP, ModelType.VISUAL, ModelTask.SEARCH:
|
||||||
return OpenCLIPEncoder(model_name, **model_kwargs)
|
return OpenClipVisualEncoder
|
||||||
elif is_mclip(model_name):
|
|
||||||
return MCLIPEncoder(model_name, **model_kwargs)
|
case ModelSource.OPENCLIP, ModelType.TEXTUAL, ModelTask.SEARCH:
|
||||||
case ModelType.FACIAL_RECOGNITION:
|
return OpenClipTextualEncoder
|
||||||
if is_insightface(model_name):
|
|
||||||
return FaceRecognizer(model_name, **model_kwargs)
|
case ModelSource.MCLIP, ModelType.TEXTUAL, ModelTask.SEARCH:
|
||||||
|
return MClipTextualEncoder
|
||||||
|
|
||||||
|
case ModelSource.INSIGHTFACE, ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION:
|
||||||
|
return FaceDetector
|
||||||
|
|
||||||
|
case ModelSource.INSIGHTFACE, ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION:
|
||||||
|
return FaceRecognizer
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Unknown model type {model_type}")
|
raise ValueError(f"Unknown model combination: {source}, {model_type}, {model_task}")
|
||||||
|
|
||||||
raise ValueError(f"Unknown {model_type} model {model_name}")
|
|
||||||
|
def from_model_type(model_name: str, model_type: ModelType, model_task: ModelTask, **kwargs: Any) -> InferenceModel:
|
||||||
|
return get_model_class(model_name, model_type, model_task)(model_name, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_deps(model_name: str, model_type: ModelType, model_task: ModelTask) -> list[tuple[ModelType, ModelTask]]:
|
||||||
|
return get_model_class(model_name, model_type, model_task).depends
|
||||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import rmtree
|
from shutil import rmtree
|
||||||
from typing import Any
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
@ -11,13 +11,14 @@ from huggingface_hub import snapshot_download
|
||||||
import ann.ann
|
import ann.ann
|
||||||
from app.models.constants import SUPPORTED_PROVIDERS
|
from app.models.constants import SUPPORTED_PROVIDERS
|
||||||
|
|
||||||
from ..config import get_cache_dir, get_hf_model_name, log, settings
|
from ..config import clean_name, log, settings
|
||||||
from ..schemas import ModelRuntime, ModelType
|
from ..schemas import ModelFormat, ModelIdentity, ModelSession, ModelTask, ModelType
|
||||||
from .ann import AnnSession
|
from .ann import AnnSession
|
||||||
|
|
||||||
|
|
||||||
class InferenceModel(ABC):
|
class InferenceModel(ABC):
|
||||||
_model_type: ModelType
|
depends: ClassVar[list[ModelIdentity]]
|
||||||
|
identity: ClassVar[ModelIdentity]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -26,16 +27,16 @@ class InferenceModel(ABC):
|
||||||
providers: list[str] | None = None,
|
providers: list[str] | None = None,
|
||||||
provider_options: list[dict[str, Any]] | None = None,
|
provider_options: list[dict[str, Any]] | None = None,
|
||||||
sess_options: ort.SessionOptions | None = None,
|
sess_options: ort.SessionOptions | None = None,
|
||||||
preferred_runtime: ModelRuntime | None = None,
|
preferred_format: ModelFormat | None = None,
|
||||||
**model_kwargs: Any,
|
**model_kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.loaded = False
|
self.loaded = False
|
||||||
self.model_name = model_name
|
self.model_name = clean_name(model_name)
|
||||||
self.cache_dir = Path(cache_dir) if cache_dir is not None else self.cache_dir_default
|
self.cache_dir = Path(cache_dir) if cache_dir is not None else self.cache_dir_default
|
||||||
self.providers = providers if providers is not None else self.providers_default
|
self.providers = providers if providers is not None else self.providers_default
|
||||||
self.provider_options = provider_options if provider_options is not None else self.provider_options_default
|
self.provider_options = provider_options if provider_options is not None else self.provider_options_default
|
||||||
self.sess_options = sess_options if sess_options is not None else self.sess_options_default
|
self.sess_options = sess_options if sess_options is not None else self.sess_options_default
|
||||||
self.preferred_runtime = preferred_runtime if preferred_runtime is not None else self.preferred_runtime_default
|
self.preferred_format = preferred_format if preferred_format is not None else self.preferred_format_default
|
||||||
|
|
||||||
def download(self) -> None:
|
def download(self) -> None:
|
||||||
if not self.cached:
|
if not self.cached:
|
||||||
|
@ -47,35 +48,36 @@ class InferenceModel(ABC):
|
||||||
def load(self) -> None:
|
def load(self) -> None:
|
||||||
if self.loaded:
|
if self.loaded:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.download()
|
self.download()
|
||||||
log.info(f"Loading {self.model_type.replace('-', ' ')} model '{self.model_name}' to memory")
|
log.info(f"Loading {self.model_type.replace('-', ' ')} model '{self.model_name}' to memory")
|
||||||
self._load()
|
self.session = self._load()
|
||||||
self.loaded = True
|
self.loaded = True
|
||||||
|
|
||||||
def predict(self, inputs: Any, **model_kwargs: Any) -> Any:
|
def predict(self, *inputs: Any, **model_kwargs: Any) -> Any:
|
||||||
self.load()
|
self.load()
|
||||||
if model_kwargs:
|
if model_kwargs:
|
||||||
self.configure(**model_kwargs)
|
self.configure(**model_kwargs)
|
||||||
return self._predict(inputs)
|
return self._predict(*inputs, **model_kwargs)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _predict(self, inputs: Any) -> Any: ...
|
def _predict(self, *inputs: Any, **model_kwargs: Any) -> Any: ...
|
||||||
|
|
||||||
def configure(self, **model_kwargs: Any) -> None:
|
def configure(self, **kwargs: Any) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _download(self) -> None:
|
def _download(self) -> None:
|
||||||
ignore_patterns = [] if self.preferred_runtime == ModelRuntime.ARMNN else ["*.armnn"]
|
ignore_patterns = [] if self.preferred_format == ModelFormat.ARMNN else ["*.armnn"]
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
get_hf_model_name(self.model_name),
|
f"immich-app/{clean_name(self.model_name)}",
|
||||||
cache_dir=self.cache_dir,
|
cache_dir=self.cache_dir,
|
||||||
local_dir=self.cache_dir,
|
local_dir=self.cache_dir,
|
||||||
local_dir_use_symlinks=False,
|
local_dir_use_symlinks=False,
|
||||||
ignore_patterns=ignore_patterns,
|
ignore_patterns=ignore_patterns,
|
||||||
)
|
)
|
||||||
|
|
||||||
@abstractmethod
|
def _load(self) -> ModelSession:
|
||||||
def _load(self) -> None: ...
|
return self._make_session(self.model_path)
|
||||||
|
|
||||||
def clear_cache(self) -> None:
|
def clear_cache(self) -> None:
|
||||||
if not self.cache_dir.exists():
|
if not self.cache_dir.exists():
|
||||||
|
@ -99,7 +101,7 @@ class InferenceModel(ABC):
|
||||||
self.cache_dir.unlink()
|
self.cache_dir.unlink()
|
||||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
def _make_session(self, model_path: Path) -> AnnSession | ort.InferenceSession:
|
def _make_session(self, model_path: Path) -> ModelSession:
|
||||||
if not model_path.is_file():
|
if not model_path.is_file():
|
||||||
onnx_path = model_path.with_suffix(".onnx")
|
onnx_path = model_path.with_suffix(".onnx")
|
||||||
if not onnx_path.is_file():
|
if not onnx_path.is_file():
|
||||||
|
@ -124,9 +126,21 @@ class InferenceModel(ABC):
|
||||||
raise ValueError(f"Unsupported model file type: {model_path.suffix}")
|
raise ValueError(f"Unsupported model file type: {model_path.suffix}")
|
||||||
return session
|
return session
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_dir(self) -> Path:
|
||||||
|
return self.cache_dir / self.model_type.value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_path(self) -> Path:
|
||||||
|
return self.model_dir / f"model.{self.preferred_format}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_task(self) -> ModelTask:
|
||||||
|
return self.identity[1]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_type(self) -> ModelType:
|
def model_type(self) -> ModelType:
|
||||||
return self._model_type
|
return self.identity[0]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cache_dir(self) -> Path:
|
def cache_dir(self) -> Path:
|
||||||
|
@ -138,11 +152,11 @@ class InferenceModel(ABC):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cache_dir_default(self) -> Path:
|
def cache_dir_default(self) -> Path:
|
||||||
return get_cache_dir(self.model_name, self.model_type)
|
return settings.cache_folder / self.model_task.value / self.model_name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cached(self) -> bool:
|
def cached(self) -> bool:
|
||||||
return self.cache_dir.is_dir() and any(self.cache_dir.iterdir())
|
return self.model_path.is_file()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def providers(self) -> list[str]:
|
def providers(self) -> list[str]:
|
||||||
|
@ -226,14 +240,14 @@ class InferenceModel(ABC):
|
||||||
return sess_options
|
return sess_options
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def preferred_runtime(self) -> ModelRuntime:
|
def preferred_format(self) -> ModelFormat:
|
||||||
return self._preferred_runtime
|
return self._preferred_format
|
||||||
|
|
||||||
@preferred_runtime.setter
|
@preferred_format.setter
|
||||||
def preferred_runtime(self, preferred_runtime: ModelRuntime) -> None:
|
def preferred_format(self, preferred_format: ModelFormat) -> None:
|
||||||
log.debug(f"Setting preferred runtime to {preferred_runtime}")
|
log.debug(f"Setting preferred format to {preferred_format}")
|
||||||
self._preferred_runtime = preferred_runtime
|
self._preferred_format = preferred_format
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def preferred_runtime_default(self) -> ModelRuntime:
|
def preferred_format_default(self) -> ModelFormat:
|
||||||
return ModelRuntime.ARMNN if ann.ann.is_available and settings.ann else ModelRuntime.ONNX
|
return ModelFormat.ARMNN if ann.ann.is_available and settings.ann else ModelFormat.ONNX
|
||||||
|
|
|
@ -5,9 +5,9 @@ from aiocache.lock import OptimisticLock
|
||||||
from aiocache.plugins import TimingPlugin
|
from aiocache.plugins import TimingPlugin
|
||||||
|
|
||||||
from app.models import from_model_type
|
from app.models import from_model_type
|
||||||
|
from app.models.base import InferenceModel
|
||||||
|
|
||||||
from ..schemas import ModelType, has_profiling
|
from ..schemas import ModelTask, ModelType, has_profiling
|
||||||
from .base import InferenceModel
|
|
||||||
|
|
||||||
|
|
||||||
class ModelCache:
|
class ModelCache:
|
||||||
|
@ -31,28 +31,21 @@ class ModelCache:
|
||||||
if profiling:
|
if profiling:
|
||||||
plugins.append(TimingPlugin())
|
plugins.append(TimingPlugin())
|
||||||
|
|
||||||
self.revalidate_enable = revalidate
|
self.should_revalidate = revalidate
|
||||||
|
|
||||||
self.cache = SimpleMemoryCache(timeout=timeout, plugins=plugins, namespace=None)
|
self.cache = SimpleMemoryCache(timeout=timeout, plugins=plugins, namespace=None)
|
||||||
|
|
||||||
async def get(self, model_name: str, model_type: ModelType, **model_kwargs: Any) -> InferenceModel:
|
async def get(
|
||||||
"""
|
self, model_name: str, model_type: ModelType, model_task: ModelTask, **model_kwargs: Any
|
||||||
Args:
|
) -> InferenceModel:
|
||||||
model_name: Name of model in the model hub used for the task.
|
key = f"{model_name}{model_type}{model_task}"
|
||||||
model_type: Model type or task, which determines which model zoo is used.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
model: The requested model.
|
|
||||||
"""
|
|
||||||
|
|
||||||
key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"
|
|
||||||
|
|
||||||
async with OptimisticLock(self.cache, key) as lock:
|
async with OptimisticLock(self.cache, key) as lock:
|
||||||
model: InferenceModel | None = await self.cache.get(key)
|
model: InferenceModel | None = await self.cache.get(key)
|
||||||
if model is None:
|
if model is None:
|
||||||
model = from_model_type(model_type, model_name, **model_kwargs)
|
model = from_model_type(model_name, model_type, model_task, **model_kwargs)
|
||||||
await lock.cas(model, ttl=model_kwargs.get("ttl", None))
|
await lock.cas(model, ttl=model_kwargs.get("ttl", None))
|
||||||
elif self.revalidate_enable:
|
elif self.should_revalidate:
|
||||||
await self.revalidate(key, model_kwargs.get("ttl", None))
|
await self.revalidate(key, model_kwargs.get("ttl", None))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
|
@ -1,189 +0,0 @@
|
||||||
import json
|
|
||||||
from abc import abstractmethod
|
|
||||||
from functools import cached_property
|
|
||||||
from io import BytesIO
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from numpy.typing import NDArray
|
|
||||||
from PIL import Image
|
|
||||||
from tokenizers import Encoding, Tokenizer
|
|
||||||
|
|
||||||
from app.config import clean_name, log
|
|
||||||
from app.models.transforms import crop, get_pil_resampling, normalize, resize, to_numpy
|
|
||||||
from app.schemas import ModelType
|
|
||||||
|
|
||||||
from .base import InferenceModel
|
|
||||||
|
|
||||||
|
|
||||||
class BaseCLIPEncoder(InferenceModel):
|
|
||||||
_model_type = ModelType.CLIP
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
cache_dir: Path | str | None = None,
|
|
||||||
mode: Literal["text", "vision"] | None = None,
|
|
||||||
**model_kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
self.mode = mode
|
|
||||||
super().__init__(model_name, cache_dir, **model_kwargs)
|
|
||||||
|
|
||||||
def _load(self) -> None:
|
|
||||||
if self.mode == "text" or self.mode is None:
|
|
||||||
log.debug(f"Loading clip text model '{self.model_name}'")
|
|
||||||
self.text_model = self._make_session(self.textual_path)
|
|
||||||
log.debug(f"Loaded clip text model '{self.model_name}'")
|
|
||||||
|
|
||||||
if self.mode == "vision" or self.mode is None:
|
|
||||||
log.debug(f"Loading clip vision model '{self.model_name}'")
|
|
||||||
self.vision_model = self._make_session(self.visual_path)
|
|
||||||
log.debug(f"Loaded clip vision model '{self.model_name}'")
|
|
||||||
|
|
||||||
def _predict(self, image_or_text: Image.Image | str) -> NDArray[np.float32]:
|
|
||||||
if isinstance(image_or_text, bytes):
|
|
||||||
image_or_text = Image.open(BytesIO(image_or_text))
|
|
||||||
|
|
||||||
match image_or_text:
|
|
||||||
case Image.Image():
|
|
||||||
if self.mode == "text":
|
|
||||||
raise TypeError("Cannot encode image as text-only model")
|
|
||||||
outputs: NDArray[np.float32] = self.vision_model.run(None, self.transform(image_or_text))[0][0]
|
|
||||||
case str():
|
|
||||||
if self.mode == "vision":
|
|
||||||
raise TypeError("Cannot encode text as vision-only model")
|
|
||||||
outputs = self.text_model.run(None, self.tokenize(image_or_text))[0][0]
|
|
||||||
case _:
|
|
||||||
raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def transform(self, image: Image.Image) -> dict[str, NDArray[np.float32]]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
|
||||||
def textual_dir(self) -> Path:
|
|
||||||
return self.cache_dir / "textual"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def visual_dir(self) -> Path:
|
|
||||||
return self.cache_dir / "visual"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def model_cfg_path(self) -> Path:
|
|
||||||
return self.cache_dir / "config.json"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def textual_path(self) -> Path:
|
|
||||||
return self.textual_dir / f"model.{self.preferred_runtime}"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def visual_path(self) -> Path:
|
|
||||||
return self.visual_dir / f"model.{self.preferred_runtime}"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def tokenizer_file_path(self) -> Path:
|
|
||||||
return self.textual_dir / "tokenizer.json"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def tokenizer_cfg_path(self) -> Path:
|
|
||||||
return self.textual_dir / "tokenizer_config.json"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def preprocess_cfg_path(self) -> Path:
|
|
||||||
return self.visual_dir / "preprocess_cfg.json"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cached(self) -> bool:
|
|
||||||
return self.textual_path.is_file() and self.visual_path.is_file()
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def model_cfg(self) -> dict[str, Any]:
|
|
||||||
log.debug(f"Loading model config for CLIP model '{self.model_name}'")
|
|
||||||
model_cfg: dict[str, Any] = json.load(self.model_cfg_path.open())
|
|
||||||
log.debug(f"Loaded model config for CLIP model '{self.model_name}'")
|
|
||||||
return model_cfg
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def tokenizer_file(self) -> dict[str, Any]:
|
|
||||||
log.debug(f"Loading tokenizer file for CLIP model '{self.model_name}'")
|
|
||||||
tokenizer_file: dict[str, Any] = json.load(self.tokenizer_file_path.open())
|
|
||||||
log.debug(f"Loaded tokenizer file for CLIP model '{self.model_name}'")
|
|
||||||
return tokenizer_file
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def tokenizer_cfg(self) -> dict[str, Any]:
|
|
||||||
log.debug(f"Loading tokenizer config for CLIP model '{self.model_name}'")
|
|
||||||
tokenizer_cfg: dict[str, Any] = json.load(self.tokenizer_cfg_path.open())
|
|
||||||
log.debug(f"Loaded tokenizer config for CLIP model '{self.model_name}'")
|
|
||||||
return tokenizer_cfg
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def preprocess_cfg(self) -> dict[str, Any]:
|
|
||||||
log.debug(f"Loading visual preprocessing config for CLIP model '{self.model_name}'")
|
|
||||||
preprocess_cfg: dict[str, Any] = json.load(self.preprocess_cfg_path.open())
|
|
||||||
log.debug(f"Loaded visual preprocessing config for CLIP model '{self.model_name}'")
|
|
||||||
return preprocess_cfg
|
|
||||||
|
|
||||||
|
|
||||||
class OpenCLIPEncoder(BaseCLIPEncoder):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
cache_dir: Path | str | None = None,
|
|
||||||
mode: Literal["text", "vision"] | None = None,
|
|
||||||
**model_kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(clean_name(model_name), cache_dir, mode, **model_kwargs)
|
|
||||||
|
|
||||||
def _load(self) -> None:
|
|
||||||
super()._load()
|
|
||||||
self._load_tokenizer()
|
|
||||||
|
|
||||||
size: list[int] | int = self.preprocess_cfg["size"]
|
|
||||||
self.size = size[0] if isinstance(size, list) else size
|
|
||||||
|
|
||||||
self.resampling = get_pil_resampling(self.preprocess_cfg["interpolation"])
|
|
||||||
self.mean = np.array(self.preprocess_cfg["mean"], dtype=np.float32)
|
|
||||||
self.std = np.array(self.preprocess_cfg["std"], dtype=np.float32)
|
|
||||||
|
|
||||||
def _load_tokenizer(self) -> Tokenizer:
|
|
||||||
log.debug(f"Loading tokenizer for CLIP model '{self.model_name}'")
|
|
||||||
|
|
||||||
text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
|
|
||||||
context_length: int = text_cfg.get("context_length", 77)
|
|
||||||
pad_token: str = self.tokenizer_cfg["pad_token"]
|
|
||||||
|
|
||||||
self.tokenizer: Tokenizer = Tokenizer.from_file(self.tokenizer_file_path.as_posix())
|
|
||||||
|
|
||||||
pad_id: int = self.tokenizer.token_to_id(pad_token)
|
|
||||||
self.tokenizer.enable_padding(length=context_length, pad_token=pad_token, pad_id=pad_id)
|
|
||||||
self.tokenizer.enable_truncation(max_length=context_length)
|
|
||||||
|
|
||||||
log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")
|
|
||||||
|
|
||||||
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
|
|
||||||
tokens: Encoding = self.tokenizer.encode(text)
|
|
||||||
return {"text": np.array([tokens.ids], dtype=np.int32)}
|
|
||||||
|
|
||||||
def transform(self, image: Image.Image) -> dict[str, NDArray[np.float32]]:
|
|
||||||
image = resize(image, self.size)
|
|
||||||
image = crop(image, self.size)
|
|
||||||
image_np = to_numpy(image)
|
|
||||||
image_np = normalize(image_np, self.mean, self.std)
|
|
||||||
return {"image": np.expand_dims(image_np.transpose(2, 0, 1), 0)}
|
|
||||||
|
|
||||||
|
|
||||||
class MCLIPEncoder(OpenCLIPEncoder):
|
|
||||||
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
|
|
||||||
tokens: Encoding = self.tokenizer.encode(text)
|
|
||||||
return {
|
|
||||||
"input_ids": np.array([tokens.ids], dtype=np.int32),
|
|
||||||
"attention_mask": np.array([tokens.attention_mask], dtype=np.int32),
|
|
||||||
}
|
|
98
machine-learning/app/models/clip/textual.py
Normal file
98
machine-learning/app/models/clip/textual.py
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
import json
|
||||||
|
from abc import abstractmethod
|
||||||
|
from functools import cached_property
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
from tokenizers import Encoding, Tokenizer
|
||||||
|
|
||||||
|
from app.config import log
|
||||||
|
from app.models.base import InferenceModel
|
||||||
|
from app.schemas import ModelSession, ModelTask, ModelType
|
||||||
|
|
||||||
|
|
||||||
|
class BaseCLIPTextualEncoder(InferenceModel):
|
||||||
|
depends = []
|
||||||
|
identity = (ModelType.TEXTUAL, ModelTask.SEARCH)
|
||||||
|
|
||||||
|
def _predict(self, inputs: str, **kwargs: Any) -> NDArray[np.float32]:
|
||||||
|
res: NDArray[np.float32] = self.session.run(None, self.tokenize(inputs))[0][0]
|
||||||
|
return res
|
||||||
|
|
||||||
|
def _load(self) -> ModelSession:
|
||||||
|
log.debug(f"Loading tokenizer for CLIP model '{self.model_name}'")
|
||||||
|
self.tokenizer = self._load_tokenizer()
|
||||||
|
log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")
|
||||||
|
|
||||||
|
return super()._load()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _load_tokenizer(self) -> Tokenizer:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_cfg_path(self) -> Path:
|
||||||
|
return self.cache_dir / "config.json"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tokenizer_file_path(self) -> Path:
|
||||||
|
return self.model_dir / "tokenizer.json"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tokenizer_cfg_path(self) -> Path:
|
||||||
|
return self.model_dir / "tokenizer_config.json"
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def model_cfg(self) -> dict[str, Any]:
|
||||||
|
log.debug(f"Loading model config for CLIP model '{self.model_name}'")
|
||||||
|
model_cfg: dict[str, Any] = json.load(self.model_cfg_path.open())
|
||||||
|
log.debug(f"Loaded model config for CLIP model '{self.model_name}'")
|
||||||
|
return model_cfg
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def tokenizer_file(self) -> dict[str, Any]:
|
||||||
|
log.debug(f"Loading tokenizer file for CLIP model '{self.model_name}'")
|
||||||
|
tokenizer_file: dict[str, Any] = json.load(self.tokenizer_file_path.open())
|
||||||
|
log.debug(f"Loaded tokenizer file for CLIP model '{self.model_name}'")
|
||||||
|
return tokenizer_file
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def tokenizer_cfg(self) -> dict[str, Any]:
|
||||||
|
log.debug(f"Loading tokenizer config for CLIP model '{self.model_name}'")
|
||||||
|
tokenizer_cfg: dict[str, Any] = json.load(self.tokenizer_cfg_path.open())
|
||||||
|
log.debug(f"Loaded tokenizer config for CLIP model '{self.model_name}'")
|
||||||
|
return tokenizer_cfg
|
||||||
|
|
||||||
|
|
||||||
|
class OpenClipTextualEncoder(BaseCLIPTextualEncoder):
|
||||||
|
def _load_tokenizer(self) -> Tokenizer:
|
||||||
|
text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
|
||||||
|
context_length: int = text_cfg.get("context_length", 77)
|
||||||
|
pad_token: str = self.tokenizer_cfg["pad_token"]
|
||||||
|
|
||||||
|
tokenizer: Tokenizer = Tokenizer.from_file(self.tokenizer_file_path.as_posix())
|
||||||
|
|
||||||
|
pad_id: int = tokenizer.token_to_id(pad_token)
|
||||||
|
tokenizer.enable_padding(length=context_length, pad_token=pad_token, pad_id=pad_id)
|
||||||
|
tokenizer.enable_truncation(max_length=context_length)
|
||||||
|
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
|
||||||
|
tokens: Encoding = self.tokenizer.encode(text)
|
||||||
|
return {"text": np.array([tokens.ids], dtype=np.int32)}
|
||||||
|
|
||||||
|
|
||||||
|
class MClipTextualEncoder(OpenClipTextualEncoder):
|
||||||
|
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
|
||||||
|
tokens: Encoding = self.tokenizer.encode(text)
|
||||||
|
return {
|
||||||
|
"input_ids": np.array([tokens.ids], dtype=np.int32),
|
||||||
|
"attention_mask": np.array([tokens.attention_mask], dtype=np.int32),
|
||||||
|
}
|
69
machine-learning/app/models/clip/visual.py
Normal file
69
machine-learning/app/models/clip/visual.py
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
import json
|
||||||
|
from abc import abstractmethod
|
||||||
|
from functools import cached_property
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from app.config import log
|
||||||
|
from app.models.base import InferenceModel
|
||||||
|
from app.models.transforms import crop_pil, decode_pil, get_pil_resampling, normalize, resize_pil, to_numpy
|
||||||
|
from app.schemas import ModelSession, ModelTask, ModelType
|
||||||
|
|
||||||
|
|
||||||
|
class BaseCLIPVisualEncoder(InferenceModel):
|
||||||
|
depends = []
|
||||||
|
identity = (ModelType.VISUAL, ModelTask.SEARCH)
|
||||||
|
|
||||||
|
def _predict(self, inputs: Image.Image | bytes, **kwargs: Any) -> NDArray[np.float32]:
|
||||||
|
image = decode_pil(inputs)
|
||||||
|
res: NDArray[np.float32] = self.session.run(None, self.transform(image))[0][0]
|
||||||
|
return res
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def transform(self, image: Image.Image) -> dict[str, NDArray[np.float32]]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_cfg_path(self) -> Path:
|
||||||
|
return self.cache_dir / "config.json"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def preprocess_cfg_path(self) -> Path:
|
||||||
|
return self.model_dir / "preprocess_cfg.json"
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def model_cfg(self) -> dict[str, Any]:
|
||||||
|
log.debug(f"Loading model config for CLIP model '{self.model_name}'")
|
||||||
|
model_cfg: dict[str, Any] = json.load(self.model_cfg_path.open())
|
||||||
|
log.debug(f"Loaded model config for CLIP model '{self.model_name}'")
|
||||||
|
return model_cfg
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def preprocess_cfg(self) -> dict[str, Any]:
|
||||||
|
log.debug(f"Loading visual preprocessing config for CLIP model '{self.model_name}'")
|
||||||
|
preprocess_cfg: dict[str, Any] = json.load(self.preprocess_cfg_path.open())
|
||||||
|
log.debug(f"Loaded visual preprocessing config for CLIP model '{self.model_name}'")
|
||||||
|
return preprocess_cfg
|
||||||
|
|
||||||
|
|
||||||
|
class OpenClipVisualEncoder(BaseCLIPVisualEncoder):
|
||||||
|
def _load(self) -> ModelSession:
|
||||||
|
size: list[int] | int = self.preprocess_cfg["size"]
|
||||||
|
self.size = size[0] if isinstance(size, list) else size
|
||||||
|
|
||||||
|
self.resampling = get_pil_resampling(self.preprocess_cfg["interpolation"])
|
||||||
|
self.mean = np.array(self.preprocess_cfg["mean"], dtype=np.float32)
|
||||||
|
self.std = np.array(self.preprocess_cfg["std"], dtype=np.float32)
|
||||||
|
|
||||||
|
return super()._load()
|
||||||
|
|
||||||
|
def transform(self, image: Image.Image) -> dict[str, NDArray[np.float32]]:
|
||||||
|
image = resize_pil(image, self.size)
|
||||||
|
image = crop_pil(image, self.size)
|
||||||
|
image_np = to_numpy(image)
|
||||||
|
image_np = normalize(image_np, self.mean, self.std)
|
||||||
|
return {"image": np.expand_dims(image_np.transpose(2, 0, 1), 0)}
|
|
@ -1,4 +1,5 @@
|
||||||
from app.config import clean_name
|
from app.config import clean_name
|
||||||
|
from app.schemas import ModelSource
|
||||||
|
|
||||||
_OPENCLIP_MODELS = {
|
_OPENCLIP_MODELS = {
|
||||||
"RN50__openai",
|
"RN50__openai",
|
||||||
|
@ -54,13 +55,16 @@ _INSIGHTFACE_MODELS = {
|
||||||
SUPPORTED_PROVIDERS = ["CUDAExecutionProvider", "OpenVINOExecutionProvider", "CPUExecutionProvider"]
|
SUPPORTED_PROVIDERS = ["CUDAExecutionProvider", "OpenVINOExecutionProvider", "CPUExecutionProvider"]
|
||||||
|
|
||||||
|
|
||||||
def is_openclip(model_name: str) -> bool:
|
def get_model_source(model_name: str) -> ModelSource | None:
|
||||||
return clean_name(model_name) in _OPENCLIP_MODELS
|
cleaned_name = clean_name(model_name)
|
||||||
|
|
||||||
|
if cleaned_name in _INSIGHTFACE_MODELS:
|
||||||
|
return ModelSource.INSIGHTFACE
|
||||||
|
|
||||||
def is_mclip(model_name: str) -> bool:
|
if cleaned_name in _MCLIP_MODELS:
|
||||||
return clean_name(model_name) in _MCLIP_MODELS
|
return ModelSource.MCLIP
|
||||||
|
|
||||||
|
if cleaned_name in _OPENCLIP_MODELS:
|
||||||
|
return ModelSource.OPENCLIP
|
||||||
|
|
||||||
def is_insightface(model_name: str) -> bool:
|
return None
|
||||||
return clean_name(model_name) in _INSIGHTFACE_MODELS
|
|
||||||
|
|
|
@ -1,90 +0,0 @@
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
from insightface.model_zoo import ArcFaceONNX, RetinaFace
|
|
||||||
from insightface.utils.face_align import norm_crop
|
|
||||||
from numpy.typing import NDArray
|
|
||||||
|
|
||||||
from app.config import clean_name
|
|
||||||
from app.schemas import Face, ModelType, is_ndarray
|
|
||||||
|
|
||||||
from .base import InferenceModel
|
|
||||||
|
|
||||||
|
|
||||||
class FaceRecognizer(InferenceModel):
|
|
||||||
_model_type = ModelType.FACIAL_RECOGNITION
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
min_score: float = 0.7,
|
|
||||||
cache_dir: Path | str | None = None,
|
|
||||||
**model_kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
self.min_score = model_kwargs.pop("minScore", min_score)
|
|
||||||
super().__init__(clean_name(model_name), cache_dir, **model_kwargs)
|
|
||||||
|
|
||||||
def _load(self) -> None:
|
|
||||||
self.det_model = RetinaFace(session=self._make_session(self.det_file))
|
|
||||||
self.rec_model = ArcFaceONNX(
|
|
||||||
self.rec_file.with_suffix(".onnx").as_posix(),
|
|
||||||
session=self._make_session(self.rec_file),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.det_model.prepare(
|
|
||||||
ctx_id=0,
|
|
||||||
det_thresh=self.min_score,
|
|
||||||
input_size=(640, 640),
|
|
||||||
)
|
|
||||||
self.rec_model.prepare(ctx_id=0)
|
|
||||||
|
|
||||||
def _predict(self, image: NDArray[np.uint8] | bytes) -> list[Face]:
|
|
||||||
if isinstance(image, bytes):
|
|
||||||
decoded_image = cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR)
|
|
||||||
else:
|
|
||||||
decoded_image = image
|
|
||||||
assert is_ndarray(decoded_image, np.uint8)
|
|
||||||
bboxes, kpss = self.det_model.detect(decoded_image)
|
|
||||||
if bboxes.size == 0:
|
|
||||||
return []
|
|
||||||
assert is_ndarray(kpss, np.float32)
|
|
||||||
|
|
||||||
scores = bboxes[:, 4].tolist()
|
|
||||||
bboxes = bboxes[:, :4].round().tolist()
|
|
||||||
|
|
||||||
results = []
|
|
||||||
height, width, _ = decoded_image.shape
|
|
||||||
for (x1, y1, x2, y2), score, kps in zip(bboxes, scores, kpss):
|
|
||||||
cropped_img = norm_crop(decoded_image, kps)
|
|
||||||
embedding: NDArray[np.float32] = self.rec_model.get_feat(cropped_img)[0]
|
|
||||||
face: Face = {
|
|
||||||
"imageWidth": width,
|
|
||||||
"imageHeight": height,
|
|
||||||
"boundingBox": {
|
|
||||||
"x1": x1,
|
|
||||||
"y1": y1,
|
|
||||||
"x2": x2,
|
|
||||||
"y2": y2,
|
|
||||||
},
|
|
||||||
"score": score,
|
|
||||||
"embedding": embedding,
|
|
||||||
}
|
|
||||||
results.append(face)
|
|
||||||
return results
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cached(self) -> bool:
|
|
||||||
return self.det_file.is_file() and self.rec_file.is_file()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def det_file(self) -> Path:
|
|
||||||
return self.cache_dir / "detection" / f"model.{self.preferred_runtime}"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def rec_file(self) -> Path:
|
|
||||||
return self.cache_dir / "recognition" / f"model.{self.preferred_runtime}"
|
|
||||||
|
|
||||||
def configure(self, **model_kwargs: Any) -> None:
|
|
||||||
self.det_model.det_thresh = model_kwargs.pop("minScore", self.det_model.det_thresh)
|
|
48
machine-learning/app/models/facial_recognition/detection.py
Normal file
48
machine-learning/app/models/facial_recognition/detection.py
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from insightface.model_zoo import RetinaFace
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
|
from app.models.base import InferenceModel
|
||||||
|
from app.models.transforms import decode_cv2
|
||||||
|
from app.schemas import FaceDetectionOutput, ModelSession, ModelTask, ModelType
|
||||||
|
|
||||||
|
|
||||||
|
class FaceDetector(InferenceModel):
|
||||||
|
depends = []
|
||||||
|
identity = (ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
min_score: float = 0.7,
|
||||||
|
cache_dir: Path | str | None = None,
|
||||||
|
**model_kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
self.min_score = model_kwargs.pop("minScore", min_score)
|
||||||
|
super().__init__(model_name, cache_dir, **model_kwargs)
|
||||||
|
|
||||||
|
def _load(self) -> ModelSession:
|
||||||
|
session = self._make_session(self.model_path)
|
||||||
|
self.model = RetinaFace(session=session)
|
||||||
|
self.model.prepare(ctx_id=0, det_thresh=self.min_score, input_size=(640, 640))
|
||||||
|
|
||||||
|
return session
|
||||||
|
|
||||||
|
def _predict(self, inputs: NDArray[np.uint8] | bytes, **kwargs: Any) -> FaceDetectionOutput:
|
||||||
|
inputs = decode_cv2(inputs)
|
||||||
|
|
||||||
|
bboxes, landmarks = self._detect(inputs)
|
||||||
|
return {
|
||||||
|
"boxes": bboxes[:, :4].round(),
|
||||||
|
"scores": bboxes[:, 4],
|
||||||
|
"landmarks": landmarks,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _detect(self, inputs: NDArray[np.uint8] | bytes) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
|
||||||
|
return self.model.detect(inputs) # type: ignore
|
||||||
|
|
||||||
|
def configure(self, **kwargs: Any) -> None:
|
||||||
|
self.model.det_thresh = kwargs.pop("minScore", self.model.det_thresh)
|
|
@ -0,0 +1,77 @@
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import onnx
|
||||||
|
import onnxruntime as ort
|
||||||
|
from insightface.model_zoo import ArcFaceONNX
|
||||||
|
from insightface.utils.face_align import norm_crop
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
from onnx.tools.update_model_dims import update_inputs_outputs_dims
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from app.config import clean_name, log
|
||||||
|
from app.models.base import InferenceModel
|
||||||
|
from app.models.transforms import decode_cv2
|
||||||
|
from app.schemas import FaceDetectionOutput, FacialRecognitionOutput, ModelSession, ModelTask, ModelType
|
||||||
|
|
||||||
|
|
||||||
|
class FaceRecognizer(InferenceModel):
|
||||||
|
depends = [(ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)]
|
||||||
|
identity = (ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
min_score: float = 0.7,
|
||||||
|
cache_dir: Path | str | None = None,
|
||||||
|
**model_kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
self.min_score = model_kwargs.pop("minScore", min_score)
|
||||||
|
super().__init__(clean_name(model_name), cache_dir, **model_kwargs)
|
||||||
|
|
||||||
|
def _load(self) -> ModelSession:
|
||||||
|
session = self._make_session(self.model_path)
|
||||||
|
if not self._has_batch_dim(session):
|
||||||
|
self._add_batch_dim(self.model_path)
|
||||||
|
session = self._make_session(self.model_path)
|
||||||
|
self.model = ArcFaceONNX(
|
||||||
|
self.model_path.with_suffix(".onnx").as_posix(),
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
return session
|
||||||
|
|
||||||
|
def _predict(
|
||||||
|
self, inputs: NDArray[np.uint8] | bytes | Image.Image, faces: FaceDetectionOutput, **kwargs: Any
|
||||||
|
) -> FacialRecognitionOutput:
|
||||||
|
if faces["boxes"].shape[0] == 0:
|
||||||
|
return []
|
||||||
|
inputs = decode_cv2(inputs)
|
||||||
|
embeddings: NDArray[np.float32] = self.model.get_feat(self._crop(inputs, faces))
|
||||||
|
return self.postprocess(faces, embeddings)
|
||||||
|
|
||||||
|
def postprocess(self, faces: FaceDetectionOutput, embeddings: NDArray[np.float32]) -> FacialRecognitionOutput:
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"boundingBox": {"x1": x1, "y1": y1, "x2": x2, "y2": y2},
|
||||||
|
"embedding": embedding,
|
||||||
|
"score": score,
|
||||||
|
}
|
||||||
|
for (x1, y1, x2, y2), embedding, score in zip(faces["boxes"], embeddings, faces["scores"])
|
||||||
|
]
|
||||||
|
|
||||||
|
def _crop(self, image: NDArray[np.uint8], faces: FaceDetectionOutput) -> list[NDArray[np.uint8]]:
|
||||||
|
return [norm_crop(image, landmark) for landmark in faces["landmarks"]]
|
||||||
|
|
||||||
|
def _has_batch_dim(self, session: ort.InferenceSession) -> bool:
|
||||||
|
return not isinstance(session, ort.InferenceSession) or session.get_inputs()[0].shape[0] == "batch"
|
||||||
|
|
||||||
|
def _add_batch_dim(self, model_path: Path) -> None:
|
||||||
|
log.debug(f"Adding batch dimension to model {model_path}")
|
||||||
|
proto = onnx.load(model_path)
|
||||||
|
static_input_dims = [shape.dim_value for shape in proto.graph.input[0].type.tensor_type.shape.dim[1:]]
|
||||||
|
static_output_dims = [shape.dim_value for shape in proto.graph.output[0].type.tensor_type.shape.dim[1:]]
|
||||||
|
input_dims = {proto.graph.input[0].name: ["batch"] + static_input_dims}
|
||||||
|
output_dims = {proto.graph.output[0].name: ["batch"] + static_output_dims}
|
||||||
|
updated_proto = update_inputs_outputs_dims(proto, input_dims, output_dims)
|
||||||
|
onnx.save(updated_proto, model_path)
|
0
machine-learning/app/models/session.py
Normal file
0
machine-learning/app/models/session.py
Normal file
|
@ -1,3 +1,7 @@
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import IO
|
||||||
|
|
||||||
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -5,7 +9,7 @@ from PIL import Image
|
||||||
_PIL_RESAMPLING_METHODS = {resampling.name.lower(): resampling for resampling in Image.Resampling}
|
_PIL_RESAMPLING_METHODS = {resampling.name.lower(): resampling for resampling in Image.Resampling}
|
||||||
|
|
||||||
|
|
||||||
def resize(img: Image.Image, size: int) -> Image.Image:
|
def resize_pil(img: Image.Image, size: int) -> Image.Image:
|
||||||
if img.width < img.height:
|
if img.width < img.height:
|
||||||
return img.resize((size, int((img.height / img.width) * size)), resample=Image.Resampling.BICUBIC)
|
return img.resize((size, int((img.height / img.width) * size)), resample=Image.Resampling.BICUBIC)
|
||||||
else:
|
else:
|
||||||
|
@ -13,7 +17,7 @@ def resize(img: Image.Image, size: int) -> Image.Image:
|
||||||
|
|
||||||
|
|
||||||
# https://stackoverflow.com/a/60883103
|
# https://stackoverflow.com/a/60883103
|
||||||
def crop(img: Image.Image, size: int) -> Image.Image:
|
def crop_pil(img: Image.Image, size: int) -> Image.Image:
|
||||||
left = int((img.size[0] / 2) - (size / 2))
|
left = int((img.size[0] / 2) - (size / 2))
|
||||||
upper = int((img.size[1] / 2) - (size / 2))
|
upper = int((img.size[1] / 2) - (size / 2))
|
||||||
right = left + size
|
right = left + size
|
||||||
|
@ -23,14 +27,36 @@ def crop(img: Image.Image, size: int) -> Image.Image:
|
||||||
|
|
||||||
|
|
||||||
def to_numpy(img: Image.Image) -> NDArray[np.float32]:
|
def to_numpy(img: Image.Image) -> NDArray[np.float32]:
|
||||||
return np.asarray(img.convert("RGB")).astype(np.float32) / 255.0
|
return np.asarray(img if img.mode == "RGB" else img.convert("RGB"), dtype=np.float32) / 255.0
|
||||||
|
|
||||||
|
|
||||||
def normalize(
|
def normalize(
|
||||||
img: NDArray[np.float32], mean: float | NDArray[np.float32], std: float | NDArray[np.float32]
|
img: NDArray[np.float32], mean: float | NDArray[np.float32], std: float | NDArray[np.float32]
|
||||||
) -> NDArray[np.float32]:
|
) -> NDArray[np.float32]:
|
||||||
return (img - mean) / std
|
return np.divide(img - mean, std, dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
def get_pil_resampling(resample: str) -> Image.Resampling:
|
def get_pil_resampling(resample: str) -> Image.Resampling:
|
||||||
return _PIL_RESAMPLING_METHODS[resample.lower()]
|
return _PIL_RESAMPLING_METHODS[resample.lower()]
|
||||||
|
|
||||||
|
|
||||||
|
def pil_to_cv2(image: Image.Image) -> NDArray[np.uint8]:
|
||||||
|
return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def decode_pil(image_bytes: bytes | IO[bytes] | Image.Image) -> Image.Image:
|
||||||
|
if isinstance(image_bytes, Image.Image):
|
||||||
|
return image_bytes
|
||||||
|
image = Image.open(BytesIO(image_bytes) if isinstance(image_bytes, bytes) else image_bytes)
|
||||||
|
image.load() # type: ignore
|
||||||
|
if not image.mode == "RGB":
|
||||||
|
image = image.convert("RGB")
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def decode_cv2(image_bytes: NDArray[np.uint8] | bytes | Image.Image) -> NDArray[np.uint8]:
|
||||||
|
if isinstance(image_bytes, bytes):
|
||||||
|
image_bytes = decode_pil(image_bytes) # pillow is much faster than cv2
|
||||||
|
if isinstance(image_bytes, Image.Image):
|
||||||
|
return pil_to_cv2(image_bytes)
|
||||||
|
return image_bytes
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Protocol, TypedDict, TypeGuard
|
from typing import Any, Literal, Protocol, TypedDict, TypeGuard, TypeVar
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
|
@ -28,31 +28,87 @@ class BoundingBox(TypedDict):
|
||||||
y2: int
|
y2: int
|
||||||
|
|
||||||
|
|
||||||
class ModelType(StrEnum):
|
class ModelTask(StrEnum):
|
||||||
CLIP = "clip"
|
|
||||||
FACIAL_RECOGNITION = "facial-recognition"
|
FACIAL_RECOGNITION = "facial-recognition"
|
||||||
|
SEARCH = "clip"
|
||||||
|
|
||||||
|
|
||||||
class ModelRuntime(StrEnum):
|
class ModelType(StrEnum):
|
||||||
ONNX = "onnx"
|
DETECTION = "detection"
|
||||||
|
RECOGNITION = "recognition"
|
||||||
|
TEXTUAL = "textual"
|
||||||
|
VISUAL = "visual"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelFormat(StrEnum):
|
||||||
ARMNN = "armnn"
|
ARMNN = "armnn"
|
||||||
|
ONNX = "onnx"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSource(StrEnum):
|
||||||
|
INSIGHTFACE = "insightface"
|
||||||
|
MCLIP = "mclip"
|
||||||
|
OPENCLIP = "openclip"
|
||||||
|
|
||||||
|
|
||||||
|
ModelIdentity = tuple[ModelType, ModelTask]
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSession(Protocol):
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
output_names: list[str] | None,
|
||||||
|
input_feed: dict[str, npt.NDArray[np.float32]] | dict[str, npt.NDArray[np.int32]],
|
||||||
|
run_options: Any = None,
|
||||||
|
) -> list[npt.NDArray[np.float32]]: ...
|
||||||
|
|
||||||
|
|
||||||
class HasProfiling(Protocol):
|
class HasProfiling(Protocol):
|
||||||
profiling: dict[str, float]
|
profiling: dict[str, float]
|
||||||
|
|
||||||
|
|
||||||
class Face(TypedDict):
|
class FaceDetectionOutput(TypedDict):
|
||||||
|
boxes: npt.NDArray[np.float32]
|
||||||
|
scores: npt.NDArray[np.float32]
|
||||||
|
landmarks: npt.NDArray[np.float32]
|
||||||
|
|
||||||
|
|
||||||
|
class DetectedFace(TypedDict):
|
||||||
boundingBox: BoundingBox
|
boundingBox: BoundingBox
|
||||||
embedding: npt.NDArray[np.float32]
|
embedding: npt.NDArray[np.float32]
|
||||||
imageWidth: int
|
|
||||||
imageHeight: int
|
|
||||||
score: float
|
score: float
|
||||||
|
|
||||||
|
|
||||||
|
FacialRecognitionOutput = list[DetectedFace]
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineEntry(TypedDict):
|
||||||
|
modelName: str
|
||||||
|
options: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
PipelineRequest = dict[ModelTask, dict[ModelType, PipelineEntry]]
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceEntry(TypedDict):
|
||||||
|
name: str
|
||||||
|
task: ModelTask
|
||||||
|
type: ModelType
|
||||||
|
options: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
InferenceEntries = tuple[list[InferenceEntry], list[InferenceEntry]]
|
||||||
|
|
||||||
|
|
||||||
|
InferenceResponse = dict[ModelTask | Literal["imageHeight"] | Literal["imageWidth"], Any]
|
||||||
|
|
||||||
|
|
||||||
def has_profiling(obj: Any) -> TypeGuard[HasProfiling]:
|
def has_profiling(obj: Any) -> TypeGuard[HasProfiling]:
|
||||||
return hasattr(obj, "profiling") and isinstance(obj.profiling, dict)
|
return hasattr(obj, "profiling") and isinstance(obj.profiling, dict)
|
||||||
|
|
||||||
|
|
||||||
def is_ndarray(obj: Any, dtype: "type[np._DTypeScalar_co]") -> "TypeGuard[npt.NDArray[np._DTypeScalar_co]]":
|
def is_ndarray(obj: Any, dtype: "type[np._DTypeScalar_co]") -> "TypeGuard[npt.NDArray[np._DTypeScalar_co]]":
|
||||||
return isinstance(obj, np.ndarray) and obj.dtype == dtype
|
return isinstance(obj, np.ndarray) and obj.dtype == dtype
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
|
@ -17,13 +17,15 @@ from pytest import MonkeyPatch
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
from app.main import load, preload_models
|
from app.main import load, preload_models
|
||||||
|
from app.models.clip.textual import MClipTextualEncoder, OpenClipTextualEncoder
|
||||||
|
from app.models.clip.visual import OpenClipVisualEncoder
|
||||||
|
from app.models.facial_recognition.detection import FaceDetector
|
||||||
|
from app.models.facial_recognition.recognition import FaceRecognizer
|
||||||
|
|
||||||
from .config import Settings, log, settings
|
from .config import Settings, log, settings
|
||||||
from .models.base import InferenceModel
|
from .models.base import InferenceModel
|
||||||
from .models.cache import ModelCache
|
from .models.cache import ModelCache
|
||||||
from .models.clip import MCLIPEncoder, OpenCLIPEncoder
|
from .schemas import ModelFormat, ModelTask, ModelType
|
||||||
from .models.facial_recognition import FaceRecognizer
|
|
||||||
from .schemas import ModelRuntime, ModelType
|
|
||||||
|
|
||||||
|
|
||||||
class TestBase:
|
class TestBase:
|
||||||
|
@ -35,13 +37,13 @@ class TestBase:
|
||||||
|
|
||||||
@pytest.mark.providers(CPU_EP)
|
@pytest.mark.providers(CPU_EP)
|
||||||
def test_sets_cpu_provider(self, providers: list[str]) -> None:
|
def test_sets_cpu_provider(self, providers: list[str]) -> None:
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai")
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai")
|
||||||
|
|
||||||
assert encoder.providers == self.CPU_EP
|
assert encoder.providers == self.CPU_EP
|
||||||
|
|
||||||
@pytest.mark.providers(CUDA_EP)
|
@pytest.mark.providers(CUDA_EP)
|
||||||
def test_sets_cuda_provider_if_available(self, providers: list[str]) -> None:
|
def test_sets_cuda_provider_if_available(self, providers: list[str]) -> None:
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai")
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai")
|
||||||
|
|
||||||
assert encoder.providers == self.CUDA_EP
|
assert encoder.providers == self.CUDA_EP
|
||||||
|
|
||||||
|
@ -50,7 +52,7 @@ class TestBase:
|
||||||
mocked = mocker.patch("app.models.base.ort.capi._pybind_state")
|
mocked = mocker.patch("app.models.base.ort.capi._pybind_state")
|
||||||
mocked.get_available_openvino_device_ids.return_value = ["GPU.0", "CPU"]
|
mocked.get_available_openvino_device_ids.return_value = ["GPU.0", "CPU"]
|
||||||
|
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai")
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai")
|
||||||
|
|
||||||
assert encoder.providers == self.OV_EP
|
assert encoder.providers == self.OV_EP
|
||||||
|
|
||||||
|
@ -59,25 +61,25 @@ class TestBase:
|
||||||
mocked = mocker.patch("app.models.base.ort.capi._pybind_state")
|
mocked = mocker.patch("app.models.base.ort.capi._pybind_state")
|
||||||
mocked.get_available_openvino_device_ids.return_value = ["CPU"]
|
mocked.get_available_openvino_device_ids.return_value = ["CPU"]
|
||||||
|
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai")
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai")
|
||||||
|
|
||||||
assert encoder.providers == self.CPU_EP
|
assert encoder.providers == self.CPU_EP
|
||||||
|
|
||||||
@pytest.mark.providers(CUDA_EP_OUT_OF_ORDER)
|
@pytest.mark.providers(CUDA_EP_OUT_OF_ORDER)
|
||||||
def test_sets_providers_in_correct_order(self, providers: list[str]) -> None:
|
def test_sets_providers_in_correct_order(self, providers: list[str]) -> None:
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai")
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai")
|
||||||
|
|
||||||
assert encoder.providers == self.CUDA_EP
|
assert encoder.providers == self.CUDA_EP
|
||||||
|
|
||||||
@pytest.mark.providers(TRT_EP)
|
@pytest.mark.providers(TRT_EP)
|
||||||
def test_ignores_unsupported_providers(self, providers: list[str]) -> None:
|
def test_ignores_unsupported_providers(self, providers: list[str]) -> None:
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai")
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai")
|
||||||
|
|
||||||
assert encoder.providers == self.CUDA_EP
|
assert encoder.providers == self.CUDA_EP
|
||||||
|
|
||||||
def test_sets_provider_kwarg(self) -> None:
|
def test_sets_provider_kwarg(self) -> None:
|
||||||
providers = ["CUDAExecutionProvider"]
|
providers = ["CUDAExecutionProvider"]
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai", providers=providers)
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai", providers=providers)
|
||||||
|
|
||||||
assert encoder.providers == providers
|
assert encoder.providers == providers
|
||||||
|
|
||||||
|
@ -85,7 +87,9 @@ class TestBase:
|
||||||
mocked = mocker.patch("app.models.base.ort.capi._pybind_state")
|
mocked = mocker.patch("app.models.base.ort.capi._pybind_state")
|
||||||
mocked.get_available_openvino_device_ids.return_value = ["GPU.0", "CPU"]
|
mocked.get_available_openvino_device_ids.return_value = ["GPU.0", "CPU"]
|
||||||
|
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai", providers=["OpenVINOExecutionProvider", "CPUExecutionProvider"])
|
encoder = OpenClipTextualEncoder(
|
||||||
|
"ViT-B-32__openai", providers=["OpenVINOExecutionProvider", "CPUExecutionProvider"]
|
||||||
|
)
|
||||||
|
|
||||||
assert encoder.provider_options == [
|
assert encoder.provider_options == [
|
||||||
{"device_type": "GPU_FP32", "cache_dir": (encoder.cache_dir / "openvino").as_posix()},
|
{"device_type": "GPU_FP32", "cache_dir": (encoder.cache_dir / "openvino").as_posix()},
|
||||||
|
@ -93,7 +97,7 @@ class TestBase:
|
||||||
]
|
]
|
||||||
|
|
||||||
def test_sets_provider_options_kwarg(self) -> None:
|
def test_sets_provider_options_kwarg(self) -> None:
|
||||||
encoder = OpenCLIPEncoder(
|
encoder = OpenClipTextualEncoder(
|
||||||
"ViT-B-32__openai",
|
"ViT-B-32__openai",
|
||||||
providers=["OpenVINOExecutionProvider", "CPUExecutionProvider"],
|
providers=["OpenVINOExecutionProvider", "CPUExecutionProvider"],
|
||||||
provider_options=[],
|
provider_options=[],
|
||||||
|
@ -102,7 +106,7 @@ class TestBase:
|
||||||
assert encoder.provider_options == []
|
assert encoder.provider_options == []
|
||||||
|
|
||||||
def test_sets_default_sess_options(self) -> None:
|
def test_sets_default_sess_options(self) -> None:
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai")
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai")
|
||||||
|
|
||||||
assert encoder.sess_options.execution_mode == ort.ExecutionMode.ORT_SEQUENTIAL
|
assert encoder.sess_options.execution_mode == ort.ExecutionMode.ORT_SEQUENTIAL
|
||||||
assert encoder.sess_options.inter_op_num_threads == 1
|
assert encoder.sess_options.inter_op_num_threads == 1
|
||||||
|
@ -110,7 +114,9 @@ class TestBase:
|
||||||
assert encoder.sess_options.enable_cpu_mem_arena is False
|
assert encoder.sess_options.enable_cpu_mem_arena is False
|
||||||
|
|
||||||
def test_sets_default_sess_options_does_not_set_threads_if_non_cpu_and_default_threads(self) -> None:
|
def test_sets_default_sess_options_does_not_set_threads_if_non_cpu_and_default_threads(self) -> None:
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
|
encoder = OpenClipTextualEncoder(
|
||||||
|
"ViT-B-32__openai", providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
|
||||||
|
)
|
||||||
|
|
||||||
assert encoder.sess_options.inter_op_num_threads == 0
|
assert encoder.sess_options.inter_op_num_threads == 0
|
||||||
assert encoder.sess_options.intra_op_num_threads == 0
|
assert encoder.sess_options.intra_op_num_threads == 0
|
||||||
|
@ -120,14 +126,16 @@ class TestBase:
|
||||||
mock_settings.model_inter_op_threads = 2
|
mock_settings.model_inter_op_threads = 2
|
||||||
mock_settings.model_intra_op_threads = 4
|
mock_settings.model_intra_op_threads = 4
|
||||||
|
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
|
encoder = OpenClipTextualEncoder(
|
||||||
|
"ViT-B-32__openai", providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
|
||||||
|
)
|
||||||
|
|
||||||
assert encoder.sess_options.inter_op_num_threads == 2
|
assert encoder.sess_options.inter_op_num_threads == 2
|
||||||
assert encoder.sess_options.intra_op_num_threads == 4
|
assert encoder.sess_options.intra_op_num_threads == 4
|
||||||
|
|
||||||
def test_sets_sess_options_kwarg(self) -> None:
|
def test_sets_sess_options_kwarg(self) -> None:
|
||||||
sess_options = ort.SessionOptions()
|
sess_options = ort.SessionOptions()
|
||||||
encoder = OpenCLIPEncoder(
|
encoder = OpenClipTextualEncoder(
|
||||||
"ViT-B-32__openai",
|
"ViT-B-32__openai",
|
||||||
providers=["OpenVINOExecutionProvider", "CPUExecutionProvider"],
|
providers=["OpenVINOExecutionProvider", "CPUExecutionProvider"],
|
||||||
provider_options=[],
|
provider_options=[],
|
||||||
|
@ -137,43 +145,43 @@ class TestBase:
|
||||||
assert sess_options is encoder.sess_options
|
assert sess_options is encoder.sess_options
|
||||||
|
|
||||||
def test_sets_default_cache_dir(self) -> None:
|
def test_sets_default_cache_dir(self) -> None:
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai")
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai")
|
||||||
|
|
||||||
assert encoder.cache_dir == Path(settings.cache_folder) / "clip" / "ViT-B-32__openai"
|
assert encoder.cache_dir == Path(settings.cache_folder) / "clip" / "ViT-B-32__openai"
|
||||||
|
|
||||||
def test_sets_cache_dir_kwarg(self) -> None:
|
def test_sets_cache_dir_kwarg(self) -> None:
|
||||||
cache_dir = Path("/test_cache")
|
cache_dir = Path("/test_cache")
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai", cache_dir=cache_dir)
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir=cache_dir)
|
||||||
|
|
||||||
assert encoder.cache_dir == cache_dir
|
assert encoder.cache_dir == cache_dir
|
||||||
|
|
||||||
def test_sets_default_preferred_runtime(self, mocker: MockerFixture) -> None:
|
def test_sets_default_preferred_format(self, mocker: MockerFixture) -> None:
|
||||||
mocker.patch.object(settings, "ann", True)
|
mocker.patch.object(settings, "ann", True)
|
||||||
mocker.patch("ann.ann.is_available", False)
|
mocker.patch("ann.ann.is_available", False)
|
||||||
|
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai")
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai")
|
||||||
|
|
||||||
assert encoder.preferred_runtime == ModelRuntime.ONNX
|
assert encoder.preferred_format == ModelFormat.ONNX
|
||||||
|
|
||||||
def test_sets_default_preferred_runtime_to_armnn_if_available(self, mocker: MockerFixture) -> None:
|
def test_sets_default_preferred_format_to_armnn_if_available(self, mocker: MockerFixture) -> None:
|
||||||
mocker.patch.object(settings, "ann", True)
|
mocker.patch.object(settings, "ann", True)
|
||||||
mocker.patch("ann.ann.is_available", True)
|
mocker.patch("ann.ann.is_available", True)
|
||||||
|
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai")
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai")
|
||||||
|
|
||||||
assert encoder.preferred_runtime == ModelRuntime.ARMNN
|
assert encoder.preferred_format == ModelFormat.ARMNN
|
||||||
|
|
||||||
def test_sets_preferred_runtime_kwarg(self, mocker: MockerFixture) -> None:
|
def test_sets_preferred_format_kwarg(self, mocker: MockerFixture) -> None:
|
||||||
mocker.patch.object(settings, "ann", False)
|
mocker.patch.object(settings, "ann", False)
|
||||||
mocker.patch("ann.ann.is_available", False)
|
mocker.patch("ann.ann.is_available", False)
|
||||||
|
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai", preferred_runtime=ModelRuntime.ARMNN)
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai", preferred_format=ModelFormat.ARMNN)
|
||||||
|
|
||||||
assert encoder.preferred_runtime == ModelRuntime.ARMNN
|
assert encoder.preferred_format == ModelFormat.ARMNN
|
||||||
|
|
||||||
def test_casts_cache_dir_string_to_path(self) -> None:
|
def test_casts_cache_dir_string_to_path(self) -> None:
|
||||||
cache_dir = "/test_cache"
|
cache_dir = "/test_cache"
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai", cache_dir=cache_dir)
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir=cache_dir)
|
||||||
|
|
||||||
assert encoder.cache_dir == Path(cache_dir)
|
assert encoder.cache_dir == Path(cache_dir)
|
||||||
|
|
||||||
|
@ -186,7 +194,7 @@ class TestBase:
|
||||||
mocker.patch("app.models.base.Path", return_value=mock_cache_dir)
|
mocker.patch("app.models.base.Path", return_value=mock_cache_dir)
|
||||||
info = mocker.spy(log, "info")
|
info = mocker.spy(log, "info")
|
||||||
|
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai", cache_dir=mock_cache_dir)
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir=mock_cache_dir)
|
||||||
encoder.clear_cache()
|
encoder.clear_cache()
|
||||||
|
|
||||||
mock_rmtree.assert_called_once_with(encoder.cache_dir)
|
mock_rmtree.assert_called_once_with(encoder.cache_dir)
|
||||||
|
@ -201,7 +209,7 @@ class TestBase:
|
||||||
mocker.patch("app.models.base.Path", return_value=mock_cache_dir)
|
mocker.patch("app.models.base.Path", return_value=mock_cache_dir)
|
||||||
warning = mocker.spy(log, "warning")
|
warning = mocker.spy(log, "warning")
|
||||||
|
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai", cache_dir=mock_cache_dir)
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir=mock_cache_dir)
|
||||||
encoder.clear_cache()
|
encoder.clear_cache()
|
||||||
|
|
||||||
mock_rmtree.assert_not_called()
|
mock_rmtree.assert_not_called()
|
||||||
|
@ -215,7 +223,7 @@ class TestBase:
|
||||||
mock_cache_dir.is_dir.return_value = True
|
mock_cache_dir.is_dir.return_value = True
|
||||||
mocker.patch("app.models.base.Path", return_value=mock_cache_dir)
|
mocker.patch("app.models.base.Path", return_value=mock_cache_dir)
|
||||||
|
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai", cache_dir=mock_cache_dir)
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir=mock_cache_dir)
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
encoder.clear_cache()
|
encoder.clear_cache()
|
||||||
|
|
||||||
|
@ -230,7 +238,7 @@ class TestBase:
|
||||||
mocker.patch("app.models.base.Path", return_value=mock_cache_dir)
|
mocker.patch("app.models.base.Path", return_value=mock_cache_dir)
|
||||||
warning = mocker.spy(log, "warning")
|
warning = mocker.spy(log, "warning")
|
||||||
|
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai", cache_dir=mock_cache_dir)
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir=mock_cache_dir)
|
||||||
encoder.clear_cache()
|
encoder.clear_cache()
|
||||||
|
|
||||||
mock_rmtree.assert_not_called()
|
mock_rmtree.assert_not_called()
|
||||||
|
@ -245,7 +253,7 @@ class TestBase:
|
||||||
mock_model_path.with_suffix.return_value = mock_model_path
|
mock_model_path.with_suffix.return_value = mock_model_path
|
||||||
mock_ann = mocker.patch("app.models.base.AnnSession")
|
mock_ann = mocker.patch("app.models.base.AnnSession")
|
||||||
|
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai")
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai")
|
||||||
encoder._make_session(mock_model_path)
|
encoder._make_session(mock_model_path)
|
||||||
|
|
||||||
mock_ann.assert_called_once()
|
mock_ann.assert_called_once()
|
||||||
|
@ -263,7 +271,7 @@ class TestBase:
|
||||||
mock_ann = mocker.patch("app.models.base.AnnSession")
|
mock_ann = mocker.patch("app.models.base.AnnSession")
|
||||||
mock_ort = mocker.patch("app.models.base.ort.InferenceSession")
|
mock_ort = mocker.patch("app.models.base.ort.InferenceSession")
|
||||||
|
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai")
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai")
|
||||||
encoder._make_session(mock_armnn_path)
|
encoder._make_session(mock_armnn_path)
|
||||||
|
|
||||||
mock_ort.assert_called_once()
|
mock_ort.assert_called_once()
|
||||||
|
@ -277,7 +285,7 @@ class TestBase:
|
||||||
mock_ann = mocker.patch("app.models.base.AnnSession")
|
mock_ann = mocker.patch("app.models.base.AnnSession")
|
||||||
mock_ort = mocker.patch("app.models.base.ort.InferenceSession")
|
mock_ort = mocker.patch("app.models.base.ort.InferenceSession")
|
||||||
|
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai")
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai")
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
encoder._make_session(mock_model_path)
|
encoder._make_session(mock_model_path)
|
||||||
|
|
||||||
|
@ -287,7 +295,7 @@ class TestBase:
|
||||||
def test_download(self, mocker: MockerFixture) -> None:
|
def test_download(self, mocker: MockerFixture) -> None:
|
||||||
mock_snapshot_download = mocker.patch("app.models.base.snapshot_download")
|
mock_snapshot_download = mocker.patch("app.models.base.snapshot_download")
|
||||||
|
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai", cache_dir="/path/to/cache")
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="/path/to/cache")
|
||||||
encoder.download()
|
encoder.download()
|
||||||
|
|
||||||
mock_snapshot_download.assert_called_once_with(
|
mock_snapshot_download.assert_called_once_with(
|
||||||
|
@ -298,10 +306,10 @@ class TestBase:
|
||||||
ignore_patterns=["*.armnn"],
|
ignore_patterns=["*.armnn"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_download_downloads_armnn_if_preferred_runtime(self, mocker: MockerFixture) -> None:
|
def test_download_downloads_armnn_if_preferred_format(self, mocker: MockerFixture) -> None:
|
||||||
mock_snapshot_download = mocker.patch("app.models.base.snapshot_download")
|
mock_snapshot_download = mocker.patch("app.models.base.snapshot_download")
|
||||||
|
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai", preferred_runtime=ModelRuntime.ARMNN)
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai", preferred_format=ModelFormat.ARMNN)
|
||||||
encoder.download()
|
encoder.download()
|
||||||
|
|
||||||
mock_snapshot_download.assert_called_once_with(
|
mock_snapshot_download.assert_called_once_with(
|
||||||
|
@ -323,21 +331,17 @@ class TestCLIP:
|
||||||
mocker: MockerFixture,
|
mocker: MockerFixture,
|
||||||
clip_model_cfg: dict[str, Any],
|
clip_model_cfg: dict[str, Any],
|
||||||
clip_preprocess_cfg: Callable[[Path], dict[str, Any]],
|
clip_preprocess_cfg: Callable[[Path], dict[str, Any]],
|
||||||
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
|
|
||||||
) -> None:
|
) -> None:
|
||||||
mocker.patch.object(OpenCLIPEncoder, "download")
|
mocker.patch.object(OpenClipVisualEncoder, "download")
|
||||||
mocker.patch.object(OpenCLIPEncoder, "model_cfg", clip_model_cfg)
|
mocker.patch.object(OpenClipVisualEncoder, "model_cfg", clip_model_cfg)
|
||||||
mocker.patch.object(OpenCLIPEncoder, "preprocess_cfg", clip_preprocess_cfg)
|
mocker.patch.object(OpenClipVisualEncoder, "preprocess_cfg", clip_preprocess_cfg)
|
||||||
mocker.patch.object(OpenCLIPEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
|
|
||||||
|
|
||||||
mocked = mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
|
mocked = mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
|
||||||
mocked.run.return_value = [[self.embedding]]
|
mocked.run.return_value = [[self.embedding]]
|
||||||
mocker.patch("app.models.clip.Tokenizer.from_file", autospec=True)
|
|
||||||
|
|
||||||
clip_encoder = OpenCLIPEncoder("ViT-B-32__openai", cache_dir="test_cache", mode="vision")
|
clip_encoder = OpenClipVisualEncoder("ViT-B-32__openai", cache_dir="test_cache")
|
||||||
embedding = clip_encoder.predict(pil_image)
|
embedding = clip_encoder.predict(pil_image)
|
||||||
|
|
||||||
assert clip_encoder.mode == "vision"
|
|
||||||
assert isinstance(embedding, np.ndarray)
|
assert isinstance(embedding, np.ndarray)
|
||||||
assert embedding.shape[0] == clip_model_cfg["embed_dim"]
|
assert embedding.shape[0] == clip_model_cfg["embed_dim"]
|
||||||
assert embedding.dtype == np.float32
|
assert embedding.dtype == np.float32
|
||||||
|
@ -347,22 +351,19 @@ class TestCLIP:
|
||||||
self,
|
self,
|
||||||
mocker: MockerFixture,
|
mocker: MockerFixture,
|
||||||
clip_model_cfg: dict[str, Any],
|
clip_model_cfg: dict[str, Any],
|
||||||
clip_preprocess_cfg: Callable[[Path], dict[str, Any]],
|
|
||||||
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
|
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
|
||||||
) -> None:
|
) -> None:
|
||||||
mocker.patch.object(OpenCLIPEncoder, "download")
|
mocker.patch.object(OpenClipTextualEncoder, "download")
|
||||||
mocker.patch.object(OpenCLIPEncoder, "model_cfg", clip_model_cfg)
|
mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
|
||||||
mocker.patch.object(OpenCLIPEncoder, "preprocess_cfg", clip_preprocess_cfg)
|
mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
|
||||||
mocker.patch.object(OpenCLIPEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
|
|
||||||
|
|
||||||
mocked = mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
|
mocked = mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
|
||||||
mocked.run.return_value = [[self.embedding]]
|
mocked.run.return_value = [[self.embedding]]
|
||||||
mocker.patch("app.models.clip.Tokenizer.from_file", autospec=True)
|
mocker.patch("app.models.clip.textual.Tokenizer.from_file", autospec=True)
|
||||||
|
|
||||||
clip_encoder = OpenCLIPEncoder("ViT-B-32__openai", cache_dir="test_cache", mode="text")
|
clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache")
|
||||||
embedding = clip_encoder.predict("test search query")
|
embedding = clip_encoder.predict("test search query")
|
||||||
|
|
||||||
assert clip_encoder.mode == "text"
|
|
||||||
assert isinstance(embedding, np.ndarray)
|
assert isinstance(embedding, np.ndarray)
|
||||||
assert embedding.shape[0] == clip_model_cfg["embed_dim"]
|
assert embedding.shape[0] == clip_model_cfg["embed_dim"]
|
||||||
assert embedding.dtype == np.float32
|
assert embedding.dtype == np.float32
|
||||||
|
@ -372,19 +373,18 @@ class TestCLIP:
|
||||||
self,
|
self,
|
||||||
mocker: MockerFixture,
|
mocker: MockerFixture,
|
||||||
clip_model_cfg: dict[str, Any],
|
clip_model_cfg: dict[str, Any],
|
||||||
clip_preprocess_cfg: Callable[[Path], dict[str, Any]],
|
|
||||||
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
|
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
|
||||||
) -> None:
|
) -> None:
|
||||||
mocker.patch.object(OpenCLIPEncoder, "download")
|
mocker.patch.object(OpenClipTextualEncoder, "download")
|
||||||
mocker.patch.object(OpenCLIPEncoder, "model_cfg", clip_model_cfg)
|
mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
|
||||||
mocker.patch.object(OpenCLIPEncoder, "preprocess_cfg", clip_preprocess_cfg)
|
mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
|
||||||
mocker.patch.object(OpenCLIPEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
|
mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
|
||||||
mock_tokenizer = mocker.patch("app.models.clip.Tokenizer.from_file", autospec=True).return_value
|
mock_tokenizer = mocker.patch("app.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
|
||||||
mock_ids = [randint(0, 50000) for _ in range(77)]
|
mock_ids = [randint(0, 50000) for _ in range(77)]
|
||||||
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
|
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
|
||||||
|
|
||||||
clip_encoder = OpenCLIPEncoder("ViT-B-32__openai", cache_dir="test_cache", mode="text")
|
clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache")
|
||||||
clip_encoder._load_tokenizer()
|
clip_encoder._load()
|
||||||
tokens = clip_encoder.tokenize("test search query")
|
tokens = clip_encoder.tokenize("test search query")
|
||||||
|
|
||||||
assert "text" in tokens
|
assert "text" in tokens
|
||||||
|
@ -397,20 +397,19 @@ class TestCLIP:
|
||||||
self,
|
self,
|
||||||
mocker: MockerFixture,
|
mocker: MockerFixture,
|
||||||
clip_model_cfg: dict[str, Any],
|
clip_model_cfg: dict[str, Any],
|
||||||
clip_preprocess_cfg: Callable[[Path], dict[str, Any]],
|
|
||||||
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
|
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
|
||||||
) -> None:
|
) -> None:
|
||||||
mocker.patch.object(OpenCLIPEncoder, "download")
|
mocker.patch.object(MClipTextualEncoder, "download")
|
||||||
mocker.patch.object(OpenCLIPEncoder, "model_cfg", clip_model_cfg)
|
mocker.patch.object(MClipTextualEncoder, "model_cfg", clip_model_cfg)
|
||||||
mocker.patch.object(OpenCLIPEncoder, "preprocess_cfg", clip_preprocess_cfg)
|
mocker.patch.object(MClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
|
||||||
mocker.patch.object(OpenCLIPEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
|
mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
|
||||||
mock_tokenizer = mocker.patch("app.models.clip.Tokenizer.from_file", autospec=True).return_value
|
mock_tokenizer = mocker.patch("app.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
|
||||||
mock_ids = [randint(0, 50000) for _ in range(77)]
|
mock_ids = [randint(0, 50000) for _ in range(77)]
|
||||||
mock_attention_mask = [randint(0, 1) for _ in range(77)]
|
mock_attention_mask = [randint(0, 1) for _ in range(77)]
|
||||||
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids, attention_mask=mock_attention_mask)
|
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids, attention_mask=mock_attention_mask)
|
||||||
|
|
||||||
clip_encoder = MCLIPEncoder("ViT-B-32__openai", cache_dir="test_cache", mode="text")
|
clip_encoder = MClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache")
|
||||||
clip_encoder._load_tokenizer()
|
clip_encoder._load()
|
||||||
tokens = clip_encoder.tokenize("test search query")
|
tokens = clip_encoder.tokenize("test search query")
|
||||||
|
|
||||||
assert "input_ids" in tokens
|
assert "input_ids" in tokens
|
||||||
|
@ -430,59 +429,90 @@ class TestFaceRecognition:
|
||||||
|
|
||||||
assert face_recognizer.min_score == 0.5
|
assert face_recognizer.min_score == 0.5
|
||||||
|
|
||||||
def test_basic(self, cv_image: cv2.Mat, mocker: MockerFixture) -> None:
|
def test_detection(self, cv_image: cv2.Mat, mocker: MockerFixture) -> None:
|
||||||
mocker.patch.object(FaceRecognizer, "load")
|
mocker.patch.object(FaceDetector, "load")
|
||||||
face_recognizer = FaceRecognizer("buffalo_s", min_score=0.0, cache_dir="test_cache")
|
face_detector = FaceDetector("buffalo_s", min_score=0.0, cache_dir="test_cache")
|
||||||
|
|
||||||
det_model = mock.Mock()
|
det_model = mock.Mock()
|
||||||
num_faces = 2
|
num_faces = 2
|
||||||
bbox = np.random.rand(num_faces, 4).astype(np.float32)
|
bbox = np.random.rand(num_faces, 4).astype(np.float32)
|
||||||
score = np.array([[0.67]] * num_faces).astype(np.float32)
|
scores = np.array([[0.67]] * num_faces).astype(np.float32)
|
||||||
kpss = np.random.rand(num_faces, 5, 2).astype(np.float32)
|
kpss = np.random.rand(num_faces, 5, 2).astype(np.float32)
|
||||||
det_model.detect.return_value = (np.concatenate([bbox, score], axis=-1), kpss)
|
det_model.detect.return_value = (np.concatenate([bbox, scores], axis=-1), kpss)
|
||||||
face_recognizer.det_model = det_model
|
face_detector.model = det_model
|
||||||
|
|
||||||
|
faces = face_detector.predict(cv_image)
|
||||||
|
|
||||||
|
assert isinstance(faces, dict)
|
||||||
|
assert isinstance(faces.get("boxes", None), np.ndarray)
|
||||||
|
assert isinstance(faces.get("landmarks", None), np.ndarray)
|
||||||
|
assert isinstance(faces.get("scores", None), np.ndarray)
|
||||||
|
assert np.equal(faces["boxes"], bbox.round()).all()
|
||||||
|
assert np.equal(faces["landmarks"], kpss).all()
|
||||||
|
assert np.equal(faces["scores"], scores).all()
|
||||||
|
det_model.detect.assert_called_once()
|
||||||
|
|
||||||
|
def test_recognition(self, cv_image: cv2.Mat, mocker: MockerFixture) -> None:
|
||||||
|
mocker.patch.object(FaceRecognizer, "load")
|
||||||
|
face_recognizer = FaceRecognizer("buffalo_s", min_score=0.0, cache_dir="test_cache")
|
||||||
|
|
||||||
|
num_faces = 2
|
||||||
|
bbox = np.random.rand(num_faces, 4).astype(np.float32)
|
||||||
|
scores = np.array([0.67] * num_faces).astype(np.float32)
|
||||||
|
kpss = np.random.rand(num_faces, 5, 2).astype(np.float32)
|
||||||
|
faces = {"boxes": bbox, "landmarks": kpss, "scores": scores}
|
||||||
|
|
||||||
rec_model = mock.Mock()
|
rec_model = mock.Mock()
|
||||||
embedding = np.random.rand(num_faces, 512).astype(np.float32)
|
embedding = np.random.rand(num_faces, 512).astype(np.float32)
|
||||||
rec_model.get_feat.return_value = embedding
|
rec_model.get_feat.return_value = embedding
|
||||||
face_recognizer.rec_model = rec_model
|
face_recognizer.model = rec_model
|
||||||
|
|
||||||
faces = face_recognizer.predict(cv_image)
|
faces = face_recognizer.predict(cv_image, faces)
|
||||||
|
|
||||||
|
assert isinstance(faces, list)
|
||||||
assert len(faces) == num_faces
|
assert len(faces) == num_faces
|
||||||
for face in faces:
|
for face in faces:
|
||||||
assert face["imageHeight"] == 800
|
assert isinstance(face.get("boundingBox"), dict)
|
||||||
assert face["imageWidth"] == 600
|
assert set(face["boundingBox"]) == {"x1", "y1", "x2", "y2"}
|
||||||
assert isinstance(face["embedding"], np.ndarray)
|
assert all(isinstance(val, np.float32) for val in face["boundingBox"].values())
|
||||||
|
assert isinstance(face.get("embedding"), np.ndarray)
|
||||||
assert face["embedding"].shape[0] == 512
|
assert face["embedding"].shape[0] == 512
|
||||||
assert face["embedding"].dtype == np.float32
|
assert isinstance(face.get("score", None), np.float32)
|
||||||
|
|
||||||
det_model.detect.assert_called_once()
|
rec_model.get_feat.assert_called_once()
|
||||||
assert rec_model.get_feat.call_count == num_faces
|
call_args = rec_model.get_feat.call_args_list[0].args
|
||||||
|
assert len(call_args) == 1
|
||||||
|
assert isinstance(call_args[0], list)
|
||||||
|
assert isinstance(call_args[0][0], np.ndarray)
|
||||||
|
assert call_args[0][0].shape == (112, 112, 3)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
class TestCache:
|
class TestCache:
|
||||||
async def test_caches(self, mock_get_model: mock.Mock) -> None:
|
async def test_caches(self, mock_get_model: mock.Mock) -> None:
|
||||||
model_cache = ModelCache()
|
model_cache = ModelCache()
|
||||||
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
|
await model_cache.get("test_model_name", ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION)
|
||||||
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
|
await model_cache.get("test_model_name", ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION)
|
||||||
assert len(model_cache.cache._cache) == 1
|
assert len(model_cache.cache._cache) == 1
|
||||||
mock_get_model.assert_called_once()
|
mock_get_model.assert_called_once()
|
||||||
|
|
||||||
async def test_kwargs_used(self, mock_get_model: mock.Mock) -> None:
|
async def test_kwargs_used(self, mock_get_model: mock.Mock) -> None:
|
||||||
model_cache = ModelCache()
|
model_cache = ModelCache()
|
||||||
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, cache_dir="test_cache")
|
await model_cache.get(
|
||||||
mock_get_model.assert_called_once_with(ModelType.FACIAL_RECOGNITION, "test_model_name", cache_dir="test_cache")
|
"test_model_name", ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION, cache_dir="test_cache"
|
||||||
|
)
|
||||||
|
mock_get_model.assert_called_once_with(
|
||||||
|
"test_model_name", ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION, cache_dir="test_cache"
|
||||||
|
)
|
||||||
|
|
||||||
async def test_different_clip(self, mock_get_model: mock.Mock) -> None:
|
async def test_different_clip(self, mock_get_model: mock.Mock) -> None:
|
||||||
model_cache = ModelCache()
|
model_cache = ModelCache()
|
||||||
await model_cache.get("test_image_model_name", ModelType.CLIP)
|
await model_cache.get("test_model_name", ModelType.VISUAL, ModelTask.SEARCH)
|
||||||
await model_cache.get("test_text_model_name", ModelType.CLIP)
|
await model_cache.get("test_model_name", ModelType.TEXTUAL, ModelTask.SEARCH)
|
||||||
mock_get_model.assert_has_calls(
|
mock_get_model.assert_has_calls(
|
||||||
[
|
[
|
||||||
mock.call(ModelType.CLIP, "test_image_model_name"),
|
mock.call("test_model_name", ModelType.VISUAL, ModelTask.SEARCH),
|
||||||
mock.call(ModelType.CLIP, "test_text_model_name"),
|
mock.call("test_model_name", ModelType.TEXTUAL, ModelTask.SEARCH),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
assert len(model_cache.cache._cache) == 2
|
assert len(model_cache.cache._cache) == 2
|
||||||
|
@ -490,19 +520,19 @@ class TestCache:
|
||||||
@mock.patch("app.models.cache.OptimisticLock", autospec=True)
|
@mock.patch("app.models.cache.OptimisticLock", autospec=True)
|
||||||
async def test_model_ttl(self, mock_lock_cls: mock.Mock, mock_get_model: mock.Mock) -> None:
|
async def test_model_ttl(self, mock_lock_cls: mock.Mock, mock_get_model: mock.Mock) -> None:
|
||||||
model_cache = ModelCache()
|
model_cache = ModelCache()
|
||||||
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
|
await model_cache.get("test_model_name", ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION, ttl=100)
|
||||||
mock_lock_cls.return_value.__aenter__.return_value.cas.assert_called_with(mock.ANY, ttl=100)
|
mock_lock_cls.return_value.__aenter__.return_value.cas.assert_called_with(mock.ANY, ttl=100)
|
||||||
|
|
||||||
@mock.patch("app.models.cache.SimpleMemoryCache.expire")
|
@mock.patch("app.models.cache.SimpleMemoryCache.expire")
|
||||||
async def test_revalidate_get(self, mock_cache_expire: mock.Mock, mock_get_model: mock.Mock) -> None:
|
async def test_revalidate_get(self, mock_cache_expire: mock.Mock, mock_get_model: mock.Mock) -> None:
|
||||||
model_cache = ModelCache(revalidate=True)
|
model_cache = ModelCache(revalidate=True)
|
||||||
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
|
await model_cache.get("test_model_name", ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION, ttl=100)
|
||||||
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
|
await model_cache.get("test_model_name", ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION, ttl=100)
|
||||||
mock_cache_expire.assert_called_once_with(mock.ANY, 100)
|
mock_cache_expire.assert_called_once_with(mock.ANY, 100)
|
||||||
|
|
||||||
async def test_profiling(self, mock_get_model: mock.Mock) -> None:
|
async def test_profiling(self, mock_get_model: mock.Mock) -> None:
|
||||||
model_cache = ModelCache(profiling=True)
|
model_cache = ModelCache(profiling=True)
|
||||||
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
|
await model_cache.get("test_model_name", ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION, ttl=100)
|
||||||
profiling = await model_cache.get_profiling()
|
profiling = await model_cache.get_profiling()
|
||||||
assert isinstance(profiling, dict)
|
assert isinstance(profiling, dict)
|
||||||
assert profiling == model_cache.cache.profiling
|
assert profiling == model_cache.cache.profiling
|
||||||
|
@ -510,9 +540,9 @@ class TestCache:
|
||||||
async def test_loads_mclip(self) -> None:
|
async def test_loads_mclip(self) -> None:
|
||||||
model_cache = ModelCache()
|
model_cache = ModelCache()
|
||||||
|
|
||||||
model = await model_cache.get("XLM-Roberta-Large-Vit-B-32", ModelType.CLIP, mode="text")
|
model = await model_cache.get("XLM-Roberta-Large-Vit-B-32", ModelType.TEXTUAL, ModelTask.SEARCH)
|
||||||
|
|
||||||
assert isinstance(model, MCLIPEncoder)
|
assert isinstance(model, MClipTextualEncoder)
|
||||||
assert model.model_name == "XLM-Roberta-Large-Vit-B-32"
|
assert model.model_name == "XLM-Roberta-Large-Vit-B-32"
|
||||||
|
|
||||||
async def test_raises_exception_if_invalid_model_type(self) -> None:
|
async def test_raises_exception_if_invalid_model_type(self) -> None:
|
||||||
|
@ -520,15 +550,55 @@ class TestCache:
|
||||||
model_cache = ModelCache()
|
model_cache = ModelCache()
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await model_cache.get("XLM-Roberta-Large-Vit-B-32", invalid, mode="text")
|
await model_cache.get("XLM-Roberta-Large-Vit-B-32", ModelType.TEXTUAL, invalid)
|
||||||
|
|
||||||
async def test_raises_exception_if_unknown_model_name(self) -> None:
|
async def test_raises_exception_if_unknown_model_name(self) -> None:
|
||||||
model_cache = ModelCache()
|
model_cache = ModelCache()
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await model_cache.get("test_model_name", ModelType.CLIP, mode="text")
|
await model_cache.get("test_model_name", ModelType.TEXTUAL, ModelTask.SEARCH)
|
||||||
|
|
||||||
async def test_preloads_models(self, monkeypatch: MonkeyPatch, mock_get_model: mock.Mock) -> None:
|
async def test_preloads_clip_models(self, monkeypatch: MonkeyPatch, mock_get_model: mock.Mock) -> None:
|
||||||
|
os.environ["MACHINE_LEARNING_PRELOAD__CLIP"] = "ViT-B-32__openai"
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
|
assert settings.preload is not None
|
||||||
|
assert settings.preload.clip == "ViT-B-32__openai"
|
||||||
|
|
||||||
|
model_cache = ModelCache()
|
||||||
|
monkeypatch.setattr("app.main.model_cache", model_cache)
|
||||||
|
|
||||||
|
await preload_models(settings.preload)
|
||||||
|
mock_get_model.assert_has_calls(
|
||||||
|
[
|
||||||
|
mock.call("ViT-B-32__openai", ModelType.TEXTUAL, ModelTask.SEARCH),
|
||||||
|
mock.call("ViT-B-32__openai", ModelType.VISUAL, ModelTask.SEARCH),
|
||||||
|
],
|
||||||
|
any_order=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_preloads_facial_recognition_models(
|
||||||
|
self, monkeypatch: MonkeyPatch, mock_get_model: mock.Mock
|
||||||
|
) -> None:
|
||||||
|
os.environ["MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION"] = "buffalo_s"
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
|
assert settings.preload is not None
|
||||||
|
assert settings.preload.facial_recognition == "buffalo_s"
|
||||||
|
|
||||||
|
model_cache = ModelCache()
|
||||||
|
monkeypatch.setattr("app.main.model_cache", model_cache)
|
||||||
|
|
||||||
|
await preload_models(settings.preload)
|
||||||
|
mock_get_model.assert_has_calls(
|
||||||
|
[
|
||||||
|
mock.call("buffalo_s", ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION),
|
||||||
|
mock.call("buffalo_s", ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION),
|
||||||
|
],
|
||||||
|
any_order=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_preloads_all_models(self, monkeypatch: MonkeyPatch, mock_get_model: mock.Mock) -> None:
|
||||||
os.environ["MACHINE_LEARNING_PRELOAD__CLIP"] = "ViT-B-32__openai"
|
os.environ["MACHINE_LEARNING_PRELOAD__CLIP"] = "ViT-B-32__openai"
|
||||||
os.environ["MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION"] = "buffalo_s"
|
os.environ["MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION"] = "buffalo_s"
|
||||||
|
|
||||||
|
@ -541,11 +611,15 @@ class TestCache:
|
||||||
monkeypatch.setattr("app.main.model_cache", model_cache)
|
monkeypatch.setattr("app.main.model_cache", model_cache)
|
||||||
|
|
||||||
await preload_models(settings.preload)
|
await preload_models(settings.preload)
|
||||||
assert len(model_cache.cache._cache) == 2
|
mock_get_model.assert_has_calls(
|
||||||
assert mock_get_model.call_count == 2
|
[
|
||||||
await model_cache.get("ViT-B-32__openai", ModelType.CLIP, ttl=100)
|
mock.call("ViT-B-32__openai", ModelType.TEXTUAL, ModelTask.SEARCH),
|
||||||
await model_cache.get("buffalo_s", ModelType.FACIAL_RECOGNITION, ttl=100)
|
mock.call("ViT-B-32__openai", ModelType.VISUAL, ModelTask.SEARCH),
|
||||||
assert mock_get_model.call_count == 2
|
mock.call("buffalo_s", ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION),
|
||||||
|
mock.call("buffalo_s", ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION),
|
||||||
|
],
|
||||||
|
any_order=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -572,7 +646,8 @@ class TestLoad:
|
||||||
async def test_load_clears_cache_and_retries_if_os_error(self) -> None:
|
async def test_load_clears_cache_and_retries_if_os_error(self) -> None:
|
||||||
mock_model = mock.Mock(spec=InferenceModel)
|
mock_model = mock.Mock(spec=InferenceModel)
|
||||||
mock_model.model_name = "test_model_name"
|
mock_model.model_name = "test_model_name"
|
||||||
mock_model.model_type = ModelType.CLIP
|
mock_model.model_type = ModelType.VISUAL
|
||||||
|
mock_model.model_task = ModelTask.SEARCH
|
||||||
mock_model.load.side_effect = [OSError, None]
|
mock_model.load.side_effect = [OSError, None]
|
||||||
mock_model.loaded = False
|
mock_model.loaded = False
|
||||||
|
|
||||||
|
@ -597,13 +672,15 @@ class TestEndpoints:
|
||||||
|
|
||||||
response = deployed_app.post(
|
response = deployed_app.post(
|
||||||
"http://localhost:3003/predict",
|
"http://localhost:3003/predict",
|
||||||
data={"modelName": "ViT-B-32__openai", "modelType": "clip", "options": json.dumps({"mode": "vision"})},
|
data={"entries": json.dumps({"clip": {"visual": {"modelName": "ViT-B-32__openai"}}})},
|
||||||
files={"image": byte_image.getvalue()},
|
files={"image": byte_image.getvalue()},
|
||||||
)
|
)
|
||||||
|
|
||||||
actual = response.json()
|
actual = response.json()
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert np.allclose(expected, actual)
|
assert isinstance(actual, dict)
|
||||||
|
assert isinstance(actual.get("clip", None), list)
|
||||||
|
assert np.allclose(expected, actual["clip"])
|
||||||
|
|
||||||
def test_clip_text_endpoint(self, responses: dict[str, Any], deployed_app: TestClient) -> None:
|
def test_clip_text_endpoint(self, responses: dict[str, Any], deployed_app: TestClient) -> None:
|
||||||
expected = responses["clip"]["text"]
|
expected = responses["clip"]["text"]
|
||||||
|
@ -611,38 +688,49 @@ class TestEndpoints:
|
||||||
response = deployed_app.post(
|
response = deployed_app.post(
|
||||||
"http://localhost:3003/predict",
|
"http://localhost:3003/predict",
|
||||||
data={
|
data={
|
||||||
"modelName": "ViT-B-32__openai",
|
"entries": json.dumps(
|
||||||
"modelType": "clip",
|
{
|
||||||
|
"clip": {"textual": {"modelName": "ViT-B-32__openai"}},
|
||||||
|
},
|
||||||
|
),
|
||||||
"text": "test search query",
|
"text": "test search query",
|
||||||
"options": json.dumps({"mode": "text"}),
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
actual = response.json()
|
actual = response.json()
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert np.allclose(expected, actual)
|
assert isinstance(actual, dict)
|
||||||
|
assert isinstance(actual.get("clip", None), list)
|
||||||
|
assert np.allclose(expected, actual["clip"])
|
||||||
|
|
||||||
def test_face_endpoint(self, pil_image: Image.Image, responses: dict[str, Any], deployed_app: TestClient) -> None:
|
def test_face_endpoint(self, pil_image: Image.Image, responses: dict[str, Any], deployed_app: TestClient) -> None:
|
||||||
byte_image = BytesIO()
|
byte_image = BytesIO()
|
||||||
pil_image.save(byte_image, format="jpeg")
|
pil_image.save(byte_image, format="jpeg")
|
||||||
expected = responses["facial-recognition"]
|
|
||||||
|
|
||||||
response = deployed_app.post(
|
response = deployed_app.post(
|
||||||
"http://localhost:3003/predict",
|
"http://localhost:3003/predict",
|
||||||
data={
|
data={
|
||||||
"modelName": "buffalo_l",
|
"entries": json.dumps(
|
||||||
"modelType": "facial-recognition",
|
{
|
||||||
"options": json.dumps({"minScore": 0.034}),
|
"facial-recognition": {
|
||||||
|
"detection": {"modelName": "buffalo_l", "options": {"minScore": 0.034}},
|
||||||
|
"recognition": {"modelName": "buffalo_l"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
},
|
},
|
||||||
files={"image": byte_image.getvalue()},
|
files={"image": byte_image.getvalue()},
|
||||||
)
|
)
|
||||||
|
|
||||||
actual = response.json()
|
actual = response.json()
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert len(expected) == len(actual)
|
assert isinstance(actual, dict)
|
||||||
for expected_face, actual_face in zip(expected, actual):
|
assert actual.get("imageHeight", None) == responses["imageHeight"]
|
||||||
assert expected_face["imageHeight"] == actual_face["imageHeight"]
|
assert actual.get("imageWidth", None) == responses["imageWidth"]
|
||||||
assert expected_face["imageWidth"] == actual_face["imageWidth"]
|
assert "facial-recognition" in actual and isinstance(actual["facial-recognition"], list)
|
||||||
|
assert len(actual["facial-recognition"]) == len(responses["facial-recognition"])
|
||||||
|
|
||||||
|
for expected_face, actual_face in zip(responses["facial-recognition"], actual["facial-recognition"]):
|
||||||
assert expected_face["boundingBox"] == actual_face["boundingBox"]
|
assert expected_face["boundingBox"] == actual_face["boundingBox"]
|
||||||
assert np.allclose(expected_face["embedding"], actual_face["embedding"])
|
assert np.allclose(expected_face["embedding"], actual_face["embedding"])
|
||||||
assert np.allclose(expected_face["score"], actual_face["score"])
|
assert np.allclose(expected_face["score"], actual_face["score"])
|
||||||
|
|
|
@ -37,7 +37,6 @@ def on_test_start(environment: Environment, **kwargs: Any) -> None:
|
||||||
global byte_image
|
global byte_image
|
||||||
assert environment.parsed_options is not None
|
assert environment.parsed_options is not None
|
||||||
image = Image.new("RGB", (environment.parsed_options.image_size, environment.parsed_options.image_size))
|
image = Image.new("RGB", (environment.parsed_options.image_size, environment.parsed_options.image_size))
|
||||||
byte_image = BytesIO()
|
|
||||||
image.save(byte_image, format="jpeg")
|
image.save(byte_image, format="jpeg")
|
||||||
|
|
||||||
|
|
||||||
|
@ -45,34 +44,25 @@ class InferenceLoadTest(HttpUser):
|
||||||
abstract: bool = True
|
abstract: bool = True
|
||||||
host = "http://127.0.0.1:3003"
|
host = "http://127.0.0.1:3003"
|
||||||
data: bytes
|
data: bytes
|
||||||
headers: dict[str, str] = {"Content-Type": "image/jpg"}
|
|
||||||
|
|
||||||
# re-use the image across all instances in a process
|
# re-use the image across all instances in a process
|
||||||
def on_start(self) -> None:
|
def on_start(self) -> None:
|
||||||
global byte_image
|
|
||||||
self.data = byte_image.getvalue()
|
self.data = byte_image.getvalue()
|
||||||
|
|
||||||
|
|
||||||
class CLIPTextFormDataLoadTest(InferenceLoadTest):
|
class CLIPTextFormDataLoadTest(InferenceLoadTest):
|
||||||
@task
|
@task
|
||||||
def encode_text(self) -> None:
|
def encode_text(self) -> None:
|
||||||
data = [
|
request = {"clip": {"textual": {"modelName": self.environment.parsed_options.clip_model}}}
|
||||||
("modelName", self.environment.parsed_options.clip_model),
|
data = [("entries", json.dumps(request)), ("text", "test search query")]
|
||||||
("modelType", "clip"),
|
|
||||||
("options", json.dumps({"mode": "text"})),
|
|
||||||
("text", "test search query"),
|
|
||||||
]
|
|
||||||
self.client.post("/predict", data=data)
|
self.client.post("/predict", data=data)
|
||||||
|
|
||||||
|
|
||||||
class CLIPVisionFormDataLoadTest(InferenceLoadTest):
|
class CLIPVisionFormDataLoadTest(InferenceLoadTest):
|
||||||
@task
|
@task
|
||||||
def encode_image(self) -> None:
|
def encode_image(self) -> None:
|
||||||
data = [
|
request = {"clip": {"visual": {"modelName": self.environment.parsed_options.clip_model, "options": {}}}}
|
||||||
("modelName", self.environment.parsed_options.clip_model),
|
data = [("entries", json.dumps(request))]
|
||||||
("modelType", "clip"),
|
|
||||||
("options", json.dumps({"mode": "vision"})),
|
|
||||||
]
|
|
||||||
files = {"image": self.data}
|
files = {"image": self.data}
|
||||||
self.client.post("/predict", data=data, files=files)
|
self.client.post("/predict", data=data, files=files)
|
||||||
|
|
||||||
|
@ -80,11 +70,18 @@ class CLIPVisionFormDataLoadTest(InferenceLoadTest):
|
||||||
class RecognitionFormDataLoadTest(InferenceLoadTest):
|
class RecognitionFormDataLoadTest(InferenceLoadTest):
|
||||||
@task
|
@task
|
||||||
def recognize(self) -> None:
|
def recognize(self) -> None:
|
||||||
data = [
|
request = {
|
||||||
("modelName", self.environment.parsed_options.face_model),
|
"facial-recognition": {
|
||||||
("modelType", "facial-recognition"),
|
"recognition": {
|
||||||
("options", json.dumps({"minScore": self.environment.parsed_options.face_min_score})),
|
"modelName": self.environment.parsed_options.face_model,
|
||||||
]
|
"options": {"minScore": self.environment.parsed_options.face_min_score},
|
||||||
|
},
|
||||||
|
"detection": {
|
||||||
|
"modelName": self.environment.parsed_options.face_model,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
data = [("entries", json.dumps(request))]
|
||||||
files = {"image": self.data}
|
files = {"image": self.data}
|
||||||
|
|
||||||
self.client.post("/predict", data=data, files=files)
|
self.client.post("/predict", data=data, files=files)
|
||||||
|
|
|
@ -213,8 +213,6 @@
|
||||||
},
|
},
|
||||||
"facial-recognition": [
|
"facial-recognition": [
|
||||||
{
|
{
|
||||||
"imageWidth": 600,
|
|
||||||
"imageHeight": 800,
|
|
||||||
"boundingBox": {
|
"boundingBox": {
|
||||||
"x1": 690.0,
|
"x1": 690.0,
|
||||||
"y1": -89.0,
|
"y1": -89.0,
|
||||||
|
@ -325,5 +323,7 @@
|
||||||
-0.077056274, 0.002099529
|
-0.077056274, 0.002099529
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"imageWidth": 600,
|
||||||
|
"imageHeight": 800
|
||||||
}
|
}
|
||||||
|
|
4
mobile/openapi/README.md
generated
4
mobile/openapi/README.md
generated
|
@ -279,7 +279,6 @@ Class | Method | HTTP request | Description
|
||||||
- [BulkIdResponseDto](doc//BulkIdResponseDto.md)
|
- [BulkIdResponseDto](doc//BulkIdResponseDto.md)
|
||||||
- [BulkIdsDto](doc//BulkIdsDto.md)
|
- [BulkIdsDto](doc//BulkIdsDto.md)
|
||||||
- [CLIPConfig](doc//CLIPConfig.md)
|
- [CLIPConfig](doc//CLIPConfig.md)
|
||||||
- [CLIPMode](doc//CLIPMode.md)
|
|
||||||
- [CQMode](doc//CQMode.md)
|
- [CQMode](doc//CQMode.md)
|
||||||
- [ChangePasswordDto](doc//ChangePasswordDto.md)
|
- [ChangePasswordDto](doc//ChangePasswordDto.md)
|
||||||
- [CheckExistingAssetsDto](doc//CheckExistingAssetsDto.md)
|
- [CheckExistingAssetsDto](doc//CheckExistingAssetsDto.md)
|
||||||
|
@ -299,6 +298,7 @@ Class | Method | HTTP request | Description
|
||||||
- [EntityType](doc//EntityType.md)
|
- [EntityType](doc//EntityType.md)
|
||||||
- [ExifResponseDto](doc//ExifResponseDto.md)
|
- [ExifResponseDto](doc//ExifResponseDto.md)
|
||||||
- [FaceDto](doc//FaceDto.md)
|
- [FaceDto](doc//FaceDto.md)
|
||||||
|
- [FacialRecognitionConfig](doc//FacialRecognitionConfig.md)
|
||||||
- [FileChecksumDto](doc//FileChecksumDto.md)
|
- [FileChecksumDto](doc//FileChecksumDto.md)
|
||||||
- [FileChecksumResponseDto](doc//FileChecksumResponseDto.md)
|
- [FileChecksumResponseDto](doc//FileChecksumResponseDto.md)
|
||||||
- [FileReportDto](doc//FileReportDto.md)
|
- [FileReportDto](doc//FileReportDto.md)
|
||||||
|
@ -328,7 +328,6 @@ Class | Method | HTTP request | Description
|
||||||
- [MemoryUpdateDto](doc//MemoryUpdateDto.md)
|
- [MemoryUpdateDto](doc//MemoryUpdateDto.md)
|
||||||
- [MergePersonDto](doc//MergePersonDto.md)
|
- [MergePersonDto](doc//MergePersonDto.md)
|
||||||
- [MetadataSearchDto](doc//MetadataSearchDto.md)
|
- [MetadataSearchDto](doc//MetadataSearchDto.md)
|
||||||
- [ModelType](doc//ModelType.md)
|
|
||||||
- [OAuthAuthorizeResponseDto](doc//OAuthAuthorizeResponseDto.md)
|
- [OAuthAuthorizeResponseDto](doc//OAuthAuthorizeResponseDto.md)
|
||||||
- [OAuthCallbackDto](doc//OAuthCallbackDto.md)
|
- [OAuthCallbackDto](doc//OAuthCallbackDto.md)
|
||||||
- [OAuthConfigDto](doc//OAuthConfigDto.md)
|
- [OAuthConfigDto](doc//OAuthConfigDto.md)
|
||||||
|
@ -348,7 +347,6 @@ Class | Method | HTTP request | Description
|
||||||
- [QueueStatusDto](doc//QueueStatusDto.md)
|
- [QueueStatusDto](doc//QueueStatusDto.md)
|
||||||
- [ReactionLevel](doc//ReactionLevel.md)
|
- [ReactionLevel](doc//ReactionLevel.md)
|
||||||
- [ReactionType](doc//ReactionType.md)
|
- [ReactionType](doc//ReactionType.md)
|
||||||
- [RecognitionConfig](doc//RecognitionConfig.md)
|
|
||||||
- [ReverseGeocodingStateResponseDto](doc//ReverseGeocodingStateResponseDto.md)
|
- [ReverseGeocodingStateResponseDto](doc//ReverseGeocodingStateResponseDto.md)
|
||||||
- [ScanLibraryDto](doc//ScanLibraryDto.md)
|
- [ScanLibraryDto](doc//ScanLibraryDto.md)
|
||||||
- [SearchAlbumResponseDto](doc//SearchAlbumResponseDto.md)
|
- [SearchAlbumResponseDto](doc//SearchAlbumResponseDto.md)
|
||||||
|
|
4
mobile/openapi/lib/api.dart
generated
4
mobile/openapi/lib/api.dart
generated
|
@ -106,7 +106,6 @@ part 'model/avatar_update.dart';
|
||||||
part 'model/bulk_id_response_dto.dart';
|
part 'model/bulk_id_response_dto.dart';
|
||||||
part 'model/bulk_ids_dto.dart';
|
part 'model/bulk_ids_dto.dart';
|
||||||
part 'model/clip_config.dart';
|
part 'model/clip_config.dart';
|
||||||
part 'model/clip_mode.dart';
|
|
||||||
part 'model/cq_mode.dart';
|
part 'model/cq_mode.dart';
|
||||||
part 'model/change_password_dto.dart';
|
part 'model/change_password_dto.dart';
|
||||||
part 'model/check_existing_assets_dto.dart';
|
part 'model/check_existing_assets_dto.dart';
|
||||||
|
@ -126,6 +125,7 @@ part 'model/email_notifications_update.dart';
|
||||||
part 'model/entity_type.dart';
|
part 'model/entity_type.dart';
|
||||||
part 'model/exif_response_dto.dart';
|
part 'model/exif_response_dto.dart';
|
||||||
part 'model/face_dto.dart';
|
part 'model/face_dto.dart';
|
||||||
|
part 'model/facial_recognition_config.dart';
|
||||||
part 'model/file_checksum_dto.dart';
|
part 'model/file_checksum_dto.dart';
|
||||||
part 'model/file_checksum_response_dto.dart';
|
part 'model/file_checksum_response_dto.dart';
|
||||||
part 'model/file_report_dto.dart';
|
part 'model/file_report_dto.dart';
|
||||||
|
@ -155,7 +155,6 @@ part 'model/memory_update.dart';
|
||||||
part 'model/memory_update_dto.dart';
|
part 'model/memory_update_dto.dart';
|
||||||
part 'model/merge_person_dto.dart';
|
part 'model/merge_person_dto.dart';
|
||||||
part 'model/metadata_search_dto.dart';
|
part 'model/metadata_search_dto.dart';
|
||||||
part 'model/model_type.dart';
|
|
||||||
part 'model/o_auth_authorize_response_dto.dart';
|
part 'model/o_auth_authorize_response_dto.dart';
|
||||||
part 'model/o_auth_callback_dto.dart';
|
part 'model/o_auth_callback_dto.dart';
|
||||||
part 'model/o_auth_config_dto.dart';
|
part 'model/o_auth_config_dto.dart';
|
||||||
|
@ -175,7 +174,6 @@ part 'model/places_response_dto.dart';
|
||||||
part 'model/queue_status_dto.dart';
|
part 'model/queue_status_dto.dart';
|
||||||
part 'model/reaction_level.dart';
|
part 'model/reaction_level.dart';
|
||||||
part 'model/reaction_type.dart';
|
part 'model/reaction_type.dart';
|
||||||
part 'model/recognition_config.dart';
|
|
||||||
part 'model/reverse_geocoding_state_response_dto.dart';
|
part 'model/reverse_geocoding_state_response_dto.dart';
|
||||||
part 'model/scan_library_dto.dart';
|
part 'model/scan_library_dto.dart';
|
||||||
part 'model/search_album_response_dto.dart';
|
part 'model/search_album_response_dto.dart';
|
||||||
|
|
8
mobile/openapi/lib/api_client.dart
generated
8
mobile/openapi/lib/api_client.dart
generated
|
@ -276,8 +276,6 @@ class ApiClient {
|
||||||
return BulkIdsDto.fromJson(value);
|
return BulkIdsDto.fromJson(value);
|
||||||
case 'CLIPConfig':
|
case 'CLIPConfig':
|
||||||
return CLIPConfig.fromJson(value);
|
return CLIPConfig.fromJson(value);
|
||||||
case 'CLIPMode':
|
|
||||||
return CLIPModeTypeTransformer().decode(value);
|
|
||||||
case 'CQMode':
|
case 'CQMode':
|
||||||
return CQModeTypeTransformer().decode(value);
|
return CQModeTypeTransformer().decode(value);
|
||||||
case 'ChangePasswordDto':
|
case 'ChangePasswordDto':
|
||||||
|
@ -316,6 +314,8 @@ class ApiClient {
|
||||||
return ExifResponseDto.fromJson(value);
|
return ExifResponseDto.fromJson(value);
|
||||||
case 'FaceDto':
|
case 'FaceDto':
|
||||||
return FaceDto.fromJson(value);
|
return FaceDto.fromJson(value);
|
||||||
|
case 'FacialRecognitionConfig':
|
||||||
|
return FacialRecognitionConfig.fromJson(value);
|
||||||
case 'FileChecksumDto':
|
case 'FileChecksumDto':
|
||||||
return FileChecksumDto.fromJson(value);
|
return FileChecksumDto.fromJson(value);
|
||||||
case 'FileChecksumResponseDto':
|
case 'FileChecksumResponseDto':
|
||||||
|
@ -374,8 +374,6 @@ class ApiClient {
|
||||||
return MergePersonDto.fromJson(value);
|
return MergePersonDto.fromJson(value);
|
||||||
case 'MetadataSearchDto':
|
case 'MetadataSearchDto':
|
||||||
return MetadataSearchDto.fromJson(value);
|
return MetadataSearchDto.fromJson(value);
|
||||||
case 'ModelType':
|
|
||||||
return ModelTypeTypeTransformer().decode(value);
|
|
||||||
case 'OAuthAuthorizeResponseDto':
|
case 'OAuthAuthorizeResponseDto':
|
||||||
return OAuthAuthorizeResponseDto.fromJson(value);
|
return OAuthAuthorizeResponseDto.fromJson(value);
|
||||||
case 'OAuthCallbackDto':
|
case 'OAuthCallbackDto':
|
||||||
|
@ -414,8 +412,6 @@ class ApiClient {
|
||||||
return ReactionLevelTypeTransformer().decode(value);
|
return ReactionLevelTypeTransformer().decode(value);
|
||||||
case 'ReactionType':
|
case 'ReactionType':
|
||||||
return ReactionTypeTypeTransformer().decode(value);
|
return ReactionTypeTypeTransformer().decode(value);
|
||||||
case 'RecognitionConfig':
|
|
||||||
return RecognitionConfig.fromJson(value);
|
|
||||||
case 'ReverseGeocodingStateResponseDto':
|
case 'ReverseGeocodingStateResponseDto':
|
||||||
return ReverseGeocodingStateResponseDto.fromJson(value);
|
return ReverseGeocodingStateResponseDto.fromJson(value);
|
||||||
case 'ScanLibraryDto':
|
case 'ScanLibraryDto':
|
||||||
|
|
6
mobile/openapi/lib/api_helper.dart
generated
6
mobile/openapi/lib/api_helper.dart
generated
|
@ -76,9 +76,6 @@ String parameterToString(dynamic value) {
|
||||||
if (value is AudioCodec) {
|
if (value is AudioCodec) {
|
||||||
return AudioCodecTypeTransformer().encode(value).toString();
|
return AudioCodecTypeTransformer().encode(value).toString();
|
||||||
}
|
}
|
||||||
if (value is CLIPMode) {
|
|
||||||
return CLIPModeTypeTransformer().encode(value).toString();
|
|
||||||
}
|
|
||||||
if (value is CQMode) {
|
if (value is CQMode) {
|
||||||
return CQModeTypeTransformer().encode(value).toString();
|
return CQModeTypeTransformer().encode(value).toString();
|
||||||
}
|
}
|
||||||
|
@ -106,9 +103,6 @@ String parameterToString(dynamic value) {
|
||||||
if (value is MemoryType) {
|
if (value is MemoryType) {
|
||||||
return MemoryTypeTypeTransformer().encode(value).toString();
|
return MemoryTypeTypeTransformer().encode(value).toString();
|
||||||
}
|
}
|
||||||
if (value is ModelType) {
|
|
||||||
return ModelTypeTypeTransformer().encode(value).toString();
|
|
||||||
}
|
|
||||||
if (value is PathEntityType) {
|
if (value is PathEntityType) {
|
||||||
return PathEntityTypeTypeTransformer().encode(value).toString();
|
return PathEntityTypeTypeTransformer().encode(value).toString();
|
||||||
}
|
}
|
||||||
|
|
40
mobile/openapi/lib/model/clip_config.dart
generated
40
mobile/openapi/lib/model/clip_config.dart
generated
|
@ -14,63 +14,31 @@ class CLIPConfig {
|
||||||
/// Returns a new [CLIPConfig] instance.
|
/// Returns a new [CLIPConfig] instance.
|
||||||
CLIPConfig({
|
CLIPConfig({
|
||||||
required this.enabled,
|
required this.enabled,
|
||||||
this.mode,
|
|
||||||
required this.modelName,
|
required this.modelName,
|
||||||
this.modelType,
|
|
||||||
});
|
});
|
||||||
|
|
||||||
bool enabled;
|
bool enabled;
|
||||||
|
|
||||||
///
|
|
||||||
/// Please note: This property should have been non-nullable! Since the specification file
|
|
||||||
/// does not include a default value (using the "default:" property), however, the generated
|
|
||||||
/// source code must fall back to having a nullable type.
|
|
||||||
/// Consider adding a "default:" property in the specification file to hide this note.
|
|
||||||
///
|
|
||||||
CLIPMode? mode;
|
|
||||||
|
|
||||||
String modelName;
|
String modelName;
|
||||||
|
|
||||||
///
|
|
||||||
/// Please note: This property should have been non-nullable! Since the specification file
|
|
||||||
/// does not include a default value (using the "default:" property), however, the generated
|
|
||||||
/// source code must fall back to having a nullable type.
|
|
||||||
/// Consider adding a "default:" property in the specification file to hide this note.
|
|
||||||
///
|
|
||||||
ModelType? modelType;
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
bool operator ==(Object other) => identical(this, other) || other is CLIPConfig &&
|
bool operator ==(Object other) => identical(this, other) || other is CLIPConfig &&
|
||||||
other.enabled == enabled &&
|
other.enabled == enabled &&
|
||||||
other.mode == mode &&
|
other.modelName == modelName;
|
||||||
other.modelName == modelName &&
|
|
||||||
other.modelType == modelType;
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
int get hashCode =>
|
int get hashCode =>
|
||||||
// ignore: unnecessary_parenthesis
|
// ignore: unnecessary_parenthesis
|
||||||
(enabled.hashCode) +
|
(enabled.hashCode) +
|
||||||
(mode == null ? 0 : mode!.hashCode) +
|
(modelName.hashCode);
|
||||||
(modelName.hashCode) +
|
|
||||||
(modelType == null ? 0 : modelType!.hashCode);
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
String toString() => 'CLIPConfig[enabled=$enabled, mode=$mode, modelName=$modelName, modelType=$modelType]';
|
String toString() => 'CLIPConfig[enabled=$enabled, modelName=$modelName]';
|
||||||
|
|
||||||
Map<String, dynamic> toJson() {
|
Map<String, dynamic> toJson() {
|
||||||
final json = <String, dynamic>{};
|
final json = <String, dynamic>{};
|
||||||
json[r'enabled'] = this.enabled;
|
json[r'enabled'] = this.enabled;
|
||||||
if (this.mode != null) {
|
|
||||||
json[r'mode'] = this.mode;
|
|
||||||
} else {
|
|
||||||
// json[r'mode'] = null;
|
|
||||||
}
|
|
||||||
json[r'modelName'] = this.modelName;
|
json[r'modelName'] = this.modelName;
|
||||||
if (this.modelType != null) {
|
|
||||||
json[r'modelType'] = this.modelType;
|
|
||||||
} else {
|
|
||||||
// json[r'modelType'] = null;
|
|
||||||
}
|
|
||||||
return json;
|
return json;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,9 +51,7 @@ class CLIPConfig {
|
||||||
|
|
||||||
return CLIPConfig(
|
return CLIPConfig(
|
||||||
enabled: mapValueOfType<bool>(json, r'enabled')!,
|
enabled: mapValueOfType<bool>(json, r'enabled')!,
|
||||||
mode: CLIPMode.fromJson(json[r'mode']),
|
|
||||||
modelName: mapValueOfType<String>(json, r'modelName')!,
|
modelName: mapValueOfType<String>(json, r'modelName')!,
|
||||||
modelType: ModelType.fromJson(json[r'modelType']),
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
|
|
85
mobile/openapi/lib/model/clip_mode.dart
generated
85
mobile/openapi/lib/model/clip_mode.dart
generated
|
@ -1,85 +0,0 @@
|
||||||
//
|
|
||||||
// AUTO-GENERATED FILE, DO NOT MODIFY!
|
|
||||||
//
|
|
||||||
// @dart=2.18
|
|
||||||
|
|
||||||
// ignore_for_file: unused_element, unused_import
|
|
||||||
// ignore_for_file: always_put_required_named_parameters_first
|
|
||||||
// ignore_for_file: constant_identifier_names
|
|
||||||
// ignore_for_file: lines_longer_than_80_chars
|
|
||||||
|
|
||||||
part of openapi.api;
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPMode {
|
|
||||||
/// Instantiate a new enum with the provided [value].
|
|
||||||
const CLIPMode._(this.value);
|
|
||||||
|
|
||||||
/// The underlying value of this enum member.
|
|
||||||
final String value;
|
|
||||||
|
|
||||||
@override
|
|
||||||
String toString() => value;
|
|
||||||
|
|
||||||
String toJson() => value;
|
|
||||||
|
|
||||||
static const vision = CLIPMode._(r'vision');
|
|
||||||
static const text = CLIPMode._(r'text');
|
|
||||||
|
|
||||||
/// List of all possible values in this [enum][CLIPMode].
|
|
||||||
static const values = <CLIPMode>[
|
|
||||||
vision,
|
|
||||||
text,
|
|
||||||
];
|
|
||||||
|
|
||||||
static CLIPMode? fromJson(dynamic value) => CLIPModeTypeTransformer().decode(value);
|
|
||||||
|
|
||||||
static List<CLIPMode> listFromJson(dynamic json, {bool growable = false,}) {
|
|
||||||
final result = <CLIPMode>[];
|
|
||||||
if (json is List && json.isNotEmpty) {
|
|
||||||
for (final row in json) {
|
|
||||||
final value = CLIPMode.fromJson(row);
|
|
||||||
if (value != null) {
|
|
||||||
result.add(value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result.toList(growable: growable);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Transformation class that can [encode] an instance of [CLIPMode] to String,
|
|
||||||
/// and [decode] dynamic data back to [CLIPMode].
|
|
||||||
class CLIPModeTypeTransformer {
|
|
||||||
factory CLIPModeTypeTransformer() => _instance ??= const CLIPModeTypeTransformer._();
|
|
||||||
|
|
||||||
const CLIPModeTypeTransformer._();
|
|
||||||
|
|
||||||
String encode(CLIPMode data) => data.value;
|
|
||||||
|
|
||||||
/// Decodes a [dynamic value][data] to a CLIPMode.
|
|
||||||
///
|
|
||||||
/// If [allowNull] is true and the [dynamic value][data] cannot be decoded successfully,
|
|
||||||
/// then null is returned. However, if [allowNull] is false and the [dynamic value][data]
|
|
||||||
/// cannot be decoded successfully, then an [UnimplementedError] is thrown.
|
|
||||||
///
|
|
||||||
/// The [allowNull] is very handy when an API changes and a new enum value is added or removed,
|
|
||||||
/// and users are still using an old app with the old code.
|
|
||||||
CLIPMode? decode(dynamic data, {bool allowNull = true}) {
|
|
||||||
if (data != null) {
|
|
||||||
switch (data) {
|
|
||||||
case r'vision': return CLIPMode.vision;
|
|
||||||
case r'text': return CLIPMode.text;
|
|
||||||
default:
|
|
||||||
if (!allowNull) {
|
|
||||||
throw ArgumentError('Unknown enum value to decode: $data');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Singleton [CLIPModeTypeTransformer] instance.
|
|
||||||
static CLIPModeTypeTransformer? _instance;
|
|
||||||
}
|
|
||||||
|
|
|
@ -10,15 +10,14 @@
|
||||||
|
|
||||||
part of openapi.api;
|
part of openapi.api;
|
||||||
|
|
||||||
class RecognitionConfig {
|
class FacialRecognitionConfig {
|
||||||
/// Returns a new [RecognitionConfig] instance.
|
/// Returns a new [FacialRecognitionConfig] instance.
|
||||||
RecognitionConfig({
|
FacialRecognitionConfig({
|
||||||
required this.enabled,
|
required this.enabled,
|
||||||
required this.maxDistance,
|
required this.maxDistance,
|
||||||
required this.minFaces,
|
required this.minFaces,
|
||||||
required this.minScore,
|
required this.minScore,
|
||||||
required this.modelName,
|
required this.modelName,
|
||||||
this.modelType,
|
|
||||||
});
|
});
|
||||||
|
|
||||||
bool enabled;
|
bool enabled;
|
||||||
|
@ -36,22 +35,13 @@ class RecognitionConfig {
|
||||||
|
|
||||||
String modelName;
|
String modelName;
|
||||||
|
|
||||||
///
|
|
||||||
/// Please note: This property should have been non-nullable! Since the specification file
|
|
||||||
/// does not include a default value (using the "default:" property), however, the generated
|
|
||||||
/// source code must fall back to having a nullable type.
|
|
||||||
/// Consider adding a "default:" property in the specification file to hide this note.
|
|
||||||
///
|
|
||||||
ModelType? modelType;
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
bool operator ==(Object other) => identical(this, other) || other is RecognitionConfig &&
|
bool operator ==(Object other) => identical(this, other) || other is FacialRecognitionConfig &&
|
||||||
other.enabled == enabled &&
|
other.enabled == enabled &&
|
||||||
other.maxDistance == maxDistance &&
|
other.maxDistance == maxDistance &&
|
||||||
other.minFaces == minFaces &&
|
other.minFaces == minFaces &&
|
||||||
other.minScore == minScore &&
|
other.minScore == minScore &&
|
||||||
other.modelName == modelName &&
|
other.modelName == modelName;
|
||||||
other.modelType == modelType;
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
int get hashCode =>
|
int get hashCode =>
|
||||||
|
@ -60,11 +50,10 @@ class RecognitionConfig {
|
||||||
(maxDistance.hashCode) +
|
(maxDistance.hashCode) +
|
||||||
(minFaces.hashCode) +
|
(minFaces.hashCode) +
|
||||||
(minScore.hashCode) +
|
(minScore.hashCode) +
|
||||||
(modelName.hashCode) +
|
(modelName.hashCode);
|
||||||
(modelType == null ? 0 : modelType!.hashCode);
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
String toString() => 'RecognitionConfig[enabled=$enabled, maxDistance=$maxDistance, minFaces=$minFaces, minScore=$minScore, modelName=$modelName, modelType=$modelType]';
|
String toString() => 'FacialRecognitionConfig[enabled=$enabled, maxDistance=$maxDistance, minFaces=$minFaces, minScore=$minScore, modelName=$modelName]';
|
||||||
|
|
||||||
Map<String, dynamic> toJson() {
|
Map<String, dynamic> toJson() {
|
||||||
final json = <String, dynamic>{};
|
final json = <String, dynamic>{};
|
||||||
|
@ -73,38 +62,32 @@ class RecognitionConfig {
|
||||||
json[r'minFaces'] = this.minFaces;
|
json[r'minFaces'] = this.minFaces;
|
||||||
json[r'minScore'] = this.minScore;
|
json[r'minScore'] = this.minScore;
|
||||||
json[r'modelName'] = this.modelName;
|
json[r'modelName'] = this.modelName;
|
||||||
if (this.modelType != null) {
|
|
||||||
json[r'modelType'] = this.modelType;
|
|
||||||
} else {
|
|
||||||
// json[r'modelType'] = null;
|
|
||||||
}
|
|
||||||
return json;
|
return json;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a new [RecognitionConfig] instance and imports its values from
|
/// Returns a new [FacialRecognitionConfig] instance and imports its values from
|
||||||
/// [value] if it's a [Map], null otherwise.
|
/// [value] if it's a [Map], null otherwise.
|
||||||
// ignore: prefer_constructors_over_static_methods
|
// ignore: prefer_constructors_over_static_methods
|
||||||
static RecognitionConfig? fromJson(dynamic value) {
|
static FacialRecognitionConfig? fromJson(dynamic value) {
|
||||||
if (value is Map) {
|
if (value is Map) {
|
||||||
final json = value.cast<String, dynamic>();
|
final json = value.cast<String, dynamic>();
|
||||||
|
|
||||||
return RecognitionConfig(
|
return FacialRecognitionConfig(
|
||||||
enabled: mapValueOfType<bool>(json, r'enabled')!,
|
enabled: mapValueOfType<bool>(json, r'enabled')!,
|
||||||
maxDistance: mapValueOfType<double>(json, r'maxDistance')!,
|
maxDistance: mapValueOfType<double>(json, r'maxDistance')!,
|
||||||
minFaces: mapValueOfType<int>(json, r'minFaces')!,
|
minFaces: mapValueOfType<int>(json, r'minFaces')!,
|
||||||
minScore: mapValueOfType<double>(json, r'minScore')!,
|
minScore: mapValueOfType<double>(json, r'minScore')!,
|
||||||
modelName: mapValueOfType<String>(json, r'modelName')!,
|
modelName: mapValueOfType<String>(json, r'modelName')!,
|
||||||
modelType: ModelType.fromJson(json[r'modelType']),
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
static List<RecognitionConfig> listFromJson(dynamic json, {bool growable = false,}) {
|
static List<FacialRecognitionConfig> listFromJson(dynamic json, {bool growable = false,}) {
|
||||||
final result = <RecognitionConfig>[];
|
final result = <FacialRecognitionConfig>[];
|
||||||
if (json is List && json.isNotEmpty) {
|
if (json is List && json.isNotEmpty) {
|
||||||
for (final row in json) {
|
for (final row in json) {
|
||||||
final value = RecognitionConfig.fromJson(row);
|
final value = FacialRecognitionConfig.fromJson(row);
|
||||||
if (value != null) {
|
if (value != null) {
|
||||||
result.add(value);
|
result.add(value);
|
||||||
}
|
}
|
||||||
|
@ -113,12 +96,12 @@ class RecognitionConfig {
|
||||||
return result.toList(growable: growable);
|
return result.toList(growable: growable);
|
||||||
}
|
}
|
||||||
|
|
||||||
static Map<String, RecognitionConfig> mapFromJson(dynamic json) {
|
static Map<String, FacialRecognitionConfig> mapFromJson(dynamic json) {
|
||||||
final map = <String, RecognitionConfig>{};
|
final map = <String, FacialRecognitionConfig>{};
|
||||||
if (json is Map && json.isNotEmpty) {
|
if (json is Map && json.isNotEmpty) {
|
||||||
json = json.cast<String, dynamic>(); // ignore: parameter_assignments
|
json = json.cast<String, dynamic>(); // ignore: parameter_assignments
|
||||||
for (final entry in json.entries) {
|
for (final entry in json.entries) {
|
||||||
final value = RecognitionConfig.fromJson(entry.value);
|
final value = FacialRecognitionConfig.fromJson(entry.value);
|
||||||
if (value != null) {
|
if (value != null) {
|
||||||
map[entry.key] = value;
|
map[entry.key] = value;
|
||||||
}
|
}
|
||||||
|
@ -127,14 +110,14 @@ class RecognitionConfig {
|
||||||
return map;
|
return map;
|
||||||
}
|
}
|
||||||
|
|
||||||
// maps a json object with a list of RecognitionConfig-objects as value to a dart map
|
// maps a json object with a list of FacialRecognitionConfig-objects as value to a dart map
|
||||||
static Map<String, List<RecognitionConfig>> mapListFromJson(dynamic json, {bool growable = false,}) {
|
static Map<String, List<FacialRecognitionConfig>> mapListFromJson(dynamic json, {bool growable = false,}) {
|
||||||
final map = <String, List<RecognitionConfig>>{};
|
final map = <String, List<FacialRecognitionConfig>>{};
|
||||||
if (json is Map && json.isNotEmpty) {
|
if (json is Map && json.isNotEmpty) {
|
||||||
// ignore: parameter_assignments
|
// ignore: parameter_assignments
|
||||||
json = json.cast<String, dynamic>();
|
json = json.cast<String, dynamic>();
|
||||||
for (final entry in json.entries) {
|
for (final entry in json.entries) {
|
||||||
map[entry.key] = RecognitionConfig.listFromJson(entry.value, growable: growable,);
|
map[entry.key] = FacialRecognitionConfig.listFromJson(entry.value, growable: growable,);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return map;
|
return map;
|
85
mobile/openapi/lib/model/model_type.dart
generated
85
mobile/openapi/lib/model/model_type.dart
generated
|
@ -1,85 +0,0 @@
|
||||||
//
|
|
||||||
// AUTO-GENERATED FILE, DO NOT MODIFY!
|
|
||||||
//
|
|
||||||
// @dart=2.18
|
|
||||||
|
|
||||||
// ignore_for_file: unused_element, unused_import
|
|
||||||
// ignore_for_file: always_put_required_named_parameters_first
|
|
||||||
// ignore_for_file: constant_identifier_names
|
|
||||||
// ignore_for_file: lines_longer_than_80_chars
|
|
||||||
|
|
||||||
part of openapi.api;
|
|
||||||
|
|
||||||
|
|
||||||
class ModelType {
|
|
||||||
/// Instantiate a new enum with the provided [value].
|
|
||||||
const ModelType._(this.value);
|
|
||||||
|
|
||||||
/// The underlying value of this enum member.
|
|
||||||
final String value;
|
|
||||||
|
|
||||||
@override
|
|
||||||
String toString() => value;
|
|
||||||
|
|
||||||
String toJson() => value;
|
|
||||||
|
|
||||||
static const facialRecognition = ModelType._(r'facial-recognition');
|
|
||||||
static const clip = ModelType._(r'clip');
|
|
||||||
|
|
||||||
/// List of all possible values in this [enum][ModelType].
|
|
||||||
static const values = <ModelType>[
|
|
||||||
facialRecognition,
|
|
||||||
clip,
|
|
||||||
];
|
|
||||||
|
|
||||||
static ModelType? fromJson(dynamic value) => ModelTypeTypeTransformer().decode(value);
|
|
||||||
|
|
||||||
static List<ModelType> listFromJson(dynamic json, {bool growable = false,}) {
|
|
||||||
final result = <ModelType>[];
|
|
||||||
if (json is List && json.isNotEmpty) {
|
|
||||||
for (final row in json) {
|
|
||||||
final value = ModelType.fromJson(row);
|
|
||||||
if (value != null) {
|
|
||||||
result.add(value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result.toList(growable: growable);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Transformation class that can [encode] an instance of [ModelType] to String,
|
|
||||||
/// and [decode] dynamic data back to [ModelType].
|
|
||||||
class ModelTypeTypeTransformer {
|
|
||||||
factory ModelTypeTypeTransformer() => _instance ??= const ModelTypeTypeTransformer._();
|
|
||||||
|
|
||||||
const ModelTypeTypeTransformer._();
|
|
||||||
|
|
||||||
String encode(ModelType data) => data.value;
|
|
||||||
|
|
||||||
/// Decodes a [dynamic value][data] to a ModelType.
|
|
||||||
///
|
|
||||||
/// If [allowNull] is true and the [dynamic value][data] cannot be decoded successfully,
|
|
||||||
/// then null is returned. However, if [allowNull] is false and the [dynamic value][data]
|
|
||||||
/// cannot be decoded successfully, then an [UnimplementedError] is thrown.
|
|
||||||
///
|
|
||||||
/// The [allowNull] is very handy when an API changes and a new enum value is added or removed,
|
|
||||||
/// and users are still using an old app with the old code.
|
|
||||||
ModelType? decode(dynamic data, {bool allowNull = true}) {
|
|
||||||
if (data != null) {
|
|
||||||
switch (data) {
|
|
||||||
case r'facial-recognition': return ModelType.facialRecognition;
|
|
||||||
case r'clip': return ModelType.clip;
|
|
||||||
default:
|
|
||||||
if (!allowNull) {
|
|
||||||
throw ArgumentError('Unknown enum value to decode: $data');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Singleton [ModelTypeTypeTransformer] instance.
|
|
||||||
static ModelTypeTypeTransformer? _instance;
|
|
||||||
}
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ class SystemConfigMachineLearningDto {
|
||||||
|
|
||||||
bool enabled;
|
bool enabled;
|
||||||
|
|
||||||
RecognitionConfig facialRecognition;
|
FacialRecognitionConfig facialRecognition;
|
||||||
|
|
||||||
String url;
|
String url;
|
||||||
|
|
||||||
|
@ -71,7 +71,7 @@ class SystemConfigMachineLearningDto {
|
||||||
clip: CLIPConfig.fromJson(json[r'clip'])!,
|
clip: CLIPConfig.fromJson(json[r'clip'])!,
|
||||||
duplicateDetection: DuplicateDetectionConfig.fromJson(json[r'duplicateDetection'])!,
|
duplicateDetection: DuplicateDetectionConfig.fromJson(json[r'duplicateDetection'])!,
|
||||||
enabled: mapValueOfType<bool>(json, r'enabled')!,
|
enabled: mapValueOfType<bool>(json, r'enabled')!,
|
||||||
facialRecognition: RecognitionConfig.fromJson(json[r'facialRecognition'])!,
|
facialRecognition: FacialRecognitionConfig.fromJson(json[r'facialRecognition'])!,
|
||||||
url: mapValueOfType<String>(json, r'url')!,
|
url: mapValueOfType<String>(json, r'url')!,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
@ -7878,14 +7878,8 @@
|
||||||
"enabled": {
|
"enabled": {
|
||||||
"type": "boolean"
|
"type": "boolean"
|
||||||
},
|
},
|
||||||
"mode": {
|
|
||||||
"$ref": "#/components/schemas/CLIPMode"
|
|
||||||
},
|
|
||||||
"modelName": {
|
"modelName": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
|
||||||
"modelType": {
|
|
||||||
"$ref": "#/components/schemas/ModelType"
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": [
|
"required": [
|
||||||
|
@ -7894,13 +7888,6 @@
|
||||||
],
|
],
|
||||||
"type": "object"
|
"type": "object"
|
||||||
},
|
},
|
||||||
"CLIPMode": {
|
|
||||||
"enum": [
|
|
||||||
"vision",
|
|
||||||
"text"
|
|
||||||
],
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"CQMode": {
|
"CQMode": {
|
||||||
"enum": [
|
"enum": [
|
||||||
"auto",
|
"auto",
|
||||||
|
@ -8323,6 +8310,40 @@
|
||||||
],
|
],
|
||||||
"type": "object"
|
"type": "object"
|
||||||
},
|
},
|
||||||
|
"FacialRecognitionConfig": {
|
||||||
|
"properties": {
|
||||||
|
"enabled": {
|
||||||
|
"type": "boolean"
|
||||||
|
},
|
||||||
|
"maxDistance": {
|
||||||
|
"format": "float",
|
||||||
|
"maximum": 2,
|
||||||
|
"minimum": 0,
|
||||||
|
"type": "number"
|
||||||
|
},
|
||||||
|
"minFaces": {
|
||||||
|
"minimum": 1,
|
||||||
|
"type": "integer"
|
||||||
|
},
|
||||||
|
"minScore": {
|
||||||
|
"format": "float",
|
||||||
|
"maximum": 1,
|
||||||
|
"minimum": 0,
|
||||||
|
"type": "number"
|
||||||
|
},
|
||||||
|
"modelName": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": [
|
||||||
|
"enabled",
|
||||||
|
"maxDistance",
|
||||||
|
"minFaces",
|
||||||
|
"minScore",
|
||||||
|
"modelName"
|
||||||
|
],
|
||||||
|
"type": "object"
|
||||||
|
},
|
||||||
"FileChecksumDto": {
|
"FileChecksumDto": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"filenames": {
|
"filenames": {
|
||||||
|
@ -9039,13 +9060,6 @@
|
||||||
},
|
},
|
||||||
"type": "object"
|
"type": "object"
|
||||||
},
|
},
|
||||||
"ModelType": {
|
|
||||||
"enum": [
|
|
||||||
"facial-recognition",
|
|
||||||
"clip"
|
|
||||||
],
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"OAuthAuthorizeResponseDto": {
|
"OAuthAuthorizeResponseDto": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"url": {
|
"url": {
|
||||||
|
@ -9379,43 +9393,6 @@
|
||||||
],
|
],
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"RecognitionConfig": {
|
|
||||||
"properties": {
|
|
||||||
"enabled": {
|
|
||||||
"type": "boolean"
|
|
||||||
},
|
|
||||||
"maxDistance": {
|
|
||||||
"format": "float",
|
|
||||||
"maximum": 2,
|
|
||||||
"minimum": 0,
|
|
||||||
"type": "number"
|
|
||||||
},
|
|
||||||
"minFaces": {
|
|
||||||
"minimum": 1,
|
|
||||||
"type": "integer"
|
|
||||||
},
|
|
||||||
"minScore": {
|
|
||||||
"format": "float",
|
|
||||||
"maximum": 1,
|
|
||||||
"minimum": 0,
|
|
||||||
"type": "number"
|
|
||||||
},
|
|
||||||
"modelName": {
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"modelType": {
|
|
||||||
"$ref": "#/components/schemas/ModelType"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": [
|
|
||||||
"enabled",
|
|
||||||
"maxDistance",
|
|
||||||
"minFaces",
|
|
||||||
"minScore",
|
|
||||||
"modelName"
|
|
||||||
],
|
|
||||||
"type": "object"
|
|
||||||
},
|
|
||||||
"ReverseGeocodingStateResponseDto": {
|
"ReverseGeocodingStateResponseDto": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"lastImportFileName": {
|
"lastImportFileName": {
|
||||||
|
@ -10521,7 +10498,7 @@
|
||||||
"type": "boolean"
|
"type": "boolean"
|
||||||
},
|
},
|
||||||
"facialRecognition": {
|
"facialRecognition": {
|
||||||
"$ref": "#/components/schemas/RecognitionConfig"
|
"$ref": "#/components/schemas/FacialRecognitionConfig"
|
||||||
},
|
},
|
||||||
"url": {
|
"url": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
|
|
|
@ -962,27 +962,24 @@ export type SystemConfigLoggingDto = {
|
||||||
};
|
};
|
||||||
export type ClipConfig = {
|
export type ClipConfig = {
|
||||||
enabled: boolean;
|
enabled: boolean;
|
||||||
mode?: CLIPMode;
|
|
||||||
modelName: string;
|
modelName: string;
|
||||||
modelType?: ModelType;
|
|
||||||
};
|
};
|
||||||
export type DuplicateDetectionConfig = {
|
export type DuplicateDetectionConfig = {
|
||||||
enabled: boolean;
|
enabled: boolean;
|
||||||
maxDistance: number;
|
maxDistance: number;
|
||||||
};
|
};
|
||||||
export type RecognitionConfig = {
|
export type FacialRecognitionConfig = {
|
||||||
enabled: boolean;
|
enabled: boolean;
|
||||||
maxDistance: number;
|
maxDistance: number;
|
||||||
minFaces: number;
|
minFaces: number;
|
||||||
minScore: number;
|
minScore: number;
|
||||||
modelName: string;
|
modelName: string;
|
||||||
modelType?: ModelType;
|
|
||||||
};
|
};
|
||||||
export type SystemConfigMachineLearningDto = {
|
export type SystemConfigMachineLearningDto = {
|
||||||
clip: ClipConfig;
|
clip: ClipConfig;
|
||||||
duplicateDetection: DuplicateDetectionConfig;
|
duplicateDetection: DuplicateDetectionConfig;
|
||||||
enabled: boolean;
|
enabled: boolean;
|
||||||
facialRecognition: RecognitionConfig;
|
facialRecognition: FacialRecognitionConfig;
|
||||||
url: string;
|
url: string;
|
||||||
};
|
};
|
||||||
export type SystemConfigMapDto = {
|
export type SystemConfigMapDto = {
|
||||||
|
@ -3074,14 +3071,6 @@ export enum LogLevel {
|
||||||
Error = "error",
|
Error = "error",
|
||||||
Fatal = "fatal"
|
Fatal = "fatal"
|
||||||
}
|
}
|
||||||
export enum CLIPMode {
|
|
||||||
Vision = "vision",
|
|
||||||
Text = "text"
|
|
||||||
}
|
|
||||||
export enum ModelType {
|
|
||||||
FacialRecognition = "facial-recognition",
|
|
||||||
Clip = "clip"
|
|
||||||
}
|
|
||||||
export enum TimeBucketSize {
|
export enum TimeBucketSize {
|
||||||
Day = "DAY",
|
Day = "DAY",
|
||||||
Month = "MONTH"
|
Month = "MONTH"
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
import { ApiProperty } from '@nestjs/swagger';
|
import { ApiProperty } from '@nestjs/swagger';
|
||||||
import { Type } from 'class-transformer';
|
import { Type } from 'class-transformer';
|
||||||
import { IsEnum, IsNotEmpty, IsNumber, IsString, Max, Min } from 'class-validator';
|
import { IsNotEmpty, IsNumber, IsString, Max, Min } from 'class-validator';
|
||||||
import { CLIPMode, ModelType } from 'src/interfaces/machine-learning.interface';
|
import { ValidateBoolean } from 'src/validation';
|
||||||
import { Optional, ValidateBoolean } from 'src/validation';
|
|
||||||
|
|
||||||
export class TaskConfig {
|
export class TaskConfig {
|
||||||
@ValidateBoolean()
|
@ValidateBoolean()
|
||||||
|
@ -13,19 +12,9 @@ export class ModelConfig extends TaskConfig {
|
||||||
@IsString()
|
@IsString()
|
||||||
@IsNotEmpty()
|
@IsNotEmpty()
|
||||||
modelName!: string;
|
modelName!: string;
|
||||||
|
|
||||||
@IsEnum(ModelType)
|
|
||||||
@Optional()
|
|
||||||
@ApiProperty({ enumName: 'ModelType', enum: ModelType })
|
|
||||||
modelType?: ModelType;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export class CLIPConfig extends ModelConfig {
|
export class CLIPConfig extends ModelConfig {}
|
||||||
@IsEnum(CLIPMode)
|
|
||||||
@Optional()
|
|
||||||
@ApiProperty({ enumName: 'CLIPMode', enum: CLIPMode })
|
|
||||||
mode?: CLIPMode;
|
|
||||||
}
|
|
||||||
|
|
||||||
export class DuplicateDetectionConfig extends TaskConfig {
|
export class DuplicateDetectionConfig extends TaskConfig {
|
||||||
@IsNumber()
|
@IsNumber()
|
||||||
|
@ -36,7 +25,7 @@ export class DuplicateDetectionConfig extends TaskConfig {
|
||||||
maxDistance!: number;
|
maxDistance!: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
export class RecognitionConfig extends ModelConfig {
|
export class FacialRecognitionConfig extends ModelConfig {
|
||||||
@IsNumber()
|
@IsNumber()
|
||||||
@Min(0)
|
@Min(0)
|
||||||
@Max(1)
|
@Max(1)
|
||||||
|
|
|
@ -30,7 +30,7 @@ import {
|
||||||
TranscodePolicy,
|
TranscodePolicy,
|
||||||
VideoCodec,
|
VideoCodec,
|
||||||
} from 'src/config';
|
} from 'src/config';
|
||||||
import { CLIPConfig, DuplicateDetectionConfig, RecognitionConfig } from 'src/dtos/model-config.dto';
|
import { CLIPConfig, DuplicateDetectionConfig, FacialRecognitionConfig } from 'src/dtos/model-config.dto';
|
||||||
import { ConcurrentQueueName, QueueName } from 'src/interfaces/job.interface';
|
import { ConcurrentQueueName, QueueName } from 'src/interfaces/job.interface';
|
||||||
import { ValidateBoolean, validateCronExpression } from 'src/validation';
|
import { ValidateBoolean, validateCronExpression } from 'src/validation';
|
||||||
|
|
||||||
|
@ -270,10 +270,10 @@ class SystemConfigMachineLearningDto {
|
||||||
@IsObject()
|
@IsObject()
|
||||||
duplicateDetection!: DuplicateDetectionConfig;
|
duplicateDetection!: DuplicateDetectionConfig;
|
||||||
|
|
||||||
@Type(() => RecognitionConfig)
|
@Type(() => FacialRecognitionConfig)
|
||||||
@ValidateNested()
|
@ValidateNested()
|
||||||
@IsObject()
|
@IsObject()
|
||||||
facialRecognition!: RecognitionConfig;
|
facialRecognition!: FacialRecognitionConfig;
|
||||||
}
|
}
|
||||||
|
|
||||||
enum MapTheme {
|
enum MapTheme {
|
||||||
|
|
|
@ -1,15 +1,5 @@
|
||||||
import { CLIPConfig, RecognitionConfig } from 'src/dtos/model-config.dto';
|
|
||||||
|
|
||||||
export const IMachineLearningRepository = 'IMachineLearningRepository';
|
export const IMachineLearningRepository = 'IMachineLearningRepository';
|
||||||
|
|
||||||
export interface VisionModelInput {
|
|
||||||
imagePath: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
export interface TextModelInput {
|
|
||||||
text: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
export interface BoundingBox {
|
export interface BoundingBox {
|
||||||
x1: number;
|
x1: number;
|
||||||
y1: number;
|
y1: number;
|
||||||
|
@ -17,26 +7,51 @@ export interface BoundingBox {
|
||||||
y2: number;
|
y2: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface DetectFaceResult {
|
export enum ModelTask {
|
||||||
imageWidth: number;
|
FACIAL_RECOGNITION = 'facial-recognition',
|
||||||
imageHeight: number;
|
SEARCH = 'clip',
|
||||||
boundingBox: BoundingBox;
|
|
||||||
score: number;
|
|
||||||
embedding: number[];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export enum ModelType {
|
export enum ModelType {
|
||||||
FACIAL_RECOGNITION = 'facial-recognition',
|
DETECTION = 'detection',
|
||||||
CLIP = 'clip',
|
PIPELINE = 'pipeline',
|
||||||
|
RECOGNITION = 'recognition',
|
||||||
|
TEXTUAL = 'textual',
|
||||||
|
VISUAL = 'visual',
|
||||||
}
|
}
|
||||||
|
|
||||||
export enum CLIPMode {
|
export type ModelPayload = { imagePath: string } | { text: string };
|
||||||
VISION = 'vision',
|
|
||||||
TEXT = 'text',
|
type ModelOptions = { modelName: string };
|
||||||
|
|
||||||
|
export type FaceDetectionOptions = ModelOptions & { minScore: number };
|
||||||
|
|
||||||
|
type VisualResponse = { imageHeight: number; imageWidth: number };
|
||||||
|
export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } };
|
||||||
|
export type ClipVisualResponse = { [ModelTask.SEARCH]: number[] } & VisualResponse;
|
||||||
|
|
||||||
|
export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions } };
|
||||||
|
export type ClipTextualResponse = { [ModelTask.SEARCH]: number[] };
|
||||||
|
|
||||||
|
export type FacialRecognitionRequest = {
|
||||||
|
[ModelTask.FACIAL_RECOGNITION]: {
|
||||||
|
[ModelType.DETECTION]: FaceDetectionOptions;
|
||||||
|
[ModelType.RECOGNITION]: ModelOptions;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
export interface Face {
|
||||||
|
boundingBox: BoundingBox;
|
||||||
|
embedding: number[];
|
||||||
|
score: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export type FacialRecognitionResponse = { [ModelTask.FACIAL_RECOGNITION]: Face[] } & VisualResponse;
|
||||||
|
export type DetectedFaces = { faces: Face[] } & VisualResponse;
|
||||||
|
export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest;
|
||||||
|
|
||||||
export interface IMachineLearningRepository {
|
export interface IMachineLearningRepository {
|
||||||
encodeImage(url: string, input: VisionModelInput, config: CLIPConfig): Promise<number[]>;
|
encodeImage(url: string, imagePath: string, config: ModelOptions): Promise<number[]>;
|
||||||
encodeText(url: string, input: TextModelInput, config: CLIPConfig): Promise<number[]>;
|
encodeText(url: string, text: string, config: ModelOptions): Promise<number[]>;
|
||||||
detectFaces(url: string, input: VisionModelInput, config: RecognitionConfig): Promise<DetectFaceResult[]>;
|
detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>;
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,8 +37,6 @@ export interface SearchExploreItem<T> {
|
||||||
items: SearchExploreItemSet<T>;
|
items: SearchExploreItemSet<T>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export type Embedding = number[];
|
|
||||||
|
|
||||||
export interface SearchAssetIDOptions {
|
export interface SearchAssetIDOptions {
|
||||||
checksum?: Buffer;
|
checksum?: Buffer;
|
||||||
deviceAssetId?: string;
|
deviceAssetId?: string;
|
||||||
|
@ -106,7 +104,7 @@ export interface SearchExifOptions {
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface SearchEmbeddingOptions {
|
export interface SearchEmbeddingOptions {
|
||||||
embedding: Embedding;
|
embedding: number[];
|
||||||
userIds: string[];
|
userIds: string[];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -154,7 +152,7 @@ export interface FaceEmbeddingSearch extends SearchEmbeddingOptions {
|
||||||
|
|
||||||
export interface AssetDuplicateSearch {
|
export interface AssetDuplicateSearch {
|
||||||
assetId: string;
|
assetId: string;
|
||||||
embedding: Embedding;
|
embedding: number[];
|
||||||
maxDistance?: number;
|
maxDistance?: number;
|
||||||
type: AssetType;
|
type: AssetType;
|
||||||
userIds: string[];
|
userIds: string[];
|
||||||
|
|
|
@ -1,13 +1,16 @@
|
||||||
import { Injectable } from '@nestjs/common';
|
import { Injectable } from '@nestjs/common';
|
||||||
import { readFile } from 'node:fs/promises';
|
import { readFile } from 'node:fs/promises';
|
||||||
import { CLIPConfig, ModelConfig, RecognitionConfig } from 'src/dtos/model-config.dto';
|
import { CLIPConfig } from 'src/dtos/model-config.dto';
|
||||||
import {
|
import {
|
||||||
CLIPMode,
|
ClipTextualResponse,
|
||||||
DetectFaceResult,
|
ClipVisualResponse,
|
||||||
|
FaceDetectionOptions,
|
||||||
|
FacialRecognitionResponse,
|
||||||
IMachineLearningRepository,
|
IMachineLearningRepository,
|
||||||
|
MachineLearningRequest,
|
||||||
|
ModelPayload,
|
||||||
|
ModelTask,
|
||||||
ModelType,
|
ModelType,
|
||||||
TextModelInput,
|
|
||||||
VisionModelInput,
|
|
||||||
} from 'src/interfaces/machine-learning.interface';
|
} from 'src/interfaces/machine-learning.interface';
|
||||||
import { Instrumentation } from 'src/utils/instrumentation';
|
import { Instrumentation } from 'src/utils/instrumentation';
|
||||||
|
|
||||||
|
@ -16,8 +19,8 @@ const errorPrefix = 'Machine learning request';
|
||||||
@Instrumentation()
|
@Instrumentation()
|
||||||
@Injectable()
|
@Injectable()
|
||||||
export class MachineLearningRepository implements IMachineLearningRepository {
|
export class MachineLearningRepository implements IMachineLearningRepository {
|
||||||
private async predict<T>(url: string, input: TextModelInput | VisionModelInput, config: ModelConfig): Promise<T> {
|
private async predict<T>(url: string, payload: ModelPayload, config: MachineLearningRequest): Promise<T> {
|
||||||
const formData = await this.getFormData(input, config);
|
const formData = await this.getFormData(payload, config);
|
||||||
|
|
||||||
const res = await fetch(new URL('/predict', url), { method: 'POST', body: formData }).catch(
|
const res = await fetch(new URL('/predict', url), { method: 'POST', body: formData }).catch(
|
||||||
(error: Error | any) => {
|
(error: Error | any) => {
|
||||||
|
@ -26,50 +29,46 @@ export class MachineLearningRepository implements IMachineLearningRepository {
|
||||||
);
|
);
|
||||||
|
|
||||||
if (res.status >= 400) {
|
if (res.status >= 400) {
|
||||||
const modelType = config.modelType ? ` for ${config.modelType.replace('-', ' ')}` : '';
|
throw new Error(`${errorPrefix} '${JSON.stringify(config)}' failed with status ${res.status}: ${res.statusText}`);
|
||||||
throw new Error(`${errorPrefix}${modelType} failed with status ${res.status}: ${res.statusText}`);
|
|
||||||
}
|
}
|
||||||
return res.json();
|
return res.json();
|
||||||
}
|
}
|
||||||
|
|
||||||
detectFaces(url: string, input: VisionModelInput, config: RecognitionConfig): Promise<DetectFaceResult[]> {
|
async detectFaces(url: string, imagePath: string, { modelName, minScore }: FaceDetectionOptions) {
|
||||||
return this.predict<DetectFaceResult[]>(url, input, { ...config, modelType: ModelType.FACIAL_RECOGNITION });
|
const request = {
|
||||||
|
[ModelTask.FACIAL_RECOGNITION]: {
|
||||||
|
[ModelType.DETECTION]: { modelName, minScore },
|
||||||
|
[ModelType.RECOGNITION]: { modelName },
|
||||||
|
},
|
||||||
|
};
|
||||||
|
const response = await this.predict<FacialRecognitionResponse>(url, { imagePath }, request);
|
||||||
|
return {
|
||||||
|
imageHeight: response.imageHeight,
|
||||||
|
imageWidth: response.imageWidth,
|
||||||
|
faces: response[ModelTask.FACIAL_RECOGNITION],
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
encodeImage(url: string, input: VisionModelInput, config: CLIPConfig): Promise<number[]> {
|
async encodeImage(url: string, imagePath: string, { modelName }: CLIPConfig) {
|
||||||
return this.predict<number[]>(url, input, {
|
const request = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: { modelName } } };
|
||||||
...config,
|
const response = await this.predict<ClipVisualResponse>(url, { imagePath }, request);
|
||||||
modelType: ModelType.CLIP,
|
return response[ModelTask.SEARCH];
|
||||||
mode: CLIPMode.VISION,
|
|
||||||
} as CLIPConfig);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
encodeText(url: string, input: TextModelInput, config: CLIPConfig): Promise<number[]> {
|
async encodeText(url: string, text: string, { modelName }: CLIPConfig) {
|
||||||
return this.predict<number[]>(url, input, {
|
const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } };
|
||||||
...config,
|
const response = await this.predict<ClipTextualResponse>(url, { text }, request);
|
||||||
modelType: ModelType.CLIP,
|
return response[ModelTask.SEARCH];
|
||||||
mode: CLIPMode.TEXT,
|
|
||||||
} as CLIPConfig);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private async getFormData(input: TextModelInput | VisionModelInput, config: ModelConfig): Promise<FormData> {
|
private async getFormData(payload: ModelPayload, config: MachineLearningRequest): Promise<FormData> {
|
||||||
const formData = new FormData();
|
const formData = new FormData();
|
||||||
const { enabled, modelName, modelType, ...options } = config;
|
formData.append('entries', JSON.stringify(config));
|
||||||
if (!enabled) {
|
|
||||||
throw new Error(`${modelType} is not enabled`);
|
|
||||||
}
|
|
||||||
|
|
||||||
formData.append('modelName', modelName);
|
if ('imagePath' in payload) {
|
||||||
if (modelType) {
|
formData.append('image', new Blob([await readFile(payload.imagePath)]));
|
||||||
formData.append('modelType', modelType);
|
} else if ('text' in payload) {
|
||||||
}
|
formData.append('text', payload.text);
|
||||||
if (options) {
|
|
||||||
formData.append('options', JSON.stringify(options));
|
|
||||||
}
|
|
||||||
if ('imagePath' in input) {
|
|
||||||
formData.append('image', new Blob([await readFile(input.imagePath)]));
|
|
||||||
} else if ('text' in input) {
|
|
||||||
formData.append('text', input.text);
|
|
||||||
} else {
|
} else {
|
||||||
throw new Error('Invalid input');
|
throw new Error('Invalid input');
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,7 @@ import { IAssetRepository, WithoutProperty } from 'src/interfaces/asset.interfac
|
||||||
import { ICryptoRepository } from 'src/interfaces/crypto.interface';
|
import { ICryptoRepository } from 'src/interfaces/crypto.interface';
|
||||||
import { IJobRepository, JobName, JobStatus } from 'src/interfaces/job.interface';
|
import { IJobRepository, JobName, JobStatus } from 'src/interfaces/job.interface';
|
||||||
import { ILoggerRepository } from 'src/interfaces/logger.interface';
|
import { ILoggerRepository } from 'src/interfaces/logger.interface';
|
||||||
import { IMachineLearningRepository } from 'src/interfaces/machine-learning.interface';
|
import { DetectedFaces, IMachineLearningRepository } from 'src/interfaces/machine-learning.interface';
|
||||||
import { IMediaRepository } from 'src/interfaces/media.interface';
|
import { IMediaRepository } from 'src/interfaces/media.interface';
|
||||||
import { IMoveRepository } from 'src/interfaces/move.interface';
|
import { IMoveRepository } from 'src/interfaces/move.interface';
|
||||||
import { IPersonRepository } from 'src/interfaces/person.interface';
|
import { IPersonRepository } from 'src/interfaces/person.interface';
|
||||||
|
@ -46,19 +46,21 @@ const responseDto: PersonResponseDto = {
|
||||||
|
|
||||||
const statistics = { assets: 3 };
|
const statistics = { assets: 3 };
|
||||||
|
|
||||||
const detectFaceMock = {
|
const detectFaceMock: DetectedFaces = {
|
||||||
assetId: 'asset-1',
|
faces: [
|
||||||
personId: 'person-1',
|
{
|
||||||
boundingBox: {
|
boundingBox: {
|
||||||
x1: 100,
|
x1: 100,
|
||||||
y1: 100,
|
y1: 100,
|
||||||
x2: 200,
|
x2: 200,
|
||||||
y2: 200,
|
y2: 200,
|
||||||
},
|
},
|
||||||
|
embedding: [1, 2, 3, 4],
|
||||||
|
score: 0.2,
|
||||||
|
},
|
||||||
|
],
|
||||||
imageHeight: 500,
|
imageHeight: 500,
|
||||||
imageWidth: 400,
|
imageWidth: 400,
|
||||||
embedding: [1, 2, 3, 4],
|
|
||||||
score: 0.2,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
describe(PersonService.name, () => {
|
describe(PersonService.name, () => {
|
||||||
|
@ -642,21 +644,13 @@ describe(PersonService.name, () => {
|
||||||
it('should handle no results', async () => {
|
it('should handle no results', async () => {
|
||||||
const start = Date.now();
|
const start = Date.now();
|
||||||
|
|
||||||
machineLearningMock.detectFaces.mockResolvedValue([]);
|
machineLearningMock.detectFaces.mockResolvedValue({ imageHeight: 500, imageWidth: 400, faces: [] });
|
||||||
assetMock.getByIds.mockResolvedValue([assetStub.image]);
|
assetMock.getByIds.mockResolvedValue([assetStub.image]);
|
||||||
await sut.handleDetectFaces({ id: assetStub.image.id });
|
await sut.handleDetectFaces({ id: assetStub.image.id });
|
||||||
expect(machineLearningMock.detectFaces).toHaveBeenCalledWith(
|
expect(machineLearningMock.detectFaces).toHaveBeenCalledWith(
|
||||||
'http://immich-machine-learning:3003',
|
'http://immich-machine-learning:3003',
|
||||||
{
|
assetStub.image.previewPath,
|
||||||
imagePath: assetStub.image.previewPath,
|
expect.objectContaining({ minScore: 0.7, modelName: 'buffalo_l' }),
|
||||||
},
|
|
||||||
{
|
|
||||||
enabled: true,
|
|
||||||
maxDistance: 0.5,
|
|
||||||
minScore: 0.7,
|
|
||||||
minFaces: 3,
|
|
||||||
modelName: 'buffalo_l',
|
|
||||||
},
|
|
||||||
);
|
);
|
||||||
expect(personMock.createFaces).not.toHaveBeenCalled();
|
expect(personMock.createFaces).not.toHaveBeenCalled();
|
||||||
expect(jobMock.queue).not.toHaveBeenCalled();
|
expect(jobMock.queue).not.toHaveBeenCalled();
|
||||||
|
@ -671,7 +665,7 @@ describe(PersonService.name, () => {
|
||||||
|
|
||||||
it('should create a face with no person and queue recognition job', async () => {
|
it('should create a face with no person and queue recognition job', async () => {
|
||||||
personMock.createFaces.mockResolvedValue([faceStub.face1.id]);
|
personMock.createFaces.mockResolvedValue([faceStub.face1.id]);
|
||||||
machineLearningMock.detectFaces.mockResolvedValue([detectFaceMock]);
|
machineLearningMock.detectFaces.mockResolvedValue(detectFaceMock);
|
||||||
searchMock.searchFaces.mockResolvedValue([{ face: faceStub.face1, distance: 0.7 }]);
|
searchMock.searchFaces.mockResolvedValue([{ face: faceStub.face1, distance: 0.7 }]);
|
||||||
assetMock.getByIds.mockResolvedValue([assetStub.image]);
|
assetMock.getByIds.mockResolvedValue([assetStub.image]);
|
||||||
const face = {
|
const face = {
|
||||||
|
|
|
@ -333,26 +333,28 @@ export class PersonService {
|
||||||
return JobStatus.SKIPPED;
|
return JobStatus.SKIPPED;
|
||||||
}
|
}
|
||||||
|
|
||||||
const faces = await this.machineLearningRepository.detectFaces(
|
if (!asset.isVisible) {
|
||||||
|
return JobStatus.SKIPPED;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { imageHeight, imageWidth, faces } = await this.machineLearningRepository.detectFaces(
|
||||||
machineLearning.url,
|
machineLearning.url,
|
||||||
{ imagePath: asset.previewPath },
|
asset.previewPath,
|
||||||
machineLearning.facialRecognition,
|
machineLearning.facialRecognition,
|
||||||
);
|
);
|
||||||
|
|
||||||
this.logger.debug(`${faces.length} faces detected in ${asset.previewPath}`);
|
this.logger.debug(`${faces.length} faces detected in ${asset.previewPath}`);
|
||||||
this.logger.verbose(faces.map((face) => ({ ...face, embedding: `vector(${face.embedding.length})` })));
|
|
||||||
|
|
||||||
if (faces.length > 0) {
|
if (faces.length > 0) {
|
||||||
await this.jobRepository.queue({ name: JobName.QUEUE_FACIAL_RECOGNITION, data: { force: false } });
|
await this.jobRepository.queue({ name: JobName.QUEUE_FACIAL_RECOGNITION, data: { force: false } });
|
||||||
|
|
||||||
const mappedFaces = faces.map((face) => ({
|
const mappedFaces = faces.map((face) => ({
|
||||||
assetId: asset.id,
|
assetId: asset.id,
|
||||||
embedding: face.embedding,
|
embedding: face.embedding,
|
||||||
imageHeight: face.imageHeight,
|
imageHeight,
|
||||||
imageWidth: face.imageWidth,
|
imageWidth,
|
||||||
boundingBoxX1: face.boundingBox.x1,
|
boundingBoxX1: face.boundingBox.x1,
|
||||||
boundingBoxX2: face.boundingBox.x2,
|
|
||||||
boundingBoxY1: face.boundingBox.y1,
|
boundingBoxY1: face.boundingBox.y1,
|
||||||
|
boundingBoxX2: face.boundingBox.x2,
|
||||||
boundingBoxY2: face.boundingBox.y2,
|
boundingBoxY2: face.boundingBox.y2,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
|
|
@ -102,12 +102,7 @@ export class SearchService {
|
||||||
|
|
||||||
const userIds = await this.getUserIdsToSearch(auth);
|
const userIds = await this.getUserIdsToSearch(auth);
|
||||||
|
|
||||||
const embedding = await this.machineLearning.encodeText(
|
const embedding = await this.machineLearning.encodeText(machineLearning.url, dto.query, machineLearning.clip);
|
||||||
machineLearning.url,
|
|
||||||
{ text: dto.query },
|
|
||||||
machineLearning.clip,
|
|
||||||
);
|
|
||||||
|
|
||||||
const page = dto.page ?? 1;
|
const page = dto.page ?? 1;
|
||||||
const size = dto.size || 100;
|
const size = dto.size || 100;
|
||||||
const { hasNextPage, items } = await this.searchRepository.searchSmart(
|
const { hasNextPage, items } = await this.searchRepository.searchSmart(
|
||||||
|
|
|
@ -108,8 +108,8 @@ describe(SmartInfoService.name, () => {
|
||||||
|
|
||||||
expect(machineMock.encodeImage).toHaveBeenCalledWith(
|
expect(machineMock.encodeImage).toHaveBeenCalledWith(
|
||||||
'http://immich-machine-learning:3003',
|
'http://immich-machine-learning:3003',
|
||||||
{ imagePath: assetStub.image.previewPath },
|
assetStub.image.previewPath,
|
||||||
{ enabled: true, modelName: 'ViT-B-32__openai' },
|
expect.objectContaining({ modelName: 'ViT-B-32__openai' }),
|
||||||
);
|
);
|
||||||
expect(searchMock.upsert).toHaveBeenCalledWith(assetStub.image.id, [0.01, 0.02, 0.03]);
|
expect(searchMock.upsert).toHaveBeenCalledWith(assetStub.image.id, [0.01, 0.02, 0.03]);
|
||||||
});
|
});
|
||||||
|
|
|
@ -93,9 +93,9 @@ export class SmartInfoService {
|
||||||
return JobStatus.FAILED;
|
return JobStatus.FAILED;
|
||||||
}
|
}
|
||||||
|
|
||||||
const clipEmbedding = await this.machineLearning.encodeImage(
|
const embedding = await this.machineLearning.encodeImage(
|
||||||
machineLearning.url,
|
machineLearning.url,
|
||||||
{ imagePath: asset.previewPath },
|
asset.previewPath,
|
||||||
machineLearning.clip,
|
machineLearning.clip,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -104,7 +104,7 @@ export class SmartInfoService {
|
||||||
await this.databaseRepository.wait(DatabaseLock.CLIPDimSize);
|
await this.databaseRepository.wait(DatabaseLock.CLIPDimSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
await this.repository.upsert(asset.id, clipEmbedding);
|
await this.repository.upsert(asset.id, embedding);
|
||||||
|
|
||||||
return JobStatus.SUCCESS;
|
return JobStatus.SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue