mirror of
https://github.com/immich-app/immich.git
synced 2025-04-08 03:01:32 -05:00
feat(ml): better multilingual search with nllb models (#13567)
This commit is contained in:
parent
838a8dd9a6
commit
6789c2ac19
16 changed files with 301 additions and 18 deletions
Before Width: | Height: | Size: 4.9 MiB After Width: | Height: | Size: 4.9 MiB |
|
@ -45,7 +45,7 @@ Some search examples:
|
|||
</TabItem>
|
||||
<TabItem value="Mobile" label="Mobile">
|
||||
|
||||
<img src={require('./img/moblie-smart-serach.webp').default} width="30%" title='Smart search on mobile' />
|
||||
<img src={require('./img/mobile-smart-search.webp').default} width="30%" title='Smart search on mobile' />
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
@ -56,7 +56,20 @@ Navigating to `Administration > Settings > Machine Learning Settings > Smart Sea
|
|||
|
||||
### CLIP models
|
||||
|
||||
More powerful models can be used for more accurate search results, but are slower and can require more server resources. Check the dropdowns below to see how they compare in memory usage, speed and quality by language.
|
||||
The default search model is fast, but there are many other options that can provide better search results. The tradeoff of using these models is that they're slower and/or use more memory (both when indexing images with background Smart Search jobs and when searching).
|
||||
|
||||
The first step of choosing the right model for you is to know which languages your users will search in.
|
||||
|
||||
If your users will only search in English, then the [CLIP][huggingface-clip] section is the first place to look. This is a curated list of the models that generally perform the best for their size class. The models here are ordered from higher to lower quality. This means that the top models will generally rank the most relevant results higher and have a higher capacity to understand descriptive, detailed, and/or niche queries. The models are also generally ordered from larger to smaller, so consider the impact on memory usage, job processing and search speed when deciding on one. The smaller models in this list are not too different in quality and many times faster.
|
||||
|
||||
[Multilingual models][huggingface-multilingual-clip] are also available so users can search in their native language. Use these models if you expect non-English searches to be common. They can be separated into three search patterns:
|
||||
|
||||
- `nllb` models expect the search query to be in the language specified in the user settings
|
||||
- `xlm` and `siglip2` models understand search text regardless of the current language setting
|
||||
|
||||
`nllb` models tend to perform the best and are recommended when users primarily searches in their native, non-English language. `xlm` and `siglip2` models are more flexible and are recommended for mixed language search, where the same user might search in different languages at different times.
|
||||
|
||||
For more details, check the tables below to see how they compare in memory usage, speed and quality by language.
|
||||
|
||||
Once you've chosen a model, follow these steps:
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ from tokenizers import Encoding, Tokenizer
|
|||
|
||||
from immich_ml.config import log
|
||||
from immich_ml.models.base import InferenceModel
|
||||
from immich_ml.models.constants import WEBLATE_TO_FLORES200
|
||||
from immich_ml.models.transforms import clean_text, serialize_np_array
|
||||
from immich_ml.schemas import ModelSession, ModelTask, ModelType
|
||||
|
||||
|
@ -18,8 +19,9 @@ class BaseCLIPTextualEncoder(InferenceModel):
|
|||
depends = []
|
||||
identity = (ModelType.TEXTUAL, ModelTask.SEARCH)
|
||||
|
||||
def _predict(self, inputs: str, **kwargs: Any) -> str:
|
||||
res: NDArray[np.float32] = self.session.run(None, self.tokenize(inputs))[0][0]
|
||||
def _predict(self, inputs: str, language: str | None = None, **kwargs: Any) -> str:
|
||||
tokens = self.tokenize(inputs, language=language)
|
||||
res: NDArray[np.float32] = self.session.run(None, tokens)[0][0]
|
||||
return serialize_np_array(res)
|
||||
|
||||
def _load(self) -> ModelSession:
|
||||
|
@ -28,6 +30,7 @@ class BaseCLIPTextualEncoder(InferenceModel):
|
|||
self.tokenizer = self._load_tokenizer()
|
||||
tokenizer_kwargs: dict[str, Any] | None = self.text_cfg.get("tokenizer_kwargs")
|
||||
self.canonicalize = tokenizer_kwargs is not None and tokenizer_kwargs.get("clean") == "canonicalize"
|
||||
self.is_nllb = self.model_name.startswith("nllb")
|
||||
log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")
|
||||
|
||||
return session
|
||||
|
@ -37,7 +40,7 @@ class BaseCLIPTextualEncoder(InferenceModel):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
|
||||
def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
|
||||
pass
|
||||
|
||||
@property
|
||||
|
@ -92,14 +95,23 @@ class OpenClipTextualEncoder(BaseCLIPTextualEncoder):
|
|||
|
||||
return tokenizer
|
||||
|
||||
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
|
||||
def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
|
||||
text = clean_text(text, canonicalize=self.canonicalize)
|
||||
if self.is_nllb and language is not None:
|
||||
flores_code = WEBLATE_TO_FLORES200.get(language)
|
||||
if flores_code is None:
|
||||
no_country = language.split("-")[0]
|
||||
flores_code = WEBLATE_TO_FLORES200.get(no_country)
|
||||
if flores_code is None:
|
||||
log.warning(f"Language '{language}' not found, defaulting to 'en'")
|
||||
flores_code = "eng_Latn"
|
||||
text = f"{flores_code}{text}"
|
||||
tokens: Encoding = self.tokenizer.encode(text)
|
||||
return {"text": np.array([tokens.ids], dtype=np.int32)}
|
||||
|
||||
|
||||
class MClipTextualEncoder(OpenClipTextualEncoder):
|
||||
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
|
||||
def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
|
||||
text = clean_text(text, canonicalize=self.canonicalize)
|
||||
tokens: Encoding = self.tokenizer.encode(text)
|
||||
return {
|
||||
|
|
|
@ -86,6 +86,66 @@ RKNN_SUPPORTED_SOCS = ["rk3566", "rk3568", "rk3576", "rk3588"]
|
|||
RKNN_COREMASK_SUPPORTED_SOCS = ["rk3576", "rk3588"]
|
||||
|
||||
|
||||
WEBLATE_TO_FLORES200 = {
|
||||
"af": "afr_Latn",
|
||||
"ar": "arb_Arab",
|
||||
"az": "azj_Latn",
|
||||
"be": "bel_Cyrl",
|
||||
"bg": "bul_Cyrl",
|
||||
"ca": "cat_Latn",
|
||||
"cs": "ces_Latn",
|
||||
"da": "dan_Latn",
|
||||
"de": "deu_Latn",
|
||||
"el": "ell_Grek",
|
||||
"en": "eng_Latn",
|
||||
"es": "spa_Latn",
|
||||
"et": "est_Latn",
|
||||
"fa": "pes_Arab",
|
||||
"fi": "fin_Latn",
|
||||
"fr": "fra_Latn",
|
||||
"he": "heb_Hebr",
|
||||
"hi": "hin_Deva",
|
||||
"hr": "hrv_Latn",
|
||||
"hu": "hun_Latn",
|
||||
"hy": "hye_Armn",
|
||||
"id": "ind_Latn",
|
||||
"it": "ita_Latn",
|
||||
"ja": "jpn_Hira",
|
||||
"kmr": "kmr_Latn",
|
||||
"ko": "kor_Hang",
|
||||
"lb": "ltz_Latn",
|
||||
"lt": "lit_Latn",
|
||||
"lv": "lav_Latn",
|
||||
"mfa": "zsm_Latn",
|
||||
"mk": "mkd_Cyrl",
|
||||
"mn": "khk_Cyrl",
|
||||
"mr": "mar_Deva",
|
||||
"ms": "zsm_Latn",
|
||||
"nb-NO": "nob_Latn",
|
||||
"nn": "nno_Latn",
|
||||
"nl": "nld_Latn",
|
||||
"pl": "pol_Latn",
|
||||
"pt-BR": "por_Latn",
|
||||
"pt": "por_Latn",
|
||||
"ro": "ron_Latn",
|
||||
"ru": "rus_Cyrl",
|
||||
"sk": "slk_Latn",
|
||||
"sl": "slv_Latn",
|
||||
"sr-Cyrl": "srp_Cyrl",
|
||||
"sv": "swe_Latn",
|
||||
"ta": "tam_Taml",
|
||||
"te": "tel_Telu",
|
||||
"th": "tha_Thai",
|
||||
"tr": "tur_Latn",
|
||||
"uk": "ukr_Cyrl",
|
||||
"ur": "urd_Arab",
|
||||
"vi": "vie_Latn",
|
||||
"zh-CN": "zho_Hans",
|
||||
"zh-Hans": "zho_Hans",
|
||||
"zh-TW": "zho_Hant",
|
||||
}
|
||||
|
||||
|
||||
def get_model_source(model_name: str) -> ModelSource | None:
|
||||
cleaned_name = clean_name(model_name)
|
||||
|
||||
|
|
|
@ -494,6 +494,88 @@ class TestCLIP:
|
|||
assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0)
|
||||
mock_tokenizer.encode.assert_called_once_with("test search query")
|
||||
|
||||
def test_openclip_tokenizer_adds_flores_token_for_nllb(
|
||||
self,
|
||||
mocker: MockerFixture,
|
||||
clip_model_cfg: dict[str, Any],
|
||||
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
|
||||
) -> None:
|
||||
mocker.patch.object(OpenClipTextualEncoder, "download")
|
||||
mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
|
||||
mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
|
||||
mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
|
||||
mock_tokenizer = mocker.patch("immich_ml.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
|
||||
mock_ids = [randint(0, 50000) for _ in range(77)]
|
||||
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
|
||||
|
||||
clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache")
|
||||
clip_encoder._load()
|
||||
clip_encoder.tokenize("test search query", language="de")
|
||||
|
||||
mock_tokenizer.encode.assert_called_once_with("deu_Latntest search query")
|
||||
|
||||
def test_openclip_tokenizer_removes_country_code_from_language_for_nllb_if_not_found(
|
||||
self,
|
||||
mocker: MockerFixture,
|
||||
clip_model_cfg: dict[str, Any],
|
||||
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
|
||||
) -> None:
|
||||
mocker.patch.object(OpenClipTextualEncoder, "download")
|
||||
mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
|
||||
mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
|
||||
mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
|
||||
mock_tokenizer = mocker.patch("immich_ml.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
|
||||
mock_ids = [randint(0, 50000) for _ in range(77)]
|
||||
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
|
||||
|
||||
clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache")
|
||||
clip_encoder._load()
|
||||
clip_encoder.tokenize("test search query", language="de-CH")
|
||||
|
||||
mock_tokenizer.encode.assert_called_once_with("deu_Latntest search query")
|
||||
|
||||
def test_openclip_tokenizer_falls_back_to_english_for_nllb_if_language_code_not_found(
|
||||
self,
|
||||
mocker: MockerFixture,
|
||||
clip_model_cfg: dict[str, Any],
|
||||
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
|
||||
warning: mock.Mock,
|
||||
) -> None:
|
||||
mocker.patch.object(OpenClipTextualEncoder, "download")
|
||||
mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
|
||||
mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
|
||||
mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
|
||||
mock_tokenizer = mocker.patch("immich_ml.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
|
||||
mock_ids = [randint(0, 50000) for _ in range(77)]
|
||||
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
|
||||
|
||||
clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache")
|
||||
clip_encoder._load()
|
||||
clip_encoder.tokenize("test search query", language="unknown")
|
||||
|
||||
mock_tokenizer.encode.assert_called_once_with("eng_Latntest search query")
|
||||
warning.assert_called_once_with("Language 'unknown' not found, defaulting to 'en'")
|
||||
|
||||
def test_openclip_tokenizer_does_not_add_flores_token_for_non_nllb_model(
|
||||
self,
|
||||
mocker: MockerFixture,
|
||||
clip_model_cfg: dict[str, Any],
|
||||
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
|
||||
) -> None:
|
||||
mocker.patch.object(OpenClipTextualEncoder, "download")
|
||||
mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
|
||||
mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
|
||||
mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
|
||||
mock_tokenizer = mocker.patch("immich_ml.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
|
||||
mock_ids = [randint(0, 50000) for _ in range(77)]
|
||||
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
|
||||
|
||||
clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache")
|
||||
clip_encoder._load()
|
||||
clip_encoder.tokenize("test search query", language="de")
|
||||
|
||||
mock_tokenizer.encode.assert_called_once_with("test search query")
|
||||
|
||||
def test_mclip_tokenizer(
|
||||
self,
|
||||
mocker: MockerFixture,
|
||||
|
|
|
@ -236,6 +236,7 @@ class SearchFilter {
|
|||
String? context;
|
||||
String? filename;
|
||||
String? description;
|
||||
String? language;
|
||||
Set<Person> people;
|
||||
SearchLocationFilter location;
|
||||
SearchCameraFilter camera;
|
||||
|
@ -249,6 +250,7 @@ class SearchFilter {
|
|||
this.context,
|
||||
this.filename,
|
||||
this.description,
|
||||
this.language,
|
||||
required this.people,
|
||||
required this.location,
|
||||
required this.camera,
|
||||
|
@ -279,6 +281,7 @@ class SearchFilter {
|
|||
String? context,
|
||||
String? filename,
|
||||
String? description,
|
||||
String? language,
|
||||
Set<Person>? people,
|
||||
SearchLocationFilter? location,
|
||||
SearchCameraFilter? camera,
|
||||
|
@ -290,6 +293,7 @@ class SearchFilter {
|
|||
context: context ?? this.context,
|
||||
filename: filename ?? this.filename,
|
||||
description: description ?? this.description,
|
||||
language: language ?? this.language,
|
||||
people: people ?? this.people,
|
||||
location: location ?? this.location,
|
||||
camera: camera ?? this.camera,
|
||||
|
@ -301,7 +305,7 @@ class SearchFilter {
|
|||
|
||||
@override
|
||||
String toString() {
|
||||
return 'SearchFilter(context: $context, filename: $filename, description: $description, people: $people, location: $location, camera: $camera, date: $date, display: $display, mediaType: $mediaType)';
|
||||
return 'SearchFilter(context: $context, filename: $filename, description: $description, language: $language, people: $people, location: $location, camera: $camera, date: $date, display: $display, mediaType: $mediaType)';
|
||||
}
|
||||
|
||||
@override
|
||||
|
@ -311,6 +315,7 @@ class SearchFilter {
|
|||
return other.context == context &&
|
||||
other.filename == filename &&
|
||||
other.description == description &&
|
||||
other.language == language &&
|
||||
other.people == people &&
|
||||
other.location == location &&
|
||||
other.camera == camera &&
|
||||
|
@ -324,6 +329,7 @@ class SearchFilter {
|
|||
return context.hashCode ^
|
||||
filename.hashCode ^
|
||||
description.hashCode ^
|
||||
language.hashCode ^
|
||||
people.hashCode ^
|
||||
location.hashCode ^
|
||||
camera.hashCode ^
|
||||
|
|
|
@ -48,6 +48,8 @@ class SearchPage extends HookConsumerWidget {
|
|||
isFavorite: false,
|
||||
),
|
||||
mediaType: prefilter?.mediaType ?? AssetType.other,
|
||||
language:
|
||||
"${context.locale.languageCode}-${context.locale.countryCode}",
|
||||
),
|
||||
);
|
||||
|
||||
|
|
|
@ -60,6 +60,7 @@ class SearchService {
|
|||
response = await _apiService.searchApi.searchSmart(
|
||||
SmartSearchDto(
|
||||
query: filter.context!,
|
||||
language: filter.language,
|
||||
country: filter.location.country,
|
||||
state: filter.location.state,
|
||||
city: filter.location.city,
|
||||
|
|
19
mobile/openapi/lib/model/smart_search_dto.dart
generated
19
mobile/openapi/lib/model/smart_search_dto.dart
generated
|
@ -25,6 +25,7 @@ class SmartSearchDto {
|
|||
this.isNotInAlbum,
|
||||
this.isOffline,
|
||||
this.isVisible,
|
||||
this.language,
|
||||
this.lensModel,
|
||||
this.libraryId,
|
||||
this.make,
|
||||
|
@ -132,6 +133,14 @@ class SmartSearchDto {
|
|||
///
|
||||
bool? isVisible;
|
||||
|
||||
///
|
||||
/// Please note: This property should have been non-nullable! Since the specification file
|
||||
/// does not include a default value (using the "default:" property), however, the generated
|
||||
/// source code must fall back to having a nullable type.
|
||||
/// Consider adding a "default:" property in the specification file to hide this note.
|
||||
///
|
||||
String? language;
|
||||
|
||||
String? lensModel;
|
||||
|
||||
String? libraryId;
|
||||
|
@ -271,6 +280,7 @@ class SmartSearchDto {
|
|||
other.isNotInAlbum == isNotInAlbum &&
|
||||
other.isOffline == isOffline &&
|
||||
other.isVisible == isVisible &&
|
||||
other.language == language &&
|
||||
other.lensModel == lensModel &&
|
||||
other.libraryId == libraryId &&
|
||||
other.make == make &&
|
||||
|
@ -308,6 +318,7 @@ class SmartSearchDto {
|
|||
(isNotInAlbum == null ? 0 : isNotInAlbum!.hashCode) +
|
||||
(isOffline == null ? 0 : isOffline!.hashCode) +
|
||||
(isVisible == null ? 0 : isVisible!.hashCode) +
|
||||
(language == null ? 0 : language!.hashCode) +
|
||||
(lensModel == null ? 0 : lensModel!.hashCode) +
|
||||
(libraryId == null ? 0 : libraryId!.hashCode) +
|
||||
(make == null ? 0 : make!.hashCode) +
|
||||
|
@ -331,7 +342,7 @@ class SmartSearchDto {
|
|||
(withExif == null ? 0 : withExif!.hashCode);
|
||||
|
||||
@override
|
||||
String toString() => 'SmartSearchDto[city=$city, country=$country, createdAfter=$createdAfter, createdBefore=$createdBefore, deviceId=$deviceId, isArchived=$isArchived, isEncoded=$isEncoded, isFavorite=$isFavorite, isMotion=$isMotion, isNotInAlbum=$isNotInAlbum, isOffline=$isOffline, isVisible=$isVisible, lensModel=$lensModel, libraryId=$libraryId, make=$make, model=$model, page=$page, personIds=$personIds, query=$query, rating=$rating, size=$size, state=$state, tagIds=$tagIds, takenAfter=$takenAfter, takenBefore=$takenBefore, trashedAfter=$trashedAfter, trashedBefore=$trashedBefore, type=$type, updatedAfter=$updatedAfter, updatedBefore=$updatedBefore, withArchived=$withArchived, withDeleted=$withDeleted, withExif=$withExif]';
|
||||
String toString() => 'SmartSearchDto[city=$city, country=$country, createdAfter=$createdAfter, createdBefore=$createdBefore, deviceId=$deviceId, isArchived=$isArchived, isEncoded=$isEncoded, isFavorite=$isFavorite, isMotion=$isMotion, isNotInAlbum=$isNotInAlbum, isOffline=$isOffline, isVisible=$isVisible, language=$language, lensModel=$lensModel, libraryId=$libraryId, make=$make, model=$model, page=$page, personIds=$personIds, query=$query, rating=$rating, size=$size, state=$state, tagIds=$tagIds, takenAfter=$takenAfter, takenBefore=$takenBefore, trashedAfter=$trashedAfter, trashedBefore=$trashedBefore, type=$type, updatedAfter=$updatedAfter, updatedBefore=$updatedBefore, withArchived=$withArchived, withDeleted=$withDeleted, withExif=$withExif]';
|
||||
|
||||
Map<String, dynamic> toJson() {
|
||||
final json = <String, dynamic>{};
|
||||
|
@ -395,6 +406,11 @@ class SmartSearchDto {
|
|||
} else {
|
||||
// json[r'isVisible'] = null;
|
||||
}
|
||||
if (this.language != null) {
|
||||
json[r'language'] = this.language;
|
||||
} else {
|
||||
// json[r'language'] = null;
|
||||
}
|
||||
if (this.lensModel != null) {
|
||||
json[r'lensModel'] = this.lensModel;
|
||||
} else {
|
||||
|
@ -508,6 +524,7 @@ class SmartSearchDto {
|
|||
isNotInAlbum: mapValueOfType<bool>(json, r'isNotInAlbum'),
|
||||
isOffline: mapValueOfType<bool>(json, r'isOffline'),
|
||||
isVisible: mapValueOfType<bool>(json, r'isVisible'),
|
||||
language: mapValueOfType<String>(json, r'language'),
|
||||
lensModel: mapValueOfType<String>(json, r'lensModel'),
|
||||
libraryId: mapValueOfType<String>(json, r'libraryId'),
|
||||
make: mapValueOfType<String>(json, r'make'),
|
||||
|
|
|
@ -11853,6 +11853,9 @@
|
|||
"isVisible": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"language": {
|
||||
"type": "string"
|
||||
},
|
||||
"lensModel": {
|
||||
"nullable": true,
|
||||
"type": "string"
|
||||
|
|
|
@ -924,6 +924,7 @@ export type SmartSearchDto = {
|
|||
isNotInAlbum?: boolean;
|
||||
isOffline?: boolean;
|
||||
isVisible?: boolean;
|
||||
language?: string;
|
||||
lensModel?: string | null;
|
||||
libraryId?: string | null;
|
||||
make?: string;
|
||||
|
|
|
@ -191,6 +191,11 @@ export class SmartSearchDto extends BaseSearchDto {
|
|||
@IsNotEmpty()
|
||||
query!: string;
|
||||
|
||||
@IsString()
|
||||
@IsNotEmpty()
|
||||
@Optional()
|
||||
language?: string;
|
||||
|
||||
@IsInt()
|
||||
@Min(1)
|
||||
@Type(() => Number)
|
||||
|
|
|
@ -53,6 +53,7 @@ export interface Face {
|
|||
export type FacialRecognitionResponse = { [ModelTask.FACIAL_RECOGNITION]: Face[] } & VisualResponse;
|
||||
export type DetectedFaces = { faces: Face[] } & VisualResponse;
|
||||
export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest;
|
||||
export type TextEncodingOptions = ModelOptions & { language?: string };
|
||||
|
||||
@Injectable()
|
||||
export class MachineLearningRepository {
|
||||
|
@ -170,8 +171,8 @@ export class MachineLearningRepository {
|
|||
return response[ModelTask.SEARCH];
|
||||
}
|
||||
|
||||
async encodeText(urls: string[], text: string, { modelName }: CLIPConfig) {
|
||||
const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } };
|
||||
async encodeText(urls: string[], text: string, { language, modelName }: TextEncodingOptions) {
|
||||
const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName, options: { language } } } };
|
||||
const response = await this.predict<ClipTextualResponse>(urls, { text }, request);
|
||||
return response[ModelTask.SEARCH];
|
||||
}
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import { BadRequestException } from '@nestjs/common';
|
||||
import { mapAsset } from 'src/dtos/asset-response.dto';
|
||||
import { SearchSuggestionType } from 'src/dtos/search.dto';
|
||||
import { SearchService } from 'src/services/search.service';
|
||||
|
@ -15,6 +16,7 @@ describe(SearchService.name, () => {
|
|||
|
||||
beforeEach(() => {
|
||||
({ sut, mocks } = newTestService(SearchService));
|
||||
mocks.partner.getAll.mockResolvedValue([]);
|
||||
});
|
||||
|
||||
it('should work', () => {
|
||||
|
@ -155,4 +157,83 @@ describe(SearchService.name, () => {
|
|||
expect(mocks.search.getCameraModels).toHaveBeenCalledWith([authStub.user1.user.id], expect.anything());
|
||||
});
|
||||
});
|
||||
|
||||
describe('searchSmart', () => {
|
||||
beforeEach(() => {
|
||||
mocks.search.searchSmart.mockResolvedValue({ hasNextPage: false, items: [] });
|
||||
mocks.machineLearning.encodeText.mockResolvedValue('[1, 2, 3]');
|
||||
});
|
||||
|
||||
it('should raise a BadRequestException if machine learning is disabled', async () => {
|
||||
mocks.systemMetadata.get.mockResolvedValue({
|
||||
machineLearning: { enabled: false },
|
||||
});
|
||||
|
||||
await expect(sut.searchSmart(authStub.user1, { query: 'test' })).rejects.toThrowError(
|
||||
new BadRequestException('Smart search is not enabled'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should raise a BadRequestException if smart search is disabled', async () => {
|
||||
mocks.systemMetadata.get.mockResolvedValue({
|
||||
machineLearning: { clip: { enabled: false } },
|
||||
});
|
||||
|
||||
await expect(sut.searchSmart(authStub.user1, { query: 'test' })).rejects.toThrowError(
|
||||
new BadRequestException('Smart search is not enabled'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should work', async () => {
|
||||
await sut.searchSmart(authStub.user1, { query: 'test' });
|
||||
|
||||
expect(mocks.machineLearning.encodeText).toHaveBeenCalledWith(
|
||||
[expect.any(String)],
|
||||
'test',
|
||||
expect.objectContaining({ modelName: expect.any(String) }),
|
||||
);
|
||||
expect(mocks.search.searchSmart).toHaveBeenCalledWith(
|
||||
{ page: 1, size: 100 },
|
||||
{ query: 'test', embedding: '[1, 2, 3]', userIds: [authStub.user1.user.id] },
|
||||
);
|
||||
});
|
||||
|
||||
it('should consider page and size parameters', async () => {
|
||||
await sut.searchSmart(authStub.user1, { query: 'test', page: 2, size: 50 });
|
||||
|
||||
expect(mocks.machineLearning.encodeText).toHaveBeenCalledWith(
|
||||
[expect.any(String)],
|
||||
'test',
|
||||
expect.objectContaining({ modelName: expect.any(String) }),
|
||||
);
|
||||
expect(mocks.search.searchSmart).toHaveBeenCalledWith(
|
||||
{ page: 2, size: 50 },
|
||||
expect.objectContaining({ query: 'test', embedding: '[1, 2, 3]', userIds: [authStub.user1.user.id] }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use clip model specified in config', async () => {
|
||||
mocks.systemMetadata.get.mockResolvedValue({
|
||||
machineLearning: { clip: { modelName: 'ViT-B-16-SigLIP__webli' } },
|
||||
});
|
||||
|
||||
await sut.searchSmart(authStub.user1, { query: 'test' });
|
||||
|
||||
expect(mocks.machineLearning.encodeText).toHaveBeenCalledWith(
|
||||
[expect.any(String)],
|
||||
'test',
|
||||
expect.objectContaining({ modelName: 'ViT-B-16-SigLIP__webli' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use language specified in request', async () => {
|
||||
await sut.searchSmart(authStub.user1, { query: 'test', language: 'de' });
|
||||
|
||||
expect(mocks.machineLearning.encodeText).toHaveBeenCalledWith(
|
||||
[expect.any(String)],
|
||||
'test',
|
||||
expect.objectContaining({ language: 'de' }),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
@ -78,12 +78,10 @@ export class SearchService extends BaseService {
|
|||
}
|
||||
|
||||
const userIds = await this.getUserIdsToSearch(auth);
|
||||
|
||||
const embedding = await this.machineLearningRepository.encodeText(
|
||||
machineLearning.urls,
|
||||
dto.query,
|
||||
machineLearning.clip,
|
||||
);
|
||||
const embedding = await this.machineLearningRepository.encodeText(machineLearning.urls, dto.query, {
|
||||
modelName: machineLearning.clip.modelName,
|
||||
language: dto.language,
|
||||
});
|
||||
const page = dto.page ?? 1;
|
||||
const size = dto.size || 100;
|
||||
const { hasNextPage, items } = await this.searchRepository.searchSmart(
|
||||
|
|
|
@ -33,7 +33,7 @@
|
|||
} from '@immich/sdk';
|
||||
import { mdiArrowLeft, mdiDotsVertical, mdiImageOffOutline, mdiPlus, mdiSelectAll } from '@mdi/js';
|
||||
import type { Viewport } from '$lib/stores/assets-store.svelte';
|
||||
import { locale } from '$lib/stores/preferences.store';
|
||||
import { lang, locale } from '$lib/stores/preferences.store';
|
||||
import LoadingSpinner from '$lib/components/shared-components/loading-spinner.svelte';
|
||||
import { handlePromiseError } from '$lib/utils';
|
||||
import { parseUtcDate } from '$lib/utils/date-time';
|
||||
|
@ -153,6 +153,7 @@
|
|||
page: nextPage,
|
||||
withExif: true,
|
||||
isVisible: true,
|
||||
language: $lang,
|
||||
...terms,
|
||||
};
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue