mirror of
https://github.com/immich-app/immich.git
synced 2025-01-28 00:59:18 -05:00
add ml unit tests
This commit is contained in:
parent
0bcfbc9ca7
commit
c2ecf82550
2 changed files with 40 additions and 2 deletions
|
@ -99,8 +99,6 @@ class OpenClipTextualEncoder(BaseCLIPTextualEncoder):
|
|||
text = clean_text(text, canonicalize=self.canonicalize)
|
||||
if self.is_nllb:
|
||||
flores_code = code if language and (code := WEBLATE_TO_FLORES200.get(language)) else "eng_Latn"
|
||||
print(f"{language=}")
|
||||
print(f"{flores_code=}")
|
||||
text = f"{flores_code}{text}"
|
||||
tokens: Encoding = self.tokenizer.encode(text)
|
||||
return {"text": np.array([tokens.ids], dtype=np.int32)}
|
||||
|
|
|
@ -426,6 +426,46 @@ class TestCLIP:
|
|||
assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0)
|
||||
mock_tokenizer.encode.assert_called_once_with("test search query")
|
||||
|
||||
def test_openclip_tokenizer_adds_flores_token_for_nllb(
    self,
    mocker: MockerFixture,
    clip_model_cfg: dict[str, Any],
    clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
) -> None:
    """Check that an NLLB model prefixes the query with a FLORES-200 code.

    The encoder is built with an NLLB model name and asked to tokenize a
    query with ``language="de"``; the tokenizer must receive the query
    prefixed with ``deu_Latn``.
    """
    mocker.patch.object(OpenClipTextualEncoder, "download")
    mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
    mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
    # The session mock is never inspected, so its return value is not captured.
    # NOTE(review): presumably patched so _load() does not build a real session.
    mocker.patch.object(InferenceModel, "_make_session", autospec=True)
    mock_tokenizer = mocker.patch("app.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
    mock_ids = [randint(0, 50000) for _ in range(77)]
    mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)

    clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache")
    clip_encoder._load()
    clip_encoder.tokenize("test search query", language="de")

    # The language code is concatenated directly in front of the text, with no separator.
    mock_tokenizer.encode.assert_called_once_with("deu_Latntest search query")
|
||||
|
||||
def test_openclip_tokenizer_does_not_add_flores_token_for_non_nllb_model(
    self,
    mocker: MockerFixture,
    clip_model_cfg: dict[str, Any],
    clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
) -> None:
    """Check that a non-NLLB model tokenizes the query without a language prefix.

    The encoder is built with a non-NLLB model name and still given
    ``language="de"``; the tokenizer must receive the query unchanged.
    """
    mocker.patch.object(OpenClipTextualEncoder, "download")
    mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
    mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
    # The session mock is never inspected, so its return value is not captured.
    # NOTE(review): presumably patched so _load() does not build a real session.
    mocker.patch.object(InferenceModel, "_make_session", autospec=True)
    mock_tokenizer = mocker.patch("app.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
    mock_ids = [randint(0, 50000) for _ in range(77)]
    mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)

    clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache")
    clip_encoder._load()
    clip_encoder.tokenize("test search query", language="de")

    # A language was supplied, yet the text must reach the tokenizer untouched.
    mock_tokenizer.encode.assert_called_once_with("test search query")
|
||||
|
||||
def test_mclip_tokenizer(
|
||||
self,
|
||||
mocker: MockerFixture,
|
||||
|
|
Loading…
Add table
Reference in a new issue