mirror of
https://github.com/immich-app/immich.git
synced 2025-02-04 01:09:14 -05:00
add tests
This commit is contained in:
parent
259386cf13
commit
717961ce7b
1 changed files with 67 additions and 0 deletions
|
@ -7,6 +7,7 @@ from types import SimpleNamespace
|
|||
from typing import Any, Callable
|
||||
from unittest import mock
|
||||
|
||||
from app.models.session import ort_add_batch_dim, ort_has_batch_dim, ort_squeeze_outputs
|
||||
import cv2
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
@ -734,3 +735,69 @@ class TestEndpoints:
|
|||
assert expected_face["boundingBox"] == actual_face["boundingBox"]
|
||||
assert np.allclose(expected_face["embedding"], actual_face["embedding"])
|
||||
assert np.allclose(expected_face["score"], actual_face["score"])
|
||||
|
||||
|
||||
class TestSessionUtils:
|
||||
def test_ort_has_batch_dim(self, mocker: MockerFixture) -> None:
|
||||
mock_session = mocker.Mock(spec=ort.InferenceSession)
|
||||
mock_session.get_inputs.return_value = [SimpleNamespace(shape=["batch", 3, 224, 224], name="input.1")]
|
||||
|
||||
assert ort_has_batch_dim(mock_session) is True
|
||||
|
||||
def test_ort_has_no_batch_dim(self, mocker: MockerFixture) -> None:
|
||||
mock_session = mocker.Mock(spec=ort.InferenceSession)
|
||||
mock_session.get_inputs.return_value = [SimpleNamespace(shape=[1, 3, 224, 224], name="input.1")]
|
||||
|
||||
assert ort_has_batch_dim(mock_session) is False
|
||||
|
||||
def test_ort_squeeze_outputs(self, mocker: MockerFixture) -> None:
|
||||
mock_session = mocker.Mock(spec=ort.InferenceSession)
|
||||
mock_session.run.return_value = [np.random.rand(1, 3, 224, 224).astype(np.float32)]
|
||||
|
||||
ort_squeeze_outputs(mock_session)
|
||||
out = mock_session.run(["output"], {"input": np.random.rand(1, 3, 224, 224).astype(np.float32)})
|
||||
|
||||
assert len(out) == 1
|
||||
assert out[0].shape == (3, 224, 224)
|
||||
|
||||
def test_ort_add_batch_dim(self, mocker: MockerFixture) -> None:
|
||||
mock_proto = mocker.Mock()
|
||||
mock_proto.graph.input = [
|
||||
SimpleNamespace(name="input", type=mock.Mock(tensor_type=SimpleNamespace(shape=SimpleNamespace())))
|
||||
]
|
||||
mock_proto.graph.output = [
|
||||
SimpleNamespace(name="output", type=SimpleNamespace(tensor_type=SimpleNamespace(shape=SimpleNamespace())))
|
||||
]
|
||||
mock_proto.graph.input[0].type.tensor_type.shape.dim = [
|
||||
SimpleNamespace(dim_value=3),
|
||||
SimpleNamespace(dim_value=224),
|
||||
SimpleNamespace(dim_value=224),
|
||||
]
|
||||
mock_proto.graph.output[0].type.tensor_type.shape.dim = [
|
||||
SimpleNamespace(dim_value=3),
|
||||
SimpleNamespace(dim_value=224),
|
||||
SimpleNamespace(dim_value=224),
|
||||
]
|
||||
|
||||
mock_load = mocker.patch("app.models.session.onnx.load")
|
||||
mock_load.return_value = mock_proto
|
||||
|
||||
mock_update_dims = mocker.patch("app.models.session.update_inputs_outputs_dims")
|
||||
mock_updated = mocker.Mock()
|
||||
mock_update_dims.return_value = mock_updated
|
||||
|
||||
mock_infer_shapes = mocker.patch("app.models.session.infer_shapes")
|
||||
mock_inferred = mocker.Mock()
|
||||
mock_infer_shapes.return_value = mock_inferred
|
||||
|
||||
mock_save = mocker.patch("app.models.session.onnx.save")
|
||||
|
||||
input_path, output_path = Path("input.onnx"), Path("output.onnx")
|
||||
ort_add_batch_dim(input_path, output_path)
|
||||
|
||||
mock_load.assert_called_once_with(input_path)
|
||||
mock_save.assert_called_once_with(mock_inferred, output_path)
|
||||
mock_update_dims.assert_called_once_with(mock_proto, mock.ANY, mock.ANY)
|
||||
assert mock_update_dims.call_args_list == [
|
||||
mock.call(mock_proto, {"input": ["batch", 224, 224]}, {"output": ["batch", 224, 224]})
|
||||
]
|
||||
|
|
Loading…
Add table
Reference in a new issue