0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-21 00:52:43 -05:00

feat: preload textual model

This commit is contained in:
martabal 2024-09-25 18:22:54 +02:00
parent d34d631dd4
commit 59300d2097
No known key found for this signature in database
GPG key ID: C00196E3148A52BD
10 changed files with 59 additions and 59 deletions

View file

@ -28,7 +28,6 @@ from .schemas import (
InferenceEntries, InferenceEntries,
InferenceEntry, InferenceEntry,
InferenceResponse, InferenceResponse,
LoadModelEntry,
MessageResponse, MessageResponse,
ModelFormat, ModelFormat,
ModelIdentity, ModelIdentity,
@ -125,17 +124,16 @@ def get_entries(entries: str = Form()) -> InferenceEntries:
raise HTTPException(422, "Invalid request format.") raise HTTPException(422, "Invalid request format.")
def get_entry(entries: str = Form()) -> LoadModelEntry: def get_entry(entries: str = Form()) -> InferenceEntry:
try: try:
request: PipelineRequest = orjson.loads(entries) request: PipelineRequest = orjson.loads(entries)
for task, types in request.items(): for task, types in request.items():
for type, entry in types.items(): for type, entry in types.items():
parsed: LoadModelEntry = { parsed: InferenceEntry = {
"name": entry["modelName"], "name": entry["modelName"],
"task": task, "task": task,
"type": type, "type": type,
"options": entry.get("options", {}), "options": entry.get("options", {}),
"ttl": entry["ttl"] if "ttl" in entry else settings.ttl,
} }
return parsed return parsed
except (orjson.JSONDecodeError, ValidationError, KeyError, AttributeError) as e: except (orjson.JSONDecodeError, ValidationError, KeyError, AttributeError) as e:
@ -163,6 +161,13 @@ async def load_model(entry: InferenceEntry = Depends(get_entry)) -> None:
return Response(status_code=200) return Response(status_code=200)
@app.post("/unload", response_model=TextResponse)
async def unload_model(entry: InferenceEntry = Depends(get_entry)) -> Response:
    """Remove the model described by ``entry`` from the model cache.

    ``entry`` is produced by the ``get_entry`` dependency from the request's
    ``entries`` form field and supplies the model name, type and task used to
    identify the cached model.

    Returns:
        An empty response with status code 200 once the cache has processed the
        unload request.
    """
    await model_cache.unload(entry["name"], entry["type"], entry["task"])
    return Response(status_code=200)
@app.post("/predict", dependencies=[Depends(update_state)]) @app.post("/predict", dependencies=[Depends(update_state)])
async def predict( async def predict(
entries: InferenceEntries = Depends(get_entries), entries: InferenceEntries = Depends(get_entries),

View file

@ -58,3 +58,10 @@ class ModelCache:
async def revalidate(self, key: str, ttl: int | None) -> None: async def revalidate(self, key: str, ttl: int | None) -> None:
if ttl is not None and key in self.cache._handlers: if ttl is not None and key in self.cache._handlers:
await self.cache.expire(key, ttl) await self.cache.expire(key, ttl)
async def unload(self, model_name: str, model_type: ModelType, model_task: ModelTask) -> None:
    """Delete the cached entry for the given model name, type and task, if one exists.

    The cache key is the concatenation of the three identifiers, and the lookup
    and deletion both run while holding an ``OptimisticLock`` on that key.
    """
    cache_key = f"{model_name}{model_type}{model_task}"
    async with OptimisticLock(self.cache, cache_key):
        # Nothing stored under this key: leave the cache untouched.
        if await self.cache.get(cache_key) is None:
            return
        await self.cache.delete(cache_key)

View file

@ -109,17 +109,6 @@ class InferenceEntry(TypedDict):
options: dict[str, Any] options: dict[str, Any]
class LoadModelEntry(InferenceEntry):
ttl: int
def __init__(self, name: str, task: ModelTask, type: ModelType, options: dict[str, Any], ttl: int):
super().__init__(name=name, task=task, type=type, options=options)
if ttl <= 0:
raise ValueError("ttl must be a positive integer")
self.ttl = ttl
InferenceEntries = tuple[list[InferenceEntry], list[InferenceEntry]] InferenceEntries = tuple[list[InferenceEntry], list[InferenceEntry]]

View file

@ -5307,8 +5307,8 @@
"name": "password", "name": "password",
"required": false, "required": false,
"in": "query", "in": "query",
"example": "password",
"schema": { "schema": {
"example": "password",
"type": "string" "type": "string"
} }
}, },
@ -9510,16 +9510,10 @@
"properties": { "properties": {
"enabled": { "enabled": {
"type": "boolean" "type": "boolean"
},
"ttl": {
"format": "int64",
"minimum": 0,
"type": "number"
} }
}, },
"required": [ "required": [
"enabled", "enabled"
"ttl"
], ],
"type": "object" "type": "object"
}, },

View file

@ -122,7 +122,6 @@ export interface SystemConfig {
modelName: string; modelName: string;
loadTextualModelOnConnection: { loadTextualModelOnConnection: {
enabled: boolean; enabled: boolean;
ttl: number;
}; };
}; };
duplicateDetection: { duplicateDetection: {
@ -276,7 +275,6 @@ export const defaults = Object.freeze<SystemConfig>({
modelName: 'ViT-B-32__openai', modelName: 'ViT-B-32__openai',
loadTextualModelOnConnection: { loadTextualModelOnConnection: {
enabled: false, enabled: false,
ttl: 300,
}, },
}, },
duplicateDetection: { duplicateDetection: {

View file

@ -14,12 +14,9 @@ export class ModelConfig extends TaskConfig {
modelName!: string; modelName!: string;
} }
export class LoadTextualModelOnConnection extends TaskConfig { export class LoadTextualModelOnConnection {
@IsNumber() @ValidateBoolean()
@Min(0) enabled!: boolean;
@Type(() => Number)
@ApiProperty({ type: 'number', format: 'int64' })
ttl!: number;
} }
export class CLIPConfig extends ModelConfig { export class CLIPConfig extends ModelConfig {

View file

@ -24,17 +24,13 @@ export type ModelPayload = { imagePath: string } | { text: string };
type ModelOptions = { modelName: string }; type ModelOptions = { modelName: string };
export interface LoadModelOptions extends ModelOptions {
ttl: number;
}
export type FaceDetectionOptions = ModelOptions & { minScore: number }; export type FaceDetectionOptions = ModelOptions & { minScore: number };
type VisualResponse = { imageHeight: number; imageWidth: number }; type VisualResponse = { imageHeight: number; imageWidth: number };
export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } }; export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } };
export type ClipVisualResponse = { [ModelTask.SEARCH]: number[] } & VisualResponse; export type ClipVisualResponse = { [ModelTask.SEARCH]: number[] } & VisualResponse;
export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions | LoadModelOptions } }; export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions } };
export type ClipTextualResponse = { [ModelTask.SEARCH]: number[] }; export type ClipTextualResponse = { [ModelTask.SEARCH]: number[] };
export type FacialRecognitionRequest = { export type FacialRecognitionRequest = {
@ -50,6 +46,11 @@ export interface Face {
score: number; score: number;
} }
/**
 * Actions that can be requested for the textual model via
 * `IMachineLearningRepository.prepareTextModel`.
 */
export enum LoadTextModelActions {
  /** Request that the textual model be loaded. */
  LOAD = 0,
  /** Request that the textual model be unloaded. */
  UNLOAD = 1,
}
export type FacialRecognitionResponse = { [ModelTask.FACIAL_RECOGNITION]: Face[] } & VisualResponse; export type FacialRecognitionResponse = { [ModelTask.FACIAL_RECOGNITION]: Face[] } & VisualResponse;
export type DetectedFaces = { faces: Face[] } & VisualResponse; export type DetectedFaces = { faces: Face[] } & VisualResponse;
export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest; export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest;
@ -58,5 +59,5 @@ export interface IMachineLearningRepository {
encodeImage(url: string, imagePath: string, config: ModelOptions): Promise<number[]>; encodeImage(url: string, imagePath: string, config: ModelOptions): Promise<number[]>;
encodeText(url: string, text: string, config: ModelOptions): Promise<number[]>; encodeText(url: string, text: string, config: ModelOptions): Promise<number[]>;
detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>; detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>;
loadTextModel(url: string, config: ModelOptions): Promise<void>; prepareTextModel(url: string, config: ModelOptions, action: LoadTextModelActions): Promise<void>;
} }

View file

@ -20,7 +20,7 @@ import {
ServerEventMap, ServerEventMap,
} from 'src/interfaces/event.interface'; } from 'src/interfaces/event.interface';
import { ILoggerRepository } from 'src/interfaces/logger.interface'; import { ILoggerRepository } from 'src/interfaces/logger.interface';
import { IMachineLearningRepository } from 'src/interfaces/machine-learning.interface'; import { IMachineLearningRepository, LoadTextModelActions } from 'src/interfaces/machine-learning.interface';
import { ISystemMetadataRepository } from 'src/interfaces/system-metadata.interface'; import { ISystemMetadataRepository } from 'src/interfaces/system-metadata.interface';
import { AuthService } from 'src/services/auth.service'; import { AuthService } from 'src/services/auth.service';
import { Instrumentation } from 'src/utils/instrumentation'; import { Instrumentation } from 'src/utils/instrumentation';
@ -79,7 +79,12 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect
const { machineLearning } = await this.configCore.getConfig({ withCache: true }); const { machineLearning } = await this.configCore.getConfig({ withCache: true });
if (machineLearning.clip.loadTextualModelOnConnection.enabled) { if (machineLearning.clip.loadTextualModelOnConnection.enabled) {
try { try {
this.machineLearningRepository.loadTextModel(machineLearning.url, machineLearning.clip); console.log(this.server);
this.machineLearningRepository.prepareTextModel(
machineLearning.url,
machineLearning.clip,
LoadTextModelActions.LOAD,
);
} catch (error) { } catch (error) {
this.logger.warn(error); this.logger.warn(error);
} }
@ -100,6 +105,21 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect
async handleDisconnect(client: Socket) { async handleDisconnect(client: Socket) {
this.logger.log(`Websocket Disconnect: ${client.id}`); this.logger.log(`Websocket Disconnect: ${client.id}`);
await client.leave(client.nsp.name); await client.leave(client.nsp.name);
// Only foreground clients (handshake query `background=false`) can trigger an unload.
if ('background' in client.handshake.query && client.handshake.query.background === 'false') {
  const { machineLearning } = await this.configCore.getConfig({ withCache: true });
  // Unload only when the feature is enabled and no engine clients remain connected.
  if (machineLearning.clip.loadTextualModelOnConnection.enabled && this.server?.engine.clientsCount === 0) {
    try {
      // Await the request so a rejection is caught below instead of becoming an
      // unhandled promise rejection, and so the debug line is logged only after
      // the request has completed.
      await this.machineLearningRepository.prepareTextModel(
        machineLearning.url,
        machineLearning.clip,
        LoadTextModelActions.UNLOAD,
      );
      this.logger.debug('sent request to unload text model');
    } catch (error) {
      this.logger.warn(error);
    }
  }
}
} }
on<T extends EmitEvent>(event: T, handler: EmitHandler<T>): void { on<T extends EmitEvent>(event: T, handler: EmitHandler<T>): void {

View file

@ -7,6 +7,7 @@ import {
FaceDetectionOptions, FaceDetectionOptions,
FacialRecognitionResponse, FacialRecognitionResponse,
IMachineLearningRepository, IMachineLearningRepository,
LoadTextModelActions,
MachineLearningRequest, MachineLearningRequest,
ModelPayload, ModelPayload,
ModelTask, ModelTask,
@ -38,11 +39,16 @@ export class MachineLearningRepository implements IMachineLearningRepository {
return res; return res;
} }
async loadTextModel(url: string, { modelName, loadTextualModelOnConnection: { ttl } }: CLIPConfig) { private prepareTextModelUrl: Record<LoadTextModelActions, string> = {
[LoadTextModelActions.LOAD]: '/load',
[LoadTextModelActions.UNLOAD]: '/unload',
};
async prepareTextModel(url: string, { modelName }: CLIPConfig, actions: LoadTextModelActions) {
try { try {
const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName, ttl } } }; const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } };
const formData = await this.getFormData(request); const formData = await this.getFormData(request);
const res = await this.fetchData(url, '/load', formData); const res = await this.fetchData(url, this.prepareTextModelUrl[actions], formData);
if (res.status >= 400) { if (res.status >= 400) {
throw new Error(`${errorPrefix} Loadings textual model failed with status ${res.status}: ${res.statusText}`); throw new Error(`${errorPrefix} Loadings textual model failed with status ${res.status}: ${res.statusText}`);
} }

View file

@ -88,23 +88,6 @@
bind:checked={config.machineLearning.clip.loadTextualModelOnConnection.enabled} bind:checked={config.machineLearning.clip.loadTextualModelOnConnection.enabled}
disabled={disabled || !config.machineLearning.enabled || !config.machineLearning.clip.enabled} disabled={disabled || !config.machineLearning.enabled || !config.machineLearning.clip.enabled}
/> />
<hr />
<SettingInputField
inputType={SettingInputFieldType.NUMBER}
label={$t('admin.machine_learning_preload_model_ttl')}
bind:value={config.machineLearning.clip.loadTextualModelOnConnection.ttl}
step="1"
min={0}
desc={$t('admin.machine_learning_max_detection_distance_description')}
disabled={disabled ||
!config.machineLearning.enabled ||
!config.machineLearning.clip.enabled ||
!config.machineLearning.clip.loadTextualModelOnConnection.enabled}
isEdited={config.machineLearning.clip.loadTextualModelOnConnection.ttl !==
savedConfig.machineLearning.clip.loadTextualModelOnConnection.ttl}
/>
</div> </div>
</SettingAccordion> </SettingAccordion>
</div> </div>