mirror of
https://github.com/immich-app/immich.git
synced 2025-01-21 00:52:43 -05:00
fix(ml): openvino not working with dynamic axes (#6871)
* convert to static * add comment about gross code * formatting * fixed test * fix typing * cleanup * formatting * Revert "formatting" This reverts commit073965c47e
. * Revert "cleanup" This reverts commitbb56bd3297
. * formatting
This commit is contained in:
parent
b768eef44d
commit
79d3342c3d
2 changed files with 51 additions and 6 deletions
|
@ -6,12 +6,15 @@ from pathlib import Path
|
||||||
from shutil import rmtree
|
from shutil import rmtree
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import onnx
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
from onnx.shape_inference import infer_shapes
|
||||||
|
from onnx.tools.update_model_dims import update_inputs_outputs_dims
|
||||||
from typing_extensions import Buffer
|
from typing_extensions import Buffer
|
||||||
|
|
||||||
import ann.ann
|
import ann.ann
|
||||||
from app.models.constants import SUPPORTED_PROVIDERS
|
from app.models.constants import STATIC_INPUT_PROVIDERS, SUPPORTED_PROVIDERS
|
||||||
|
|
||||||
from ..config import get_cache_dir, get_hf_model_name, log, settings
|
from ..config import get_cache_dir, get_hf_model_name, log, settings
|
||||||
from ..schemas import ModelRuntime, ModelType
|
from ..schemas import ModelRuntime, ModelType
|
||||||
|
@ -114,6 +117,13 @@ class InferenceModel(ABC):
|
||||||
)
|
)
|
||||||
model_path = onnx_path
|
model_path = onnx_path
|
||||||
|
|
||||||
|
if any(provider in STATIC_INPUT_PROVIDERS for provider in self.providers):
|
||||||
|
static_path = model_path.parent / "static_1" / "model.onnx"
|
||||||
|
static_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
if not static_path.is_file():
|
||||||
|
self._convert_to_static(model_path, static_path)
|
||||||
|
model_path = static_path
|
||||||
|
|
||||||
match model_path.suffix:
|
match model_path.suffix:
|
||||||
case ".armnn":
|
case ".armnn":
|
||||||
session = AnnSession(model_path)
|
session = AnnSession(model_path)
|
||||||
|
@ -128,6 +138,42 @@ class InferenceModel(ABC):
|
||||||
raise ValueError(f"Unsupported model file type: {model_path.suffix}")
|
raise ValueError(f"Unsupported model file type: {model_path.suffix}")
|
||||||
return session
|
return session
|
||||||
|
|
||||||
|
def _convert_to_static(self, source_path: Path, target_path: Path) -> None:
|
||||||
|
inferred = infer_shapes(onnx.load(source_path))
|
||||||
|
inputs = self._get_static_dims(inferred.graph.input)
|
||||||
|
outputs = self._get_static_dims(inferred.graph.output)
|
||||||
|
|
||||||
|
# check_model gets called in update_inputs_outputs_dims and doesn't work for large models
|
||||||
|
check_model = onnx.checker.check_model
|
||||||
|
try:
|
||||||
|
|
||||||
|
def check_model_stub(*args: Any, **kwargs: Any) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
onnx.checker.check_model = check_model_stub
|
||||||
|
updated_model = update_inputs_outputs_dims(inferred, inputs, outputs)
|
||||||
|
finally:
|
||||||
|
onnx.checker.check_model = check_model
|
||||||
|
|
||||||
|
onnx.save(
|
||||||
|
updated_model,
|
||||||
|
target_path,
|
||||||
|
save_as_external_data=True,
|
||||||
|
all_tensors_to_one_file=False,
|
||||||
|
size_threshold=1048576,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_static_dims(self, graph_io: Any, dim_size: int = 1) -> dict[str, list[int]]:
|
||||||
|
return {
|
||||||
|
field.name: [
|
||||||
|
d.dim_value if d.HasField("dim_value") else dim_size
|
||||||
|
for shape in field.type.ListFields()
|
||||||
|
if (dim := shape[1].shape.dim)
|
||||||
|
for d in dim
|
||||||
|
]
|
||||||
|
for field in graph_io
|
||||||
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_type(self) -> ModelType:
|
def model_type(self) -> ModelType:
|
||||||
return self._model_type
|
return self._model_type
|
||||||
|
|
|
@ -51,11 +51,10 @@ _INSIGHTFACE_MODELS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
SUPPORTED_PROVIDERS = [
|
SUPPORTED_PROVIDERS = ["CUDAExecutionProvider", "OpenVINOExecutionProvider", "CPUExecutionProvider"]
|
||||||
"CUDAExecutionProvider",
|
|
||||||
"OpenVINOExecutionProvider",
|
|
||||||
"CPUExecutionProvider",
|
STATIC_INPUT_PROVIDERS = ["OpenVINOExecutionProvider"]
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def is_openclip(model_name: str) -> bool:
|
def is_openclip(model_name: str) -> bool:
|
||||||
|
|
Loading…
Add table
Reference in a new issue