2023-06-27 18:21:33 -05:00
|
|
|
from typing import Any
|
2023-06-24 22:18:09 -05:00
|
|
|
|
2023-06-06 20:48:51 -05:00
|
|
|
from aiocache.backends.memory import SimpleMemoryCache
|
|
|
|
from aiocache.lock import OptimisticLock
|
2024-03-03 19:48:56 -05:00
|
|
|
from aiocache.plugins import TimingPlugin
|
2023-06-24 22:18:09 -05:00
|
|
|
|
2023-10-31 05:02:04 -05:00
|
|
|
from app.models import from_model_type
|
|
|
|
|
2023-11-13 11:18:46 -05:00
|
|
|
from ..schemas import ModelType, has_profiling
|
2023-06-24 22:18:09 -05:00
|
|
|
from .base import InferenceModel
|
2023-06-06 20:48:51 -05:00
|
|
|
|
|
|
|
|
|
|
|
class ModelCache:
    """Fetches a model from an in-memory cache, instantiating it if it's missing."""

    def __init__(
        self,
        revalidate: bool = False,
        timeout: int | None = None,
        profiling: bool = False,
    ) -> None:
        """
        Args:
            revalidate: Resets TTL on cache hit. Useful to keep models in memory while active. Defaults to False.
            timeout: Timeout for operations on the underlying aiocache cache. Disabled if None. This is handed
                to the cache itself; constructing a missing model in `get` is not bounded by it.
                Defaults to None.
            profiling: Collects metrics for cache operations, adding slight overhead. Defaults to False.
        """
        # TimingPlugin records per-operation timings on the cache object, which `get_profiling` returns.
        plugins = [TimingPlugin()] if profiling else []

        self.revalidate_enable = revalidate

        # namespace=None keeps keys unprefixed; presumably so the raw keys built in `get` match the
        # backend-internal keys inspected by `revalidate` — TODO confirm.
        self.cache = SimpleMemoryCache(timeout=timeout, plugins=plugins, namespace=None)

    async def get(self, model_name: str, model_type: ModelType, **model_kwargs: Any) -> InferenceModel:
        """
        Args:
            model_name: Name of model in the model hub used for the task.
            model_type: Model type or task, which determines which model zoo is used.
            **model_kwargs: Forwarded to `from_model_type` when the model has to be created. Two keys are
                also read here: `mode` becomes part of the cache key, and `ttl` is the expiry passed to
                the cache on insert (and on revalidation, when enabled).

        Returns:
            model: The requested model.
        """
        from aiocache.lock import OptimisticLockError

        key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"
        ttl = model_kwargs.get("ttl", None)

        async with OptimisticLock(self.cache, key) as lock:
            model: InferenceModel | None = await self.cache.get(key)
            if model is None:
                model = from_model_type(model_type, model_name, **model_kwargs)
                try:
                    await lock.cas(model, ttl=ttl)
                except OptimisticLockError:
                    # Another coroutine stored this key between our read and our write. Prefer its
                    # instance so concurrent callers share one model; if the entry is already gone,
                    # fall back to the model we just built rather than failing the request.
                    cached: InferenceModel | None = await self.cache.get(key)
                    if cached is not None:
                        model = cached
            elif self.revalidate_enable:
                await self.revalidate(key, ttl)
        return model

    async def get_profiling(self) -> dict[str, float] | None:
        """
        Returns:
            The cache's collected timing metrics, or None when `has_profiling` reports that the cache
            carries none (presumably because profiling was not enabled in `__init__`).
        """
        if not has_profiling(self.cache):
            return None

        return self.cache.profiling

    async def revalidate(self, key: str, ttl: int | None) -> None:
        """
        Reset the expiry of a cached entry.

        Args:
            key: Cache key as built by `get`.
            ttl: New time-to-live. Nothing is done when this is None.
        """
        # `_handlers` is a private attribute of aiocache's memory backend. It appears to hold an entry
        # only for keys that currently have an expiry scheduled, so keys stored without a TTL are left
        # permanent instead of gaining one here — NOTE(review): relies on aiocache internals.
        if ttl is not None and key in self.cache._handlers:
            await self.cache.expire(key, ttl)