mirror of
https://github.com/immich-app/immich.git
synced 2025-01-07 00:50:23 -05:00
fix(ml): armnn not being used (#10929)
* fix armnn not being used, move fallback handling to main, add tests * formatting
This commit is contained in:
parent
59aa347912
commit
f43721ec92
7 changed files with 111 additions and 44 deletions
|
@ -168,6 +168,12 @@ def warning() -> Iterator[mock.Mock]:
|
||||||
yield mocked
|
yield mocked
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def exception() -> Iterator[mock.Mock]:
|
||||||
|
with mock.patch.object(log, "exception") as mocked:
|
||||||
|
yield mocked
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def snapshot_download() -> Iterator[mock.Mock]:
|
def snapshot_download() -> Iterator[mock.Mock]:
|
||||||
with mock.patch("app.models.base.snapshot_download") as mocked:
|
with mock.patch("app.models.base.snapshot_download") as mocked:
|
||||||
|
|
|
@ -29,6 +29,7 @@ from .schemas import (
|
||||||
InferenceEntry,
|
InferenceEntry,
|
||||||
InferenceResponse,
|
InferenceResponse,
|
||||||
MessageResponse,
|
MessageResponse,
|
||||||
|
ModelFormat,
|
||||||
ModelIdentity,
|
ModelIdentity,
|
||||||
ModelTask,
|
ModelTask,
|
||||||
ModelType,
|
ModelType,
|
||||||
|
@ -195,6 +196,16 @@ async def load(model: InferenceModel) -> InferenceModel:
|
||||||
if model.load_attempts > 1:
|
if model.load_attempts > 1:
|
||||||
raise HTTPException(500, f"Failed to load model '{model.model_name}'")
|
raise HTTPException(500, f"Failed to load model '{model.model_name}'")
|
||||||
with lock:
|
with lock:
|
||||||
|
try:
|
||||||
|
model.load()
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
if model.model_format == ModelFormat.ONNX:
|
||||||
|
raise e
|
||||||
|
log.exception(e)
|
||||||
|
log.warning(
|
||||||
|
f"{model.model_format.upper()} is available, but model '{model.model_name}' does not support it."
|
||||||
|
)
|
||||||
|
model.model_format = ModelFormat.ONNX
|
||||||
model.load()
|
model.load()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,7 @@ class InferenceModel(ABC):
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
cache_dir: Path | str | None = None,
|
cache_dir: Path | str | None = None,
|
||||||
preferred_format: ModelFormat | None = None,
|
model_format: ModelFormat | None = None,
|
||||||
session: ModelSession | None = None,
|
session: ModelSession | None = None,
|
||||||
**model_kwargs: Any,
|
**model_kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -31,7 +31,7 @@ class InferenceModel(ABC):
|
||||||
self.load_attempts = 0
|
self.load_attempts = 0
|
||||||
self.model_name = clean_name(model_name)
|
self.model_name = clean_name(model_name)
|
||||||
self.cache_dir = Path(cache_dir) if cache_dir is not None else self._cache_dir_default
|
self.cache_dir = Path(cache_dir) if cache_dir is not None else self._cache_dir_default
|
||||||
self.model_format = preferred_format if preferred_format is not None else self._model_format_default
|
self.model_format = model_format if model_format is not None else self._model_format_default
|
||||||
if session is not None:
|
if session is not None:
|
||||||
self.session = session
|
self.session = session
|
||||||
|
|
||||||
|
@ -48,7 +48,7 @@ class InferenceModel(ABC):
|
||||||
self.load_attempts += 1
|
self.load_attempts += 1
|
||||||
|
|
||||||
self.download()
|
self.download()
|
||||||
attempt = f"Attempt #{self.load_attempts + 1} to load" if self.load_attempts else "Loading"
|
attempt = f"Attempt #{self.load_attempts} to load" if self.load_attempts > 1 else "Loading"
|
||||||
log.info(f"{attempt} {self.model_type.replace('-', ' ')} model '{self.model_name}' to memory")
|
log.info(f"{attempt} {self.model_type.replace('-', ' ')} model '{self.model_name}' to memory")
|
||||||
self.session = self._load()
|
self.session = self._load()
|
||||||
self.loaded = True
|
self.loaded = True
|
||||||
|
@ -101,6 +101,9 @@ class InferenceModel(ABC):
|
||||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
def _make_session(self, model_path: Path) -> ModelSession:
|
def _make_session(self, model_path: Path) -> ModelSession:
|
||||||
|
if not model_path.is_file():
|
||||||
|
raise FileNotFoundError(f"Model file not found: {model_path}")
|
||||||
|
|
||||||
match model_path.suffix:
|
match model_path.suffix:
|
||||||
case ".armnn":
|
case ".armnn":
|
||||||
session: ModelSession = AnnSession(model_path)
|
session: ModelSession = AnnSession(model_path)
|
||||||
|
@ -144,17 +147,13 @@ class InferenceModel(ABC):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_format(self) -> ModelFormat:
|
def model_format(self) -> ModelFormat:
|
||||||
return self._preferred_format
|
return self._model_format
|
||||||
|
|
||||||
@model_format.setter
|
@model_format.setter
|
||||||
def model_format(self, preferred_format: ModelFormat) -> None:
|
def model_format(self, model_format: ModelFormat) -> None:
|
||||||
log.debug(f"Setting preferred format to {preferred_format}")
|
log.debug(f"Setting model format to {model_format}")
|
||||||
self._preferred_format = preferred_format
|
self._model_format = model_format
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _model_format_default(self) -> ModelFormat:
|
def _model_format_default(self) -> ModelFormat:
|
||||||
prefer_ann = ann.ann.is_available and settings.ann
|
return ModelFormat.ARMNN if ann.ann.is_available and settings.ann else ModelFormat.ONNX
|
||||||
ann_exists = (self.model_dir / "model.armnn").is_file()
|
|
||||||
if prefer_ann and not ann_exists:
|
|
||||||
log.warning(f"ARM NN is available, but '{self.model_name}' does not support ARM NN. Falling back to ONNX.")
|
|
||||||
return ModelFormat.ARMNN if prefer_ann and ann_exists else ModelFormat.ONNX
|
|
||||||
|
|
|
@ -22,11 +22,12 @@ class BaseCLIPTextualEncoder(InferenceModel):
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def _load(self) -> ModelSession:
|
def _load(self) -> ModelSession:
|
||||||
|
session = super()._load()
|
||||||
log.debug(f"Loading tokenizer for CLIP model '{self.model_name}'")
|
log.debug(f"Loading tokenizer for CLIP model '{self.model_name}'")
|
||||||
self.tokenizer = self._load_tokenizer()
|
self.tokenizer = self._load_tokenizer()
|
||||||
log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")
|
log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")
|
||||||
|
|
||||||
return super()._load()
|
return session
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _load_tokenizer(self) -> Tokenizer:
|
def _load_tokenizer(self) -> Tokenizer:
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -14,15 +13,9 @@ class FaceDetector(InferenceModel):
|
||||||
depends = []
|
depends = []
|
||||||
identity = (ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)
|
identity = (ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, model_name: str, min_score: float = 0.7, **model_kwargs: Any) -> None:
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
min_score: float = 0.7,
|
|
||||||
cache_dir: Path | str | None = None,
|
|
||||||
**model_kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
self.min_score = model_kwargs.pop("minScore", min_score)
|
self.min_score = model_kwargs.pop("minScore", min_score)
|
||||||
super().__init__(model_name, cache_dir, **model_kwargs)
|
super().__init__(model_name, **model_kwargs)
|
||||||
|
|
||||||
def _load(self) -> ModelSession:
|
def _load(self) -> ModelSession:
|
||||||
session = self._make_session(self.model_path)
|
session = self._make_session(self.model_path)
|
||||||
|
|
|
@ -9,7 +9,7 @@ from numpy.typing import NDArray
|
||||||
from onnx.tools.update_model_dims import update_inputs_outputs_dims
|
from onnx.tools.update_model_dims import update_inputs_outputs_dims
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from app.config import clean_name, log
|
from app.config import log
|
||||||
from app.models.base import InferenceModel
|
from app.models.base import InferenceModel
|
||||||
from app.models.transforms import decode_cv2
|
from app.models.transforms import decode_cv2
|
||||||
from app.schemas import FaceDetectionOutput, FacialRecognitionOutput, ModelFormat, ModelSession, ModelTask, ModelType
|
from app.schemas import FaceDetectionOutput, FacialRecognitionOutput, ModelFormat, ModelSession, ModelTask, ModelType
|
||||||
|
@ -20,20 +20,14 @@ class FaceRecognizer(InferenceModel):
|
||||||
depends = [(ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)]
|
depends = [(ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)]
|
||||||
identity = (ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION)
|
identity = (ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION)
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, model_name: str, min_score: float = 0.7, **model_kwargs: Any) -> None:
|
||||||
self,
|
super().__init__(model_name, **model_kwargs)
|
||||||
model_name: str,
|
|
||||||
min_score: float = 0.7,
|
|
||||||
cache_dir: Path | str | None = None,
|
|
||||||
**model_kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(clean_name(model_name), cache_dir, **model_kwargs)
|
|
||||||
self.min_score = model_kwargs.pop("minScore", min_score)
|
self.min_score = model_kwargs.pop("minScore", min_score)
|
||||||
self.batch = self.model_format == ModelFormat.ONNX
|
self.batch = self.model_format == ModelFormat.ONNX
|
||||||
|
|
||||||
def _load(self) -> ModelSession:
|
def _load(self) -> ModelSession:
|
||||||
session = self._make_session(self.model_path)
|
session = self._make_session(self.model_path)
|
||||||
if self.model_format == ModelFormat.ONNX and not has_batch_axis(session):
|
if self.batch and not has_batch_axis(session):
|
||||||
self._add_batch_axis(self.model_path)
|
self._add_batch_axis(self.model_path)
|
||||||
session = self._make_session(self.model_path)
|
session = self._make_session(self.model_path)
|
||||||
self.model = ArcFaceONNX(
|
self.model = ArcFaceONNX(
|
||||||
|
|
|
@ -43,7 +43,7 @@ class TestBase:
|
||||||
|
|
||||||
assert encoder.cache_dir == cache_dir
|
assert encoder.cache_dir == cache_dir
|
||||||
|
|
||||||
def test_sets_default_preferred_format(self, mocker: MockerFixture) -> None:
|
def test_sets_default_model_format(self, mocker: MockerFixture) -> None:
|
||||||
mocker.patch.object(settings, "ann", True)
|
mocker.patch.object(settings, "ann", True)
|
||||||
mocker.patch("ann.ann.is_available", False)
|
mocker.patch("ann.ann.is_available", False)
|
||||||
|
|
||||||
|
@ -51,7 +51,7 @@ class TestBase:
|
||||||
|
|
||||||
assert encoder.model_format == ModelFormat.ONNX
|
assert encoder.model_format == ModelFormat.ONNX
|
||||||
|
|
||||||
def test_sets_default_preferred_format_to_armnn_if_available(self, path: mock.Mock, mocker: MockerFixture) -> None:
|
def test_sets_default_model_format_to_armnn_if_available(self, path: mock.Mock, mocker: MockerFixture) -> None:
|
||||||
mocker.patch.object(settings, "ann", True)
|
mocker.patch.object(settings, "ann", True)
|
||||||
mocker.patch("ann.ann.is_available", True)
|
mocker.patch("ann.ann.is_available", True)
|
||||||
path.suffix = ".armnn"
|
path.suffix = ".armnn"
|
||||||
|
@ -60,11 +60,11 @@ class TestBase:
|
||||||
|
|
||||||
assert encoder.model_format == ModelFormat.ARMNN
|
assert encoder.model_format == ModelFormat.ARMNN
|
||||||
|
|
||||||
def test_sets_preferred_format_kwarg(self, mocker: MockerFixture) -> None:
|
def test_sets_model_format_kwarg(self, mocker: MockerFixture) -> None:
|
||||||
mocker.patch.object(settings, "ann", False)
|
mocker.patch.object(settings, "ann", False)
|
||||||
mocker.patch("ann.ann.is_available", False)
|
mocker.patch("ann.ann.is_available", False)
|
||||||
|
|
||||||
encoder = OpenClipTextualEncoder("ViT-B-32__openai", preferred_format=ModelFormat.ARMNN)
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai", model_format=ModelFormat.ARMNN)
|
||||||
|
|
||||||
assert encoder.model_format == ModelFormat.ARMNN
|
assert encoder.model_format == ModelFormat.ARMNN
|
||||||
|
|
||||||
|
@ -129,7 +129,7 @@ class TestBase:
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_download_downloads_armnn_if_preferred_format(self, snapshot_download: mock.Mock) -> None:
|
def test_download_downloads_armnn_if_preferred_format(self, snapshot_download: mock.Mock) -> None:
|
||||||
encoder = OpenClipTextualEncoder("ViT-B-32__openai", preferred_format=ModelFormat.ARMNN)
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai", model_format=ModelFormat.ARMNN)
|
||||||
encoder.download()
|
encoder.download()
|
||||||
|
|
||||||
snapshot_download.assert_called_once_with(
|
snapshot_download.assert_called_once_with(
|
||||||
|
@ -140,6 +140,19 @@ class TestBase:
|
||||||
ignore_patterns=[],
|
ignore_patterns=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_throws_exception_if_model_path_does_not_exist(
|
||||||
|
self, snapshot_download: mock.Mock, ort_session: mock.Mock, path: mock.Mock
|
||||||
|
) -> None:
|
||||||
|
path.return_value.__truediv__.return_value.__truediv__.return_value.is_file.return_value = False
|
||||||
|
|
||||||
|
encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir=path)
|
||||||
|
|
||||||
|
with pytest.raises(FileNotFoundError):
|
||||||
|
encoder.load()
|
||||||
|
|
||||||
|
snapshot_download.assert_called_once()
|
||||||
|
ort_session.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("ort_session")
|
@pytest.mark.usefixtures("ort_session")
|
||||||
class TestOrtSession:
|
class TestOrtSession:
|
||||||
|
@ -467,16 +480,18 @@ class TestFaceRecognition:
|
||||||
assert isinstance(call_args[0][0], np.ndarray)
|
assert isinstance(call_args[0][0], np.ndarray)
|
||||||
assert call_args[0][0].shape == (112, 112, 3)
|
assert call_args[0][0].shape == (112, 112, 3)
|
||||||
|
|
||||||
def test_recognition_adds_batch_axis_for_ort(self, ort_session: mock.Mock, mocker: MockerFixture) -> None:
|
def test_recognition_adds_batch_axis_for_ort(
|
||||||
|
self, ort_session: mock.Mock, path: mock.Mock, mocker: MockerFixture
|
||||||
|
) -> None:
|
||||||
onnx = mocker.patch("app.models.facial_recognition.recognition.onnx", autospec=True)
|
onnx = mocker.patch("app.models.facial_recognition.recognition.onnx", autospec=True)
|
||||||
update_dims = mocker.patch(
|
update_dims = mocker.patch(
|
||||||
"app.models.facial_recognition.recognition.update_inputs_outputs_dims", autospec=True
|
"app.models.facial_recognition.recognition.update_inputs_outputs_dims", autospec=True
|
||||||
)
|
)
|
||||||
mocker.patch("app.models.base.InferenceModel.download")
|
mocker.patch("app.models.base.InferenceModel.download")
|
||||||
mocker.patch("app.models.facial_recognition.recognition.ArcFaceONNX")
|
mocker.patch("app.models.facial_recognition.recognition.ArcFaceONNX")
|
||||||
|
|
||||||
ort_session.return_value.get_inputs.return_value = [SimpleNamespace(name="input.1", shape=(1, 3, 224, 224))]
|
ort_session.return_value.get_inputs.return_value = [SimpleNamespace(name="input.1", shape=(1, 3, 224, 224))]
|
||||||
ort_session.return_value.get_outputs.return_value = [SimpleNamespace(name="output.1", shape=(1, 800))]
|
ort_session.return_value.get_outputs.return_value = [SimpleNamespace(name="output.1", shape=(1, 800))]
|
||||||
|
path.return_value.__truediv__.return_value.__truediv__.return_value.suffix = ".onnx"
|
||||||
|
|
||||||
proto = mock.Mock()
|
proto = mock.Mock()
|
||||||
|
|
||||||
|
@ -492,27 +507,30 @@ class TestFaceRecognition:
|
||||||
|
|
||||||
onnx.load.return_value = proto
|
onnx.load.return_value = proto
|
||||||
|
|
||||||
face_recognizer = FaceRecognizer("buffalo_s")
|
face_recognizer = FaceRecognizer("buffalo_s", cache_dir=path)
|
||||||
face_recognizer.load()
|
face_recognizer.load()
|
||||||
|
|
||||||
assert face_recognizer.batch is True
|
assert face_recognizer.batch is True
|
||||||
update_dims.assert_called_once_with(proto, {"input.1": ["batch", 3, 224, 224]}, {"output.1": ["batch", 800]})
|
update_dims.assert_called_once_with(proto, {"input.1": ["batch", 3, 224, 224]}, {"output.1": ["batch", 800]})
|
||||||
onnx.save.assert_called_once_with(update_dims.return_value, face_recognizer.model_path)
|
onnx.save.assert_called_once_with(update_dims.return_value, face_recognizer.model_path)
|
||||||
|
|
||||||
def test_recognition_does_not_add_batch_axis_if_exists(self, ort_session: mock.Mock, mocker: MockerFixture) -> None:
|
def test_recognition_does_not_add_batch_axis_if_exists(
|
||||||
|
self, ort_session: mock.Mock, path: mock.Mock, mocker: MockerFixture
|
||||||
|
) -> None:
|
||||||
onnx = mocker.patch("app.models.facial_recognition.recognition.onnx", autospec=True)
|
onnx = mocker.patch("app.models.facial_recognition.recognition.onnx", autospec=True)
|
||||||
update_dims = mocker.patch(
|
update_dims = mocker.patch(
|
||||||
"app.models.facial_recognition.recognition.update_inputs_outputs_dims", autospec=True
|
"app.models.facial_recognition.recognition.update_inputs_outputs_dims", autospec=True
|
||||||
)
|
)
|
||||||
mocker.patch("app.models.base.InferenceModel.download")
|
mocker.patch("app.models.base.InferenceModel.download")
|
||||||
mocker.patch("app.models.facial_recognition.recognition.ArcFaceONNX")
|
mocker.patch("app.models.facial_recognition.recognition.ArcFaceONNX")
|
||||||
|
path.return_value.__truediv__.return_value.__truediv__.return_value.suffix = ".onnx"
|
||||||
|
|
||||||
inputs = [SimpleNamespace(name="input.1", shape=("batch", 3, 224, 224))]
|
inputs = [SimpleNamespace(name="input.1", shape=("batch", 3, 224, 224))]
|
||||||
outputs = [SimpleNamespace(name="output.1", shape=("batch", 800))]
|
outputs = [SimpleNamespace(name="output.1", shape=("batch", 800))]
|
||||||
ort_session.return_value.get_inputs.return_value = inputs
|
ort_session.return_value.get_inputs.return_value = inputs
|
||||||
ort_session.return_value.get_outputs.return_value = outputs
|
ort_session.return_value.get_outputs.return_value = outputs
|
||||||
|
|
||||||
face_recognizer = FaceRecognizer("buffalo_s")
|
face_recognizer = FaceRecognizer("buffalo_s", cache_dir=path)
|
||||||
face_recognizer.load()
|
face_recognizer.load()
|
||||||
|
|
||||||
assert face_recognizer.batch is True
|
assert face_recognizer.batch is True
|
||||||
|
@ -520,6 +538,30 @@ class TestFaceRecognition:
|
||||||
onnx.load.assert_not_called()
|
onnx.load.assert_not_called()
|
||||||
onnx.save.assert_not_called()
|
onnx.save.assert_not_called()
|
||||||
|
|
||||||
|
def test_recognition_does_not_add_batch_axis_for_armnn(
|
||||||
|
self, ann_session: mock.Mock, path: mock.Mock, mocker: MockerFixture
|
||||||
|
) -> None:
|
||||||
|
onnx = mocker.patch("app.models.facial_recognition.recognition.onnx", autospec=True)
|
||||||
|
update_dims = mocker.patch(
|
||||||
|
"app.models.facial_recognition.recognition.update_inputs_outputs_dims", autospec=True
|
||||||
|
)
|
||||||
|
mocker.patch("app.models.base.InferenceModel.download")
|
||||||
|
mocker.patch("app.models.facial_recognition.recognition.ArcFaceONNX")
|
||||||
|
path.return_value.__truediv__.return_value.__truediv__.return_value.suffix = ".armnn"
|
||||||
|
|
||||||
|
inputs = [SimpleNamespace(name="input.1", shape=("batch", 3, 224, 224))]
|
||||||
|
outputs = [SimpleNamespace(name="output.1", shape=("batch", 800))]
|
||||||
|
ann_session.return_value.get_inputs.return_value = inputs
|
||||||
|
ann_session.return_value.get_outputs.return_value = outputs
|
||||||
|
|
||||||
|
face_recognizer = FaceRecognizer("buffalo_s", model_format=ModelFormat.ARMNN, cache_dir=path)
|
||||||
|
face_recognizer.load()
|
||||||
|
|
||||||
|
assert face_recognizer.batch is False
|
||||||
|
update_dims.assert_not_called()
|
||||||
|
onnx.load.assert_not_called()
|
||||||
|
onnx.save.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
class TestCache:
|
class TestCache:
|
||||||
|
@ -693,7 +735,7 @@ class TestLoad:
|
||||||
mock_model.clear_cache.assert_called_once()
|
mock_model.clear_cache.assert_called_once()
|
||||||
assert mock_model.load.call_count == 2
|
assert mock_model.load.call_count == 2
|
||||||
|
|
||||||
async def test_load_clears_cache_and_raises_if_os_error_and_already_retried(self) -> None:
|
async def test_load_raises_if_os_error_and_already_retried(self) -> None:
|
||||||
mock_model = mock.Mock(spec=InferenceModel)
|
mock_model = mock.Mock(spec=InferenceModel)
|
||||||
mock_model.model_name = "test_model_name"
|
mock_model.model_name = "test_model_name"
|
||||||
mock_model.model_type = ModelType.VISUAL
|
mock_model.model_type = ModelType.VISUAL
|
||||||
|
@ -707,6 +749,27 @@ class TestLoad:
|
||||||
mock_model.clear_cache.assert_not_called()
|
mock_model.clear_cache.assert_not_called()
|
||||||
mock_model.load.assert_not_called()
|
mock_model.load.assert_not_called()
|
||||||
|
|
||||||
|
async def test_falls_back_to_onnx_if_other_format_does_not_exist(
|
||||||
|
self, exception: mock.Mock, warning: mock.Mock
|
||||||
|
) -> None:
|
||||||
|
mock_model = mock.Mock(spec=InferenceModel)
|
||||||
|
mock_model.model_name = "test_model_name"
|
||||||
|
mock_model.model_type = ModelType.VISUAL
|
||||||
|
mock_model.model_task = ModelTask.SEARCH
|
||||||
|
mock_model.model_format = ModelFormat.ARMNN
|
||||||
|
mock_model.loaded = False
|
||||||
|
mock_model.load_attempts = 0
|
||||||
|
error = FileNotFoundError()
|
||||||
|
mock_model.load.side_effect = [error, None]
|
||||||
|
|
||||||
|
await load(mock_model)
|
||||||
|
|
||||||
|
mock_model.clear_cache.assert_not_called()
|
||||||
|
assert mock_model.load.call_count == 2
|
||||||
|
exception.assert_called_once_with(error)
|
||||||
|
warning.assert_called_once_with("ARMNN is available, but model 'test_model_name' does not support it.")
|
||||||
|
mock_model.model_format = ModelFormat.ONNX
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
not settings.test_full,
|
not settings.test_full,
|
||||||
|
|
Loading…
Reference in a new issue