diff --git a/machine-learning/app/test_main.py b/machine-learning/app/test_main.py index d9d1455bd1..3a5f62c3b2 100644 --- a/machine-learning/app/test_main.py +++ b/machine-learning/app/test_main.py @@ -7,6 +7,7 @@ from types import SimpleNamespace from typing import Any, Callable from unittest import mock +from app.models.session import ort_add_batch_dim, ort_has_batch_dim, ort_squeeze_outputs import cv2 import numpy as np import onnxruntime as ort @@ -734,3 +735,69 @@ class TestEndpoints: assert expected_face["boundingBox"] == actual_face["boundingBox"] assert np.allclose(expected_face["embedding"], actual_face["embedding"]) assert np.allclose(expected_face["score"], actual_face["score"]) + + +class TestSessionUtils: + def test_ort_has_batch_dim(self, mocker: MockerFixture) -> None: + mock_session = mocker.Mock(spec=ort.InferenceSession) + mock_session.get_inputs.return_value = [SimpleNamespace(shape=["batch", 3, 224, 224], name="input.1")] + + assert ort_has_batch_dim(mock_session) is True + + def test_ort_has_no_batch_dim(self, mocker: MockerFixture) -> None: + mock_session = mocker.Mock(spec=ort.InferenceSession) + mock_session.get_inputs.return_value = [SimpleNamespace(shape=[1, 3, 224, 224], name="input.1")] + + assert ort_has_batch_dim(mock_session) is False + + def test_ort_squeeze_outputs(self, mocker: MockerFixture) -> None: + mock_session = mocker.Mock(spec=ort.InferenceSession) + mock_session.run.return_value = [np.random.rand(1, 3, 224, 224).astype(np.float32)] + + ort_squeeze_outputs(mock_session) + out = mock_session.run(["output"], {"input": np.random.rand(1, 3, 224, 224).astype(np.float32)}) + + assert len(out) == 1 + assert out[0].shape == (3, 224, 224) + + def test_ort_add_batch_dim(self, mocker: MockerFixture) -> None: + mock_proto = mocker.Mock() + mock_proto.graph.input = [ + SimpleNamespace(name="input", type=mock.Mock(tensor_type=SimpleNamespace(shape=SimpleNamespace()))) + ] + mock_proto.graph.output = [ + SimpleNamespace(name="output", type=SimpleNamespace(tensor_type=SimpleNamespace(shape=SimpleNamespace()))) + ] + mock_proto.graph.input[0].type.tensor_type.shape.dim = [ + SimpleNamespace(dim_value=3), + SimpleNamespace(dim_value=224), + SimpleNamespace(dim_value=224), + ] + mock_proto.graph.output[0].type.tensor_type.shape.dim = [ + SimpleNamespace(dim_value=3), + SimpleNamespace(dim_value=224), + SimpleNamespace(dim_value=224), + ] + + mock_load = mocker.patch("app.models.session.onnx.load") + mock_load.return_value = mock_proto + + mock_update_dims = mocker.patch("app.models.session.update_inputs_outputs_dims") + mock_updated = mocker.Mock() + mock_update_dims.return_value = mock_updated + + mock_infer_shapes = mocker.patch("app.models.session.infer_shapes") + mock_inferred = mocker.Mock() + mock_infer_shapes.return_value = mock_inferred + + mock_save = mocker.patch("app.models.session.onnx.save") + + input_path, output_path = Path("input.onnx"), Path("output.onnx") + ort_add_batch_dim(input_path, output_path) + + mock_load.assert_called_once_with(input_path) + mock_save.assert_called_once_with(mock_inferred, output_path) + mock_update_dims.assert_called_once_with(mock_proto, mock.ANY, mock.ANY) + assert mock_update_dims.call_args_list == [ + mock.call(mock_proto, {"input": ["batch", 224, 224]}, {"output": ["batch", 224, 224]}) + ]