- @Type(() => RecognitionConfig)
+ @Type(() => FacialRecognitionConfig)
@ValidateNested()
@IsObject()
- facialRecognition!: RecognitionConfig;
+ facialRecognition!: FacialRecognitionConfig;
}
enum MapTheme {
diff --git a/server/src/interfaces/machine-learning.interface.ts b/server/src/interfaces/machine-learning.interface.ts
index 0aeed7635a..143281c23a 100644
--- a/server/src/interfaces/machine-learning.interface.ts
+++ b/server/src/interfaces/machine-learning.interface.ts
@@ -1,15 +1,5 @@
-import { CLIPConfig, RecognitionConfig } from 'src/dtos/model-config.dto';
-
export const IMachineLearningRepository = 'IMachineLearningRepository';
-export interface VisionModelInput {
- imagePath: string;
-}
-
-export interface TextModelInput {
- text: string;
-}
-
export interface BoundingBox {
x1: number;
y1: number;
@@ -17,26 +7,51 @@ export interface BoundingBox {
y2: number;
}
-export interface DetectFaceResult {
- imageWidth: number;
- imageHeight: number;
- boundingBox: BoundingBox;
- score: number;
- embedding: number[];
+export enum ModelTask {
+ FACIAL_RECOGNITION = 'facial-recognition',
+ SEARCH = 'clip',
}
export enum ModelType {
- FACIAL_RECOGNITION = 'facial-recognition',
- CLIP = 'clip',
+ DETECTION = 'detection',
+ PIPELINE = 'pipeline',
+ RECOGNITION = 'recognition',
+ TEXTUAL = 'textual',
+ VISUAL = 'visual',
}
-export enum CLIPMode {
- VISION = 'vision',
- TEXT = 'text',
+export type ModelPayload = { imagePath: string } | { text: string };
+
+type ModelOptions = { modelName: string };
+
+export type FaceDetectionOptions = ModelOptions & { minScore: number };
+
+type VisualResponse = { imageHeight: number; imageWidth: number };
+export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } };
+export type ClipVisualResponse = { [ModelTask.SEARCH]: number[] } & VisualResponse;
+
+export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions } };
+export type ClipTextualResponse = { [ModelTask.SEARCH]: number[] };
+
+export type FacialRecognitionRequest = {
+ [ModelTask.FACIAL_RECOGNITION]: {
+ [ModelType.DETECTION]: FaceDetectionOptions;
+ [ModelType.RECOGNITION]: ModelOptions;
+ };
+};
+
+export interface Face {
+ boundingBox: BoundingBox;
+ embedding: number[];
+ score: number;
}
+export type FacialRecognitionResponse = { [ModelTask.FACIAL_RECOGNITION]: Face[] } & VisualResponse;
+export type DetectedFaces = { faces: Face[] } & VisualResponse;
+export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest;
+
export interface IMachineLearningRepository {
- encodeImage(url: string, input: VisionModelInput, config: CLIPConfig): Promise<number[]>;
- encodeText(url: string, input: TextModelInput, config: CLIPConfig): Promise<number[]>;
- detectFaces(url: string, input: VisionModelInput, config: RecognitionConfig): Promise<DetectFaceResult[]>;
+ encodeImage(url: string, imagePath: string, config: ModelOptions): Promise<number[]>;
+ encodeText(url: string, text: string, config: ModelOptions): Promise<number[]>;
+ detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>;
}
diff --git a/server/src/interfaces/search.interface.ts b/server/src/interfaces/search.interface.ts
index ce9e2a1940..d5382a04fa 100644
--- a/server/src/interfaces/search.interface.ts
+++ b/server/src/interfaces/search.interface.ts
@@ -37,8 +37,6 @@ export interface SearchExploreItem<T> {
items: SearchExploreItemSet<T>;
}
-export type Embedding = number[];
-
export interface SearchAssetIDOptions {
checksum?: Buffer;
deviceAssetId?: string;
@@ -106,7 +104,7 @@ export interface SearchExifOptions {
}
export interface SearchEmbeddingOptions {
- embedding: Embedding;
+ embedding: number[];
userIds: string[];
}
@@ -154,7 +152,7 @@ export interface FaceEmbeddingSearch extends SearchEmbeddingOptions {
export interface AssetDuplicateSearch {
assetId: string;
- embedding: Embedding;
+ embedding: number[];
maxDistance?: number;
type: AssetType;
userIds: string[];
diff --git a/server/src/repositories/machine-learning.repository.ts b/server/src/repositories/machine-learning.repository.ts
index bff22b9507..405e5a421d 100644
--- a/server/src/repositories/machine-learning.repository.ts
+++ b/server/src/repositories/machine-learning.repository.ts
@@ -1,13 +1,16 @@
import { Injectable } from '@nestjs/common';
import { readFile } from 'node:fs/promises';
-import { CLIPConfig, ModelConfig, RecognitionConfig } from 'src/dtos/model-config.dto';
+import { CLIPConfig } from 'src/dtos/model-config.dto';
import {
- CLIPMode,
- DetectFaceResult,
+ ClipTextualResponse,
+ ClipVisualResponse,
+ FaceDetectionOptions,
+ FacialRecognitionResponse,
IMachineLearningRepository,
+ MachineLearningRequest,
+ ModelPayload,
+ ModelTask,
ModelType,
- TextModelInput,
- VisionModelInput,
} from 'src/interfaces/machine-learning.interface';
import { Instrumentation } from 'src/utils/instrumentation';
@@ -16,8 +19,8 @@ const errorPrefix = 'Machine learning request';
@Instrumentation()
@Injectable()
export class MachineLearningRepository implements IMachineLearningRepository {
- private async predict<T>(url: string, input: TextModelInput | VisionModelInput, config: ModelConfig): Promise<T> {
- const formData = await this.getFormData(input, config);
+ private async predict<T>(url: string, payload: ModelPayload, config: MachineLearningRequest): Promise<T> {
+ const formData = await this.getFormData(payload, config);
const res = await fetch(new URL('/predict', url), { method: 'POST', body: formData }).catch(
(error: Error | any) => {
@@ -26,50 +29,46 @@ export class MachineLearningRepository implements IMachineLearningRepository {
);
if (res.status >= 400) {
- const modelType = config.modelType ? ` for ${config.modelType.replace('-', ' ')}` : '';
- throw new Error(`${errorPrefix}${modelType} failed with status ${res.status}: ${res.statusText}`);
+ throw new Error(`${errorPrefix} '${JSON.stringify(config)}' failed with status ${res.status}: ${res.statusText}`);
}
return res.json();
}
- detectFaces(url: string, input: VisionModelInput, config: RecognitionConfig): Promise<DetectFaceResult[]> {
- return this.predict<DetectFaceResult[]>(url, input, { ...config, modelType: ModelType.FACIAL_RECOGNITION });
+ async detectFaces(url: string, imagePath: string, { modelName, minScore }: FaceDetectionOptions) {
+ const request = {
+ [ModelTask.FACIAL_RECOGNITION]: {
+ [ModelType.DETECTION]: { modelName, minScore },
+ [ModelType.RECOGNITION]: { modelName },
+ },
+ };
+ const response = await this.predict<FacialRecognitionResponse>(url, { imagePath }, request);
+ return {
+ imageHeight: response.imageHeight,
+ imageWidth: response.imageWidth,
+ faces: response[ModelTask.FACIAL_RECOGNITION],
+ };
}
- encodeImage(url: string, input: VisionModelInput, config: CLIPConfig): Promise<number[]> {
- return this.predict<number[]>(url, input, {
- ...config,
- modelType: ModelType.CLIP,
- mode: CLIPMode.VISION,
- } as CLIPConfig);
+ async encodeImage(url: string, imagePath: string, { modelName }: CLIPConfig) {
+ const request = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: { modelName } } };
+ const response = await this.predict<ClipVisualResponse>(url, { imagePath }, request);
+ return response[ModelTask.SEARCH];
}
- encodeText(url: string, input: TextModelInput, config: CLIPConfig): Promise<number[]> {
- return this.predict<number[]>(url, input, {
- ...config,
- modelType: ModelType.CLIP,
- mode: CLIPMode.TEXT,
- } as CLIPConfig);
+ async encodeText(url: string, text: string, { modelName }: CLIPConfig) {
+ const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } };
+ const response = await this.predict<ClipTextualResponse>(url, { text }, request);
+ return response[ModelTask.SEARCH];
}
- private async getFormData(input: TextModelInput | VisionModelInput, config: ModelConfig): Promise<FormData> {
+ private async getFormData(payload: ModelPayload, config: MachineLearningRequest): Promise<FormData> {
const formData = new FormData();
- const { enabled, modelName, modelType, ...options } = config;
- if (!enabled) {
- throw new Error(`${modelType} is not enabled`);
- }
+ formData.append('entries', JSON.stringify(config));
- formData.append('modelName', modelName);
- if (modelType) {
- formData.append('modelType', modelType);
- }
- if (options) {
- formData.append('options', JSON.stringify(options));
- }
- if ('imagePath' in input) {
- formData.append('image', new Blob([await readFile(input.imagePath)]));
- } else if ('text' in input) {
- formData.append('text', input.text);
+ if ('imagePath' in payload) {
+ formData.append('image', new Blob([await readFile(payload.imagePath)]));
+ } else if ('text' in payload) {
+ formData.append('text', payload.text);
} else {
throw new Error('Invalid input');
}
diff --git a/server/src/services/person.service.spec.ts b/server/src/services/person.service.spec.ts
index 1644c0c896..56447c8d20 100644
--- a/server/src/services/person.service.spec.ts
+++ b/server/src/services/person.service.spec.ts
@@ -7,7 +7,7 @@ import { IAssetRepository, WithoutProperty } from 'src/interfaces/asset.interfac
import { ICryptoRepository } from 'src/interfaces/crypto.interface';
import { IJobRepository, JobName, JobStatus } from 'src/interfaces/job.interface';
import { ILoggerRepository } from 'src/interfaces/logger.interface';
-import { IMachineLearningRepository } from 'src/interfaces/machine-learning.interface';
+import { DetectedFaces, IMachineLearningRepository } from 'src/interfaces/machine-learning.interface';
import { IMediaRepository } from 'src/interfaces/media.interface';
import { IMoveRepository } from 'src/interfaces/move.interface';
import { IPersonRepository } from 'src/interfaces/person.interface';
@@ -46,19 +46,21 @@ const responseDto: PersonResponseDto = {
const statistics = { assets: 3 };
-const detectFaceMock = {
- assetId: 'asset-1',
- personId: 'person-1',
- boundingBox: {
- x1: 100,
- y1: 100,
- x2: 200,
- y2: 200,
- },
+const detectFaceMock: DetectedFaces = {
+ faces: [
+ {
+ boundingBox: {
+ x1: 100,
+ y1: 100,
+ x2: 200,
+ y2: 200,
+ },
+ embedding: [1, 2, 3, 4],
+ score: 0.2,
+ },
+ ],
imageHeight: 500,
imageWidth: 400,
- embedding: [1, 2, 3, 4],
- score: 0.2,
};
describe(PersonService.name, () => {
@@ -642,21 +644,13 @@ describe(PersonService.name, () => {
it('should handle no results', async () => {
const start = Date.now();
- machineLearningMock.detectFaces.mockResolvedValue([]);
+ machineLearningMock.detectFaces.mockResolvedValue({ imageHeight: 500, imageWidth: 400, faces: [] });
assetMock.getByIds.mockResolvedValue([assetStub.image]);
await sut.handleDetectFaces({ id: assetStub.image.id });
expect(machineLearningMock.detectFaces).toHaveBeenCalledWith(
'http://immich-machine-learning:3003',
- {
- imagePath: assetStub.image.previewPath,
- },
- {
- enabled: true,
- maxDistance: 0.5,
- minScore: 0.7,
- minFaces: 3,
- modelName: 'buffalo_l',
- },
+ assetStub.image.previewPath,
+ expect.objectContaining({ minScore: 0.7, modelName: 'buffalo_l' }),
);
expect(personMock.createFaces).not.toHaveBeenCalled();
expect(jobMock.queue).not.toHaveBeenCalled();
@@ -671,7 +665,7 @@ describe(PersonService.name, () => {
it('should create a face with no person and queue recognition job', async () => {
personMock.createFaces.mockResolvedValue([faceStub.face1.id]);
- machineLearningMock.detectFaces.mockResolvedValue([detectFaceMock]);
+ machineLearningMock.detectFaces.mockResolvedValue(detectFaceMock);
searchMock.searchFaces.mockResolvedValue([{ face: faceStub.face1, distance: 0.7 }]);
assetMock.getByIds.mockResolvedValue([assetStub.image]);
const face = {
diff --git a/server/src/services/person.service.ts b/server/src/services/person.service.ts
index de0c191667..faa65974d4 100644
--- a/server/src/services/person.service.ts
+++ b/server/src/services/person.service.ts
@@ -333,26 +333,28 @@ export class PersonService {
return JobStatus.SKIPPED;
}
- const faces = await this.machineLearningRepository.detectFaces(
+ if (!asset.isVisible) {
+ return JobStatus.SKIPPED;
+ }
+
+ const { imageHeight, imageWidth, faces } = await this.machineLearningRepository.detectFaces(
machineLearning.url,
- { imagePath: asset.previewPath },
+ asset.previewPath,
machineLearning.facialRecognition,
);
this.logger.debug(`${faces.length} faces detected in ${asset.previewPath}`);
- this.logger.verbose(faces.map((face) => ({ ...face, embedding: `vector(${face.embedding.length})` })));
if (faces.length > 0) {
await this.jobRepository.queue({ name: JobName.QUEUE_FACIAL_RECOGNITION, data: { force: false } });
-
const mappedFaces = faces.map((face) => ({
assetId: asset.id,
embedding: face.embedding,
- imageHeight: face.imageHeight,
- imageWidth: face.imageWidth,
+ imageHeight,
+ imageWidth,
boundingBoxX1: face.boundingBox.x1,
- boundingBoxX2: face.boundingBox.x2,
boundingBoxY1: face.boundingBox.y1,
+ boundingBoxX2: face.boundingBox.x2,
boundingBoxY2: face.boundingBox.y2,
}));
diff --git a/server/src/services/search.service.ts b/server/src/services/search.service.ts
index 8c89218138..9213cc4290 100644
--- a/server/src/services/search.service.ts
+++ b/server/src/services/search.service.ts
@@ -102,12 +102,7 @@ export class SearchService {
const userIds = await this.getUserIdsToSearch(auth);
- const embedding = await this.machineLearning.encodeText(
- machineLearning.url,
- { text: dto.query },
- machineLearning.clip,
- );
-
+ const embedding = await this.machineLearning.encodeText(machineLearning.url, dto.query, machineLearning.clip);
const page = dto.page ?? 1;
const size = dto.size || 100;
const { hasNextPage, items } = await this.searchRepository.searchSmart(
diff --git a/server/src/services/smart-info.service.spec.ts b/server/src/services/smart-info.service.spec.ts
index 7ac6dac414..95f76edc49 100644
--- a/server/src/services/smart-info.service.spec.ts
+++ b/server/src/services/smart-info.service.spec.ts
@@ -108,8 +108,8 @@ describe(SmartInfoService.name, () => {
expect(machineMock.encodeImage).toHaveBeenCalledWith(
'http://immich-machine-learning:3003',
- { imagePath: assetStub.image.previewPath },
- { enabled: true, modelName: 'ViT-B-32__openai' },
+ assetStub.image.previewPath,
+ expect.objectContaining({ modelName: 'ViT-B-32__openai' }),
);
expect(searchMock.upsert).toHaveBeenCalledWith(assetStub.image.id, [0.01, 0.02, 0.03]);
});
diff --git a/server/src/services/smart-info.service.ts b/server/src/services/smart-info.service.ts
index f902aa7e57..46a57c3cd0 100644
--- a/server/src/services/smart-info.service.ts
+++ b/server/src/services/smart-info.service.ts
@@ -93,9 +93,9 @@ export class SmartInfoService {
return JobStatus.FAILED;
}
- const clipEmbedding = await this.machineLearning.encodeImage(
+ const embedding = await this.machineLearning.encodeImage(
machineLearning.url,
- { imagePath: asset.previewPath },
+ asset.previewPath,
machineLearning.clip,
);
@@ -104,7 +104,7 @@ export class SmartInfoService {
await this.databaseRepository.wait(DatabaseLock.CLIPDimSize);
}
- await this.repository.upsert(asset.id, clipEmbedding);
+ await this.repository.upsert(asset.id, embedding);
return JobStatus.SUCCESS;
}
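
A minimal sketch (not part of the diff) of the request payloads the refactored repository builds, based on the enums and types in machine-learning.interface.ts above; the model names and the minScore value are taken from the test fixtures and are illustrative only:

// Illustrative sketch — mirrors the request shapes defined in machine-learning.interface.ts.
import { MachineLearningRequest, ModelTask, ModelType } from 'src/interfaces/machine-learning.interface';

// A CLIP text request nests the model options under task and model type:
const textRequest: MachineLearningRequest = {
  [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName: 'ViT-B-32__openai' } },
};
// getFormData() serializes the whole request into a single multipart field:
// entries = {"clip":{"textual":{"modelName":"ViT-B-32__openai"}}}

// Facial recognition bundles detection and recognition options into one request,
// so the ML service can run both stages from a single call:
const faceRequest: MachineLearningRequest = {
  [ModelTask.FACIAL_RECOGNITION]: {
    [ModelType.DETECTION]: { modelName: 'buffalo_l', minScore: 0.7 },
    [ModelType.RECOGNITION]: { modelName: 'buffalo_l' },
  },
};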