2023-06-06 21:48:51 -04:00
|
|
|
import torch
|
|
|
|
from insightface.app import FaceAnalysis
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
from transformers import pipeline, Pipeline
|
|
|
|
from sentence_transformers import SentenceTransformer
|
2023-06-17 22:49:19 -05:00
|
|
|
from typing import Any, BinaryIO
|
2023-06-06 21:48:51 -04:00
|
|
|
import cv2 as cv
|
2023-06-17 22:49:19 -05:00
|
|
|
import numpy as np
|
|
|
|
from PIL import Image
|
|
|
|
from config import settings
|
2023-06-06 21:48:51 -04:00
|
|
|
|
|
|
|
# Torch device name for this process: "cuda" when PyTorch reports an available
# CUDA device, otherwise "cpu". `_get_cache_dir` also uses it as a path component.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
|
|
|
|
|
|
|
|
|
def get_model(model_name: str, model_type: str, **model_kwargs):
    """Load a model for the given task, caching its files under a per-device directory.

    Args:
        model_name: Name of the model to load (an InsightFace model name for
            ``facial-recognition``, otherwise a model on the HF Model Hub).
        model_type: Task that selects the loader:
            - ``facial-recognition``: InsightFace ``FaceAnalysis`` via
              ``_load_facial_recognition``.
            - ``clip``: ``sentence_transformers.SentenceTransformer``.
            - anything else (e.g. ``image-classification``): passed unchanged as
              the task name to ``transformers.pipeline``.
        **model_kwargs: Extra keyword arguments forwarded to the selected loader.
            On the ``pipeline`` path they are merged into its ``model_kwargs``
            argument after ``cache_dir``, so a caller-supplied ``cache_dir`` wins.

    Returns:
        The loaded model object.
    """
    cache_dir = _get_cache_dir(model_name, model_type)

    if model_type == "facial-recognition":
        return _load_facial_recognition(model_name, cache_dir=cache_dir, **model_kwargs)

    if model_type == "clip":
        return SentenceTransformer(model_name, cache_folder=cache_dir, **model_kwargs)

    # Every other model type is treated as a transformers pipeline task name.
    pipeline_kwargs = {"cache_dir": cache_dir, **model_kwargs}
    return pipeline(model_type, model_name, model_kwargs=pipeline_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
def run_classification(
|
2023-06-17 22:49:19 -05:00
|
|
|
model: Pipeline, image: Image, min_score: float | None = None
|
2023-06-06 21:48:51 -04:00
|
|
|
):
|
2023-06-17 22:49:19 -05:00
|
|
|
predictions: list[dict[str, Any]] = model(image) # type: ignore
|
2023-06-06 21:48:51 -04:00
|
|
|
result = {
|
|
|
|
tag
|
|
|
|
for pred in predictions
|
|
|
|
for tag in pred["label"].split(", ")
|
|
|
|
if min_score is None or pred["score"] >= min_score
|
|
|
|
}
|
|
|
|
|
|
|
|
return list(result)
|
|
|
|
|
|
|
|
|
|
|
|
def run_facial_recognition(
    model: "FaceAnalysis", image: bytes
) -> list[dict[str, Any]]:
    """Detect faces in an encoded image and return one record per face.

    Args:
        model: A prepared InsightFace ``FaceAnalysis`` instance; its ``get``
            method is called on the decoded image.
        image: Raw bytes of an encoded image file, decoded with
            ``cv.imdecode(..., cv.IMREAD_COLOR)``.

    Returns:
        A list with one dict per detected face containing:
        ``imageWidth`` / ``imageHeight`` of the decoded image, a
        ``boundingBox`` dict of the rounded ``x1``, ``y1``, ``x2``, ``y2``
        values from ``face.bbox``, the detector ``score`` as a Python number,
        and the face's ``normed_embedding`` converted to a list.

    Raises:
        ValueError: If ``image`` is empty or cannot be decoded as an image.
    """
    file_bytes = np.frombuffer(image, dtype=np.uint8)
    # cv.imdecode returns None for data it cannot decode (and fails with an
    # OpenCV assertion on an empty buffer); turn both into one clear error
    # instead of an AttributeError on `img.shape` further down.
    img = cv.imdecode(file_bytes, cv.IMREAD_COLOR) if file_bytes.size else None
    if img is None:
        raise ValueError("Unable to decode image for facial recognition")

    height, width, _ = img.shape
    results = []
    for face in model.get(img):
        x1, y1, x2, y2 = face.bbox
        results.append(
            {
                "imageWidth": width,
                "imageHeight": height,
                "boundingBox": {
                    "x1": round(x1),
                    "y1": round(y1),
                    "x2": round(x2),
                    "y2": round(y2),
                },
                "score": face.det_score.item(),
                "embedding": face.normed_embedding.tolist(),
            }
        )
    return results
|
|
|
|
|
|
|
|
|
|
|
|
def _load_facial_recognition(
    model_name: str,
    min_face_score: float | None = None,
    cache_dir: Path | str | None = None,
    **model_kwargs,
):
    """Create and prepare an InsightFace ``FaceAnalysis`` for detection + recognition.

    Args:
        model_name: InsightFace model name, passed as ``FaceAnalysis(name=...)``.
        min_face_score: Detection threshold passed to ``prepare(det_thresh=...)``;
            falls back to ``settings.min_face_score`` when None.
        cache_dir: Root directory for the model files, passed as
            ``FaceAnalysis(root=...)``. Defaults to
            ``_get_cache_dir(model_name, "facial-recognition")``. A ``Path`` is
            converted to a POSIX string first (presumably because InsightFace
            expects a str root — verify).
        **model_kwargs: Forwarded unchanged to the ``FaceAnalysis`` constructor.

    Returns:
        The ``FaceAnalysis`` instance after ``prepare`` has been called with
        ``ctx_id=0`` and a fixed 640x640 detection size.
    """
    root = (
        cache_dir
        if cache_dir is not None
        else _get_cache_dir(model_name, "facial-recognition")
    )
    if isinstance(root, Path):
        root = root.as_posix()

    det_thresh = settings.min_face_score if min_face_score is None else min_face_score

    face_app = FaceAnalysis(
        name=model_name,
        root=root,
        allowed_modules=["detection", "recognition"],
        **model_kwargs,
    )
    face_app.prepare(ctx_id=0, det_thresh=det_thresh, det_size=(640, 640))
    return face_app
|
|
|
|
|
|
|
|
|
|
|
|
def _get_cache_dir(model_name: str, model_type: str) -> Path:
    """Return the cache directory for a model.

    The layout is ``<settings.cache_folder>/<device>/<model_type>/<model_name>``,
    so the same model loaded on a different device gets a separate directory.
    """
    base = Path(settings.cache_folder)
    return base / device / model_type / model_name
|