mirror of
https://github.com/immich-app/immich.git
synced 2025-01-21 00:52:43 -05:00
feat: preload textual model
This commit is contained in:
parent
d34d631dd4
commit
59300d2097
10 changed files with 59 additions and 59 deletions
|
@ -28,7 +28,6 @@ from .schemas import (
|
||||||
InferenceEntries,
|
InferenceEntries,
|
||||||
InferenceEntry,
|
InferenceEntry,
|
||||||
InferenceResponse,
|
InferenceResponse,
|
||||||
LoadModelEntry,
|
|
||||||
MessageResponse,
|
MessageResponse,
|
||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelIdentity,
|
ModelIdentity,
|
||||||
|
@ -125,17 +124,16 @@ def get_entries(entries: str = Form()) -> InferenceEntries:
|
||||||
raise HTTPException(422, "Invalid request format.")
|
raise HTTPException(422, "Invalid request format.")
|
||||||
|
|
||||||
|
|
||||||
def get_entry(entries: str = Form()) -> LoadModelEntry:
|
def get_entry(entries: str = Form()) -> InferenceEntry:
|
||||||
try:
|
try:
|
||||||
request: PipelineRequest = orjson.loads(entries)
|
request: PipelineRequest = orjson.loads(entries)
|
||||||
for task, types in request.items():
|
for task, types in request.items():
|
||||||
for type, entry in types.items():
|
for type, entry in types.items():
|
||||||
parsed: LoadModelEntry = {
|
parsed: InferenceEntry = {
|
||||||
"name": entry["modelName"],
|
"name": entry["modelName"],
|
||||||
"task": task,
|
"task": task,
|
||||||
"type": type,
|
"type": type,
|
||||||
"options": entry.get("options", {}),
|
"options": entry.get("options", {}),
|
||||||
"ttl": entry["ttl"] if "ttl" in entry else settings.ttl,
|
|
||||||
}
|
}
|
||||||
return parsed
|
return parsed
|
||||||
except (orjson.JSONDecodeError, ValidationError, KeyError, AttributeError) as e:
|
except (orjson.JSONDecodeError, ValidationError, KeyError, AttributeError) as e:
|
||||||
|
@ -163,6 +161,13 @@ async def load_model(entry: InferenceEntry = Depends(get_entry)) -> None:
|
||||||
return Response(status_code=200)
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/unload", response_model=TextResponse)
|
||||||
|
async def unload_model(entry: InferenceEntry = Depends(get_entry)) -> None:
|
||||||
|
await model_cache.unload(entry["name"], entry["type"], entry["task"])
|
||||||
|
print("unload")
|
||||||
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/predict", dependencies=[Depends(update_state)])
|
@app.post("/predict", dependencies=[Depends(update_state)])
|
||||||
async def predict(
|
async def predict(
|
||||||
entries: InferenceEntries = Depends(get_entries),
|
entries: InferenceEntries = Depends(get_entries),
|
||||||
|
|
|
@ -58,3 +58,10 @@ class ModelCache:
|
||||||
async def revalidate(self, key: str, ttl: int | None) -> None:
|
async def revalidate(self, key: str, ttl: int | None) -> None:
|
||||||
if ttl is not None and key in self.cache._handlers:
|
if ttl is not None and key in self.cache._handlers:
|
||||||
await self.cache.expire(key, ttl)
|
await self.cache.expire(key, ttl)
|
||||||
|
|
||||||
|
async def unload(self, model_name: str, model_type: ModelType, model_task: ModelTask) -> None:
|
||||||
|
key = f"{model_name}{model_type}{model_task}"
|
||||||
|
async with OptimisticLock(self.cache, key):
|
||||||
|
value = await self.cache.get(key)
|
||||||
|
if value is not None:
|
||||||
|
await self.cache.delete(key)
|
||||||
|
|
|
@ -109,17 +109,6 @@ class InferenceEntry(TypedDict):
|
||||||
options: dict[str, Any]
|
options: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
class LoadModelEntry(InferenceEntry):
|
|
||||||
ttl: int
|
|
||||||
|
|
||||||
def __init__(self, name: str, task: ModelTask, type: ModelType, options: dict[str, Any], ttl: int):
|
|
||||||
super().__init__(name=name, task=task, type=type, options=options)
|
|
||||||
|
|
||||||
if ttl <= 0:
|
|
||||||
raise ValueError("ttl must be a positive integer")
|
|
||||||
self.ttl = ttl
|
|
||||||
|
|
||||||
|
|
||||||
InferenceEntries = tuple[list[InferenceEntry], list[InferenceEntry]]
|
InferenceEntries = tuple[list[InferenceEntry], list[InferenceEntry]]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -5307,8 +5307,8 @@
|
||||||
"name": "password",
|
"name": "password",
|
||||||
"required": false,
|
"required": false,
|
||||||
"in": "query",
|
"in": "query",
|
||||||
|
"example": "password",
|
||||||
"schema": {
|
"schema": {
|
||||||
"example": "password",
|
|
||||||
"type": "string"
|
"type": "string"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -9510,16 +9510,10 @@
|
||||||
"properties": {
|
"properties": {
|
||||||
"enabled": {
|
"enabled": {
|
||||||
"type": "boolean"
|
"type": "boolean"
|
||||||
},
|
|
||||||
"ttl": {
|
|
||||||
"format": "int64",
|
|
||||||
"minimum": 0,
|
|
||||||
"type": "number"
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": [
|
"required": [
|
||||||
"enabled",
|
"enabled"
|
||||||
"ttl"
|
|
||||||
],
|
],
|
||||||
"type": "object"
|
"type": "object"
|
||||||
},
|
},
|
||||||
|
|
|
@ -122,7 +122,6 @@ export interface SystemConfig {
|
||||||
modelName: string;
|
modelName: string;
|
||||||
loadTextualModelOnConnection: {
|
loadTextualModelOnConnection: {
|
||||||
enabled: boolean;
|
enabled: boolean;
|
||||||
ttl: number;
|
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
duplicateDetection: {
|
duplicateDetection: {
|
||||||
|
@ -276,7 +275,6 @@ export const defaults = Object.freeze<SystemConfig>({
|
||||||
modelName: 'ViT-B-32__openai',
|
modelName: 'ViT-B-32__openai',
|
||||||
loadTextualModelOnConnection: {
|
loadTextualModelOnConnection: {
|
||||||
enabled: false,
|
enabled: false,
|
||||||
ttl: 300,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
duplicateDetection: {
|
duplicateDetection: {
|
||||||
|
|
|
@ -14,12 +14,9 @@ export class ModelConfig extends TaskConfig {
|
||||||
modelName!: string;
|
modelName!: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export class LoadTextualModelOnConnection extends TaskConfig {
|
export class LoadTextualModelOnConnection {
|
||||||
@IsNumber()
|
@ValidateBoolean()
|
||||||
@Min(0)
|
enabled!: boolean;
|
||||||
@Type(() => Number)
|
|
||||||
@ApiProperty({ type: 'number', format: 'int64' })
|
|
||||||
ttl!: number;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export class CLIPConfig extends ModelConfig {
|
export class CLIPConfig extends ModelConfig {
|
||||||
|
|
|
@ -24,17 +24,13 @@ export type ModelPayload = { imagePath: string } | { text: string };
|
||||||
|
|
||||||
type ModelOptions = { modelName: string };
|
type ModelOptions = { modelName: string };
|
||||||
|
|
||||||
export interface LoadModelOptions extends ModelOptions {
|
|
||||||
ttl: number;
|
|
||||||
}
|
|
||||||
|
|
||||||
export type FaceDetectionOptions = ModelOptions & { minScore: number };
|
export type FaceDetectionOptions = ModelOptions & { minScore: number };
|
||||||
|
|
||||||
type VisualResponse = { imageHeight: number; imageWidth: number };
|
type VisualResponse = { imageHeight: number; imageWidth: number };
|
||||||
export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } };
|
export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } };
|
||||||
export type ClipVisualResponse = { [ModelTask.SEARCH]: number[] } & VisualResponse;
|
export type ClipVisualResponse = { [ModelTask.SEARCH]: number[] } & VisualResponse;
|
||||||
|
|
||||||
export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions | LoadModelOptions } };
|
export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions } };
|
||||||
export type ClipTextualResponse = { [ModelTask.SEARCH]: number[] };
|
export type ClipTextualResponse = { [ModelTask.SEARCH]: number[] };
|
||||||
|
|
||||||
export type FacialRecognitionRequest = {
|
export type FacialRecognitionRequest = {
|
||||||
|
@ -50,6 +46,11 @@ export interface Face {
|
||||||
score: number;
|
score: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export enum LoadTextModelActions {
|
||||||
|
LOAD,
|
||||||
|
UNLOAD,
|
||||||
|
}
|
||||||
|
|
||||||
export type FacialRecognitionResponse = { [ModelTask.FACIAL_RECOGNITION]: Face[] } & VisualResponse;
|
export type FacialRecognitionResponse = { [ModelTask.FACIAL_RECOGNITION]: Face[] } & VisualResponse;
|
||||||
export type DetectedFaces = { faces: Face[] } & VisualResponse;
|
export type DetectedFaces = { faces: Face[] } & VisualResponse;
|
||||||
export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest;
|
export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest;
|
||||||
|
@ -58,5 +59,5 @@ export interface IMachineLearningRepository {
|
||||||
encodeImage(url: string, imagePath: string, config: ModelOptions): Promise<number[]>;
|
encodeImage(url: string, imagePath: string, config: ModelOptions): Promise<number[]>;
|
||||||
encodeText(url: string, text: string, config: ModelOptions): Promise<number[]>;
|
encodeText(url: string, text: string, config: ModelOptions): Promise<number[]>;
|
||||||
detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>;
|
detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>;
|
||||||
loadTextModel(url: string, config: ModelOptions): Promise<void>;
|
prepareTextModel(url: string, config: ModelOptions, action: LoadTextModelActions): Promise<void>;
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,7 +20,7 @@ import {
|
||||||
ServerEventMap,
|
ServerEventMap,
|
||||||
} from 'src/interfaces/event.interface';
|
} from 'src/interfaces/event.interface';
|
||||||
import { ILoggerRepository } from 'src/interfaces/logger.interface';
|
import { ILoggerRepository } from 'src/interfaces/logger.interface';
|
||||||
import { IMachineLearningRepository } from 'src/interfaces/machine-learning.interface';
|
import { IMachineLearningRepository, LoadTextModelActions } from 'src/interfaces/machine-learning.interface';
|
||||||
import { ISystemMetadataRepository } from 'src/interfaces/system-metadata.interface';
|
import { ISystemMetadataRepository } from 'src/interfaces/system-metadata.interface';
|
||||||
import { AuthService } from 'src/services/auth.service';
|
import { AuthService } from 'src/services/auth.service';
|
||||||
import { Instrumentation } from 'src/utils/instrumentation';
|
import { Instrumentation } from 'src/utils/instrumentation';
|
||||||
|
@ -79,7 +79,12 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect
|
||||||
const { machineLearning } = await this.configCore.getConfig({ withCache: true });
|
const { machineLearning } = await this.configCore.getConfig({ withCache: true });
|
||||||
if (machineLearning.clip.loadTextualModelOnConnection.enabled) {
|
if (machineLearning.clip.loadTextualModelOnConnection.enabled) {
|
||||||
try {
|
try {
|
||||||
this.machineLearningRepository.loadTextModel(machineLearning.url, machineLearning.clip);
|
console.log(this.server);
|
||||||
|
this.machineLearningRepository.prepareTextModel(
|
||||||
|
machineLearning.url,
|
||||||
|
machineLearning.clip,
|
||||||
|
LoadTextModelActions.LOAD,
|
||||||
|
);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
this.logger.warn(error);
|
this.logger.warn(error);
|
||||||
}
|
}
|
||||||
|
@ -100,6 +105,21 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect
|
||||||
async handleDisconnect(client: Socket) {
|
async handleDisconnect(client: Socket) {
|
||||||
this.logger.log(`Websocket Disconnect: ${client.id}`);
|
this.logger.log(`Websocket Disconnect: ${client.id}`);
|
||||||
await client.leave(client.nsp.name);
|
await client.leave(client.nsp.name);
|
||||||
|
if ('background' in client.handshake.query && client.handshake.query.background === 'false') {
|
||||||
|
const { machineLearning } = await this.configCore.getConfig({ withCache: true });
|
||||||
|
if (machineLearning.clip.loadTextualModelOnConnection.enabled && this.server?.engine.clientsCount == 0) {
|
||||||
|
try {
|
||||||
|
this.machineLearningRepository.prepareTextModel(
|
||||||
|
machineLearning.url,
|
||||||
|
machineLearning.clip,
|
||||||
|
LoadTextModelActions.UNLOAD,
|
||||||
|
);
|
||||||
|
this.logger.debug('sent request to unload text model');
|
||||||
|
} catch (error) {
|
||||||
|
this.logger.warn(error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
on<T extends EmitEvent>(event: T, handler: EmitHandler<T>): void {
|
on<T extends EmitEvent>(event: T, handler: EmitHandler<T>): void {
|
||||||
|
|
|
@ -7,6 +7,7 @@ import {
|
||||||
FaceDetectionOptions,
|
FaceDetectionOptions,
|
||||||
FacialRecognitionResponse,
|
FacialRecognitionResponse,
|
||||||
IMachineLearningRepository,
|
IMachineLearningRepository,
|
||||||
|
LoadTextModelActions,
|
||||||
MachineLearningRequest,
|
MachineLearningRequest,
|
||||||
ModelPayload,
|
ModelPayload,
|
||||||
ModelTask,
|
ModelTask,
|
||||||
|
@ -38,11 +39,16 @@ export class MachineLearningRepository implements IMachineLearningRepository {
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
async loadTextModel(url: string, { modelName, loadTextualModelOnConnection: { ttl } }: CLIPConfig) {
|
private prepareTextModelUrl: Record<LoadTextModelActions, string> = {
|
||||||
|
[LoadTextModelActions.LOAD]: '/load',
|
||||||
|
[LoadTextModelActions.UNLOAD]: '/unload',
|
||||||
|
};
|
||||||
|
|
||||||
|
async prepareTextModel(url: string, { modelName }: CLIPConfig, actions: LoadTextModelActions) {
|
||||||
try {
|
try {
|
||||||
const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName, ttl } } };
|
const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } };
|
||||||
const formData = await this.getFormData(request);
|
const formData = await this.getFormData(request);
|
||||||
const res = await this.fetchData(url, '/load', formData);
|
const res = await this.fetchData(url, this.prepareTextModelUrl[actions], formData);
|
||||||
if (res.status >= 400) {
|
if (res.status >= 400) {
|
||||||
throw new Error(`${errorPrefix} Loadings textual model failed with status ${res.status}: ${res.statusText}`);
|
throw new Error(`${errorPrefix} Loadings textual model failed with status ${res.status}: ${res.statusText}`);
|
||||||
}
|
}
|
||||||
|
|
|
@ -88,23 +88,6 @@
|
||||||
bind:checked={config.machineLearning.clip.loadTextualModelOnConnection.enabled}
|
bind:checked={config.machineLearning.clip.loadTextualModelOnConnection.enabled}
|
||||||
disabled={disabled || !config.machineLearning.enabled || !config.machineLearning.clip.enabled}
|
disabled={disabled || !config.machineLearning.enabled || !config.machineLearning.clip.enabled}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<hr />
|
|
||||||
|
|
||||||
<SettingInputField
|
|
||||||
inputType={SettingInputFieldType.NUMBER}
|
|
||||||
label={$t('admin.machine_learning_preload_model_ttl')}
|
|
||||||
bind:value={config.machineLearning.clip.loadTextualModelOnConnection.ttl}
|
|
||||||
step="1"
|
|
||||||
min={0}
|
|
||||||
desc={$t('admin.machine_learning_max_detection_distance_description')}
|
|
||||||
disabled={disabled ||
|
|
||||||
!config.machineLearning.enabled ||
|
|
||||||
!config.machineLearning.clip.enabled ||
|
|
||||||
!config.machineLearning.clip.loadTextualModelOnConnection.enabled}
|
|
||||||
isEdited={config.machineLearning.clip.loadTextualModelOnConnection.ttl !==
|
|
||||||
savedConfig.machineLearning.clip.loadTextualModelOnConnection.ttl}
|
|
||||||
/>
|
|
||||||
</div>
|
</div>
|
||||||
</SettingAccordion>
|
</SettingAccordion>
|
||||||
</div>
|
</div>
|
||||||
|
|
Loading…
Add table
Reference in a new issue