0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-02-04 01:09:14 -05:00

add tests

This commit is contained in:
mertalev 2024-06-07 00:28:07 -04:00
parent 259386cf13
commit 717961ce7b
No known key found for this signature in database
GPG key ID: 9181CD92C0A1C5E3

View file

@ -7,6 +7,7 @@ from types import SimpleNamespace
from typing import Any, Callable from typing import Any, Callable
from unittest import mock from unittest import mock
from app.models.session import ort_add_batch_dim, ort_has_batch_dim, ort_squeeze_outputs
import cv2 import cv2
import numpy as np import numpy as np
import onnxruntime as ort import onnxruntime as ort
@ -734,3 +735,69 @@ class TestEndpoints:
assert expected_face["boundingBox"] == actual_face["boundingBox"] assert expected_face["boundingBox"] == actual_face["boundingBox"]
assert np.allclose(expected_face["embedding"], actual_face["embedding"]) assert np.allclose(expected_face["embedding"], actual_face["embedding"])
assert np.allclose(expected_face["score"], actual_face["score"]) assert np.allclose(expected_face["score"], actual_face["score"])
class TestSessionUtils:
def test_ort_has_batch_dim(self, mocker: MockerFixture) -> None:
mock_session = mocker.Mock(spec=ort.InferenceSession)
mock_session.get_inputs.return_value = [SimpleNamespace(shape=["batch", 3, 224, 224], name="input.1")]
assert ort_has_batch_dim(mock_session) is True
def test_ort_has_no_batch_dim(self, mocker: MockerFixture) -> None:
mock_session = mocker.Mock(spec=ort.InferenceSession)
mock_session.get_inputs.return_value = [SimpleNamespace(shape=[1, 3, 224, 224], name="input.1")]
assert ort_has_batch_dim(mock_session) is False
def test_ort_squeeze_outputs(self, mocker: MockerFixture) -> None:
mock_session = mocker.Mock(spec=ort.InferenceSession)
mock_session.run.return_value = [np.random.rand(1, 3, 224, 224).astype(np.float32)]
ort_squeeze_outputs(mock_session)
out = mock_session.run(["output"], {"input": np.random.rand(1, 3, 224, 224).astype(np.float32)})
assert len(out) == 1
assert out[0].shape == (3, 224, 224)
def test_ort_add_batch_dim(self, mocker: MockerFixture) -> None:
mock_proto = mocker.Mock()
mock_proto.graph.input = [
SimpleNamespace(name="input", type=mock.Mock(tensor_type=SimpleNamespace(shape=SimpleNamespace())))
]
mock_proto.graph.output = [
SimpleNamespace(name="output", type=SimpleNamespace(tensor_type=SimpleNamespace(shape=SimpleNamespace())))
]
mock_proto.graph.input[0].type.tensor_type.shape.dim = [
SimpleNamespace(dim_value=3),
SimpleNamespace(dim_value=224),
SimpleNamespace(dim_value=224),
]
mock_proto.graph.output[0].type.tensor_type.shape.dim = [
SimpleNamespace(dim_value=3),
SimpleNamespace(dim_value=224),
SimpleNamespace(dim_value=224),
]
mock_load = mocker.patch("app.models.session.onnx.load")
mock_load.return_value = mock_proto
mock_update_dims = mocker.patch("app.models.session.update_inputs_outputs_dims")
mock_updated = mocker.Mock()
mock_update_dims.return_value = mock_updated
mock_infer_shapes = mocker.patch("app.models.session.infer_shapes")
mock_inferred = mocker.Mock()
mock_infer_shapes.return_value = mock_inferred
mock_save = mocker.patch("app.models.session.onnx.save")
input_path, output_path = Path("input.onnx"), Path("output.onnx")
ort_add_batch_dim(input_path, output_path)
mock_load.assert_called_once_with(input_path)
mock_save.assert_called_once_with(mock_inferred, output_path)
mock_update_dims.assert_called_once_with(mock_proto, mock.ANY, mock.ANY)
assert mock_update_dims.call_args_list == [
mock.call(mock_proto, {"input": ["batch", 224, 224]}, {"output": ["batch", 224, 224]})
]