diff --git a/machine-learning/app/models/clip/textual.py b/machine-learning/app/models/clip/textual.py index 7a25c2f4ad..32c28ea2bb 100644 --- a/machine-learning/app/models/clip/textual.py +++ b/machine-learning/app/models/clip/textual.py @@ -10,6 +10,7 @@ from tokenizers import Encoding, Tokenizer from app.config import log from app.models.base import InferenceModel +from app.models.transforms import clean_text from app.schemas import ModelSession, ModelTask, ModelType @@ -25,6 +26,8 @@ class BaseCLIPTextualEncoder(InferenceModel): session = super()._load() log.debug(f"Loading tokenizer for CLIP model '{self.model_name}'") self.tokenizer = self._load_tokenizer() + tokenizer_kwargs: dict[str, Any] | None = self.text_cfg.get("tokenizer_kwargs") + self.canonicalize = tokenizer_kwargs is not None and tokenizer_kwargs.get("clean") == "canonicalize" log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'") return session @@ -56,6 +59,11 @@ class BaseCLIPTextualEncoder(InferenceModel): log.debug(f"Loaded model config for CLIP model '{self.model_name}'") return model_cfg + @property + def text_cfg(self) -> dict[str, Any]: + text_cfg: dict[str, Any] = self.model_cfg["text_cfg"] + return text_cfg + @cached_property def tokenizer_file(self) -> dict[str, Any]: log.debug(f"Loading tokenizer file for CLIP model '{self.model_name}'") @@ -73,8 +81,7 @@ class BaseCLIPTextualEncoder(InferenceModel): class OpenClipTextualEncoder(BaseCLIPTextualEncoder): def _load_tokenizer(self) -> Tokenizer: - text_cfg: dict[str, Any] = self.model_cfg["text_cfg"] - context_length: int = text_cfg.get("context_length", 77) + context_length: int = self.text_cfg.get("context_length", 77) pad_token: str = self.tokenizer_cfg["pad_token"] tokenizer: Tokenizer = Tokenizer.from_file(self.tokenizer_file_path.as_posix()) @@ -86,12 +93,14 @@ class OpenClipTextualEncoder(BaseCLIPTextualEncoder): return tokenizer def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]: + text = clean_text(text, canonicalize=self.canonicalize) tokens: Encoding = self.tokenizer.encode(text) return {"text": np.array([tokens.ids], dtype=np.int32)} class MClipTextualEncoder(OpenClipTextualEncoder): def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]: + text = clean_text(text, canonicalize=self.canonicalize) tokens: Encoding = self.tokenizer.encode(text) return { "input_ids": np.array([tokens.ids], dtype=np.int32), diff --git a/machine-learning/app/models/transforms.py b/machine-learning/app/models/transforms.py index cae9b6b1ab..bb03103d4b 100644 --- a/machine-learning/app/models/transforms.py +++ b/machine-learning/app/models/transforms.py @@ -1,3 +1,4 @@ +import string from io import BytesIO from typing import IO @@ -7,6 +8,7 @@ from numpy.typing import NDArray from PIL import Image _PIL_RESAMPLING_METHODS = {resampling.name.lower(): resampling for resampling in Image.Resampling} +_PUNCTUATION_TRANS = str.maketrans("", "", string.punctuation) def resize_pil(img: Image.Image, size: int) -> Image.Image: @@ -60,3 +62,10 @@ def decode_cv2(image_bytes: NDArray[np.uint8] | bytes | Image.Image) -> NDArray[ if isinstance(image_bytes, Image.Image): return pil_to_cv2(image_bytes) return image_bytes + + +def clean_text(text: str, canonicalize: bool = False) -> str: + text = " ".join(text.split()) + if canonicalize: + text = text.translate(_PUNCTUATION_TRANS).lower() + return text diff --git a/machine-learning/app/test_main.py b/machine-learning/app/test_main.py index fb3542e7e4..17fdb5b1fa 100644 --- a/machine-learning/app/test_main.py +++ b/machine-learning/app/test_main.py @@ -379,13 +379,40 @@ class TestCLIP: clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache") clip_encoder._load() - tokens = clip_encoder.tokenize("test search query") + tokens = clip_encoder.tokenize("test search query") assert "text" in tokens assert isinstance(tokens["text"], np.ndarray) assert tokens["text"].shape == (1, 77) assert tokens["text"].dtype == np.int32 assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0) + mock_tokenizer.encode.assert_called_once_with("test search query") + + def test_openclip_tokenizer_canonicalizes_text( + self, + mocker: MockerFixture, + clip_model_cfg: dict[str, Any], + clip_tokenizer_cfg: Callable[[Path], dict[str, Any]], + ) -> None: + clip_model_cfg["text_cfg"]["tokenizer_kwargs"] = {"clean": "canonicalize"} + mocker.patch.object(OpenClipTextualEncoder, "download") + mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg) + mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg) + mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value + mock_tokenizer = mocker.patch("app.models.clip.textual.Tokenizer.from_file", autospec=True).return_value + mock_ids = [randint(0, 50000) for _ in range(77)] + mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids) + + clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache") + clip_encoder._load() + tokens = clip_encoder.tokenize("Test Search Query!") + + assert "text" in tokens + assert isinstance(tokens["text"], np.ndarray) + assert tokens["text"].shape == (1, 77) + assert tokens["text"].dtype == np.int32 + assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0) + mock_tokenizer.encode.assert_called_once_with("test search query") def test_mclip_tokenizer( self,