mirror of
https://github.com/immich-app/immich.git
synced 2025-01-21 00:52:43 -05:00
chore(ml): updated dockerfile, added typing, packaging (#2642)
* updated dockerfile, added typing, packaging apply env change * added arm64 support * added ml version pump, second try for arm64 * added linting config to pyproject.toml * renamed ml input field * fixed linter config * fixed dev docker compose
This commit is contained in:
parent
c92c442356
commit
1e748864c5
13 changed files with 2647 additions and 67 deletions
|
@ -35,7 +35,7 @@ services:
|
|||
ports:
|
||||
- 3003:3003
|
||||
volumes:
|
||||
- ../machine-learning/src:/usr/src/app
|
||||
- ../machine-learning/app:/usr/src/app
|
||||
- ${UPLOAD_LOCATION}:/usr/src/app/upload
|
||||
- model-cache:/cache
|
||||
env_file:
|
||||
|
|
|
@ -1,29 +1,26 @@
|
|||
FROM python:3.10 as builder
|
||||
|
||||
FROM python:3.11 as builder
|
||||
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PIP_NO_CACHE_DIR=true
|
||||
|
||||
RUN pip install --upgrade pip && pip install poetry
|
||||
RUN poetry config installer.max-workers 10 && \
|
||||
poetry config virtualenvs.create false
|
||||
RUN python -m venv /opt/venv
|
||||
RUN /opt/venv/bin/pip install torch --index-url https://download.pytorch.org/whl/cpu
|
||||
RUN /opt/venv/bin/pip install transformers tqdm numpy scikit-learn scipy nltk sentencepiece fastapi Pillow uvicorn[standard]
|
||||
RUN /opt/venv/bin/pip install --no-deps sentence-transformers
|
||||
# Facial Recognition Stuff
|
||||
RUN /opt/venv/bin/pip install insightface onnxruntime
|
||||
ENV VIRTUAL_ENV="/opt/venv" PATH="/opt/venv/bin:${PATH}"
|
||||
|
||||
FROM python:3.10-slim
|
||||
COPY poetry.lock pyproject.toml ./
|
||||
RUN poetry install --sync --no-interaction --no-ansi --no-root --only main
|
||||
|
||||
ENV NODE_ENV=production
|
||||
|
||||
COPY --from=builder /opt/venv /opt/venv
|
||||
|
||||
ENV TRANSFORMERS_CACHE=/cache \
|
||||
FROM python:3.11-slim
|
||||
WORKDIR /usr/src/app
|
||||
ENV NODE_ENV=production \
|
||||
TRANSFORMERS_CACHE=/cache \
|
||||
PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PATH="/opt/venv/bin:$PATH"
|
||||
PATH="/opt/venv/bin:$PATH" \
|
||||
PYTHONPATH=`pwd`
|
||||
|
||||
WORKDIR /usr/src/app
|
||||
|
||||
COPY . .
|
||||
ENV PYTHONPATH=`pwd`
|
||||
CMD ["python", "src/main.py"]
|
||||
COPY --from=builder /opt/venv /opt/venv
|
||||
COPY app .
|
||||
ENTRYPOINT ["python", "main.py"]
|
||||
|
|
|
@ -1,5 +1,13 @@
|
|||
|
||||
# Immich Machine Learning
|
||||
|
||||
- Object Detection
|
||||
- Image Classification
|
||||
- Image classification
|
||||
- CLIP embeddings
|
||||
- Facial recognition
|
||||
|
||||
# Setup
|
||||
|
||||
This project uses [Poetry](https://python-poetry.org/docs/#installation), so be sure to install it first.
|
||||
Running `poetry install --no-root --with dev` will install everything you need in an isolated virtual environment.
|
||||
|
||||
To add or remove dependencies, you can use the commands `poetry add $PACKAGE_NAME` and `poetry remove $PACKAGE_NAME`, respectively.
|
||||
Be sure to commit the `poetry.lock` and `pyproject.toml` files to reflect any changes in dependencies.
|
||||
|
|
|
@ -1,22 +1,23 @@
|
|||
import os
|
||||
import numpy as np
|
||||
from typing import Any
|
||||
from schemas import (
|
||||
EmbeddingResponse,
|
||||
FaceResponse,
|
||||
TagResponse,
|
||||
MessageResponse,
|
||||
TextModelRequest,
|
||||
TextResponse,
|
||||
VisionModelRequest,
|
||||
)
|
||||
import cv2 as cv
|
||||
import uvicorn
|
||||
|
||||
from insightface.app import FaceAnalysis
|
||||
from transformers import pipeline
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from transformers import Pipeline
|
||||
from PIL import Image
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class MlRequestBody(BaseModel):
|
||||
thumbnailPath: str
|
||||
|
||||
|
||||
class ClipRequestBody(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
classification_model = os.getenv(
|
||||
|
@ -42,7 +43,7 @@ app = FastAPI()
|
|||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
async def startup_event() -> None:
|
||||
models = [
|
||||
(classification_model, "image-classification"),
|
||||
(clip_image_model, "clip"),
|
||||
|
@ -58,42 +59,51 @@ async def startup_event():
|
|||
_get_model(model_name, model_type)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
@app.get("/", response_model=MessageResponse)
|
||||
async def root() -> dict[str, str]:
|
||||
return {"message": "Immich ML"}
|
||||
|
||||
|
||||
@app.get("/ping")
|
||||
def ping():
|
||||
@app.get("/ping", response_model=TextResponse)
|
||||
def ping() -> str:
|
||||
return "pong"
|
||||
|
||||
|
||||
@app.post("/image-classifier/tag-image", status_code=200)
|
||||
def image_classification(payload: MlRequestBody):
|
||||
@app.post("/image-classifier/tag-image", response_model=TagResponse, status_code=200)
|
||||
def image_classification(payload: VisionModelRequest) -> list[str]:
|
||||
model = get_cached_model(classification_model, "image-classification")
|
||||
assetPath = payload.thumbnailPath
|
||||
return run_engine(model, assetPath)
|
||||
assetPath = payload.image_path
|
||||
labels = run_engine(model, assetPath)
|
||||
return labels
|
||||
|
||||
|
||||
@app.post("/sentence-transformer/encode-image", status_code=200)
|
||||
def clip_encode_image(payload: MlRequestBody):
|
||||
@app.post(
|
||||
"/sentence-transformer/encode-image",
|
||||
response_model=EmbeddingResponse,
|
||||
status_code=200,
|
||||
)
|
||||
def clip_encode_image(payload: VisionModelRequest) -> list[float]:
|
||||
model = get_cached_model(clip_image_model, "clip")
|
||||
assetPath = payload.thumbnailPath
|
||||
return model.encode(Image.open(assetPath)).tolist()
|
||||
image = Image.open(payload.image_path)
|
||||
return model.encode(image).tolist()
|
||||
|
||||
|
||||
@app.post("/sentence-transformer/encode-text", status_code=200)
|
||||
def clip_encode_text(payload: ClipRequestBody):
|
||||
@app.post(
|
||||
"/sentence-transformer/encode-text",
|
||||
response_model=EmbeddingResponse,
|
||||
status_code=200,
|
||||
)
|
||||
def clip_encode_text(payload: TextModelRequest) -> list[float]:
|
||||
model = get_cached_model(clip_text_model, "clip")
|
||||
text = payload.text
|
||||
return model.encode(text).tolist()
|
||||
return model.encode(payload.text).tolist()
|
||||
|
||||
|
||||
@app.post("/facial-recognition/detect-faces", status_code=200)
|
||||
def facial_recognition(payload: MlRequestBody):
|
||||
@app.post(
|
||||
"/facial-recognition/detect-faces", response_model=FaceResponse, status_code=200
|
||||
)
|
||||
def facial_recognition(payload: VisionModelRequest) -> list[dict[str, Any]]:
|
||||
model = get_cached_model(facial_recognition_model, "facial-recognition")
|
||||
assetPath = payload.thumbnailPath
|
||||
img = cv.imread(assetPath)
|
||||
img = cv.imread(payload.image_path)
|
||||
height, width, _ = img.shape
|
||||
results = []
|
||||
faces = model.get(img)
|
||||
|
@ -120,11 +130,11 @@ def facial_recognition(payload: MlRequestBody):
|
|||
return results
|
||||
|
||||
|
||||
def run_engine(engine, path):
|
||||
result = []
|
||||
predictions = engine(path)
|
||||
def run_engine(engine: Pipeline, path: str) -> list[str]:
|
||||
result: list[str] = []
|
||||
predictions: list[dict[str, Any]] = engine(path) # type: ignore
|
||||
|
||||
for index, pred in enumerate(predictions):
|
||||
for pred in predictions:
|
||||
tags = pred["label"].split(", ")
|
||||
if pred["score"] > min_tag_score:
|
||||
result = [*result, *tags]
|
||||
|
@ -135,7 +145,7 @@ def run_engine(engine, path):
|
|||
return result
|
||||
|
||||
|
||||
def get_cached_model(model, task):
|
||||
def get_cached_model(model, task) -> Any:
|
||||
global _model_cache
|
||||
key = "|".join([model, str(task)])
|
||||
if key not in _model_cache:
|
||||
|
@ -145,7 +155,7 @@ def get_cached_model(model, task):
|
|||
return _model_cache[key]
|
||||
|
||||
|
||||
def _get_model(model, task):
|
||||
def _get_model(model, task) -> Any:
|
||||
match task:
|
||||
case "facial-recognition":
|
||||
model = FaceAnalysis(
|
64
machine-learning/app/schemas.py
Normal file
64
machine-learning/app/schemas.py
Normal file
|
@ -0,0 +1,64 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def to_lower_camel(string: str) -> str:
|
||||
tokens = [
|
||||
token.capitalize() if i > 0 else token
|
||||
for i, token in enumerate(string.split("_"))
|
||||
]
|
||||
return "".join(tokens)
|
||||
|
||||
|
||||
class VisionModelRequest(BaseModel):
|
||||
image_path: str
|
||||
|
||||
class Config:
|
||||
alias_generator = to_lower_camel
|
||||
allow_population_by_field_name = True
|
||||
|
||||
|
||||
class TextModelRequest(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class TextResponse(BaseModel):
|
||||
__root__: str
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class TagResponse(BaseModel):
|
||||
__root__: list[str]
|
||||
|
||||
|
||||
class Embedding(BaseModel):
|
||||
__root__: list[float]
|
||||
|
||||
|
||||
class EmbeddingResponse(BaseModel):
|
||||
__root__: Embedding
|
||||
|
||||
|
||||
class BoundingBox(BaseModel):
|
||||
x1: int
|
||||
y1: int
|
||||
x2: int
|
||||
y2: int
|
||||
|
||||
|
||||
class Face(BaseModel):
|
||||
image_width: int
|
||||
image_height: int
|
||||
bounding_box: BoundingBox
|
||||
score: float
|
||||
embedding: Embedding
|
||||
|
||||
class Config:
|
||||
alias_generator = to_lower_camel
|
||||
allow_population_by_field_name = True
|
||||
|
||||
|
||||
class FaceResponse(BaseModel):
|
||||
__root__: list[Face]
|
2444
machine-learning/poetry.lock
generated
Normal file
2444
machine-learning/poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
56
machine-learning/pyproject.toml
Normal file
56
machine-learning/pyproject.toml
Normal file
|
@ -0,0 +1,56 @@
|
|||
[tool.poetry]
|
||||
name = "machine-learning"
|
||||
version = "1.59.1"
|
||||
description = ""
|
||||
authors = ["Hau Tran <alex.tran1502@gmail.com>"]
|
||||
readme = "README.md"
|
||||
packages = [{include = "app"}]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.11"
|
||||
torch = [
|
||||
{markers = "platform_machine == 'arm64' or platform_machine == 'aarch64'", version = "=2.0.1", source = "pypi"},
|
||||
{markers = "platform_machine == 'amd64' or platform_machine == 'x86_64'", version = "=2.0.1+cpu", source = "pytorch-cpu"}
|
||||
]
|
||||
transformers = "^4.29.2"
|
||||
sentence-transformers = "^2.2.2"
|
||||
onnxruntime = "^1.15.0"
|
||||
insightface = "^0.7.3"
|
||||
opencv-python-headless = "^4.7.0.72"
|
||||
pillow = "^9.5.0"
|
||||
fastapi = "^0.95.2"
|
||||
uvicorn = {extras = ["standard"], version = "^0.22.0"}
|
||||
pydantic = "^1.10.8"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
mypy = "^1.3.0"
|
||||
black = "^23.3.0"
|
||||
pytest = "^7.3.1"
|
||||
|
||||
[[tool.poetry.source]]
|
||||
name = "pytorch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
priority = "explicit"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.flake8]
|
||||
max-line-length = 120
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.11"
|
||||
plugins = "pydantic.mypy"
|
||||
follow_imports = "silent"
|
||||
warn_redundant_casts = true
|
||||
disallow_any_generics = true
|
||||
check_untyped_defs = true
|
||||
no_implicit_reexport = true
|
||||
disallow_untyped_defs = true
|
||||
|
||||
[tool.pydantic-mypy]
|
||||
init_forbid_extra = true
|
||||
init_typed = true
|
||||
warn_required_dynamic_aliases = true
|
||||
warn_untyped_fields = true
|
|
@ -63,6 +63,7 @@ if [ "$CURRENT_SERVER" != "$NEXT_SERVER" ]; then
|
|||
echo "Pumping Server: $CURRENT_SERVER => $NEXT_SERVER"
|
||||
npm --prefix server version $SERVER_PUMP
|
||||
npm --prefix server run api:generate
|
||||
poetry --directory machine-learning version $SERVER_PUMP
|
||||
fi
|
||||
|
||||
if [ "$CURRENT_MOBILE" != "$NEXT_MOBILE" ]; then
|
||||
|
|
|
@ -175,7 +175,7 @@ describe(FacialRecognitionService.name, () => {
|
|||
assetMock.getByIds.mockResolvedValue([assetEntityStub.image]);
|
||||
await sut.handleRecognizeFaces({ id: assetEntityStub.image.id });
|
||||
expect(machineLearningMock.detectFaces).toHaveBeenCalledWith({
|
||||
thumbnailPath: assetEntityStub.image.resizePath,
|
||||
imagePath: assetEntityStub.image.resizePath,
|
||||
});
|
||||
expect(faceMock.create).not.toHaveBeenCalled();
|
||||
expect(jobMock.queue).not.toHaveBeenCalled();
|
||||
|
|
|
@ -54,7 +54,7 @@ export class FacialRecognitionService {
|
|||
return false;
|
||||
}
|
||||
|
||||
const faces = await this.machineLearning.detectFaces({ thumbnailPath: asset.resizePath });
|
||||
const faces = await this.machineLearning.detectFaces({ imagePath: asset.resizePath });
|
||||
|
||||
this.logger.debug(`${faces.length} faces detected in ${asset.resizePath}`);
|
||||
this.logger.verbose(faces.map((face) => ({ ...face, embedding: `float[${face.embedding.length}]` })));
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
export const IMachineLearningRepository = 'IMachineLearningRepository';
|
||||
|
||||
export interface MachineLearningInput {
|
||||
thumbnailPath: string;
|
||||
imagePath: string;
|
||||
}
|
||||
|
||||
export interface BoundingBox {
|
||||
|
|
|
@ -84,7 +84,7 @@ describe(SmartInfoService.name, () => {
|
|||
|
||||
await sut.handleClassifyImage({ id: asset.id });
|
||||
|
||||
expect(machineMock.classifyImage).toHaveBeenCalledWith({ thumbnailPath: 'path/to/resize.ext' });
|
||||
expect(machineMock.classifyImage).toHaveBeenCalledWith({ imagePath: 'path/to/resize.ext' });
|
||||
expect(smartMock.upsert).toHaveBeenCalledWith({
|
||||
assetId: 'asset-1',
|
||||
tags: ['tag1', 'tag2', 'tag3'],
|
||||
|
@ -143,7 +143,7 @@ describe(SmartInfoService.name, () => {
|
|||
|
||||
await sut.handleEncodeClip({ id: asset.id });
|
||||
|
||||
expect(machineMock.encodeImage).toHaveBeenCalledWith({ thumbnailPath: 'path/to/resize.ext' });
|
||||
expect(machineMock.encodeImage).toHaveBeenCalledWith({ imagePath: 'path/to/resize.ext' });
|
||||
expect(smartMock.upsert).toHaveBeenCalledWith({
|
||||
assetId: 'asset-1',
|
||||
clipEmbedding: [0.01, 0.02, 0.03],
|
||||
|
|
|
@ -40,7 +40,7 @@ export class SmartInfoService {
|
|||
return false;
|
||||
}
|
||||
|
||||
const tags = await this.machineLearning.classifyImage({ thumbnailPath: asset.resizePath });
|
||||
const tags = await this.machineLearning.classifyImage({ imagePath: asset.resizePath });
|
||||
if (tags.length === 0) {
|
||||
return false;
|
||||
}
|
||||
|
@ -73,7 +73,7 @@ export class SmartInfoService {
|
|||
return false;
|
||||
}
|
||||
|
||||
const clipEmbedding = await this.machineLearning.encodeImage({ thumbnailPath: asset.resizePath });
|
||||
const clipEmbedding = await this.machineLearning.encodeImage({ imagePath: asset.resizePath });
|
||||
await this.repository.upsert({ assetId: asset.id, clipEmbedding: clipEmbedding });
|
||||
|
||||
return true;
|
||||
|
|
Loading…
Add table
Reference in a new issue