From ad062ba78e5a294830b180ae8261861e8ca70f01 Mon Sep 17 00:00:00 2001
From: mertalev <101130780+mertalev@users.noreply.github.com>
Date: Thu, 17 Oct 2024 18:07:49 -0400
Subject: [PATCH] use user language for search

---
 machine-learning/app/models/clip/textual.py  | 18 ++++--
 machine-learning/app/models/constants.py     | 59 +++++++++++++++++++
 server/src/dtos/search.dto.ts                |  5 ++
 .../interfaces/machine-learning.interface.ts |  7 ++-
 .../machine-learning.repository.ts           |  5 +-
 server/src/services/search.service.ts        | 10 ++--
 .../[[assetId=id]]/+page.svelte              |  3 +-
 7 files changed, 91 insertions(+), 16 deletions(-)

diff --git a/machine-learning/app/models/clip/textual.py b/machine-learning/app/models/clip/textual.py
index 32c28ea2bb..b164dcc17c 100644
--- a/machine-learning/app/models/clip/textual.py
+++ b/machine-learning/app/models/clip/textual.py
@@ -10,6 +10,7 @@ from tokenizers import Encoding, Tokenizer
 
 from app.config import log
 from app.models.base import InferenceModel
+from app.models.constants import WEBLATE_TO_FLORES200
 from app.models.transforms import clean_text
 from app.schemas import ModelSession, ModelTask, ModelType
 
@@ -18,8 +19,9 @@ class BaseCLIPTextualEncoder(InferenceModel):
     depends = []
     identity = (ModelType.TEXTUAL, ModelTask.SEARCH)
 
-    def _predict(self, inputs: str, **kwargs: Any) -> NDArray[np.float32]:
-        res: NDArray[np.float32] = self.session.run(None, self.tokenize(inputs))[0][0]
+    def _predict(self, inputs: str, language: str | None = None, **kwargs: Any) -> NDArray[np.float32]:
+        tokens = self.tokenize(inputs, language=language)
+        res: NDArray[np.float32] = self.session.run(None, tokens)[0][0]
         return res
 
     def _load(self) -> ModelSession:
@@ -28,6 +30,7 @@ class BaseCLIPTextualEncoder(InferenceModel):
         self.tokenizer = self._load_tokenizer()
         tokenizer_kwargs: dict[str, Any] | None = self.text_cfg.get("tokenizer_kwargs")
         self.canonicalize = tokenizer_kwargs is not None and tokenizer_kwargs.get("clean") == "canonicalize"
+        self.is_nllb = self.model_name.startswith("nllb")
 
         log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")
         return session
@@ -37,7 +40,7 @@ class BaseCLIPTextualEncoder(InferenceModel):
         pass
 
     @abstractmethod
-    def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
+    def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
         pass
 
     @property
@@ -92,14 +95,19 @@ class OpenClipTextualEncoder(BaseCLIPTextualEncoder):
 
         return tokenizer
 
-    def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
+    def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
         text = clean_text(text, canonicalize=self.canonicalize)
+        if self.is_nllb:
+            flores_code = code if language and (code := WEBLATE_TO_FLORES200.get(language)) else "eng_Latn"
+            log.debug(f"{language=}")
+            log.debug(f"{flores_code=}")
+            text = f"{flores_code}{text}"
         tokens: Encoding = self.tokenizer.encode(text)
         return {"text": np.array([tokens.ids], dtype=np.int32)}
 
 
 class MClipTextualEncoder(OpenClipTextualEncoder):
-    def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
+    def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
         text = clean_text(text, canonicalize=self.canonicalize)
         tokens: Encoding = self.tokenizer.encode(text)
         return {
diff --git a/machine-learning/app/models/constants.py b/machine-learning/app/models/constants.py
index 338a481594..a84c3802ec 100644
--- a/machine-learning/app/models/constants.py
+++ b/machine-learning/app/models/constants.py
@@ -66,6 +66,65 @@ _INSIGHTFACE_MODELS = {
 SUPPORTED_PROVIDERS = ["CUDAExecutionProvider", "OpenVINOExecutionProvider", "CPUExecutionProvider"]
 
 
+WEBLATE_TO_FLORES200 = {
+    "af": "afr_Latn",
+    "ar": "arb_Arab",
+    "az": "azj_Latn",
+    "be": "bel_Cyrl",
+    "bg": "bul_Cyrl",
+    "ca": "cat_Latn",
+    "cs": "ces_Latn",
+    "da": "dan_Latn",
+    "de": "deu_Latn",
+    "el": "ell_Grek",
+    "en": "eng_Latn",
+    "es": "spa_Latn",
+    "et": "est_Latn",
+    "fa": "pes_Arab",
+    "fi": "fin_Latn",
+    "fr": "fra_Latn",
+    "he": "heb_Hebr",
+    "hi": "hin_Deva",
+    "hr": "hrv_Latn",
+    "hu": "hun_Latn",
+    "hy": "hye_Armn",
+    "id": "ind_Latn",
+    "it": "ita_Latn",
+    "ja": "jpn_Jpan",
+    "kmr": "kmr_Latn",
+    "ko": "kor_Hang",
+    "lb": "ltz_Latn",
+    "lt": "lit_Latn",
+    "lv": "lvs_Latn",
+    "mfa": "zsm_Latn",
+    "mk": "mkd_Cyrl",
+    "mn": "khk_Cyrl",
+    "mr": "mar_Deva",
+    "ms": "zsm_Latn",
+    "nb_NO": "nob_Latn",
+    "nl": "nld_Latn",
+    "pl": "pol_Latn",
+    "pt_BR": "por_Latn",
+    "pt": "por_Latn",
+    "ro": "ron_Latn",
+    "ru": "rus_Cyrl",
+    "sk": "slk_Latn",
+    "sl": "slv_Latn",
+    "sr_Cyrl": "srp_Cyrl",
+    "sv": "swe_Latn",
+    "ta": "tam_Taml",
+    "te": "tel_Telu",
+    "th": "tha_Thai",
+    "tr": "tur_Latn",
+    "uk": "ukr_Cyrl",
+    "vi": "vie_Latn",
+    "zh-CN": "zho_Hans",
+    "zh-TW": "zho_Hant",
+    "zh_Hant": "zho_Hant",
+    "zh_SIMPLIFIED": "zho_Hans",
+}
+
+
 def get_model_source(model_name: str) -> ModelSource | None:
     cleaned_name = clean_name(model_name)
 
diff --git a/server/src/dtos/search.dto.ts b/server/src/dtos/search.dto.ts
index 5c5dce1a11..434bb3562f 100644
--- a/server/src/dtos/search.dto.ts
+++ b/server/src/dtos/search.dto.ts
@@ -177,6 +177,11 @@ export class SmartSearchDto extends BaseSearchDto {
   @IsNotEmpty()
   query!: string;
 
+  @IsString()
+  @IsNotEmpty()
+  @Optional()
+  language?: string;
+
   @IsInt()
   @Min(1)
   @Type(() => Number)
diff --git a/server/src/interfaces/machine-learning.interface.ts b/server/src/interfaces/machine-learning.interface.ts
index 5342030c8f..b140755a27 100644
--- a/server/src/interfaces/machine-learning.interface.ts
+++ b/server/src/interfaces/machine-learning.interface.ts
@@ -30,7 +30,9 @@ type VisualResponse = { imageHeight: number; imageWidth: number };
 export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } };
 export type ClipVisualResponse = { [ModelTask.SEARCH]: number[] } & VisualResponse;
 
-export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions } };
+export type ClipTextualRequest = {
+  [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions & { options: { language?: string } } };
+};
 export type ClipTextualResponse = { [ModelTask.SEARCH]: number[] };
 
 export type FacialRecognitionRequest = {
@@ -49,9 +51,10 @@ export interface Face {
 export type FacialRecognitionResponse = { [ModelTask.FACIAL_RECOGNITION]: Face[] } & VisualResponse;
 export type DetectedFaces = { faces: Face[] } & VisualResponse;
 export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest;
+export type TextEncodingOptions = ModelOptions & { language?: string };
 
 export interface IMachineLearningRepository {
   encodeImage(url: string, imagePath: string, config: ModelOptions): Promise<number[]>;
-  encodeText(url: string, text: string, config: ModelOptions): Promise<number[]>;
+  encodeText(url: string, text: string, config: TextEncodingOptions): Promise<number[]>;
   detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>;
 }
diff --git a/server/src/repositories/machine-learning.repository.ts b/server/src/repositories/machine-learning.repository.ts
index b9404022ef..f15364c24f 100644
--- a/server/src/repositories/machine-learning.repository.ts
+++ b/server/src/repositories/machine-learning.repository.ts
@@ -11,6 +11,7 @@ import {
   ModelPayload,
   ModelTask,
   ModelType,
+  TextEncodingOptions,
 } from 'src/interfaces/machine-learning.interface';
 import { Instrumentation } from 'src/utils/instrumentation';
 
@@ -55,8 +56,8 @@ export class MachineLearningRepository implements IMachineLearningRepository {
     return response[ModelTask.SEARCH];
   }
 
-  async encodeText(url: string, text: string, { modelName }: CLIPConfig) {
-    const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } };
+  async encodeText(url: string, text: string, { language, modelName }: TextEncodingOptions) {
+    const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName, options: { language } } } };
     const response = await this.predict<ClipTextualResponse>(url, { text }, request);
     return response[ModelTask.SEARCH];
   }
diff --git a/server/src/services/search.service.ts b/server/src/services/search.service.ts
index 03ffbe97db..db6c70b143 100644
--- a/server/src/services/search.service.ts
+++ b/server/src/services/search.service.ts
@@ -86,12 +86,10 @@ export class SearchService extends BaseService {
     }
 
     const userIds = await this.getUserIdsToSearch(auth);
-
-    const embedding = await this.machineLearningRepository.encodeText(
-      machineLearning.url,
-      dto.query,
-      machineLearning.clip,
-    );
+    const embedding = await this.machineLearningRepository.encodeText(machineLearning.url, dto.query, {
+      modelName: machineLearning.clip.modelName,
+      language: dto.language,
+    });
     const page = dto.page ?? 1;
     const size = dto.size || 100;
     const { hasNextPage, items } = await this.searchRepository.searchSmart(
diff --git a/web/src/routes/(user)/search/[[photos=photos]]/[[assetId=id]]/+page.svelte b/web/src/routes/(user)/search/[[photos=photos]]/[[assetId=id]]/+page.svelte
index eb0c493204..0efee47472 100644
--- a/web/src/routes/(user)/search/[[photos=photos]]/[[assetId=id]]/+page.svelte
+++ b/web/src/routes/(user)/search/[[photos=photos]]/[[assetId=id]]/+page.svelte
@@ -31,7 +31,7 @@
   } from '@immich/sdk';
   import { mdiArrowLeft, mdiDotsVertical, mdiImageOffOutline, mdiPlus, mdiSelectAll } from '@mdi/js';
   import type { Viewport } from '$lib/stores/assets.store';
-  import { locale } from '$lib/stores/preferences.store';
+  import { lang, locale } from '$lib/stores/preferences.store';
   import LoadingSpinner from '$lib/components/shared-components/loading-spinner.svelte';
   import { handlePromiseError } from '$lib/utils';
   import { parseUtcDate } from '$lib/utils/date-time';
@@ -144,6 +144,7 @@
       page: nextPage,
       withExif: true,
      isVisible: true,
+      language: $lang,
       ...terms,
     };
 
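
For reference, the NLLB handling added to OpenClipTextualEncoder.tokenize reduces to a dictionary lookup with an English fallback, followed by prefixing the query with the resolved FLORES-200 code before tokenization. Below is a minimal, self-contained sketch of that step, assuming a WEBLATE_TO_FLORES200-style mapping like the one added in constants.py; the excerpted mapping and the prefix_query helper are illustrative names, not code from this patch, and the real encoder then tokenizes the prefixed string.

    # Sketch: resolve a UI/Weblate language tag to a FLORES-200 code and
    # prepend it to the search query, falling back to English.
    WEBLATE_TO_FLORES200 = {
        "en": "eng_Latn",
        "pt_BR": "por_Latn",
        "zh-CN": "zho_Hans",
    }


    def prefix_query(text: str, language: str | None) -> str:
        """Prepend the FLORES-200 language token, defaulting to eng_Latn."""
        flores_code = WEBLATE_TO_FLORES200.get(language or "", "eng_Latn")
        return f"{flores_code}{text}"


    if __name__ == "__main__":
        print(prefix_query("fotos da praia", "pt_BR"))  # por_Latnfotos da praia
        print(prefix_query("beach photos", None))       # eng_Latnbeach photos

Falling back to eng_Latn keeps smart search usable when the client omits the language or sends a locale that has no FLORES-200 mapping.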