From c2ecf825501f002f7b34f457d9eaf3a4fcd28871 Mon Sep 17 00:00:00 2001
From: mertalev <101130780+mertalev@users.noreply.github.com>
Date: Thu, 17 Oct 2024 18:58:49 -0400
Subject: [PATCH] add ml unit tests

---
 machine-learning/app/models/clip/textual.py |  2 --
 machine-learning/app/test_main.py           | 40 +++++++++++++++++++++
 2 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/machine-learning/app/models/clip/textual.py b/machine-learning/app/models/clip/textual.py
index b164dcc17c..28e5c8102c 100644
--- a/machine-learning/app/models/clip/textual.py
+++ b/machine-learning/app/models/clip/textual.py
@@ -99,8 +99,6 @@ class OpenClipTextualEncoder(BaseCLIPTextualEncoder):
         text = clean_text(text, canonicalize=self.canonicalize)
         if self.is_nllb:
             flores_code = code if language and (code := WEBLATE_TO_FLORES200.get(language)) else "eng_Latn"
-            print(f"{language=}")
-            print(f"{flores_code=}")
             text = f"{flores_code}{text}"
         tokens: Encoding = self.tokenizer.encode(text)
         return {"text": np.array([tokens.ids], dtype=np.int32)}
diff --git a/machine-learning/app/test_main.py b/machine-learning/app/test_main.py
index 50ec188aa4..362c76a50d 100644
--- a/machine-learning/app/test_main.py
+++ b/machine-learning/app/test_main.py
@@ -426,6 +426,46 @@ class TestCLIP:
         assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0)
         mock_tokenizer.encode.assert_called_once_with("test search query")
 
+    def test_openclip_tokenizer_adds_flores_token_for_nllb(
+        self,
+        mocker: MockerFixture,
+        clip_model_cfg: dict[str, Any],
+        clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
+    ) -> None:
+        mocker.patch.object(OpenClipTextualEncoder, "download")
+        mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
+        mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
+        mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
+        mock_tokenizer = mocker.patch("app.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
+        mock_ids = [randint(0, 50000) for _ in range(77)]
+        mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
+
+        clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache")
+        clip_encoder._load()
+        clip_encoder.tokenize("test search query", language="de")
+
+        mock_tokenizer.encode.assert_called_once_with("deu_Latntest search query")
+
+    def test_openclip_tokenizer_does_not_add_flores_token_for_non_nllb_model(
+        self,
+        mocker: MockerFixture,
+        clip_model_cfg: dict[str, Any],
+        clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
+    ) -> None:
+        mocker.patch.object(OpenClipTextualEncoder, "download")
+        mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
+        mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
+        mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
+        mock_tokenizer = mocker.patch("app.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
+        mock_ids = [randint(0, 50000) for _ in range(77)]
+        mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
+
+        clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache")
+        clip_encoder._load()
+        clip_encoder.tokenize("test search query", language="de")
+
+        mock_tokenizer.encode.assert_called_once_with("test search query")
+
     def test_mclip_tokenizer(
         self,
         mocker: MockerFixture,
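
For reference, the behavior the two new tests pin down is the FLORES-200 prefixing that tokenize performs for NLLB models: the Weblate language code is mapped to a FLORES-200 code and prepended to the query, with an English fallback, while non-NLLB models tokenize the query unchanged. The sketch below is a standalone illustration under assumptions, not the project's code: prefix_for_nllb is a hypothetical helper, and the WEBLATE_TO_FLORES200 entries shown are an illustrative subset of the real mapping.

# Standalone sketch (assumptions: illustrative subset of WEBLATE_TO_FLORES200;
# prefix_for_nllb is a hypothetical helper, not part of the codebase).
WEBLATE_TO_FLORES200 = {"de": "deu_Latn", "fr": "fra_Latn"}


def prefix_for_nllb(text: str, language: str | None, is_nllb: bool) -> str:
    if not is_nllb:
        # Non-NLLB models (e.g. ViT-B-32__openai) tokenize the query as-is.
        return text
    # Map the Weblate language code to a FLORES-200 code, falling back to English.
    flores_code = code if language and (code := WEBLATE_TO_FLORES200.get(language)) else "eng_Latn"
    return f"{flores_code}{text}"


assert prefix_for_nllb("test search query", "de", is_nllb=True) == "deu_Latntest search query"
assert prefix_for_nllb("test search query", "de", is_nllb=False) == "test search query"
assert prefix_for_nllb("test search query", None, is_nllb=True) == "eng_Latntest search query"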