0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-07 00:50:23 -05:00
immich/machine-learning/app/main.py
Mert 2b1b43a7e4
feat(ml): composable ml (#9973)
* modularize model classes

* various fixes

* expose port

* change response

* round coordinates

* simplify preload

* update server

* simplify interface

simplify

* update tests

* composable endpoint

* cleanup

fixes

remove unnecessary interface

support text input, cleanup

* ew camelcase

* update server

server fixes

fix typing

* ml fixes

update locustfile

fixes

* cleaner response

* better repo response

* update tests

formatting and typing

rename

* undo compose change

* linting

fix type

actually fix typing

* stricter typing

fix detection-only response

no need for defaultdict

* update spec file

update api

linting

* update e2e

* unnecessary dimension

* remove commented code

* remove duplicate code

* remove unused imports

* add batch dim
2024-06-07 03:09:47 +00:00

226 lines
7.4 KiB
Python

import asyncio
import gc
import os
import signal
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from functools import partial
from typing import Any, AsyncGenerator, Callable, Iterator
from zipfile import BadZipFile
import orjson
from fastapi import Depends, FastAPI, File, Form, HTTPException
from fastapi.responses import ORJSONResponse
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
from PIL.Image import Image
from pydantic import ValidationError
from starlette.formparsers import MultiPartParser
from app.models import get_model_deps
from app.models.base import InferenceModel
from app.models.transforms import decode_pil
from .config import PreloadModelData, log, settings
from .models.cache import ModelCache
from .schemas import (
InferenceEntries,
InferenceEntry,
InferenceResponse,
MessageResponse,
ModelIdentity,
ModelTask,
ModelType,
PipelineRequest,
T,
TextResponse,
)
# Module-level configuration and shared mutable state for the request handlers below.
MultiPartParser.max_file_size = 2**26 # spools to disk if payload is 64 MiB or larger
# Cache revalidation is only enabled when a positive model TTL is configured.
model_cache = ModelCache(revalidate=settings.model_ttl > 0)
# Created in `lifespan` when settings.request_threads > 0; while None, `run` calls functions inline.
thread_pool: ThreadPoolExecutor | None = None
# Held while a model is being loaded (see `load`); also consulted by `idle_shutdown_task`.
lock = threading.Lock()
# Number of in-flight requests, maintained by the `update_state` dependency.
active_requests = 0
# time.time() of the most recent request start; None until the first request arrives.
last_called: float | None = None
@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
    """Manage application startup and shutdown.

    On startup: optionally creates the request thread pool, starts the
    idle-shutdown watcher and preloads the configured models. On shutdown:
    cancels the watcher, clears the log handlers, shuts the thread pool down
    and forces a garbage collection pass.
    """
    global thread_pool
    # Keep a strong reference to the watcher task: the event loop only holds
    # weak references to tasks, so an unreferenced task may be collected.
    idle_task: asyncio.Task[None] | None = None
    log.info(
        (
            "Created in-memory cache with unloading "
            f"{f'after {settings.model_ttl}s of inactivity' if settings.model_ttl > 0 else 'disabled'}."
        )
    )

    try:
        if settings.request_threads > 0:
            # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
            thread_pool = ThreadPoolExecutor(settings.request_threads)
            log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
        if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0:
            idle_task = asyncio.create_task(idle_shutdown_task())
        if settings.preload is not None:
            await preload_models(settings.preload)
        yield
    finally:
        if idle_task is not None:
            # No-op if the watcher already finished (e.g. it triggered the shutdown itself).
            idle_task.cancel()
        log.handlers.clear()
        # NOTE(review): `del model` only unbinds the loop variable; it does not
        # remove entries from the cache. Kept as-is to preserve existing behavior.
        for model in model_cache.cache._cache.values():
            del model
        if thread_pool is not None:
            thread_pool.shutdown()
        gc.collect()
async def preload_models(preload: PreloadModelData) -> None:
    """Fetch and load every model named in *preload*, one after another.

    For CLIP the textual model is loaded before the visual one; for facial
    recognition the detection model is loaded before the recognition one.
    """

    async def _fetch_and_load(name: str, model_type: ModelType, task: ModelTask) -> None:
        # Resolve the model through the cache, then make sure it is loaded.
        cached = await model_cache.get(name, model_type, task)
        await load(cached)

    log.info(f"Preloading models: {preload}")

    if preload.clip is not None:
        await _fetch_and_load(preload.clip, ModelType.TEXTUAL, ModelTask.SEARCH)
        await _fetch_and_load(preload.clip, ModelType.VISUAL, ModelTask.SEARCH)

    if preload.facial_recognition is not None:
        await _fetch_and_load(preload.facial_recognition, ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)
        await _fetch_and_load(preload.facial_recognition, ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION)
def update_state() -> Iterator[None]:
    """Track request activity around a single request.

    Generator dependency: before yielding it bumps the in-flight request
    counter and records the current time as the last call; once the
    generator is resumed or closed it decrements the counter again.
    """
    global active_requests, last_called

    active_requests = active_requests + 1
    last_called = time.time()
    try:
        yield
    finally:
        # Always undo the increment, even if the request failed.
        active_requests = active_requests - 1
def get_entries(entries: str = Form()) -> InferenceEntries:
    """Parse the `entries` form field into inference entries.

    The field is expected to be JSON of the form
    ``{task: {type: {"modelName": ..., "options": {...}}}}``; ``options``
    is optional and defaults to an empty dict.

    Returns:
        A ``(without_deps, with_deps)`` pair of entry lists, split by whether
        `get_model_deps` reports dependencies for the model.

    Raises:
        HTTPException: 422 if the payload is not valid JSON or does not have
            the expected shape.
    """
    try:
        request: PipelineRequest = orjson.loads(entries)
        without_deps: list[InferenceEntry] = []
        with_deps: list[InferenceEntry] = []
        for task, types in request.items():
            for model_type, entry in types.items():
                parsed: InferenceEntry = {
                    "name": entry["modelName"],
                    "task": task,
                    "type": model_type,
                    "options": entry.get("options", {}),
                }
                dep = get_model_deps(parsed["name"], model_type, task)
                (with_deps if dep else without_deps).append(parsed)
        return without_deps, with_deps
    # TypeError covers entries that are not mappings (e.g. a string, list or null),
    # where `entry["modelName"]` would otherwise surface as a 500.
    except (orjson.JSONDecodeError, ValidationError, KeyError, AttributeError, TypeError) as e:
        log.error(f"Invalid request format: {e}")
        raise HTTPException(422, "Invalid request format.") from e
# Application instance; startup and shutdown work is delegated to `lifespan`.
app = FastAPI(lifespan=lifespan)
@app.get("/", response_model=MessageResponse)
async def root() -> dict[str, str]:
    """Return a static message identifying the service."""
    banner = {"message": "Immich ML"}
    return banner
@app.get("/ping", response_model=TextResponse)
def ping() -> str:
    """Liveness probe: always answers with the literal text ``pong``."""
    reply = "pong"
    return reply
@app.post("/predict", dependencies=[Depends(update_state)])
async def predict(
    entries: InferenceEntries = Depends(get_entries),
    image: bytes | None = File(default=None),
    text: str | None = Form(default=None),
) -> Any:
    """Run the requested inference entries on an image or a text input.

    An uploaded image takes precedence over text and is decoded via
    `decode_pil` through `run`. If neither input is supplied the request is
    rejected with HTTP 400.
    """
    if image is None:
        if text is None:
            raise HTTPException(400, "Either image or text must be provided")
        inputs: Image | str = text
    else:
        inputs = await run(decode_pil, image)

    return ORJSONResponse(await run_inference(inputs, entries))
async def run_inference(payload: Image | str, entries: InferenceEntries) -> InferenceResponse:
    """Execute every requested entry against *payload* and collect the results.

    Entries without dependencies run first (concurrently); entries with
    dependencies run afterwards and receive the outputs of the models they
    depend on as extra positional inputs. For image payloads the image
    height and width are added to the response.

    Raises:
        HTTPException: 400 if an entry depends on an output that has not
            been produced.
    """
    outputs: dict[ModelIdentity, Any] = {}
    response: InferenceResponse = {}

    async def _run_inference(entry: InferenceEntry) -> None:
        model = await model_cache.get(entry["name"], entry["type"], entry["task"], ttl=settings.model_ttl)

        # The payload always comes first, followed by each dependency's output.
        model_inputs = [payload]
        for dep in model.depends:
            if dep not in outputs:
                message = f"Task {entry['task']} of type {entry['type']} depends on output of {dep}"
                raise HTTPException(400, message)
            model_inputs.append(outputs[dep])

        model = await load(model)
        result = await run(model.predict, *model_inputs, **entry["options"])
        outputs[model.identity] = result
        response[entry["task"]] = result

    independent, dependent = entries
    await asyncio.gather(*(_run_inference(item) for item in independent))
    if dependent:
        await asyncio.gather(*(_run_inference(item) for item in dependent))

    if isinstance(payload, Image):
        response["imageHeight"] = payload.height
        response["imageWidth"] = payload.width

    return response
async def run(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
    """Call ``func(*args, **kwargs)`` and return its result.

    When the module-level thread pool exists the call is dispatched to it
    and awaited; otherwise the function is invoked directly in the caller.
    """
    if thread_pool is not None:
        loop = asyncio.get_running_loop()
        bound_call = partial(func, *args, **kwargs)
        return await loop.run_in_executor(thread_pool, bound_call)
    return func(*args, **kwargs)
async def load(model: InferenceModel) -> InferenceModel:
    """Ensure *model* is loaded and return it.

    Already-loaded models are returned immediately. Otherwise the model is
    loaded via `run` while holding the module-level lock. If loading fails
    with one of the handled errors, the model's cache is cleared and loading
    is attempted once more; a failure of that second attempt propagates to
    the caller.
    """
    if model.loaded:
        return model

    def _load(model: InferenceModel) -> InferenceModel:
        # Hold the module-level lock for the duration of the load.
        with lock:
            model.load()
        return model

    try:
        await run(_load, model)
        return model
    except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
        log.warning(
            (
                # Trailing space keeps the two sentences separated in the log output.
                f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'. "
                "Clearing cache and retrying."
            )
        )
        model.clear_cache()
        await run(_load, model)
        return model
async def idle_shutdown_task() -> None:
    """Poll for inactivity and send SIGINT to this process once it is idle.

    The process counts as idle when at least one request has been seen, no
    request is in flight, the model-loading lock is free, and more than
    ``settings.model_ttl`` seconds have passed since the last request began.
    The check repeats every ``settings.model_ttl_poll_s`` seconds and the
    task ends after sending the signal.
    """

    def _idle_too_long() -> bool:
        # Short-circuit order matters: last_called must be checked for None first.
        return (
            last_called is not None
            and not active_requests
            and not lock.locked()
            and time.time() - last_called > settings.model_ttl
        )

    while True:
        log.debug("Checking for inactivity...")
        if _idle_too_long():
            log.info("Shutting down due to inactivity.")
            os.kill(os.getpid(), signal.SIGINT)
            return
        await asyncio.sleep(settings.model_ttl_poll_s)