mirror of https://github.com/immich-app/immich.git synced 2025-02-04 01:09:14 -05:00
Commit 72269ab58c (parent 3db69b94ed)
mertalev 2024-07-11 19:12:55 -04:00
No known key found for this signature in database
GPG key ID: 9181CD92C0A1C5E3
13 changed files with 2185 additions and 689 deletions


@@ -1,28 +1,35 @@
FROM mambaorg/micromamba:bookworm-slim@sha256:333f7598ff2c2400fb10bfe057709c68b7daab5d847143af85abcf224a07271a as builder
ENV TRANSFORMERS_CACHE=/cache \
PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1 \
PATH="/opt/venv/bin:$PATH"
WORKDIR /export/ann
USER root
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cmake \
curl \
git
USER $MAMBA_USER
COPY --chown=$MAMBA_USER:$MAMBA_USER env.yaml ./
RUN micromamba install -y -f env.yaml
COPY --chown=$MAMBA_USER:$MAMBA_USER *.sh *.cpp ./
ENV ARMNN_PATH=/export/ann/armnn
WORKDIR /home/mambauser
ENV ARMNN_PATH=armnn
COPY --chown=$MAMBA_USER:$MAMBA_USER scripts/* .
RUN ./download-armnn.sh && \
./build-converter.sh && \
./build.sh
COPY --chown=$MAMBA_USER:$MAMBA_USER run.py ./
ENTRYPOINT ["/usr/local/bin/_entrypoint.sh"]
CMD ["python", "run.py"]
COPY --chown=$MAMBA_USER:$MAMBA_USER conda-lock.yml .
RUN micromamba create -y -p /home/mambauser/venv -f conda-lock.yml && \
micromamba clean --all --yes
ENV PATH="/home/mambauser/venv/bin:${PATH}"
FROM gcr.io/distroless/base-debian12
# FROM mambaorg/micromamba:bookworm-slim@sha256:333f7598ff2c2400fb10bfe057709c68b7daab5d847143af85abcf224a07271a
WORKDIR /export/ann
ENV PYTHONDONTWRITEBYTECODE=1 \
LD_LIBRARY_PATH=/export/ann/armnn \
PATH="/opt/venv/bin:${PATH}"
COPY --from=builder /home/mambauser/armnnconverter /home/mambauser/armnn ./
COPY --from=builder /home/mambauser/venv /opt/venv
COPY --chown=$MAMBA_USER:$MAMBA_USER onnx2ann onnx2ann
ENTRYPOINT ["python", "-m", "onnx2ann"]

File diff suppressed because it is too large.


@@ -1,201 +1,21 @@
name: annexport
name: onnx2ann
channels:
- pytorch
- nvidia
- conda-forge
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_kmp_llvm
- aiohttp=3.9.1=py310h2372a71_0
- aiosignal=1.3.1=pyhd8ed1ab_0
- arpack=3.8.0=nompi_h0baa96a_101
- async-timeout=4.0.3=pyhd8ed1ab_0
- attrs=23.1.0=pyh71513ae_1
- aws-c-auth=0.7.3=h28f7589_1
- aws-c-cal=0.6.1=hc309b26_1
- aws-c-common=0.9.0=hd590300_0
- aws-c-compression=0.2.17=h4d4d85c_2
- aws-c-event-stream=0.3.1=h2e3709c_4
- aws-c-http=0.7.11=h00aa349_4
- aws-c-io=0.13.32=he9a53bd_1
- aws-c-mqtt=0.9.3=hb447be9_1
- aws-c-s3=0.3.14=hf3aad02_1
- aws-c-sdkutils=0.1.12=h4d4d85c_1
- aws-checksums=0.1.17=h4d4d85c_1
- aws-crt-cpp=0.21.0=hb942446_5
- aws-sdk-cpp=1.10.57=h85b1a90_19
- blas=2.120=openblas
- blas-devel=3.9.0=20_linux64_openblas
- brotli-python=1.0.9=py310hd8f1fbe_9
- bzip2=1.0.8=hd590300_5
- c-ares=1.23.0=hd590300_0
- ca-certificates=2023.11.17=hbcca054_0
- certifi=2023.11.17=pyhd8ed1ab_0
- charset-normalizer=3.3.2=pyhd8ed1ab_0
- click=8.1.7=unix_pyh707e725_0
- colorama=0.4.6=pyhd8ed1ab_0
- coloredlogs=15.0.1=pyhd8ed1ab_3
- cuda-cudart=11.7.99=0
- cuda-cupti=11.7.101=0
- cuda-libraries=11.7.1=0
- cuda-nvrtc=11.7.99=0
- cuda-nvtx=11.7.91=0
- cuda-runtime=11.7.1=0
- dataclasses=0.8=pyhc8e2a94_3
- datasets=2.14.7=pyhd8ed1ab_0
- dill=0.3.7=pyhd8ed1ab_0
- filelock=3.13.1=pyhd8ed1ab_0
- flatbuffers=23.5.26=h59595ed_1
- freetype=2.12.1=h267a509_2
- frozenlist=1.4.0=py310h2372a71_1
- fsspec=2023.10.0=pyhca7485f_0
- ftfy=6.1.3=pyhd8ed1ab_0
- gflags=2.2.2=he1b5a44_1004
- glog=0.6.0=h6f12383_0
- glpk=5.0=h445213a_0
- gmp=6.3.0=h59595ed_0
- gmpy2=2.1.2=py310h3ec546c_1
- huggingface_hub=0.17.3=pyhd8ed1ab_0
- humanfriendly=10.0=pyhd8ed1ab_6
- icu=73.2=h59595ed_0
- idna=3.6=pyhd8ed1ab_0
- importlib-metadata=7.0.0=pyha770c72_0
- importlib_metadata=7.0.0=hd8ed1ab_0
- joblib=1.3.2=pyhd8ed1ab_0
- keyutils=1.6.1=h166bdaf_0
- krb5=1.21.2=h659d440_0
- lcms2=2.15=h7f713cb_2
- ld_impl_linux-64=2.40=h41732ed_0
- lerc=4.0.0=h27087fc_0
- libabseil=20230125.3=cxx17_h59595ed_0
- libarrow=12.0.1=hb87d912_8_cpu
- libblas=3.9.0=20_linux64_openblas
- libbrotlicommon=1.0.9=h166bdaf_9
- libbrotlidec=1.0.9=h166bdaf_9
- libbrotlienc=1.0.9=h166bdaf_9
- libcblas=3.9.0=20_linux64_openblas
- libcrc32c=1.1.2=h9c3ff4c_0
- libcublas=11.10.3.66=0
- libcufft=10.7.2.124=h4fbf590_0
- libcufile=1.8.1.2=0
- libcurand=10.3.4.101=0
- libcurl=8.5.0=hca28451_0
- libcusolver=11.4.0.1=0
- libcusparse=11.7.4.91=0
- libdeflate=1.19=hd590300_0
- libedit=3.1.20191231=he28a2e2_2
- libev=4.33=hd590300_2
- libevent=2.1.12=hf998b51_1
- libffi=3.4.2=h7f98852_5
- libgcc-ng=13.2.0=h807b86a_3
- libgfortran-ng=13.2.0=h69a702a_3
- libgfortran5=13.2.0=ha4646dd_3
- libgoogle-cloud=2.12.0=hac9eb74_1
- libgrpc=1.54.3=hb20ce57_0
- libhwloc=2.9.3=default_h554bfaf_1009
- libiconv=1.17=hd590300_1
- libjpeg-turbo=2.1.5.1=hd590300_1
- liblapack=3.9.0=20_linux64_openblas
- liblapacke=3.9.0=20_linux64_openblas
- libnghttp2=1.58.0=h47da74e_1
- libnpp=11.7.4.75=0
- libnsl=2.0.1=hd590300_0
- libnuma=2.0.16=h0b41bf4_1
- libnvjpeg=11.8.0.2=0
- libopenblas=0.3.25=pthreads_h413a1c8_0
- libpng=1.6.39=h753d276_0
- libprotobuf=3.21.12=hfc55251_2
- libsentencepiece=0.1.99=h180e1df_0
- libsqlite=3.44.2=h2797004_0
- libssh2=1.11.0=h0841786_0
- libstdcxx-ng=13.2.0=h7e041cc_3
- libthrift=0.18.1=h8fd135c_2
- libtiff=4.6.0=h29866fb_1
- libutf8proc=2.8.0=h166bdaf_0
- libuuid=2.38.1=h0b41bf4_0
- libwebp-base=1.3.2=hd590300_0
- libxcb=1.15=h0b41bf4_0
- libxml2=2.11.6=h232c23b_0
- libzlib=1.2.13=hd590300_5
- llvm-openmp=17.0.6=h4dfa4b3_0
- lz4-c=1.9.4=hcb278e6_0
- mkl=2022.2.1=h84fe81f_16997
- mkl-devel=2022.2.1=ha770c72_16998
- mkl-include=2022.2.1=h84fe81f_16997
- mpc=1.3.1=hfe3b2da_0
- mpfr=4.2.1=h9458935_0
- mpmath=1.3.0=pyhd8ed1ab_0
- multidict=6.0.4=py310h2372a71_1
- multiprocess=0.70.15=py310h2372a71_1
- ncurses=6.4=h59595ed_2
- numpy=1.26.2=py310hb13e2d6_0
- onnx=1.14.0=py310ha3deec4_1
- onnx2torch=1.5.13=pyhd8ed1ab_0
- onnxruntime=1.16.3=py310hd4b7fbc_1_cpu
- open-clip-torch=2.23.0=pyhd8ed1ab_1
- openblas=0.3.25=pthreads_h7a3da1a_0
- openjpeg=2.5.0=h488ebb8_3
- openssl=3.2.0=hd590300_1
- orc=1.9.0=h2f23424_1
- packaging=23.2=pyhd8ed1ab_0
- pandas=2.1.4=py310hcc13569_0
- pillow=10.0.1=py310h29da1c1_1
- pip=23.3.1=pyhd8ed1ab_0
- protobuf=4.21.12=py310heca2aa9_0
- pthread-stubs=0.4=h36c2ea0_1001
- pyarrow=12.0.1=py310h0576679_8_cpu
- pyarrow-hotfix=0.6=pyhd8ed1ab_0
- pysocks=1.7.1=pyha2e5f31_6
- python=3.10.13=hd12c33a_0_cpython
- python-dateutil=2.8.2=pyhd8ed1ab_0
- python-flatbuffers=23.5.26=pyhd8ed1ab_0
- python-tzdata=2023.3=pyhd8ed1ab_0
- python-xxhash=3.4.1=py310h2372a71_0
- python_abi=3.10=4_cp310
- pytorch=1.13.1=cpu_py310hd11e9c7_1
- pytorch-cuda=11.7=h778d358_5
- pytorch-mutex=1.0=cuda
- pytz=2023.3.post1=pyhd8ed1ab_0
- pyyaml=6.0.1=py310h2372a71_1
- rdma-core=28.9=h59595ed_1
- re2=2023.03.02=h8c504da_0
- readline=8.2=h8228510_1
- regex=2023.10.3=py310h2372a71_0
- requests=2.31.0=pyhd8ed1ab_0
- s2n=1.3.49=h06160fa_0
- sacremoses=0.0.53=pyhd8ed1ab_0
- safetensors=0.3.3=py310hcb5633a_1
- sentencepiece=0.1.99=hff52083_0
- sentencepiece-python=0.1.99=py310hebdb9f0_0
- sentencepiece-spm=0.1.99=h180e1df_0
- setuptools=68.2.2=pyhd8ed1ab_0
- six=1.16.0=pyh6c4a22f_0
- sleef=3.5.1=h9b69904_2
- snappy=1.1.10=h9fff704_0
- sympy=1.12=pypyh9d50eac_103
- tbb=2021.11.0=h00ab1b0_0
- texttable=1.7.0=pyhd8ed1ab_0
- timm=0.9.12=pyhd8ed1ab_0
- tk=8.6.13=noxft_h4845f30_101
- tokenizers=0.14.1=py310h320607d_2
- torchvision=0.14.1=cpu_py310hd3d2ac3_1
- tqdm=4.66.1=pyhd8ed1ab_0
- transformers=4.35.2=pyhd8ed1ab_0
- typing-extensions=4.9.0=hd8ed1ab_0
- typing_extensions=4.9.0=pyha770c72_0
- tzdata=2023c=h71feb2d_0
- ucx=1.14.1=h64cca9d_5
- urllib3=2.1.0=pyhd8ed1ab_0
- wcwidth=0.2.12=pyhd8ed1ab_0
- wheel=0.42.0=pyhd8ed1ab_0
- xorg-libxau=1.0.11=hd590300_0
- xorg-libxdmcp=1.1.3=h7f98852_0
- xxhash=0.8.2=hd590300_0
- xz=5.2.6=h166bdaf_0
- yaml=0.2.5=h7f98852_2
- yarl=1.9.3=py310h2372a71_0
- zipp=3.17.0=pyhd8ed1ab_0
- zlib=1.2.13=hd590300_5
- zstd=1.5.5=hfc55251_0
- python>=3.11,<4.0
- onnx>=1.16.1
# - onnxruntime>=1.18.1 # conda only has gpu version
- psutil>=6.0.0
- flatbuffers>=24.3.25
- ml_dtypes>=0.3.1
- typer-slim>=0.12.3
- huggingface_hub>=0.23.4
- pip
- pip:
- git+https://github.com/fyfrey/TinyNeuralNetwork.git
- onnxruntime>=1.18.1 # conda only has gpu version
- onnxsim>=0.4.36
- onnx2tf>=1.24.1
- onnx_graphsurgeon>=0.5.2
- simple_onnx_processing_tools>=1.1.32
- tf_keras>=2.16.0
- git+https://github.com/microsoft/onnxconverter-common.git
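As a quick sanity check on a freshly created environment, an illustrative sketch that imports the packages declared above (module names are taken from the dependency list; the sketch itself is hypothetical and not part of env.yaml):

import importlib

# Each module below corresponds to a dependency declared in env.yaml above.
for module in ("onnx", "onnxruntime", "onnxsim", "onnx2tf", "flatbuffers", "ml_dtypes", "typer"):
    importlib.import_module(module)  # raises ImportError if the environment is incomplete
print("environment OK")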


@@ -0,0 +1,99 @@
import os
import platform
from typing import Annotated, Optional
import typer
from onnx2ann.export import Exporter, ModelType, Precision
app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)
@app.command()
def export(
model_name: Annotated[
str, typer.Argument(..., help="The name of the model to be exported as it exists in Hugging Face.")
],
model_type: Annotated[ModelType, typer.Option(..., "--type", "-t", help="The type of model to be exported.")],
input_shapes: Annotated[
list[str],
typer.Option(
...,
"--input-shape",
"-s",
help="The shape of an input tensor to the model, each dimension separated by commas. "
"Multiple shapes can be provided for multiple inputs.",
),
],
precision: Annotated[
Precision,
typer.Option(
...,
"--precision",
"-p",
help="The precision of the exported model. `float16` requires a GPU.",
),
] = Precision.FLOAT32,
cache_dir: Annotated[
str,
typer.Option(
...,
"--cache-dir",
"-c",
help="Directory where pre-export models will be stored.",
envvar="CACHE_DIR",
show_envvar=True,
),
] = "~/.cache/huggingface",
output_dir: Annotated[
str,
typer.Option(
...,
"--output-dir",
"-o",
help="Directory where exported models will be stored.",
),
] = "output",
auth_token: Annotated[
Optional[str],
typer.Option(
...,
"--auth-token",
"-a",  # "-t" is already used by --type
help="If uploading models to Hugging Face, the auth token of the user or organisation.",
envvar="HF_AUTH_TOKEN",
show_envvar=True,
),
] = None,
force_export: Annotated[
bool,
typer.Option(
...,
"--force-export",
"-f",
help="Export the model even if an exported model already exists in the output directory.",
),
] = False,
) -> None:
if platform.machine() not in ("x86_64", "AMD64"):
msg = f"Can only run on x86_64 / AMD64, not {platform.machine()}"
raise RuntimeError(msg)
os.environ.setdefault("LD_LIBRARY_PATH", "armnn")
parsed_input_shapes = [tuple(map(int, shape.split(","))) for shape in input_shapes]
model = Exporter(
model_name, model_type, input_shapes=parsed_input_shapes, cache_dir=cache_dir, force_export=force_export
)
model_dir = os.path.join(output_dir, model_name)
task_dir = os.path.join(model_dir, model_type)
armnn_model = model.to_armnn(task_dir, precision)
if not auth_token:
return
from huggingface_hub import upload_file
relative_path = os.path.relpath(armnn_model, start=model_dir)
upload_file(path_or_fileobj=armnn_model, path_in_repo=relative_path, repo_id=model.repo_name, token=auth_token)
if __name__ == "__main__":
    app()
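The --input-shape/-s option takes comma-separated dimensions and can be repeated once per model input; a small illustration of the parsing done in export() above (shape values are made up):

# Mirrors the parsed_input_shapes expression in export() above; values are illustrative.
shapes = ["1,3,224,224", "1,77"]
parsed = [tuple(map(int, shape.split(","))) for shape in shapes]
assert parsed == [(1, 3, 224, 224), (1, 77)]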


@@ -0,0 +1,129 @@
import os
import subprocess
from enum import StrEnum
from onnx2ann.helpers import onnx_make_armnn_compatible, onnx_make_inputs_fixed
class ModelType(StrEnum):
VISUAL = "visual"
TEXTUAL = "textual"
RECOGNITION = "recognition"
DETECTION = "detection"
class Precision(StrEnum):
FLOAT16 = "float16"
FLOAT32 = "float32"
class Exporter:
def __init__(
self,
model_name: str,
model_type: str,
input_shapes: list[tuple[int, ...]],
optimization_level: int = 5,
cache_dir: str = os.environ.get("CACHE_DIR", "~/.cache/huggingface"),
force_export: bool = False,
):
self.model_name = model_name.split("/")[-1]
self.model_type = model_type
self.optimize = optimization_level
self.input_shapes = input_shapes
self.cache_dir = os.path.join(os.path.expanduser(cache_dir), self.repo_name)
self.force_export = force_export
def download(self) -> str:
model_path = os.path.join(self.cache_dir, self.model_type, "model.onnx")
if os.path.isfile(model_path):
print(f"Model is already downloaded at {model_path}")
return model_path
from huggingface_hub import snapshot_download
snapshot_download(
self.repo_name, cache_dir=self.cache_dir, local_dir=self.cache_dir, local_dir_use_symlinks=False
)
return model_path
def to_onnx_static(self, precision: Precision) -> str:
import onnx
from onnxconverter_common import float16
onnx_path_original = self.download()
static_dir = os.path.join(self.cache_dir, self.model_type, "static")
static_path = os.path.join(static_dir, "model.onnx")
if self.force_export or not os.path.isfile(static_path):
print(f"Making {self} static")
os.makedirs(static_dir, exist_ok=True)
onnx_make_inputs_fixed(onnx_path_original, static_path, self.input_shapes)
onnx_make_armnn_compatible(static_path)
print(f"Finished making {self} static")
model = onnx.load(static_path)
self.inputs = [input_.name for input_ in model.graph.input]
self.outputs = [output_.name for output_ in model.graph.output]
if precision == Precision.FLOAT16:
static_path = os.path.join(static_dir, f"model_{precision}.onnx")
print(f"Converting {self} to {precision} precision")
model = float16.convert_float_to_float16(model, keep_io_types=True, disable_shape_infer=True)
onnx.save(model, static_path)
print(f"Finished converting {self} to {precision} precision")
# self.inputs, self.outputs = onnx_get_inputs_outputs(static_path)
return static_path
def to_tflite(self, output_dir: str, precision: Precision) -> str:
onnx_model = self.to_onnx_static(precision)
tflite_dir = os.path.join(output_dir, precision)
tflite_model = os.path.join(tflite_dir, f"model_{precision}.tflite")
if self.force_export or not os.path.isfile(tflite_model):
import onnx2tf
print(f"Exporting {self} to TFLite with {precision} precision (this might take a few minutes)")
onnx2tf.convert(
input_onnx_file_path=onnx_model,
output_folder_path=tflite_dir,
keep_shape_absolutely_input_names=self.inputs,
# verbosity="warn",
copy_onnx_input_output_names_to_tflite=True,
output_signaturedefs=True,
not_use_onnxsim=True,
)
print(f"Finished exporting {self} to TFLite with {precision} precision")
return tflite_model
def to_armnn(self, output_dir: str, precision: Precision) -> str:
armnn_model = os.path.join(output_dir, "model.armnn")
if not self.force_export and os.path.isfile(armnn_model):
return armnn_model
tflite_model_dir = os.path.join(output_dir, "tflite")
tflite_model = self.to_tflite(tflite_model_dir, precision)
args = ["./armnnconverter", "-f", "tflite-binary", "-m", tflite_model, "-p", armnn_model]
args.append("-i")
args.extend(self.inputs)
args.append("-o")
args.extend(self.outputs)
print(f"Exporting {self} to ARM NN with {precision} precision")
try:
if (stdout := subprocess.check_output(args, stderr=subprocess.STDOUT).decode()):
print(stdout)
print(f"Finished exporting {self} to ARM NN with {precision} precision")
except subprocess.CalledProcessError as e:
print(e.output.decode())
try:
from shutil import rmtree
rmtree(tflite_model_dir, ignore_errors=True)
finally:
raise e
@property
def repo_name(self) -> str:
return f"immich-app/{self.model_name}"
def __repr__(self) -> str:
return f"{self.model_name} ({self.model_type})"


@@ -0,0 +1,260 @@
from typing import Any
def onnx_make_armnn_compatible(model_path: str) -> None:
"""
i can explain
armnn only supports up to 4d tranposes, but the model has a 5d transpose due to a redundant unsqueeze
this function folds the unsqueeze+transpose+squeeze into a single 4d transpose
it also switches from gather ops to slices since armnn has different dimension semantics for gathers
also fixes batch normalization being in training mode
"""
import numpy as np
import onnx
from onnx_graphsurgeon import Constant, Node, Variable, export_onnx, import_onnx
proto = onnx.load(model_path)
graph = import_onnx(proto)
gather_idx = 1
squeeze_idx = 1
for node in graph.nodes:
for link1 in node.outputs:
if "Unsqueeze" in link1.name:
for node1 in link1.outputs:
for link2 in node1.outputs:
if "Transpose" in link2.name:
for node2 in link2.outputs:
if node2.attrs.get("perm") == [3, 1, 2, 0, 4]:
node2.attrs["perm"] = [2, 0, 1, 3]
link2.shape = link1.shape
for link3 in node2.outputs:
if "Squeeze" in link3.name:
link3.shape = [link3.shape[x] for x in [0, 1, 2, 4]]
for node3 in link3.outputs:
for link4 in node3.outputs:
link4.shape = link3.shape
try:
idx = link2.inputs.index(node1)
link2.inputs[idx] = node
except ValueError:
pass
node.outputs = [link2]
if "Gather" in link4.name:
for node4 in link4.outputs:
axis = node1.attrs.get("axis", 0)
index = node4.inputs[1].values
slice_link = Variable(
f"onnx::Slice_123{gather_idx}",
dtype=link4.dtype,
shape=[1] + link3.shape[1:],
)
slice_node = Node(
op="Slice",
inputs=[
link3,
Constant(
f"SliceStart_123{gather_idx}",
np.array([index]),
),
Constant(
f"SliceEnd_123{gather_idx}",
np.array([index + 1]),
),
Constant(
f"SliceAxis_123{gather_idx}",
np.array([axis]),
),
],
outputs=[slice_link],
name=f"Slice_123{gather_idx}",
)
graph.nodes.append(slice_node)
gather_idx += 1
for link5 in node4.outputs:
for node5 in link5.outputs:
try:
idx = node5.inputs.index(link5)
node5.inputs[idx] = slice_link
except ValueError:
pass
elif node.op == "LayerNormalization":
for node1 in link1.outputs:
if node1.op == "Gather":
for link2 in node1.outputs:
for node2 in link2.outputs:
axis = node1.attrs.get("axis", 0)
index = node1.inputs[1].values
slice_link = Variable(
f"onnx::Slice_123{gather_idx}",
dtype=link2.dtype,
shape=[1, *link2.shape],
)
slice_node = Node(
op="Slice",
inputs=[
node1.inputs[0],
Constant(
f"SliceStart_123{gather_idx}",
np.array([index]),
),
Constant(
f"SliceEnd_123{gather_idx}",
np.array([index + 1]),
),
Constant(
f"SliceAxis_123{gather_idx}",
np.array([axis]),
),
],
outputs=[slice_link],
name=f"Slice_123{gather_idx}",
)
graph.nodes.append(slice_node)
gather_idx += 1
squeeze_link = Variable(
f"onnx::Squeeze_123{squeeze_idx}",
dtype=link2.dtype,
shape=link2.shape,
)
squeeze_node = Node(
op="Squeeze",
inputs=[
slice_link,
Constant(
f"SqueezeAxis_123{squeeze_idx}",
np.array([0]),
),
],
outputs=[squeeze_link],
name=f"Squeeze_123{squeeze_idx}",
)
graph.nodes.append(squeeze_node)
squeeze_idx += 1
try:
idx = node2.inputs.index(link2)
node2.inputs[idx] = squeeze_link
except ValueError:
pass
elif node.op == "Reshape":
for node1 in link1.outputs:
if node1.op == "Gather":
node2s = [n for link in node1.outputs for n in link.outputs]
if any(n.op == "Abs" for n in node2s):
axis = node1.attrs.get("axis", 0)
index = node1.inputs[1].values
slice_link = Variable(
f"onnx::Slice_123{gather_idx}",
dtype=node1.outputs[0].dtype,
shape=[1, *node1.outputs[0].shape],
)
slice_node = Node(
op="Slice",
inputs=[
node1.inputs[0],
Constant(
f"SliceStart_123{gather_idx}",
np.array([index]),
),
Constant(
f"SliceEnd_123{gather_idx}",
np.array([index + 1]),
),
Constant(
f"SliceAxis_123{gather_idx}",
np.array([axis]),
),
],
outputs=[slice_link],
name=f"Slice_123{gather_idx}",
)
graph.nodes.append(slice_node)
gather_idx += 1
squeeze_link = Variable(
f"onnx::Squeeze_123{squeeze_idx}",
dtype=node1.outputs[0].dtype,
shape=node1.outputs[0].shape,
)
squeeze_node = Node(
op="Squeeze",
inputs=[
slice_link,
Constant(
f"SqueezeAxis_123{squeeze_idx}",
np.array([0]),
),
],
outputs=[squeeze_link],
name=f"Squeeze_123{squeeze_idx}",
)
graph.nodes.append(squeeze_node)
squeeze_idx += 1
for node2 in node2s:
node2.inputs[0] = squeeze_link
elif node.op == "BatchNormalization" and node.attrs.get("training_mode") == 1:
node.attrs["training_mode"] = 0
node.outputs = node.outputs[:1]
graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
graph.toposort()
graph.fold_constants()
updated = export_onnx(graph)
onnx_save(updated, model_path)
# for some reason, reloading the model is necessary to apply the correct shape
proto = onnx.load(model_path)
graph = import_onnx(proto)
for node in graph.nodes:
if node.op == "Slice":
for link in node.outputs:
if "Slice_123" in link.name and link.shape[0] == 3: # noqa: PLR2004
link.shape[0] = 1
graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
graph.toposort()
graph.fold_constants()
updated = export_onnx(graph)
onnx_save(updated, model_path)
onnx.shape_inference.infer_shapes_path(model_path, check_type=True, strict_mode=True, data_prop=True)
def onnx_make_inputs_fixed(input_path: str, output_path: str, input_shapes: list[tuple[int, ...]]) -> None:
import onnx
import onnxsim
from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed
model, success = onnxsim.simplify(input_path)
if not success:
msg = f"Failed to simplify {input_path}"
raise RuntimeError(msg)
onnx_save(model, output_path)
onnx.shape_inference.infer_shapes_path(output_path, check_type=True, strict_mode=True, data_prop=True)
model = onnx.load_model(output_path)
for input_node, shape in zip(model.graph.input, input_shapes, strict=False):
make_input_shape_fixed(model.graph, input_node.name, shape)
fix_output_shapes(model)
onnx_save(model, output_path)
onnx.shape_inference.infer_shapes_path(output_path, check_type=True, strict_mode=True, data_prop=True)
def onnx_get_inputs_outputs(model_path: str) -> tuple[list[str], list[str]]:
import onnx
model = onnx.load(model_path)
inputs = [input_.name for input_ in model.graph.input]
outputs = [output_.name for output_ in model.graph.output]
return inputs, outputs
def onnx_save(model: Any, output_path: str) -> None:
import onnx
try:
onnx.save(model, output_path)
except ValueError:
# protobuf cannot serialize models larger than 2GB; fall back to external tensor data
onnx.save(model, output_path, save_as_external_data=True, all_tensors_to_one_file=False, size_threshold=1_000_000)
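The transpose folding performed by onnx_make_armnn_compatible can be verified numerically. Assuming the redundant Unsqueeze adds its axis at position 0, a 5D transpose with perm [3, 1, 2, 0, 4] followed by squeezing axis 3 equals a single 4D transpose with perm [2, 0, 1, 3]:

import numpy as np

# Numeric check of the unsqueeze + transpose + squeeze folding; the input shape is arbitrary.
x = np.random.rand(2, 3, 4, 5)
unfolded = np.squeeze(np.transpose(x[None], (3, 1, 2, 0, 4)), axis=3)
folded = np.transpose(x, (2, 0, 1, 3))
assert np.array_equal(unfolded, folded)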


@@ -0,0 +1,56 @@
[project]
name = "onnx2ann"
version = "1.107.2"
dependencies = [
"onnx>=1.16.1",
"psutil>=6.0.0",
"flatbuffers>=24.3.25",
"ml_dtypes>=0.3.1,<1.0.0",
"typer-slim>=0.12.3,<1.0.0",
"huggingface_hub>=0.23.4,<1.0.0",
"onnxruntime>=1.18.1",
"onnxsim>=0.4.36,<1.0.0",
"onnx2tf>=1.24.0",
"onnx_graphsurgeon>=0.5.2,<1.0.0",
"simple_onnx_processing_tools>=1.1.32",
"tf_keras>=2.16.0",
"onnxconverter-common @ git+https://github.com/microsoft/onnxconverter-common"
]
requires-python = ">=3.11"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.sdist]
only-include = ["onnx2ann"]
[tool.hatch.metadata]
allow-direct-references = true
[tool.mypy]
python_version = "3.12"
follow_imports = "silent"
warn_redundant_casts = true
disallow_any_generics = true
check_untyped_defs = true
disallow_untyped_defs = true
ignore_missing_imports = true
[tool.pydantic-mypy]
init_forbid_extra = true
init_typed = true
warn_required_dynamic_aliases = true
warn_untyped_fields = true
[tool.ruff]
line-length = 120
target-version = "py312"
[tool.ruff.lint]
extend-select = ["E", "F", "I"]
extend-ignore = ["FBT001", "FBT002"]
[tool.black]
line-length = 120
target-version = ['py312']


@@ -1,475 +0,0 @@
import os
import platform
import subprocess
from typing import Callable, ClassVar
import onnx
from onnx_graphsurgeon import Constant, Node, Variable, import_onnx, export_onnx
from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed
from huggingface_hub import snapshot_download
from onnx.shape_inference import infer_shapes_path
from huggingface_hub import login, upload_file
import onnx2tf
import numpy as np
import onnxsim
from shutil import rmtree
# hack: changed Mul op in onnx2tf to skip broadcast if graph_node.o().op == 'Sigmoid'
# i can explain
# armnn only supports up to 4d transposes, but the model has a 5d transpose due to a redundant unsqueeze
# this function folds the unsqueeze+transpose+squeeze into a single 4d transpose
# it also switches from gather ops to slices since armnn has different dimension semantics for gathers
# also fixes batch normalization being in training mode
def make_onnx_armnn_compatible(model_path: str):
proto = onnx.load(model_path)
graph = import_onnx(proto)
gather_idx = 1
squeeze_idx = 1
for node in graph.nodes:
for link1 in node.outputs:
if "Unsqueeze" in link1.name:
for node1 in link1.outputs:
for link2 in node1.outputs:
if "Transpose" in link2.name:
for node2 in link2.outputs:
if node2.attrs.get("perm") == [3, 1, 2, 0, 4]:
node2.attrs["perm"] = [2, 0, 1, 3]
link2.shape = link1.shape
for link3 in node2.outputs:
if "Squeeze" in link3.name:
link3.shape = [link3.shape[x] for x in [0, 1, 2, 4]]
for node3 in link3.outputs:
for link4 in node3.outputs:
link4.shape = link3.shape
try:
idx = link2.inputs.index(node1)
link2.inputs[idx] = node
except ValueError:
pass
node.outputs = [link2]
if "Gather" in link4.name:
for node4 in link4.outputs:
axis = node1.attrs.get("axis", 0)
index = node4.inputs[1].values
slice_link = Variable(
f"onnx::Slice_123{gather_idx}",
dtype=link4.dtype,
shape=[1] + link3.shape[1:],
)
slice_node = Node(
op="Slice",
inputs=[
link3,
Constant(
f"SliceStart_123{gather_idx}",
np.array([index]),
),
Constant(
f"SliceEnd_123{gather_idx}",
np.array([index + 1]),
),
Constant(
f"SliceAxis_123{gather_idx}",
np.array([axis]),
),
],
outputs=[slice_link],
name=f"Slice_123{gather_idx}",
)
graph.nodes.append(slice_node)
gather_idx += 1
for link5 in node4.outputs:
for node5 in link5.outputs:
try:
idx = node5.inputs.index(link5)
node5.inputs[idx] = slice_link
except ValueError:
pass
elif node.op == "LayerNormalization":
for node1 in link1.outputs:
if node1.op == "Gather":
for link2 in node1.outputs:
for node2 in link2.outputs:
axis = node1.attrs.get("axis", 0)
index = node1.inputs[1].values
slice_link = Variable(
f"onnx::Slice_123{gather_idx}",
dtype=link2.dtype,
shape=[1] + link2.shape,
)
slice_node = Node(
op="Slice",
inputs=[
node1.inputs[0],
Constant(
f"SliceStart_123{gather_idx}",
np.array([index]),
),
Constant(
f"SliceEnd_123{gather_idx}",
np.array([index + 1]),
),
Constant(
f"SliceAxis_123{gather_idx}",
np.array([axis]),
),
],
outputs=[slice_link],
name=f"Slice_123{gather_idx}",
)
graph.nodes.append(slice_node)
gather_idx += 1
squeeze_link = Variable(
f"onnx::Squeeze_123{squeeze_idx}",
dtype=link2.dtype,
shape=link2.shape,
)
squeeze_node = Node(
op="Squeeze",
inputs=[slice_link, Constant(f"SqueezeAxis_123{squeeze_idx}",np.array([0]),)],
outputs=[squeeze_link],
name=f"Squeeze_123{squeeze_idx}",
)
graph.nodes.append(squeeze_node)
squeeze_idx += 1
try:
idx = node2.inputs.index(link2)
node2.inputs[idx] = squeeze_link
except ValueError:
pass
elif node.op == "Reshape":
for node1 in link1.outputs:
if node1.op == "Gather":
node2s = [n for l in node1.outputs for n in l.outputs]
if any(n.op == "Abs" for n in node2s):
axis = node1.attrs.get("axis", 0)
index = node1.inputs[1].values
slice_link = Variable(
f"onnx::Slice_123{gather_idx}",
dtype=node1.outputs[0].dtype,
shape=[1] + node1.outputs[0].shape,
)
slice_node = Node(
op="Slice",
inputs=[
node1.inputs[0],
Constant(
f"SliceStart_123{gather_idx}",
np.array([index]),
),
Constant(
f"SliceEnd_123{gather_idx}",
np.array([index + 1]),
),
Constant(
f"SliceAxis_123{gather_idx}",
np.array([axis]),
),
],
outputs=[slice_link],
name=f"Slice_123{gather_idx}",
)
graph.nodes.append(slice_node)
gather_idx += 1
squeeze_link = Variable(
f"onnx::Squeeze_123{squeeze_idx}",
dtype=node1.outputs[0].dtype,
shape=node1.outputs[0].shape,
)
squeeze_node = Node(
op="Squeeze",
inputs=[slice_link, Constant(f"SqueezeAxis_123{squeeze_idx}",np.array([0]),)],
outputs=[squeeze_link],
name=f"Squeeze_123{squeeze_idx}",
)
graph.nodes.append(squeeze_node)
squeeze_idx += 1
for node2 in node2s:
node2.inputs[0] = squeeze_link
elif node.op == "BatchNormalization":
if node.attrs.get("training_mode") == 1:
node.attrs["training_mode"] = 0
node.outputs = node.outputs[:1]
graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
graph.toposort()
graph.fold_constants()
updated = export_onnx(graph)
onnx.save(updated, model_path)
# infer_shapes_path(updated, check_type=True, strict_mode=False, data_prop=True)
# for some reason, reloading the model is necessary to apply the correct shape
proto = onnx.load(model_path)
graph = import_onnx(proto)
for node in graph.nodes:
if node.op == "Slice":
for link in node.outputs:
if "Slice_123" in link.name and link.shape[0] == 3:
link.shape[0] = 1
graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
graph.toposort()
graph.fold_constants()
updated = export_onnx(graph)
onnx.save(updated, model_path)
infer_shapes_path(model_path, check_type=True, strict_mode=True, data_prop=True)
def onnx_make_fixed(input_path: str, output_path: str, input_shape: tuple[int, ...]):
simplified, success = onnxsim.simplify(input_path)
if not success:
raise RuntimeError(f"Failed to simplify {input_path}")
try:
onnx.save(simplified, output_path)
except ValueError:
onnx.save(simplified, output_path, save_as_external_data=True, all_tensors_to_one_file=False)
infer_shapes_path(output_path, check_type=True, strict_mode=True, data_prop=True)
model = onnx.load_model(output_path)
make_input_shape_fixed(model.graph, model.graph.input[0].name, input_shape)
fix_output_shapes(model)
try:
onnx.save(model, output_path)
except ValueError:
onnx.save(model, output_path, save_as_external_data=True, all_tensors_to_one_file=False)
infer_shapes_path(output_path, check_type=True, strict_mode=True, data_prop=True)
class ExportBase:
task: ClassVar[str]
def __init__(
self,
name: str,
input_shape: tuple[int, ...],
pretrained: str | None = None,
optimization_level: int = 5,
):
super().__init__()
self.name = name
self.optimize = optimization_level
self.input_shape = input_shape
self.pretrained = pretrained
self.cache_dir = os.path.join(os.environ["CACHE_DIR"], self.model_name)
def download(self) -> str:
model_path = os.path.join(self.cache_dir, self.task, "model.onnx")
if not os.path.isfile(model_path):
print(f"Downloading {self.model_name}...")
snapshot_download(self.repo_name, cache_dir=self.cache_dir, local_dir=self.cache_dir, local_dir_use_symlinks=False)
return model_path
def to_onnx_static(self) -> str:
onnx_path_original = self.download()
static_dir = os.path.join(self.cache_dir, self.task, "static")
os.makedirs(static_dir, exist_ok=True)
static_path = os.path.join(static_dir, "model.onnx")
if not os.path.isfile(static_path):
print(f"Making {self.model_name} ({self.task}) static")
onnx_make_fixed(onnx_path_original, static_path, self.input_shape)
make_onnx_armnn_compatible(static_path)
static_model = onnx.load_model(static_path)
self.inputs = [input_.name for input_ in static_model.graph.input]
self.outputs = [output_.name for output_ in static_model.graph.output]
return static_path
def to_tflite(self, output_dir: str) -> tuple[str, str]:
input_path = self.to_onnx_static()
tflite_fp32 = os.path.join(output_dir, "model_float32.tflite")
tflite_fp16 = os.path.join(output_dir, "model_float16.tflite")
if not os.path.isfile(tflite_fp32) or not os.path.isfile(tflite_fp16):
print(f"Exporting {self.model_name} ({self.task}) to TFLite (this might take a few minutes)")
onnx2tf.convert(
input_onnx_file_path=input_path,
output_folder_path=output_dir,
keep_shape_absolutely_input_names=self.inputs,
verbosity="warn",
copy_onnx_input_output_names_to_tflite=True,
output_signaturedefs=True,
)
return tflite_fp32, tflite_fp16
def to_armnn(self, output_dir: str) -> tuple[str, str]:
output_dir = os.path.abspath(output_dir)
tflite_model_dir = os.path.join(output_dir, "tflite")
tflite_fp32, tflite_fp16 = self.to_tflite(tflite_model_dir)
fp16_dir = os.path.join(output_dir, "fp16")
os.makedirs(fp16_dir, exist_ok=True)
armnn_fp32 = os.path.join(output_dir, "model.armnn")
armnn_fp16 = os.path.join(fp16_dir, "model.armnn")
args = ["./armnnconverter", "-f", "tflite-binary"]
args.append("-i")
args.extend(self.inputs)
args.append("-o")
args.extend(self.outputs)
fp32_args = args.copy()
fp32_args.extend(["-m", tflite_fp32, "-p", armnn_fp32])
print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp32 precision")
try:
print(subprocess.check_output(fp32_args, stderr=subprocess.STDOUT).decode())
except subprocess.CalledProcessError as e:
print(e.output.decode())
try:
rmtree(tflite_model_dir, ignore_errors=True)
finally:
raise e
print(f"Finished exporting {self.model_name} ({self.task}) with fp32 precision")
fp16_args = args.copy()
fp16_args.extend(["-m", tflite_fp16, "-p", armnn_fp16])
print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp16 precision")
try:
print(subprocess.check_output(fp16_args, stderr=subprocess.STDOUT).decode())
except subprocess.CalledProcessError as e:
print(e.output.decode())
try:
rmtree(tflite_model_dir, ignore_errors=True)
finally:
raise e
print(f"Finished exporting {self.model_name} ({self.task}) with fp16 precision")
return armnn_fp32, armnn_fp16
@property
def model_name(self) -> str:
return f"{self.name}__{self.pretrained}" if self.pretrained else self.name
@property
def repo_name(self) -> str:
return f"immich-app/{self.model_name}"
class ArcFace(ExportBase):
task = "recognition"
class RetinaFace(ExportBase):
task = "detection"
class OpenClipVisual(ExportBase):
task = "visual"
class OpenClipTextual(ExportBase):
task = "textual"
class MClipTextual(ExportBase):
task = "textual"
def main() -> None:
if platform.machine() not in ("x86_64", "AMD64"):
raise RuntimeError(f"Can only run on x86_64 / AMD64, not {platform.machine()}")
hf_token = os.environ.get("HF_AUTH_TOKEN")
if hf_token:
login(token=hf_token)
os.environ["LD_LIBRARY_PATH"] = "armnn"
failed: list[Callable[[], ExportBase]] = [
lambda: OpenClipVisual("ViT-H-14-378-quickgelu", (1, 3, 378, 378), pretrained="dfn5b"), # flatbuffers: cannot grow buffer beyond 2 gigabytes (will probably work with fp16)
lambda: OpenClipVisual("ViT-H-14-quickgelu", (1, 3, 224, 224), pretrained="dfn5b"), # flatbuffers: cannot grow buffer beyond 2 gigabytes (will probably work with fp16)
lambda: OpenClipVisual("ViT-H-14", (1, 3, 224, 224), pretrained="laion2b-s32b-b79k"),
lambda: OpenClipTextual("ViT-H-14", (1, 77), pretrained="laion2b-s32b-b79k"),
lambda: OpenClipVisual("ViT-g-14", (1, 3, 224, 224), pretrained="laion2b-s12b-b42k"),
lambda: OpenClipTextual("ViT-g-14", (1, 77), pretrained="laion2b-s12b-b42k"),
lambda: OpenClipVisual("XLM-Roberta-Large-Vit-B-16Plus", (1, 3, 240, 240)),
lambda: OpenClipVisual("XLM-Roberta-Large-ViT-H-14", (1, 3, 224, 224), pretrained="frozen_laion5b_s13b_b90k"),
lambda: MClipTextual("XLM-Roberta-Large-Vit-L-14", (1, 77)), # Expected normalized_shape to be at least 1-dimensional, i.e., containing at least one element, but got normalized_shape = []
lambda: MClipTextual("XLM-Roberta-Large-Vit-B-16Plus", (1, 77)), # Expected normalized_shape to be at least 1-dimensional, i.e., containing at least one element, but got normalized_shape = []
lambda: MClipTextual("LABSE-Vit-L-14", (1, 77)), # Expected normalized_shape to be at least 1-dimensional, i.e., containing at least one element, but got normalized_shape = []
lambda: OpenClipTextual("XLM-Roberta-Large-ViT-H-14", (1, 77), pretrained="frozen_laion5b_s13b_b90k"), # Expected normalized_shape to be at least 1-dimensional, i.e., containing at least one element, but got normalized_shape = []
]
oom = [
lambda: OpenClipVisual("nllb-clip-base-siglip", (1, 3, 384, 384), pretrained="v1"),
lambda: OpenClipTextual("nllb-clip-base-siglip", (1, 77), pretrained="v1"),
lambda: OpenClipVisual("nllb-clip-large-siglip", (1, 3, 384, 384), pretrained="v1"),
lambda: OpenClipTextual("nllb-clip-large-siglip", (1, 77), pretrained="v1"), # ERROR (tinynn.converter.base) Unsupported ops: aten::logical_not
# lambda: OpenClipTextual("ViT-H-14-quickgelu", (1, 77), pretrained="dfn5b"),
# lambda: OpenClipTextual("ViT-H-14-378-quickgelu", (1, 77), pretrained="dfn5b"),
# lambda: OpenClipVisual("XLM-Roberta-Large-Vit-L-14", (1, 3, 224, 224)),
]
succeeded: list[Callable[[], ExportBase]] = [
# lambda: OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="laion2b_e16"),
# lambda: OpenClipTextual("ViT-B-32", (1, 77), pretrained="laion2b_e16"),
# lambda: OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="laion400m_e31"),
# lambda: OpenClipTextual("ViT-B-32", (1, 77), pretrained="laion400m_e31"),
# lambda: OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="laion400m_e32"),
# lambda: OpenClipTextual("ViT-B-32", (1, 77), pretrained="laion400m_e32"),
# lambda: OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="laion2b-s34b-b79k"),
# lambda: OpenClipTextual("ViT-B-32", (1, 77), pretrained="laion2b-s34b-b79k"),
# lambda: OpenClipVisual("ViT-B-16", (1, 3, 224, 224), pretrained="laion400m_e31"),
# lambda: OpenClipTextual("ViT-B-16", (1, 77), pretrained="laion400m_e31"),
# lambda: OpenClipVisual("ViT-B-16", (1, 3, 224, 224), pretrained="laion400m_e32"),
# lambda: OpenClipTextual("ViT-B-16", (1, 77), pretrained="laion400m_e32"),
# lambda: OpenClipVisual("ViT-B-16-plus-240", (1, 3, 240, 240), pretrained="laion400m_e31"),
# lambda: OpenClipTextual("ViT-B-16-plus-240", (1, 77), pretrained="laion400m_e31"),
# lambda: OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="openai"),
# lambda: OpenClipTextual("ViT-B-32", (1, 77), pretrained="openai"),
# lambda: OpenClipVisual("ViT-B-16", (1, 3, 224, 224), pretrained="openai"),
# lambda: OpenClipTextual("ViT-B-16", (1, 77), pretrained="openai"),
# lambda: OpenClipVisual("RN50", (1, 3, 224, 224), pretrained="openai"),
# lambda: OpenClipTextual("RN50", (1, 77), pretrained="openai"),
# lambda: OpenClipVisual("RN50", (1, 3, 224, 224), pretrained="yfcc15m"),
# lambda: OpenClipTextual("RN50", (1, 77), pretrained="yfcc15m"),
# lambda: OpenClipVisual("RN50", (1, 3, 224, 224), pretrained="cc12m"),
# lambda: OpenClipTextual("RN50", (1, 77), pretrained="cc12m"),
# lambda: OpenClipVisual("XLM-Roberta-Large-Vit-B-32", (1, 3, 224, 224)),
# lambda: OpenClipVisual("ViT-L-14", (1, 3, 224, 224), pretrained="openai"),
# lambda: OpenClipTextual("ViT-L-14", (1, 77), pretrained="openai"),
lambda: OpenClipVisual("ViT-L-14", (1, 3, 224, 224), pretrained="laion400m_e31"),
lambda: OpenClipTextual("ViT-L-14", (1, 77), pretrained="laion400m_e31"),
lambda: OpenClipVisual("ViT-L-14", (1, 3, 224, 224), pretrained="laion400m_e32"),
lambda: OpenClipTextual("ViT-L-14", (1, 77), pretrained="laion400m_e32"),
lambda: OpenClipVisual("ViT-L-14", (1, 3, 224, 224), pretrained="laion2b-s32b-b82k"),
lambda: OpenClipTextual("ViT-L-14", (1, 77), pretrained="laion2b-s32b-b82k"),
# lambda: OpenClipVisual("ViT-L-14-336", (1, 3, 336, 336), pretrained="openai"),
# lambda: OpenClipTextual("ViT-L-14-336", (1, 77), pretrained="openai"),
# lambda: ArcFace("buffalo_s", (1, 3, 112, 112), optimization_level=3),
# lambda: RetinaFace("buffalo_s", (1, 3, 640, 640), optimization_level=3),
# lambda: ArcFace("buffalo_m", (1, 3, 112, 112), optimization_level=3),
# lambda: RetinaFace("buffalo_m", (1, 3, 640, 640), optimization_level=3),
# lambda: ArcFace("buffalo_l", (1, 3, 112, 112), optimization_level=3),
# lambda: RetinaFace("buffalo_l", (1, 3, 640, 640), optimization_level=3),
# lambda: ArcFace("antelopev2", (1, 3, 112, 112), optimization_level=3),
# lambda: RetinaFace("antelopev2", (1, 3, 640, 640), optimization_level=3),
]
models: list[Callable[[], ExportBase]] = [*failed, *succeeded]
for _model in succeeded:
model = _model()
try:
model_dir = os.path.join("output", model.model_name)
output_dir = os.path.join(model_dir, model.task)
armnn_fp32, armnn_fp16 = model.to_armnn(output_dir)
relative_fp32 = os.path.relpath(armnn_fp32, start=model_dir)
relative_fp16 = os.path.relpath(armnn_fp16, start=model_dir)
if hf_token and os.path.isfile(armnn_fp32):
print(f"Uploading {model.model_name} ({model.task}) ARM NN model with fp32 precision")
upload_file(path_or_fileobj=armnn_fp32, path_in_repo=relative_fp32, repo_id=model.repo_name)
print(f"Finished uploading {model.model_name} ({model.task}) ARM NN model with fp32 precision")
if hf_token and os.path.isfile(armnn_fp16):
print(f"Uploading {model.model_name} ({model.task}) ARM NN model with fp16 precision")
upload_file(path_or_fileobj=armnn_fp16, path_in_repo=relative_fp16, repo_id=model.repo_name)
print(f"Finished uploading {model.model_name} ({model.task}) ARM NN model with fp16 precision")
except Exception as exc:
print(f"Failed to export {model.model_name} ({model.task}): {exc}")
raise exc
if __name__ == "__main__":
main()