from __future__ import annotations

from pathlib import Path
from typing import Any, NamedTuple

from numpy import ascontiguousarray

from ann.ann import Ann
from app.schemas import ndarray_f32, ndarray_i32

from ..config import log, settings


class AnnSession:
    """
    Wrapper for ANN to be drop-in replacement for ONNX session.

    One ``Ann`` instance and one loaded model are owned per session; both are
    released again in ``__del__``.
    """

    def __init__(self, model_path: Path):
        """
        Create the ANN instance and load the model at ``model_path``.

        A ``gpu-tuning.ann`` file under ``settings.cache_folder`` and a
        ``.anncache`` file next to the model are created if they do not exist.

        Raises:
            Whatever ``Ann(...)`` or ``Ann.load`` raise; a cache placeholder
            created by this call is removed again before re-raising.
        """
        tuning_file = Path(settings.cache_folder) / "gpu-tuning.ann"
        with tuning_file.open(mode="a"):
            # make sure tuning file exists (without clearing contents)
            # once filled, the tuning file reduces the cost/time of the first
            # inference after model load by 10s of seconds
            pass
        self.ann = Ann(tuning_level=3, tuning_file=tuning_file.as_posix())

        log.info("Loading ANN model %s ...", model_path)
        cache_file = model_path.with_suffix(".anncache")
        # Only ask ANN to write the cache when there is none yet.
        save = not cache_file.is_file()
        if save:
            with cache_file.open(mode="a"):
                # create empty model cache file
                pass

        try:
            self.model = self.ann.load(
                model_path.as_posix(),
                save_cached_network=save,
                cached_network_path=cache_file.as_posix(),
            )
        except Exception:
            # Do not leave behind the placeholder (or a partially written
            # cache) created above: on the next start it would be mistaken
            # for an existing cache and loaded with save_cached_network=False.
            if save:
                cache_file.unlink(missing_ok=True)
            raise

        log.info("Loaded ANN model with ID %d", self.model)

    def __del__(self) -> None:
        """Unload the model (if one was loaded) and destroy the ANN instance."""
        # __del__ also runs for objects whose __init__ raised part-way, so
        # neither attribute is guaranteed to exist.
        ann = getattr(self, "ann", None)
        if ann is None:
            return
        model = getattr(self, "model", None)
        if model is not None:
            ann.unload(model)
            log.info("Unloaded ANN model %d", model)
        ann.destroy()

    def get_inputs(self) -> list[AnnNode]:
        """Return one unnamed ``AnnNode`` per input shape of the loaded model."""
        shapes = self.ann.input_shapes[self.model]
        return [AnnNode(None, s) for s in shapes]

    def get_outputs(self) -> list[AnnNode]:
        """Return one unnamed ``AnnNode`` per output shape of the loaded model."""
        shapes = self.ann.output_shapes[self.model]
        return [AnnNode(None, s) for s in shapes]

    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, ndarray_f32] | dict[str, ndarray_i32],
        run_options: Any = None,
    ) -> list[ndarray_f32]:
        """
        Execute the loaded model on the values of ``input_feed``.

        The values are passed to ANN in the dict's iteration order after being
        made C-contiguous; the keys are not used. ``output_names`` and
        ``run_options`` are accepted for signature compatibility and ignored.
        """
        inputs: list[ndarray_f32] = [ascontiguousarray(v) for v in input_feed.values()]
        return self.ann.execute(self.model, inputs)


class AnnNode(NamedTuple):
    """Input/output descriptor returned by ``AnnSession.get_inputs``/``get_outputs``."""

    # AnnSession always passes None here.
    name: str | None
    # Shape as reported by Ann.input_shapes / Ann.output_shapes.
    shape: tuple[int, ...]