mirror of https://github.com/immich-app/immich.git (synced 2025-02-04 01:09:14 -05:00)
use user language for search
parent 12628b80bc
commit ad062ba78e

7 changed files with 91 additions and 16 deletions
@@ -10,6 +10,7 @@ from tokenizers import Encoding, Tokenizer
 from app.config import log
 from app.models.base import InferenceModel
+from app.models.constants import WEBLATE_TO_FLORES200
 from app.models.transforms import clean_text
 from app.schemas import ModelSession, ModelTask, ModelType
 
 
@@ -18,8 +19,9 @@ class BaseCLIPTextualEncoder(InferenceModel):
     depends = []
     identity = (ModelType.TEXTUAL, ModelTask.SEARCH)
 
-    def _predict(self, inputs: str, **kwargs: Any) -> NDArray[np.float32]:
-        res: NDArray[np.float32] = self.session.run(None, self.tokenize(inputs))[0][0]
+    def _predict(self, inputs: str, language: str | None = None, **kwargs: Any) -> NDArray[np.float32]:
+        tokens = self.tokenize(inputs, language=language)
+        res: NDArray[np.float32] = self.session.run(None, tokens)[0][0]
         return res
 
     def _load(self) -> ModelSession:
@@ -28,6 +30,7 @@ class BaseCLIPTextualEncoder(InferenceModel):
         self.tokenizer = self._load_tokenizer()
         tokenizer_kwargs: dict[str, Any] | None = self.text_cfg.get("tokenizer_kwargs")
         self.canonicalize = tokenizer_kwargs is not None and tokenizer_kwargs.get("clean") == "canonicalize"
+        self.is_nllb = self.model_name.startswith("nllb")
         log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")
 
         return session
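
Note on the new is_nllb flag above: language-aware tokenization is gated purely on the model name prefix. A minimal illustration (the model names here are examples, not the supported list):

    for name in ("nllb-clip-base-siglip__v1", "ViT-B-32__openai"):
        print(name, name.startswith("nllb"))
    # nllb-clip-base-siglip__v1 True
    # ViT-B-32__openai False
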
@@ -37,7 +40,7 @@ class BaseCLIPTextualEncoder(InferenceModel):
         pass
 
     @abstractmethod
-    def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
+    def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
         pass
 
     @property
@@ -92,14 +95,19 @@ class OpenClipTextualEncoder(BaseCLIPTextualEncoder):
 
         return tokenizer
 
-    def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
+    def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
         text = clean_text(text, canonicalize=self.canonicalize)
+        if self.is_nllb:
+            flores_code = code if language and (code := WEBLATE_TO_FLORES200.get(language)) else "eng_Latn"
+            print(f"{language=}")
+            print(f"{flores_code=}")
+            text = f"{flores_code}{text}"
         tokens: Encoding = self.tokenizer.encode(text)
         return {"text": np.array([tokens.ids], dtype=np.int32)}
 
 
 class MClipTextualEncoder(OpenClipTextualEncoder):
-    def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
+    def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
         text = clean_text(text, canonicalize=self.canonicalize)
        tokens: Encoding = self.tokenizer.encode(text)
         return {
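
Note on the NLLB branch above: the assignment expression resolves the user's Weblate locale to a FLORES-200 code and prepends it to the query, since NLLB-style tokenizers treat these codes as special language tokens; unset or unmapped locales fall back to English. A self-contained sketch of that fallback behavior, using a stubbed two-entry mapping and a hypothetical helper name:

    WEBLATE_TO_FLORES200 = {"de": "deu_Latn", "pt_BR": "por_Latn"}  # stub; full table below

    def prefix_for_nllb(text: str, language: str | None) -> str:
        # None and locales missing from the mapping both yield eng_Latn.
        flores_code = code if language and (code := WEBLATE_TO_FLORES200.get(language)) else "eng_Latn"
        return f"{flores_code}{text}"

    print(prefix_for_nllb("beach", "de"))  # deu_Latnbeach
    print(prefix_for_nllb("beach", None))  # eng_Latnbeach
    print(prefix_for_nllb("beach", "xx"))  # eng_Latnbeach (unmapped)
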
@@ -66,6 +66,65 @@ _INSIGHTFACE_MODELS = {
 SUPPORTED_PROVIDERS = ["CUDAExecutionProvider", "OpenVINOExecutionProvider", "CPUExecutionProvider"]
 
 
+WEBLATE_TO_FLORES200 = {
+    "af": "afr_Latn",
+    "ar": "arb_Arab",
+    "az": "azj_Latn",
+    "be": "bel_Cyrl",
+    "bg": "bul_Cyrl",
+    "ca": "cat_Latn",
+    "cs": "ces_Latn",
+    "da": "dan_Latn",
+    "de": "deu_Latn",
+    "el": "ell_Grek",
+    "en": "eng_Latn",
+    "es": "spa_Latn",
+    "et": "est_Latn",
+    "fa": "pes_Arab",
+    "fi": "fin_Latn",
+    "fr": "fra_Latn",
+    "he": "heb_Hebr",
+    "hi": "hin_Deva",
+    "hr": "hrv_Latn",
+    "hu": "hun_Latn",
+    "hy": "hye_Armn",
+    "id": "ind_Latn",
+    "it": "ita_Latn",
+    "ja": "jpn_Hira",
+    "kmr": "kmr_Latn",
+    "ko": "kor_Hang",
+    "lb": "ltz_Latn",
+    "lt": "lit_Latn",
+    "lv": "lav_Latn",
+    "mfa": "zsm_Latn",
+    "mk": "mkd_Cyrl",
+    "mn": "khk_Cyrl",
+    "mr": "mar_Deva",
+    "ms": "zsm_Latn",
+    "nb_NO": "nob_Latn",
+    "nl": "nld_Latn",
+    "pl": "pol_Latn",
+    "pt_BR": "por_Latn",
+    "pt": "por_Latn",
+    "ro": "ron_Latn",
+    "ru": "rus_Cyrl",
+    "sk": "slk_Latn",
+    "sl": "slv_Latn",
+    "sr_Cyrl": "srp_Cyrl",
+    "sv": "swe_Latn",
+    "ta": "tam_Taml",
+    "te": "tel_Telu",
+    "th": "tha_Thai",
+    "tr": "tur_Latn",
+    "uk": "ukr_Cyrl",
+    "vi": "vie_Latn",
+    "zh-CN": "zho_Hans",
+    "zh-TW": "zho_Hant",
+    "zh_Hant": "zho_Hant",
+    "zh_SIMPLIFIED": "zho_Hans",
+}
 
 
 def get_model_source(model_name: str) -> ModelSource | None:
     cleaned_name = clean_name(model_name)
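
Note on the mapping above: keys mirror Weblate's locale identifiers verbatim, mixing bare codes ("pt"), underscored variants ("pt_BR", "nb_NO"), and hyphenated ones ("zh-CN"), so lookups must use the exact Weblate code. A quick sketch, assuming the import path used in the first hunk:

    from app.models.constants import WEBLATE_TO_FLORES200

    # Regional variants get their own entries rather than being normalized.
    assert WEBLATE_TO_FLORES200["pt"] == "por_Latn"
    assert WEBLATE_TO_FLORES200["pt_BR"] == "por_Latn"
    # Locales absent from the table resolve to eng_Latn in tokenize().
    assert WEBLATE_TO_FLORES200.get("xx") is None
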
@@ -177,6 +177,11 @@ export class SmartSearchDto extends BaseSearchDto {
   @IsNotEmpty()
   query!: string;
 
+  @IsString()
+  @IsNotEmpty()
+  @Optional()
+  language?: string;
+
   @IsInt()
   @Min(1)
   @Type(() => Number)
@@ -30,7 +30,9 @@ type VisualResponse = { imageHeight: number; imageWidth: number };
 export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } };
 export type ClipVisualResponse = { [ModelTask.SEARCH]: number[] } & VisualResponse;
 
-export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions } };
+export type ClipTextualRequest = {
+  [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions & { options: { language?: string } } };
+};
 export type ClipTextualResponse = { [ModelTask.SEARCH]: number[] };
 
 export type FacialRecognitionRequest = {
@@ -49,9 +51,10 @@ export interface Face {
 export type FacialRecognitionResponse = { [ModelTask.FACIAL_RECOGNITION]: Face[] } & VisualResponse;
 export type DetectedFaces = { faces: Face[] } & VisualResponse;
 export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest;
+export type TextEncodingOptions = ModelOptions & { language?: string };
 
 export interface IMachineLearningRepository {
   encodeImage(url: string, imagePath: string, config: ModelOptions): Promise<number[]>;
-  encodeText(url: string, text: string, config: ModelOptions): Promise<number[]>;
+  encodeText(url: string, text: string, config: TextEncodingOptions): Promise<number[]>;
   detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>;
 }
@@ -11,6 +11,7 @@ import {
   ModelPayload,
   ModelTask,
   ModelType,
+  TextEncodingOptions,
 } from 'src/interfaces/machine-learning.interface';
 import { Instrumentation } from 'src/utils/instrumentation';
 
@@ -55,8 +56,8 @@ export class MachineLearningRepository implements IMachineLearningRepository {
     return response[ModelTask.SEARCH];
   }
 
-  async encodeText(url: string, text: string, { modelName }: CLIPConfig) {
-    const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } };
+  async encodeText(url: string, text: string, { language, modelName }: TextEncodingOptions) {
+    const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName, options: { language } } } };
     const response = await this.predict<ClipTextualResponse>(url, { text }, request);
     return response[ModelTask.SEARCH];
   }
@@ -86,12 +86,10 @@ export class SearchService extends BaseService {
     }
 
     const userIds = await this.getUserIdsToSearch(auth);
-    const embedding = await this.machineLearningRepository.encodeText(
-      machineLearning.url,
-      dto.query,
-      machineLearning.clip,
-    );
-
+    const embedding = await this.machineLearningRepository.encodeText(machineLearning.url, dto.query, {
+      modelName: machineLearning.clip.modelName,
+      language: dto.language,
+    });
     const page = dto.page ?? 1;
     const size = dto.size || 100;
     const { hasNextPage, items } = await this.searchRepository.searchSmart(
@@ -31,7 +31,7 @@
   } from '@immich/sdk';
   import { mdiArrowLeft, mdiDotsVertical, mdiImageOffOutline, mdiPlus, mdiSelectAll } from '@mdi/js';
   import type { Viewport } from '$lib/stores/assets.store';
-  import { locale } from '$lib/stores/preferences.store';
+  import { lang, locale } from '$lib/stores/preferences.store';
   import LoadingSpinner from '$lib/components/shared-components/loading-spinner.svelte';
   import { handlePromiseError } from '$lib/utils';
   import { parseUtcDate } from '$lib/utils/date-time';
@@ -144,6 +144,7 @@
       page: nextPage,
       withExif: true,
       isVisible: true,
+      language: $lang,
       ...terms,
     };
 