mirror of
https://github.com/immich-app/immich.git
synced 2025-01-21 00:52:43 -05:00
fix(ml): batch axis not being added for recognition model (#12588)
* fix has_batch_axis * fix typing
This commit is contained in:
parent
fa095c3ca0
commit
22dc9bcebb
2 changed files with 1 additions and 7 deletions
|
@ -13,7 +13,6 @@ from app.config import log
|
||||||
from app.models.base import InferenceModel
|
from app.models.base import InferenceModel
|
||||||
from app.models.transforms import decode_cv2
|
from app.models.transforms import decode_cv2
|
||||||
from app.schemas import FaceDetectionOutput, FacialRecognitionOutput, ModelFormat, ModelSession, ModelTask, ModelType
|
from app.schemas import FaceDetectionOutput, FacialRecognitionOutput, ModelFormat, ModelSession, ModelTask, ModelType
|
||||||
from app.sessions import has_batch_axis
|
|
||||||
|
|
||||||
|
|
||||||
class FaceRecognizer(InferenceModel):
|
class FaceRecognizer(InferenceModel):
|
||||||
|
@ -27,7 +26,7 @@ class FaceRecognizer(InferenceModel):
|
||||||
|
|
||||||
def _load(self) -> ModelSession:
|
def _load(self) -> ModelSession:
|
||||||
session = self._make_session(self.model_path)
|
session = self._make_session(self.model_path)
|
||||||
if self.batch and not has_batch_axis(session):
|
if self.batch and str(session.get_inputs()[0].shape[0]) != "batch":
|
||||||
self._add_batch_axis(self.model_path)
|
self._add_batch_axis(self.model_path)
|
||||||
session = self._make_session(self.model_path)
|
session = self._make_session(self.model_path)
|
||||||
self.model = ArcFaceONNX(
|
self.model = ArcFaceONNX(
|
||||||
|
|
|
@ -1,5 +0,0 @@
|
||||||
from app.schemas import ModelSession
|
|
||||||
|
|
||||||
|
|
||||||
def has_batch_axis(session: ModelSession) -> bool:
|
|
||||||
return not isinstance(session.get_inputs()[0].shape[0], int) or session.get_inputs()[0].shape[0] < 0
|
|
Loading…
Add table
Reference in a new issue