0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-07 00:50:23 -05:00
immich/machine-learning/app/models/cache.py
Mert 848ba685eb
fix(ml): race condition when loading models (#3207)
* sync model loading, disabled model ttl by default

* disable revalidation if model unloading disabled

* moved lock
2023-07-11 12:01:21 -05:00

97 lines
3.1 KiB
Python

from typing import Any
from aiocache.backends.memory import SimpleMemoryCache
from aiocache.lock import OptimisticLock
from aiocache.plugins import BasePlugin, TimingPlugin
from ..schemas import ModelType
from .base import InferenceModel
class ModelCache:
    """In-memory store of inference models that creates a model on first request."""

    def __init__(
        self,
        ttl: float | None = None,
        revalidate: bool = False,
        timeout: int | None = None,
        profiling: bool = False,
    ):
        """
        Args:
            ttl: Seconds a model stays cached before it is unloaded. No unloading if None. Defaults to None.
            revalidate: Restart a model's TTL each time it is served from the cache, so models in
                active use stay loaded. Defaults to False.
            timeout: Upper bound on how long a model load may take. No limit if None. Defaults to None.
            profiling: Record metrics for cache operations at a small runtime cost. Defaults to False.
        """
        self.ttl = ttl

        # Plugin order is preserved: revalidation (if enabled) runs before timing.
        plugins: list[BasePlugin] = [RevalidationPlugin()] if revalidate else []
        if profiling:
            plugins.append(TimingPlugin())

        self.cache = SimpleMemoryCache(ttl=ttl, timeout=timeout, plugins=plugins, namespace=None)

    async def get(self, model_name: str, model_type: ModelType, **model_kwargs: Any) -> InferenceModel:
        """
        Return the cached model for this name and type, creating and caching it on a miss.

        Args:
            model_name: Name of model in the model hub used for the task.
            model_type: Model type or task, which determines which model zoo is used.
            **model_kwargs: Forwarded to `InferenceModel.from_model_type` when the model has to be created.

        Returns:
            model: The requested model.
        """
        # The model type's value is passed as the namespace part of the cache key.
        cache_key = self.cache.build_key(model_name, model_type.value)

        async with OptimisticLock(self.cache, cache_key) as lock:
            entry = await self.cache.get(cache_key)
            if entry is None:
                # Cache miss: the model is built synchronously (no await) and then
                # written through the lock with this cache's TTL.
                entry = InferenceModel.from_model_type(model_type, model_name, **model_kwargs)
                await lock.cas(entry, ttl=self.ttl)
        return entry

    async def get_profiling(self) -> dict[str, float] | None:
        """Return the cache's `profiling` attribute, or None when the cache has no such attribute."""
        # NOTE(review): presumably populated by TimingPlugin when profiling=True — confirm.
        return getattr(self.cache, "profiling", None)  # type: ignore
class RevalidationPlugin(BasePlugin):
    """Revalidates cache item's TTL after cache hit.

    After a successful `get` / `multi_get`, every returned key that currently has
    an expiry handler on the client gets its expiry reset to the client's default
    TTL (`client.ttl`). Keys without a handler (i.e. stored without a TTL) are
    left untouched.
    """

    async def post_get(
        self,
        client: SimpleMemoryCache,
        key: str,
        ret: Any | None = None,
        namespace: str | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            client: Cache the `get` was issued against.
            key: Key as passed by the caller (not yet namespaced).
            ret: Value returned by the `get`; None means a cache miss.
            namespace: Namespace passed to the `get` call, if any.
        """
        if ret is None:
            return
        await self._revalidate(client, key, namespace)

    async def post_multi_get(
        self,
        client: SimpleMemoryCache,
        keys: list[str],
        ret: list[Any] | None = None,
        namespace: str | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            client: Cache the `multi_get` was issued against.
            keys: Keys as passed by the caller (not yet namespaced).
            ret: Values returned by the `multi_get`, aligned with `keys`; None entries are misses.
            namespace: Namespace passed to the `multi_get` call, if any.
        """
        if ret is None:
            return
        for key, val in zip(keys, ret):
            if val is not None:
                await self._revalidate(client, key, namespace)

    @staticmethod
    async def _revalidate(client: SimpleMemoryCache, key: str, namespace: str | None) -> None:
        """Reset the expiry of `key` to the client's default TTL if it has a pending expiry."""
        # Always resolve through build_key: with namespace=None it falls back to the
        # client's own default namespace, so the lookup matches the handler's key.
        ns_key = client.build_key(key, namespace)
        if ns_key in client._handlers:
            # Pass the raw key plus namespace; expire() namespaces the key itself,
            # so handing it an already-namespaced key could prefix it twice.
            await client.expire(key, client.ttl, namespace=namespace)