0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-28 00:59:18 -05:00

add ml unit tests

This commit is contained in:
mertalev 2024-10-17 18:58:49 -04:00
parent 0bcfbc9ca7
commit c2ecf82550
No known key found for this signature in database
GPG key ID: 46904880C3E8B346
2 changed files with 40 additions and 2 deletions

View file

@ -99,8 +99,6 @@ class OpenClipTextualEncoder(BaseCLIPTextualEncoder):
text = clean_text(text, canonicalize=self.canonicalize)
if self.is_nllb:
flores_code = code if language and (code := WEBLATE_TO_FLORES200.get(language)) else "eng_Latn"
print(f"{language=}")
print(f"{flores_code=}")
text = f"{flores_code}{text}"
tokens: Encoding = self.tokenizer.encode(text)
return {"text": np.array([tokens.ids], dtype=np.int32)}

View file

@ -426,6 +426,46 @@ class TestCLIP:
assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0)
mock_tokenizer.encode.assert_called_once_with("test search query")
def test_openclip_tokenizer_adds_flores_token_for_nllb(
    self,
    mocker: MockerFixture,
    clip_model_cfg: dict[str, Any],
    clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
) -> None:
    """Check that an NLLB model prefixes the query with a FLORES-200 code.

    The encoder is built for ``nllb-clip-base-siglip__mrl`` and asked to
    tokenize a query with ``language="de"``. The mocked tokenizer must
    receive the query with ``deu_Latn`` prepended (no separator), and the
    token ids it returns must come back as a single-row ``int32`` array
    under the ``"text"`` key.
    """
    # Replace the download step, configs, session factory and tokenizer
    # loader so the encoder is built purely from fixtures and mocks.
    mocker.patch.object(OpenClipTextualEncoder, "download")
    mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
    mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
    mocker.patch.object(InferenceModel, "_make_session", autospec=True)
    mock_tokenizer = mocker.patch("app.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
    mock_ids = [randint(0, 50000) for _ in range(77)]
    mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)

    clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache")
    clip_encoder._load()
    tokens = clip_encoder.tokenize("test search query", language="de")

    # The language code is concatenated directly in front of the query text.
    mock_tokenizer.encode.assert_called_once_with("deu_Latntest search query")
    # The prefix must not alter how the tokenizer output is packaged.
    assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0)
def test_openclip_tokenizer_does_not_add_flores_token_for_non_nllb_model(
    self,
    mocker: MockerFixture,
    clip_model_cfg: dict[str, Any],
    clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
) -> None:
    """Check that a non-NLLB model tokenizes the query without a language prefix.

    The encoder is built for ``ViT-B-32__openai`` and asked to tokenize a
    query with ``language="de"``. The mocked tokenizer must receive the
    query text unchanged, and the token ids it returns must come back as a
    single-row ``int32`` array under the ``"text"`` key.
    """
    # Replace the download step, configs, session factory and tokenizer
    # loader so the encoder is built purely from fixtures and mocks.
    mocker.patch.object(OpenClipTextualEncoder, "download")
    mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
    mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
    mocker.patch.object(InferenceModel, "_make_session", autospec=True)
    mock_tokenizer = mocker.patch("app.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
    mock_ids = [randint(0, 50000) for _ in range(77)]
    mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)

    clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache")
    clip_encoder._load()
    tokens = clip_encoder.tokenize("test search query", language="de")

    # A language is supplied, but it must be ignored for this model.
    mock_tokenizer.encode.assert_called_once_with("test search query")
    assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0)
def test_mclip_tokenizer(
self,
mocker: MockerFixture,