mirror of
https://github.com/immich-app/immich.git
synced 2025-02-04 01:09:14 -05:00
add cli
This commit is contained in:
parent
3db69b94ed
commit
72269ab58c
13 changed files with 2185 additions and 689 deletions
|
@ -1,28 +1,35 @@
|
|||
FROM mambaorg/micromamba:bookworm-slim@sha256:333f7598ff2c2400fb10bfe057709c68b7daab5d847143af85abcf224a07271a as builder
|
||||
|
||||
ENV TRANSFORMERS_CACHE=/cache \
|
||||
PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
WORKDIR /export/ann
|
||||
|
||||
USER root
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
cmake \
|
||||
curl \
|
||||
git
|
||||
|
||||
USER $MAMBA_USER
|
||||
COPY --chown=$MAMBA_USER:$MAMBA_USER env.yaml ./
|
||||
RUN micromamba install -y -f env.yaml
|
||||
COPY --chown=$MAMBA_USER:$MAMBA_USER *.sh *.cpp ./
|
||||
|
||||
ENV ARMNN_PATH=/export/ann/armnn
|
||||
WORKDIR /home/mambauser
|
||||
ENV ARMNN_PATH=armnn
|
||||
COPY --chown=$MAMBA_USER:$MAMBA_USER scripts/* .
|
||||
RUN ./download-armnn.sh && \
|
||||
./build-converter.sh && \
|
||||
./build.sh
|
||||
COPY --chown=$MAMBA_USER:$MAMBA_USER run.py ./
|
||||
./build-converter.sh && \
|
||||
./build.sh
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/_entrypoint.sh"]
|
||||
CMD ["python", "run.py"]
|
||||
COPY --chown=$MAMBA_USER:$MAMBA_USER conda-lock.yml .
|
||||
RUN micromamba create -y -p /home/mambauser/venv -f conda-lock.yml && \
|
||||
micromamba clean --all --yes
|
||||
ENV PATH="/home/mambauser/venv/bin:${PATH}"
|
||||
|
||||
FROM gcr.io/distroless/base-debian12
|
||||
# FROM mambaorg/micromamba:bookworm-slim@sha256:333f7598ff2c2400fb10bfe057709c68b7daab5d847143af85abcf224a07271a
|
||||
|
||||
WORKDIR /export/ann
|
||||
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
LD_LIBRARY_PATH=/export/ann/armnn \
|
||||
PATH="/opt/venv/bin:${PATH}"
|
||||
|
||||
COPY --from=builder /home/mambauser/armnnconverter /home/mambauser/armnn ./
|
||||
COPY --from=builder /home/mambauser/venv /opt/venv
|
||||
COPY --chown=$MAMBA_USER:$MAMBA_USER onnx2ann onnx2ann
|
||||
|
||||
ENTRYPOINT ["python", "-m", "onnx2ann"]
|
||||
|
|
1600
machine-learning/export/ann/conda-lock.yml
Normal file
1600
machine-learning/export/ann/conda-lock.yml
Normal file
File diff suppressed because it is too large
Load diff
|
@ -1,201 +1,21 @@
|
|||
name: annexport
|
||||
name: onnx2ann
|
||||
channels:
|
||||
- pytorch
|
||||
- nvidia
|
||||
- conda-forge
|
||||
dependencies:
|
||||
- _libgcc_mutex=0.1=conda_forge
|
||||
- _openmp_mutex=4.5=2_kmp_llvm
|
||||
- aiohttp=3.9.1=py310h2372a71_0
|
||||
- aiosignal=1.3.1=pyhd8ed1ab_0
|
||||
- arpack=3.8.0=nompi_h0baa96a_101
|
||||
- async-timeout=4.0.3=pyhd8ed1ab_0
|
||||
- attrs=23.1.0=pyh71513ae_1
|
||||
- aws-c-auth=0.7.3=h28f7589_1
|
||||
- aws-c-cal=0.6.1=hc309b26_1
|
||||
- aws-c-common=0.9.0=hd590300_0
|
||||
- aws-c-compression=0.2.17=h4d4d85c_2
|
||||
- aws-c-event-stream=0.3.1=h2e3709c_4
|
||||
- aws-c-http=0.7.11=h00aa349_4
|
||||
- aws-c-io=0.13.32=he9a53bd_1
|
||||
- aws-c-mqtt=0.9.3=hb447be9_1
|
||||
- aws-c-s3=0.3.14=hf3aad02_1
|
||||
- aws-c-sdkutils=0.1.12=h4d4d85c_1
|
||||
- aws-checksums=0.1.17=h4d4d85c_1
|
||||
- aws-crt-cpp=0.21.0=hb942446_5
|
||||
- aws-sdk-cpp=1.10.57=h85b1a90_19
|
||||
- blas=2.120=openblas
|
||||
- blas-devel=3.9.0=20_linux64_openblas
|
||||
- brotli-python=1.0.9=py310hd8f1fbe_9
|
||||
- bzip2=1.0.8=hd590300_5
|
||||
- c-ares=1.23.0=hd590300_0
|
||||
- ca-certificates=2023.11.17=hbcca054_0
|
||||
- certifi=2023.11.17=pyhd8ed1ab_0
|
||||
- charset-normalizer=3.3.2=pyhd8ed1ab_0
|
||||
- click=8.1.7=unix_pyh707e725_0
|
||||
- colorama=0.4.6=pyhd8ed1ab_0
|
||||
- coloredlogs=15.0.1=pyhd8ed1ab_3
|
||||
- cuda-cudart=11.7.99=0
|
||||
- cuda-cupti=11.7.101=0
|
||||
- cuda-libraries=11.7.1=0
|
||||
- cuda-nvrtc=11.7.99=0
|
||||
- cuda-nvtx=11.7.91=0
|
||||
- cuda-runtime=11.7.1=0
|
||||
- dataclasses=0.8=pyhc8e2a94_3
|
||||
- datasets=2.14.7=pyhd8ed1ab_0
|
||||
- dill=0.3.7=pyhd8ed1ab_0
|
||||
- filelock=3.13.1=pyhd8ed1ab_0
|
||||
- flatbuffers=23.5.26=h59595ed_1
|
||||
- freetype=2.12.1=h267a509_2
|
||||
- frozenlist=1.4.0=py310h2372a71_1
|
||||
- fsspec=2023.10.0=pyhca7485f_0
|
||||
- ftfy=6.1.3=pyhd8ed1ab_0
|
||||
- gflags=2.2.2=he1b5a44_1004
|
||||
- glog=0.6.0=h6f12383_0
|
||||
- glpk=5.0=h445213a_0
|
||||
- gmp=6.3.0=h59595ed_0
|
||||
- gmpy2=2.1.2=py310h3ec546c_1
|
||||
- huggingface_hub=0.17.3=pyhd8ed1ab_0
|
||||
- humanfriendly=10.0=pyhd8ed1ab_6
|
||||
- icu=73.2=h59595ed_0
|
||||
- idna=3.6=pyhd8ed1ab_0
|
||||
- importlib-metadata=7.0.0=pyha770c72_0
|
||||
- importlib_metadata=7.0.0=hd8ed1ab_0
|
||||
- joblib=1.3.2=pyhd8ed1ab_0
|
||||
- keyutils=1.6.1=h166bdaf_0
|
||||
- krb5=1.21.2=h659d440_0
|
||||
- lcms2=2.15=h7f713cb_2
|
||||
- ld_impl_linux-64=2.40=h41732ed_0
|
||||
- lerc=4.0.0=h27087fc_0
|
||||
- libabseil=20230125.3=cxx17_h59595ed_0
|
||||
- libarrow=12.0.1=hb87d912_8_cpu
|
||||
- libblas=3.9.0=20_linux64_openblas
|
||||
- libbrotlicommon=1.0.9=h166bdaf_9
|
||||
- libbrotlidec=1.0.9=h166bdaf_9
|
||||
- libbrotlienc=1.0.9=h166bdaf_9
|
||||
- libcblas=3.9.0=20_linux64_openblas
|
||||
- libcrc32c=1.1.2=h9c3ff4c_0
|
||||
- libcublas=11.10.3.66=0
|
||||
- libcufft=10.7.2.124=h4fbf590_0
|
||||
- libcufile=1.8.1.2=0
|
||||
- libcurand=10.3.4.101=0
|
||||
- libcurl=8.5.0=hca28451_0
|
||||
- libcusolver=11.4.0.1=0
|
||||
- libcusparse=11.7.4.91=0
|
||||
- libdeflate=1.19=hd590300_0
|
||||
- libedit=3.1.20191231=he28a2e2_2
|
||||
- libev=4.33=hd590300_2
|
||||
- libevent=2.1.12=hf998b51_1
|
||||
- libffi=3.4.2=h7f98852_5
|
||||
- libgcc-ng=13.2.0=h807b86a_3
|
||||
- libgfortran-ng=13.2.0=h69a702a_3
|
||||
- libgfortran5=13.2.0=ha4646dd_3
|
||||
- libgoogle-cloud=2.12.0=hac9eb74_1
|
||||
- libgrpc=1.54.3=hb20ce57_0
|
||||
- libhwloc=2.9.3=default_h554bfaf_1009
|
||||
- libiconv=1.17=hd590300_1
|
||||
- libjpeg-turbo=2.1.5.1=hd590300_1
|
||||
- liblapack=3.9.0=20_linux64_openblas
|
||||
- liblapacke=3.9.0=20_linux64_openblas
|
||||
- libnghttp2=1.58.0=h47da74e_1
|
||||
- libnpp=11.7.4.75=0
|
||||
- libnsl=2.0.1=hd590300_0
|
||||
- libnuma=2.0.16=h0b41bf4_1
|
||||
- libnvjpeg=11.8.0.2=0
|
||||
- libopenblas=0.3.25=pthreads_h413a1c8_0
|
||||
- libpng=1.6.39=h753d276_0
|
||||
- libprotobuf=3.21.12=hfc55251_2
|
||||
- libsentencepiece=0.1.99=h180e1df_0
|
||||
- libsqlite=3.44.2=h2797004_0
|
||||
- libssh2=1.11.0=h0841786_0
|
||||
- libstdcxx-ng=13.2.0=h7e041cc_3
|
||||
- libthrift=0.18.1=h8fd135c_2
|
||||
- libtiff=4.6.0=h29866fb_1
|
||||
- libutf8proc=2.8.0=h166bdaf_0
|
||||
- libuuid=2.38.1=h0b41bf4_0
|
||||
- libwebp-base=1.3.2=hd590300_0
|
||||
- libxcb=1.15=h0b41bf4_0
|
||||
- libxml2=2.11.6=h232c23b_0
|
||||
- libzlib=1.2.13=hd590300_5
|
||||
- llvm-openmp=17.0.6=h4dfa4b3_0
|
||||
- lz4-c=1.9.4=hcb278e6_0
|
||||
- mkl=2022.2.1=h84fe81f_16997
|
||||
- mkl-devel=2022.2.1=ha770c72_16998
|
||||
- mkl-include=2022.2.1=h84fe81f_16997
|
||||
- mpc=1.3.1=hfe3b2da_0
|
||||
- mpfr=4.2.1=h9458935_0
|
||||
- mpmath=1.3.0=pyhd8ed1ab_0
|
||||
- multidict=6.0.4=py310h2372a71_1
|
||||
- multiprocess=0.70.15=py310h2372a71_1
|
||||
- ncurses=6.4=h59595ed_2
|
||||
- numpy=1.26.2=py310hb13e2d6_0
|
||||
- onnx=1.14.0=py310ha3deec4_1
|
||||
- onnx2torch=1.5.13=pyhd8ed1ab_0
|
||||
- onnxruntime=1.16.3=py310hd4b7fbc_1_cpu
|
||||
- open-clip-torch=2.23.0=pyhd8ed1ab_1
|
||||
- openblas=0.3.25=pthreads_h7a3da1a_0
|
||||
- openjpeg=2.5.0=h488ebb8_3
|
||||
- openssl=3.2.0=hd590300_1
|
||||
- orc=1.9.0=h2f23424_1
|
||||
- packaging=23.2=pyhd8ed1ab_0
|
||||
- pandas=2.1.4=py310hcc13569_0
|
||||
- pillow=10.0.1=py310h29da1c1_1
|
||||
- pip=23.3.1=pyhd8ed1ab_0
|
||||
- protobuf=4.21.12=py310heca2aa9_0
|
||||
- pthread-stubs=0.4=h36c2ea0_1001
|
||||
- pyarrow=12.0.1=py310h0576679_8_cpu
|
||||
- pyarrow-hotfix=0.6=pyhd8ed1ab_0
|
||||
- pysocks=1.7.1=pyha2e5f31_6
|
||||
- python=3.10.13=hd12c33a_0_cpython
|
||||
- python-dateutil=2.8.2=pyhd8ed1ab_0
|
||||
- python-flatbuffers=23.5.26=pyhd8ed1ab_0
|
||||
- python-tzdata=2023.3=pyhd8ed1ab_0
|
||||
- python-xxhash=3.4.1=py310h2372a71_0
|
||||
- python_abi=3.10=4_cp310
|
||||
- pytorch=1.13.1=cpu_py310hd11e9c7_1
|
||||
- pytorch-cuda=11.7=h778d358_5
|
||||
- pytorch-mutex=1.0=cuda
|
||||
- pytz=2023.3.post1=pyhd8ed1ab_0
|
||||
- pyyaml=6.0.1=py310h2372a71_1
|
||||
- rdma-core=28.9=h59595ed_1
|
||||
- re2=2023.03.02=h8c504da_0
|
||||
- readline=8.2=h8228510_1
|
||||
- regex=2023.10.3=py310h2372a71_0
|
||||
- requests=2.31.0=pyhd8ed1ab_0
|
||||
- s2n=1.3.49=h06160fa_0
|
||||
- sacremoses=0.0.53=pyhd8ed1ab_0
|
||||
- safetensors=0.3.3=py310hcb5633a_1
|
||||
- sentencepiece=0.1.99=hff52083_0
|
||||
- sentencepiece-python=0.1.99=py310hebdb9f0_0
|
||||
- sentencepiece-spm=0.1.99=h180e1df_0
|
||||
- setuptools=68.2.2=pyhd8ed1ab_0
|
||||
- six=1.16.0=pyh6c4a22f_0
|
||||
- sleef=3.5.1=h9b69904_2
|
||||
- snappy=1.1.10=h9fff704_0
|
||||
- sympy=1.12=pypyh9d50eac_103
|
||||
- tbb=2021.11.0=h00ab1b0_0
|
||||
- texttable=1.7.0=pyhd8ed1ab_0
|
||||
- timm=0.9.12=pyhd8ed1ab_0
|
||||
- tk=8.6.13=noxft_h4845f30_101
|
||||
- tokenizers=0.14.1=py310h320607d_2
|
||||
- torchvision=0.14.1=cpu_py310hd3d2ac3_1
|
||||
- tqdm=4.66.1=pyhd8ed1ab_0
|
||||
- transformers=4.35.2=pyhd8ed1ab_0
|
||||
- typing-extensions=4.9.0=hd8ed1ab_0
|
||||
- typing_extensions=4.9.0=pyha770c72_0
|
||||
- tzdata=2023c=h71feb2d_0
|
||||
- ucx=1.14.1=h64cca9d_5
|
||||
- urllib3=2.1.0=pyhd8ed1ab_0
|
||||
- wcwidth=0.2.12=pyhd8ed1ab_0
|
||||
- wheel=0.42.0=pyhd8ed1ab_0
|
||||
- xorg-libxau=1.0.11=hd590300_0
|
||||
- xorg-libxdmcp=1.1.3=h7f98852_0
|
||||
- xxhash=0.8.2=hd590300_0
|
||||
- xz=5.2.6=h166bdaf_0
|
||||
- yaml=0.2.5=h7f98852_2
|
||||
- yarl=1.9.3=py310h2372a71_0
|
||||
- zipp=3.17.0=pyhd8ed1ab_0
|
||||
- zlib=1.2.13=hd590300_5
|
||||
- zstd=1.5.5=hfc55251_0
|
||||
- python>=3.11,<4.0
|
||||
- onnx>=1.16.1
|
||||
# - onnxruntime>=1.18.1 # conda only has gpu version
|
||||
- psutil>=6.0.0
|
||||
- flatbuffers>=24.3.25
|
||||
- ml_dtypes>=0.3.1
|
||||
- typer-slim>=0.12.3
|
||||
- huggingface_hub>=0.23.4
|
||||
- pip
|
||||
- pip:
|
||||
- git+https://github.com/fyfrey/TinyNeuralNetwork.git
|
||||
- onnxruntime>=1.18.1 # conda only has gpu version
|
||||
- onnxsim>=0.4.36
|
||||
- onnx2tf>=1.24.1
|
||||
- onnx_graphsurgeon>=0.5.2
|
||||
- simple_onnx_processing_tools>=1.1.32
|
||||
- tf_keras>=2.16.0
|
||||
- git+https://github.com/microsoft/onnxconverter-common.git
|
||||
|
|
0
machine-learning/export/ann/onnx2ann/__init__.py
Normal file
0
machine-learning/export/ann/onnx2ann/__init__.py
Normal file
99
machine-learning/export/ann/onnx2ann/__main__.py
Normal file
99
machine-learning/export/ann/onnx2ann/__main__.py
Normal file
|
@ -0,0 +1,99 @@
|
|||
import os
|
||||
import platform
|
||||
from typing import Annotated, Optional
|
||||
|
||||
import typer
|
||||
|
||||
from onnx2ann.export import Exporter, ModelType, Precision
|
||||
|
||||
app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)
|
||||
|
||||
|
||||
@app.command()
|
||||
def export(
|
||||
model_name: Annotated[
|
||||
str, typer.Argument(..., help="The name of the model to be exported as it exists in Hugging Face.")
|
||||
],
|
||||
model_type: Annotated[ModelType, typer.Option(..., "--type", "-t", help="The type of model to be exported.")],
|
||||
input_shapes: Annotated[
|
||||
list[str],
|
||||
typer.Option(
|
||||
...,
|
||||
"--input-shape",
|
||||
"-s",
|
||||
help="The shape of an input tensor to the model, each dimension separated by commas. "
|
||||
"Multiple shapes can be provided for multiple inputs.",
|
||||
),
|
||||
],
|
||||
precision: Annotated[
|
||||
Precision,
|
||||
typer.Option(
|
||||
...,
|
||||
"--precision",
|
||||
"-p",
|
||||
help="The precision of the exported model. `float16` requires a GPU.",
|
||||
),
|
||||
] = Precision.FLOAT32,
|
||||
cache_dir: Annotated[
|
||||
str,
|
||||
typer.Option(
|
||||
...,
|
||||
"--cache-dir",
|
||||
"-c",
|
||||
help="Directory where pre-export models will be stored.",
|
||||
envvar="CACHE_DIR",
|
||||
show_envvar=True,
|
||||
),
|
||||
] = "~/.cache/huggingface",
|
||||
output_dir: Annotated[
|
||||
str,
|
||||
typer.Option(
|
||||
...,
|
||||
"--output-dir",
|
||||
"-o",
|
||||
help="Directory where exported models will be stored.",
|
||||
),
|
||||
] = "output",
|
||||
auth_token: Annotated[
|
||||
Optional[str],
|
||||
typer.Option(
|
||||
...,
|
||||
"--auth-token",
|
||||
"-t",
|
||||
help="If uploading models to Hugging Face, the auth token of the user or organisation.",
|
||||
envvar="HF_AUTH_TOKEN",
|
||||
show_envvar=True,
|
||||
),
|
||||
] = None,
|
||||
force_export: Annotated[
|
||||
bool,
|
||||
typer.Option(
|
||||
...,
|
||||
"--force-export",
|
||||
"-f",
|
||||
help="Export the model even if an exported model already exists in the output directory.",
|
||||
),
|
||||
] = False,
|
||||
) -> None:
|
||||
if platform.machine() not in ("x86_64", "AMD64"):
|
||||
msg = f"Can only run on x86_64 / AMD64, not {platform.machine()}"
|
||||
raise RuntimeError(msg)
|
||||
os.environ.setdefault("LD_LIBRARY_PATH", "armnn")
|
||||
parsed_input_shapes = [tuple(map(int, shape.split(","))) for shape in input_shapes]
|
||||
model = Exporter(
|
||||
model_name, model_type, input_shapes=parsed_input_shapes, cache_dir=cache_dir, force_export=force_export
|
||||
)
|
||||
model_dir = os.path.join("output", model_name)
|
||||
output_dir = os.path.join(model_dir, model_type)
|
||||
armnn_model = model.to_armnn(output_dir, precision)
|
||||
|
||||
if not auth_token:
|
||||
return
|
||||
|
||||
from huggingface_hub import upload_file
|
||||
|
||||
relative_path = os.path.relpath(armnn_model, start=model_dir)
|
||||
upload_file(path_or_fileobj=armnn_model, path_in_repo=relative_path, repo_id=model.repo_name, token=auth_token)
|
||||
|
||||
|
||||
app()
|
129
machine-learning/export/ann/onnx2ann/export.py
Normal file
129
machine-learning/export/ann/onnx2ann/export.py
Normal file
|
@ -0,0 +1,129 @@
|
|||
import os
|
||||
import subprocess
|
||||
from enum import StrEnum
|
||||
|
||||
from onnx2ann.helpers import onnx_make_armnn_compatible, onnx_make_inputs_fixed
|
||||
|
||||
|
||||
class ModelType(StrEnum):
|
||||
VISUAL = "visual"
|
||||
TEXTUAL = "textual"
|
||||
RECOGNITION = "recognition"
|
||||
DETECTION = "detection"
|
||||
|
||||
|
||||
class Precision(StrEnum):
|
||||
FLOAT16 = "float16"
|
||||
FLOAT32 = "float32"
|
||||
|
||||
|
||||
class Exporter:
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
model_type: str,
|
||||
input_shapes: list[tuple[int, ...]],
|
||||
optimization_level: int = 5,
|
||||
cache_dir: str = os.environ.get("CACHE_DIR", "~/.cache/huggingface"),
|
||||
force_export: bool = False,
|
||||
):
|
||||
self.model_name = model_name.split("/")[-1]
|
||||
self.model_type = model_type
|
||||
self.optimize = optimization_level
|
||||
self.input_shapes = input_shapes
|
||||
self.cache_dir = os.path.join(cache_dir, self.repo_name)
|
||||
self.force_export = force_export
|
||||
|
||||
def download(self) -> str:
|
||||
model_path = os.path.join(self.cache_dir, self.model_type, "model.onnx")
|
||||
if os.path.isfile(model_path):
|
||||
print(f"Model is already downloaded at {model_path}")
|
||||
return model_path
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
snapshot_download(
|
||||
self.repo_name, cache_dir=self.cache_dir, local_dir=self.cache_dir, local_dir_use_symlinks=False
|
||||
)
|
||||
return model_path
|
||||
|
||||
def to_onnx_static(self, precision: Precision) -> str:
|
||||
import onnx
|
||||
from onnxconverter_common import float16
|
||||
onnx_path_original = self.download()
|
||||
static_dir = os.path.join(self.cache_dir, self.model_type, "static")
|
||||
|
||||
static_path = os.path.join(static_dir, f"model.onnx")
|
||||
if self.force_export and not os.path.isfile(static_path):
|
||||
print(f"Making {self} static")
|
||||
os.makedirs(static_dir, exist_ok=True)
|
||||
onnx_make_inputs_fixed(onnx_path_original, static_path, self.input_shapes)
|
||||
onnx_make_armnn_compatible(static_path)
|
||||
print(f"Finished making {self} static")
|
||||
|
||||
model = onnx.load(static_path)
|
||||
self.inputs = [input_.name for input_ in model.graph.input]
|
||||
self.outputs = [output_.name for output_ in model.graph.output]
|
||||
if precision == Precision.FLOAT16:
|
||||
static_path = os.path.join(static_dir, f"model_{precision}.onnx")
|
||||
print(f"Converting {self} to {precision} precision")
|
||||
model = float16.convert_float_to_float16(model, keep_io_types=True, disable_shape_infer=True)
|
||||
onnx.save(model, static_path)
|
||||
print(f"Finished converting {self} to {precision} precision")
|
||||
# self.inputs, self.outputs = onnx_get_inputs_outputs(static_path)
|
||||
return static_path
|
||||
|
||||
def to_tflite(self, output_dir: str, precision: Precision) -> str:
|
||||
onnx_model = self.to_onnx_static(precision)
|
||||
tflite_dir = os.path.join(output_dir, precision)
|
||||
tflite_model = os.path.join(tflite_dir, f"model_{precision}.tflite")
|
||||
if self.force_export or not os.path.isfile(tflite_model):
|
||||
import onnx2tf
|
||||
|
||||
print(f"Exporting {self} to TFLite with {precision} precision (this might take a few minutes)")
|
||||
onnx2tf.convert(
|
||||
input_onnx_file_path=onnx_model,
|
||||
output_folder_path=tflite_dir,
|
||||
keep_shape_absolutely_input_names=self.inputs,
|
||||
# verbosity="warn",
|
||||
copy_onnx_input_output_names_to_tflite=True,
|
||||
output_signaturedefs=True,
|
||||
not_use_onnxsim=True,
|
||||
)
|
||||
print(f"Finished exporting {self} to TFLite with {precision} precision")
|
||||
|
||||
return tflite_model
|
||||
|
||||
def to_armnn(self, output_dir: str, precision: Precision) -> tuple[str, str]:
|
||||
armnn_model = os.path.join(output_dir, "model.armnn")
|
||||
if not self.force_export and os.path.isfile(armnn_model):
|
||||
return armnn_model
|
||||
|
||||
tflite_model_dir = os.path.join(output_dir, "tflite")
|
||||
tflite_model = self.to_tflite(tflite_model_dir, precision)
|
||||
|
||||
args = ["./armnnconverter", "-f", "tflite-binary", "-m", tflite_model, "-p", armnn_model]
|
||||
args.append("-i")
|
||||
args.extend(self.inputs)
|
||||
args.append("-o")
|
||||
args.extend(self.outputs)
|
||||
|
||||
print(f"Exporting {self} to ARM NN with {precision} precision")
|
||||
try:
|
||||
if (stdout := subprocess.check_output(args, stderr=subprocess.STDOUT).decode()):
|
||||
print(stdout)
|
||||
print(f"Finished exporting {self} to ARM NN with {precision} precision")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e.output.decode())
|
||||
try:
|
||||
from shutil import rmtree
|
||||
|
||||
rmtree(tflite_model_dir, ignore_errors=True)
|
||||
finally:
|
||||
raise e
|
||||
|
||||
@property
|
||||
def repo_name(self) -> str:
|
||||
return f"immich-app/{self.model_name}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.model_name} ({self.model_type})"
|
260
machine-learning/export/ann/onnx2ann/helpers.py
Normal file
260
machine-learning/export/ann/onnx2ann/helpers.py
Normal file
|
@ -0,0 +1,260 @@
|
|||
from typing import Any
|
||||
|
||||
|
||||
def onnx_make_armnn_compatible(model_path: str) -> None:
|
||||
"""
|
||||
i can explain
|
||||
armnn only supports up to 4d tranposes, but the model has a 5d transpose due to a redundant unsqueeze
|
||||
this function folds the unsqueeze+transpose+squeeze into a single 4d transpose
|
||||
it also switches from gather ops to slices since armnn has different dimension semantics for gathers
|
||||
also fixes batch normalization being in training mode
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
from onnx_graphsurgeon import Constant, Node, Variable, export_onnx, import_onnx
|
||||
|
||||
proto = onnx.load(model_path)
|
||||
graph = import_onnx(proto)
|
||||
|
||||
gather_idx = 1
|
||||
squeeze_idx = 1
|
||||
for node in graph.nodes:
|
||||
for link1 in node.outputs:
|
||||
if "Unsqueeze" in link1.name:
|
||||
for node1 in link1.outputs:
|
||||
for link2 in node1.outputs:
|
||||
if "Transpose" in link2.name:
|
||||
for node2 in link2.outputs:
|
||||
if node2.attrs.get("perm") == [3, 1, 2, 0, 4]:
|
||||
node2.attrs["perm"] = [2, 0, 1, 3]
|
||||
link2.shape = link1.shape
|
||||
for link3 in node2.outputs:
|
||||
if "Squeeze" in link3.name:
|
||||
link3.shape = [link3.shape[x] for x in [0, 1, 2, 4]]
|
||||
for node3 in link3.outputs:
|
||||
for link4 in node3.outputs:
|
||||
link4.shape = link3.shape
|
||||
try:
|
||||
idx = link2.inputs.index(node1)
|
||||
link2.inputs[idx] = node
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
node.outputs = [link2]
|
||||
if "Gather" in link4.name:
|
||||
for node4 in link4.outputs:
|
||||
axis = node1.attrs.get("axis", 0)
|
||||
index = node4.inputs[1].values
|
||||
slice_link = Variable(
|
||||
f"onnx::Slice_123{gather_idx}",
|
||||
dtype=link4.dtype,
|
||||
shape=[1] + link3.shape[1:],
|
||||
)
|
||||
slice_node = Node(
|
||||
op="Slice",
|
||||
inputs=[
|
||||
link3,
|
||||
Constant(
|
||||
f"SliceStart_123{gather_idx}",
|
||||
np.array([index]),
|
||||
),
|
||||
Constant(
|
||||
f"SliceEnd_123{gather_idx}",
|
||||
np.array([index + 1]),
|
||||
),
|
||||
Constant(
|
||||
f"SliceAxis_123{gather_idx}",
|
||||
np.array([axis]),
|
||||
),
|
||||
],
|
||||
outputs=[slice_link],
|
||||
name=f"Slice_123{gather_idx}",
|
||||
)
|
||||
graph.nodes.append(slice_node)
|
||||
gather_idx += 1
|
||||
|
||||
for link5 in node4.outputs:
|
||||
for node5 in link5.outputs:
|
||||
try:
|
||||
idx = node5.inputs.index(link5)
|
||||
node5.inputs[idx] = slice_link
|
||||
except ValueError:
|
||||
pass
|
||||
elif node.op == "LayerNormalization":
|
||||
for node1 in link1.outputs:
|
||||
if node1.op == "Gather":
|
||||
for link2 in node1.outputs:
|
||||
for node2 in link2.outputs:
|
||||
axis = node1.attrs.get("axis", 0)
|
||||
index = node1.inputs[1].values
|
||||
slice_link = Variable(
|
||||
f"onnx::Slice_123{gather_idx}",
|
||||
dtype=link2.dtype,
|
||||
shape=[1, *link2.shape],
|
||||
)
|
||||
slice_node = Node(
|
||||
op="Slice",
|
||||
inputs=[
|
||||
node1.inputs[0],
|
||||
Constant(
|
||||
f"SliceStart_123{gather_idx}",
|
||||
np.array([index]),
|
||||
),
|
||||
Constant(
|
||||
f"SliceEnd_123{gather_idx}",
|
||||
np.array([index + 1]),
|
||||
),
|
||||
Constant(
|
||||
f"SliceAxis_123{gather_idx}",
|
||||
np.array([axis]),
|
||||
),
|
||||
],
|
||||
outputs=[slice_link],
|
||||
name=f"Slice_123{gather_idx}",
|
||||
)
|
||||
graph.nodes.append(slice_node)
|
||||
gather_idx += 1
|
||||
|
||||
squeeze_link = Variable(
|
||||
f"onnx::Squeeze_123{squeeze_idx}",
|
||||
dtype=link2.dtype,
|
||||
shape=link2.shape,
|
||||
)
|
||||
squeeze_node = Node(
|
||||
op="Squeeze",
|
||||
inputs=[
|
||||
slice_link,
|
||||
Constant(
|
||||
f"SqueezeAxis_123{squeeze_idx}",
|
||||
np.array([0]),
|
||||
),
|
||||
],
|
||||
outputs=[squeeze_link],
|
||||
name=f"Squeeze_123{squeeze_idx}",
|
||||
)
|
||||
graph.nodes.append(squeeze_node)
|
||||
squeeze_idx += 1
|
||||
try:
|
||||
idx = node2.inputs.index(link2)
|
||||
node2.inputs[idx] = squeeze_link
|
||||
except ValueError:
|
||||
pass
|
||||
elif node.op == "Reshape":
|
||||
for node1 in link1.outputs:
|
||||
if node1.op == "Gather":
|
||||
node2s = [n for link in node1.outputs for n in link.outputs]
|
||||
if any(n.op == "Abs" for n in node2s):
|
||||
axis = node1.attrs.get("axis", 0)
|
||||
index = node1.inputs[1].values
|
||||
slice_link = Variable(
|
||||
f"onnx::Slice_123{gather_idx}",
|
||||
dtype=node1.outputs[0].dtype,
|
||||
shape=[1, *node1.outputs[0].shape],
|
||||
)
|
||||
slice_node = Node(
|
||||
op="Slice",
|
||||
inputs=[
|
||||
node1.inputs[0],
|
||||
Constant(
|
||||
f"SliceStart_123{gather_idx}",
|
||||
np.array([index]),
|
||||
),
|
||||
Constant(
|
||||
f"SliceEnd_123{gather_idx}",
|
||||
np.array([index + 1]),
|
||||
),
|
||||
Constant(
|
||||
f"SliceAxis_123{gather_idx}",
|
||||
np.array([axis]),
|
||||
),
|
||||
],
|
||||
outputs=[slice_link],
|
||||
name=f"Slice_123{gather_idx}",
|
||||
)
|
||||
graph.nodes.append(slice_node)
|
||||
gather_idx += 1
|
||||
|
||||
squeeze_link = Variable(
|
||||
f"onnx::Squeeze_123{squeeze_idx}",
|
||||
dtype=node1.outputs[0].dtype,
|
||||
shape=node1.outputs[0].shape,
|
||||
)
|
||||
squeeze_node = Node(
|
||||
op="Squeeze",
|
||||
inputs=[
|
||||
slice_link,
|
||||
Constant(
|
||||
f"SqueezeAxis_123{squeeze_idx}",
|
||||
np.array([0]),
|
||||
),
|
||||
],
|
||||
outputs=[squeeze_link],
|
||||
name=f"Squeeze_123{squeeze_idx}",
|
||||
)
|
||||
graph.nodes.append(squeeze_node)
|
||||
squeeze_idx += 1
|
||||
for node2 in node2s:
|
||||
node2.inputs[0] = squeeze_link
|
||||
elif node.op == "BatchNormalization" and node.attrs.get("training_mode") == 1:
|
||||
node.attrs["training_mode"] = 0
|
||||
node.outputs = node.outputs[:1]
|
||||
|
||||
graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
|
||||
graph.toposort()
|
||||
graph.fold_constants()
|
||||
updated = export_onnx(graph)
|
||||
onnx_save(updated, model_path)
|
||||
|
||||
# for some reason, reloading the model is necessary to apply the correct shape
|
||||
proto = onnx.load(model_path)
|
||||
graph = import_onnx(proto)
|
||||
for node in graph.nodes:
|
||||
if node.op == "Slice":
|
||||
for link in node.outputs:
|
||||
if "Slice_123" in link.name and link.shape[0] == 3: # noqa: PLR2004
|
||||
link.shape[0] = 1
|
||||
|
||||
graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
|
||||
graph.toposort()
|
||||
graph.fold_constants()
|
||||
updated = export_onnx(graph)
|
||||
onnx_save(updated, model_path)
|
||||
onnx.shape_inference.infer_shapes_path(model_path, check_type=True, strict_mode=True, data_prop=True)
|
||||
|
||||
|
||||
def onnx_make_inputs_fixed(input_path: str, output_path: str, input_shapes: list[tuple[int, ...]]) -> None:
|
||||
import onnx
|
||||
import onnxsim
|
||||
from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed
|
||||
|
||||
model, success = onnxsim.simplify(input_path)
|
||||
if not success:
|
||||
msg = f"Failed to simplify {input_path}"
|
||||
raise RuntimeError(msg)
|
||||
onnx_save(model, output_path)
|
||||
onnx.shape_inference.infer_shapes_path(output_path, check_type=True, strict_mode=True, data_prop=True)
|
||||
model = onnx.load_model(output_path)
|
||||
for input_node, shape in zip(model.graph.input, input_shapes, strict=False):
|
||||
make_input_shape_fixed(model.graph, input_node.name, shape)
|
||||
fix_output_shapes(model)
|
||||
onnx_save(model, output_path)
|
||||
onnx.shape_inference.infer_shapes_path(output_path, check_type=True, strict_mode=True, data_prop=True)
|
||||
|
||||
|
||||
def onnx_get_inputs_outputs(model_path: str) -> tuple[list[str], list[str]]:
|
||||
import onnx
|
||||
|
||||
model = onnx.load(model_path)
|
||||
inputs = [input_.name for input_ in model.graph.input]
|
||||
outputs = [output_.name for output_ in model.graph.output]
|
||||
return inputs, outputs
|
||||
|
||||
|
||||
def onnx_save(model: Any, output_path: str) -> None:
|
||||
import onnx
|
||||
|
||||
try:
|
||||
onnx.save(model, output_path)
|
||||
except:
|
||||
onnx.save(model, output_path, save_as_external_data=True, all_tensors_to_one_file=False, size_threshold=1_000_000)
|
56
machine-learning/export/ann/pyproject.toml
Normal file
56
machine-learning/export/ann/pyproject.toml
Normal file
|
@ -0,0 +1,56 @@
|
|||
[project]
|
||||
name = "onnx2ann"
|
||||
version = "1.107.2"
|
||||
dependencies = [
|
||||
"onnx>=1.16.1",
|
||||
"psutil>=6.0.0",
|
||||
"flatbuffers>=24.3.25",
|
||||
"ml_dtypes>=0.3.1,<1.0.0",
|
||||
"typer-slim>=0.12.3,<1.0.0",
|
||||
"huggingface_hub>=0.23.4,<1.0.0",
|
||||
"onnxruntime>=1.18.1",
|
||||
"onnxsim>=0.4.36,<1.0.0",
|
||||
"onnx2tf>=1.24.0",
|
||||
"onnx_graphsurgeon>=0.5.2,<1.0.0",
|
||||
"simple_onnx_processing_tools>=1.1.32",
|
||||
"tf_keras>=2.16.0",
|
||||
"onnxconverter-common @ git+https://github.com/microsoft/onnxconverter-common"
|
||||
]
|
||||
requires-python = ">=3.11"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.sdist]
|
||||
only-include = ["onnx2ann"]
|
||||
|
||||
[tool.hatch.metadata]
|
||||
allow-direct-references = true
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.12"
|
||||
follow_imports = "silent"
|
||||
warn_redundant_casts = true
|
||||
disallow_any_generics = true
|
||||
check_untyped_defs = true
|
||||
disallow_untyped_defs = true
|
||||
ignore_missing_imports = true
|
||||
|
||||
[tool.pydantic-mypy]
|
||||
init_forbid_extra = true
|
||||
init_typed = true
|
||||
warn_required_dynamic_aliases = true
|
||||
warn_untyped_fields = true
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
target-version = "py312"
|
||||
|
||||
[tool.ruff.lint]
|
||||
extend-select = ["E", "F", "I"]
|
||||
extend-ignore = ["FBT001", "FBT002"]
|
||||
|
||||
[tool.black]
|
||||
line-length = 120
|
||||
target-version = ['py312']
|
|
@ -1,475 +0,0 @@
|
|||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
from typing import Callable, ClassVar
|
||||
|
||||
import onnx
|
||||
from onnx_graphsurgeon import Constant, Node, Variable, import_onnx, export_onnx
|
||||
from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed
|
||||
from huggingface_hub import snapshot_download
|
||||
from onnx.shape_inference import infer_shapes_path
|
||||
from huggingface_hub import login, upload_file
|
||||
import onnx2tf
|
||||
import numpy as np
|
||||
import onnxsim
|
||||
from shutil import rmtree
|
||||
|
||||
# hack: changed Mul op in onnx2tf to skip broadcast if graph_node.o().op == 'Sigmoid'
|
||||
|
||||
# i can explain
|
||||
# armnn only supports up to 4d tranposes, but the model has a 5d transpose due to a redundant unsqueeze
|
||||
# this function folds the unsqueeze+transpose+squeeze into a single 4d transpose
|
||||
# it also switches from gather ops to slices since armnn has different dimension semantics for gathers
|
||||
# also fixes batch normalization being in training mode
|
||||
def make_onnx_armnn_compatible(model_path: str):
|
||||
proto = onnx.load(model_path)
|
||||
graph = import_onnx(proto)
|
||||
|
||||
gather_idx = 1
|
||||
squeeze_idx = 1
|
||||
for node in graph.nodes:
|
||||
for link1 in node.outputs:
|
||||
if "Unsqueeze" in link1.name:
|
||||
for node1 in link1.outputs:
|
||||
for link2 in node1.outputs:
|
||||
if "Transpose" in link2.name:
|
||||
for node2 in link2.outputs:
|
||||
if node2.attrs.get("perm") == [3, 1, 2, 0, 4]:
|
||||
node2.attrs["perm"] = [2, 0, 1, 3]
|
||||
link2.shape = link1.shape
|
||||
for link3 in node2.outputs:
|
||||
if "Squeeze" in link3.name:
|
||||
link3.shape = [link3.shape[x] for x in [0, 1, 2, 4]]
|
||||
for node3 in link3.outputs:
|
||||
for link4 in node3.outputs:
|
||||
link4.shape = link3.shape
|
||||
try:
|
||||
idx = link2.inputs.index(node1)
|
||||
link2.inputs[idx] = node
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
node.outputs = [link2]
|
||||
if "Gather" in link4.name:
|
||||
for node4 in link4.outputs:
|
||||
axis = node1.attrs.get("axis", 0)
|
||||
index = node4.inputs[1].values
|
||||
slice_link = Variable(
|
||||
f"onnx::Slice_123{gather_idx}",
|
||||
dtype=link4.dtype,
|
||||
shape=[1] + link3.shape[1:],
|
||||
)
|
||||
slice_node = Node(
|
||||
op="Slice",
|
||||
inputs=[
|
||||
link3,
|
||||
Constant(
|
||||
f"SliceStart_123{gather_idx}",
|
||||
np.array([index]),
|
||||
),
|
||||
Constant(
|
||||
f"SliceEnd_123{gather_idx}",
|
||||
np.array([index + 1]),
|
||||
),
|
||||
Constant(
|
||||
f"SliceAxis_123{gather_idx}",
|
||||
np.array([axis]),
|
||||
),
|
||||
],
|
||||
outputs=[slice_link],
|
||||
name=f"Slice_123{gather_idx}",
|
||||
)
|
||||
graph.nodes.append(slice_node)
|
||||
gather_idx += 1
|
||||
|
||||
for link5 in node4.outputs:
|
||||
for node5 in link5.outputs:
|
||||
try:
|
||||
idx = node5.inputs.index(link5)
|
||||
node5.inputs[idx] = slice_link
|
||||
except ValueError:
|
||||
pass
|
||||
elif node.op == "LayerNormalization":
|
||||
for node1 in link1.outputs:
|
||||
if node1.op == "Gather":
|
||||
for link2 in node1.outputs:
|
||||
for node2 in link2.outputs:
|
||||
axis = node1.attrs.get("axis", 0)
|
||||
index = node1.inputs[1].values
|
||||
slice_link = Variable(
|
||||
f"onnx::Slice_123{gather_idx}",
|
||||
dtype=link2.dtype,
|
||||
shape=[1] + link2.shape,
|
||||
)
|
||||
slice_node = Node(
|
||||
op="Slice",
|
||||
inputs=[
|
||||
node1.inputs[0],
|
||||
Constant(
|
||||
f"SliceStart_123{gather_idx}",
|
||||
np.array([index]),
|
||||
),
|
||||
Constant(
|
||||
f"SliceEnd_123{gather_idx}",
|
||||
np.array([index + 1]),
|
||||
),
|
||||
Constant(
|
||||
f"SliceAxis_123{gather_idx}",
|
||||
np.array([axis]),
|
||||
),
|
||||
],
|
||||
outputs=[slice_link],
|
||||
name=f"Slice_123{gather_idx}",
|
||||
)
|
||||
graph.nodes.append(slice_node)
|
||||
gather_idx += 1
|
||||
|
||||
squeeze_link = Variable(
|
||||
f"onnx::Squeeze_123{squeeze_idx}",
|
||||
dtype=link2.dtype,
|
||||
shape=link2.shape,
|
||||
)
|
||||
squeeze_node = Node(
|
||||
op="Squeeze",
|
||||
inputs=[slice_link, Constant(f"SqueezeAxis_123{squeeze_idx}",np.array([0]),)],
|
||||
outputs=[squeeze_link],
|
||||
name=f"Squeeze_123{squeeze_idx}",
|
||||
)
|
||||
graph.nodes.append(squeeze_node)
|
||||
squeeze_idx += 1
|
||||
try:
|
||||
idx = node2.inputs.index(link2)
|
||||
node2.inputs[idx] = squeeze_link
|
||||
except ValueError:
|
||||
pass
|
||||
elif node.op == "Reshape":
|
||||
for node1 in link1.outputs:
|
||||
if node1.op == "Gather":
|
||||
node2s = [n for l in node1.outputs for n in l.outputs]
|
||||
if any(n.op == "Abs" for n in node2s):
|
||||
axis = node1.attrs.get("axis", 0)
|
||||
index = node1.inputs[1].values
|
||||
slice_link = Variable(
|
||||
f"onnx::Slice_123{gather_idx}",
|
||||
dtype=node1.outputs[0].dtype,
|
||||
shape=[1] + node1.outputs[0].shape,
|
||||
)
|
||||
slice_node = Node(
|
||||
op="Slice",
|
||||
inputs=[
|
||||
node1.inputs[0],
|
||||
Constant(
|
||||
f"SliceStart_123{gather_idx}",
|
||||
np.array([index]),
|
||||
),
|
||||
Constant(
|
||||
f"SliceEnd_123{gather_idx}",
|
||||
np.array([index + 1]),
|
||||
),
|
||||
Constant(
|
||||
f"SliceAxis_123{gather_idx}",
|
||||
np.array([axis]),
|
||||
),
|
||||
],
|
||||
outputs=[slice_link],
|
||||
name=f"Slice_123{gather_idx}",
|
||||
)
|
||||
graph.nodes.append(slice_node)
|
||||
gather_idx += 1
|
||||
|
||||
squeeze_link = Variable(
|
||||
f"onnx::Squeeze_123{squeeze_idx}",
|
||||
dtype=node1.outputs[0].dtype,
|
||||
shape=node1.outputs[0].shape,
|
||||
)
|
||||
squeeze_node = Node(
|
||||
op="Squeeze",
|
||||
inputs=[slice_link, Constant(f"SqueezeAxis_123{squeeze_idx}",np.array([0]),)],
|
||||
outputs=[squeeze_link],
|
||||
name=f"Squeeze_123{squeeze_idx}",
|
||||
)
|
||||
graph.nodes.append(squeeze_node)
|
||||
squeeze_idx += 1
|
||||
for node2 in node2s:
|
||||
node2.inputs[0] = squeeze_link
|
||||
elif node.op == "BatchNormalization":
|
||||
if node.attrs.get("training_mode") == 1:
|
||||
node.attrs["training_mode"] = 0
|
||||
node.outputs = node.outputs[:1]
|
||||
|
||||
graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
|
||||
graph.toposort()
|
||||
graph.fold_constants()
|
||||
updated = export_onnx(graph)
|
||||
onnx.save(updated, model_path)
|
||||
# infer_shapes_path(updated, check_type=True, strict_mode=False, data_prop=True)
|
||||
|
||||
# for some reason, reloading the model is necessary to apply the correct shape
|
||||
proto = onnx.load(model_path)
|
||||
graph = import_onnx(proto)
|
||||
for node in graph.nodes:
|
||||
if node.op == "Slice":
|
||||
for link in node.outputs:
|
||||
if "Slice_123" in link.name and link.shape[0] == 3:
|
||||
link.shape[0] = 1
|
||||
|
||||
graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
|
||||
graph.toposort()
|
||||
graph.fold_constants()
|
||||
updated = export_onnx(graph)
|
||||
onnx.save(updated, model_path)
|
||||
infer_shapes_path(model_path, check_type=True, strict_mode=True, data_prop=True)
|
||||
|
||||
|
||||
def onnx_make_fixed(input_path: str, output_path: str, input_shape: tuple[int, ...]):
|
||||
simplified, success = onnxsim.simplify(input_path)
|
||||
if not success:
|
||||
raise RuntimeError(f"Failed to simplify {input_path}")
|
||||
try:
|
||||
onnx.save(simplified, output_path)
|
||||
except:
|
||||
onnx.save(simplified, output_path, save_as_external_data=True, all_tensors_to_one_file=False)
|
||||
infer_shapes_path(output_path, check_type=True, strict_mode=True, data_prop=True)
|
||||
model = onnx.load_model(output_path)
|
||||
make_input_shape_fixed(model.graph, model.graph.input[0].name, input_shape)
|
||||
fix_output_shapes(model)
|
||||
try:
|
||||
onnx.save(model, output_path)
|
||||
except:
|
||||
onnx.save(model, output_path, save_as_external_data=True, all_tensors_to_one_file=False)
|
||||
onnx.save(model, output_path)
|
||||
infer_shapes_path(output_path, check_type=True, strict_mode=True, data_prop=True)
|
||||
|
||||
|
||||
class ExportBase:
|
||||
task: ClassVar[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
input_shape: tuple[int, ...],
|
||||
pretrained: str | None = None,
|
||||
optimization_level: int = 5,
|
||||
):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self.optimize = optimization_level
|
||||
self.input_shape = input_shape
|
||||
self.pretrained = pretrained
|
||||
self.cache_dir = os.path.join(os.environ["CACHE_DIR"], self.model_name)
|
||||
|
||||
def download(self) -> str:
|
||||
model_path = os.path.join(self.cache_dir, self.task, "model.onnx")
|
||||
if not os.path.isfile(model_path):
|
||||
print(f"Downloading {self.model_name}...")
|
||||
snapshot_download(self.repo_name, cache_dir=self.cache_dir, local_dir=self.cache_dir, local_dir_use_symlinks=False)
|
||||
return model_path
|
||||
|
||||
def to_onnx_static(self) -> str:
|
||||
onnx_path_original = self.download()
|
||||
static_dir = os.path.join(self.cache_dir, self.task, "static")
|
||||
os.makedirs(static_dir, exist_ok=True)
|
||||
|
||||
static_path = os.path.join(static_dir, "model.onnx")
|
||||
if not os.path.isfile(static_path):
|
||||
print(f"Making {self.model_name} ({self.task}) static")
|
||||
onnx_make_fixed(onnx_path_original, static_path, self.input_shape)
|
||||
make_onnx_armnn_compatible(static_path)
|
||||
static_model = onnx.load_model(static_path)
|
||||
self.inputs = [input_.name for input_ in static_model.graph.input]
|
||||
self.outputs = [output_.name for output_ in static_model.graph.output]
|
||||
return static_path
|
||||
|
||||
def to_tflite(self, output_dir: str) -> tuple[str, str]:
|
||||
input_path = self.to_onnx_static()
|
||||
tflite_fp32 = os.path.join(output_dir, "model_float32.tflite")
|
||||
tflite_fp16 = os.path.join(output_dir, "model_float16.tflite")
|
||||
if not os.path.isfile(tflite_fp32) or not os.path.isfile(tflite_fp16):
|
||||
print(f"Exporting {self.model_name} ({self.task}) to TFLite (this might take a few minutes)")
|
||||
onnx2tf.convert(
|
||||
input_onnx_file_path=input_path,
|
||||
output_folder_path=output_dir,
|
||||
keep_shape_absolutely_input_names=self.inputs,
|
||||
verbosity="warn",
|
||||
copy_onnx_input_output_names_to_tflite=True,
|
||||
output_signaturedefs=True,
|
||||
)
|
||||
|
||||
return tflite_fp32, tflite_fp16
|
||||
|
||||
def to_armnn(self, output_dir: str) -> tuple[str, str]:
|
||||
output_dir = os.path.abspath(output_dir)
|
||||
tflite_model_dir = os.path.join(output_dir, "tflite")
|
||||
tflite_fp32, tflite_fp16 = self.to_tflite(tflite_model_dir)
|
||||
|
||||
fp16_dir = os.path.join(output_dir, "fp16")
|
||||
os.makedirs(fp16_dir, exist_ok=True)
|
||||
armnn_fp32 = os.path.join(output_dir, "model.armnn")
|
||||
armnn_fp16 = os.path.join(fp16_dir, "model.armnn")
|
||||
|
||||
args = ["./armnnconverter", "-f", "tflite-binary"]
|
||||
args.append("-i")
|
||||
args.extend(self.inputs)
|
||||
args.append("-o")
|
||||
args.extend(self.outputs)
|
||||
|
||||
fp32_args = args.copy()
|
||||
fp32_args.extend(["-m", tflite_fp32, "-p", armnn_fp32])
|
||||
|
||||
print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp32 precision")
|
||||
try:
|
||||
print(subprocess.check_output(fp32_args, stderr=subprocess.STDOUT).decode())
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e.output.decode())
|
||||
try:
|
||||
rmtree(tflite_model_dir, ignore_errors=True)
|
||||
finally:
|
||||
raise e
|
||||
print(f"Finished exporting {self.model_name} ({self.task}) with fp32 precision")
|
||||
|
||||
fp16_args = args.copy()
|
||||
fp16_args.extend(["-m", tflite_fp16, "-p", armnn_fp16])
|
||||
|
||||
print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp16 precision")
|
||||
try:
|
||||
print(subprocess.check_output(fp16_args, stderr=subprocess.STDOUT).decode())
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e.output.decode())
|
||||
try:
|
||||
rmtree(tflite_model_dir, ignore_errors=True)
|
||||
finally:
|
||||
raise e
|
||||
print(f"Finished exporting {self.model_name} ({self.task}) with fp16 precision")
|
||||
|
||||
return armnn_fp32, armnn_fp16
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return f"{self.name}__{self.pretrained}" if self.pretrained else self.name
|
||||
|
||||
@property
|
||||
def repo_name(self) -> str:
|
||||
return f"immich-app/{self.model_name}"
|
||||
|
||||
class ArcFace(ExportBase):
|
||||
task = "recognition"
|
||||
|
||||
|
||||
class RetinaFace(ExportBase):
|
||||
task = "detection"
|
||||
|
||||
|
||||
class OpenClipVisual(ExportBase):
|
||||
task = "visual"
|
||||
|
||||
|
||||
class OpenClipTextual(ExportBase):
|
||||
task = "textual"
|
||||
|
||||
|
||||
class MClipTextual(ExportBase):
|
||||
task = "textual"
|
||||
|
||||
|
||||
def main() -> None:
|
||||
if platform.machine() not in ("x86_64", "AMD64"):
|
||||
raise RuntimeError(f"Can only run on x86_64 / AMD64, not {platform.machine()}")
|
||||
hf_token = os.environ.get("HF_AUTH_TOKEN")
|
||||
if hf_token:
|
||||
login(token=hf_token)
|
||||
os.environ["LD_LIBRARY_PATH"] = "armnn"
|
||||
failed: list[Callable[[], ExportBase]] = [
|
||||
lambda: OpenClipVisual("ViT-H-14-378-quickgelu", (1, 3, 378, 378), pretrained="dfn5b"), # flatbuffers: cannot grow buffer beyond 2 gigabytes (will probably work with fp16)
|
||||
lambda: OpenClipVisual("ViT-H-14-quickgelu", (1, 3, 224, 224), pretrained="dfn5b"), # flatbuffers: cannot grow buffer beyond 2 gigabytes (will probably work with fp16)
|
||||
lambda: OpenClipVisual("ViT-H-14", (1, 3, 224, 224), pretrained="laion2b-s32b-b79k"),
|
||||
lambda: OpenClipTextual("ViT-H-14", (1, 77), pretrained="laion2b-s32b-b79k"),
|
||||
lambda: OpenClipVisual("ViT-g-14", (1, 3, 224, 224), pretrained="laion2b-s12b-b42k"),
|
||||
lambda: OpenClipTextual("ViT-g-14", (1, 77), pretrained="laion2b-s12b-b42k"),
|
||||
lambda: OpenClipVisual("XLM-Roberta-Large-Vit-B-16Plus", (1, 3, 240, 240)),
|
||||
lambda: OpenClipVisual("XLM-Roberta-Large-ViT-H-14", (1, 3, 224, 224), pretrained="frozen_laion5b_s13b_b90k"),
|
||||
lambda: MClipTextual("XLM-Roberta-Large-Vit-L-14", (1, 77)), # Expected normalized_shape to be at least 1-dimensional, i.e., containing at least one element, but got normalized_shape = []
|
||||
lambda: MClipTextual("XLM-Roberta-Large-Vit-B-16Plus", (1, 77)), # Expected normalized_shape to be at least 1-dimensional, i.e., containing at least one element, but got normalized_shape = []
|
||||
lambda: MClipTextual("LABSE-Vit-L-14", (1, 77)), # Expected normalized_shape to be at least 1-dimensional, i.e., containing at least one element, but got normalized_shape = []
|
||||
lambda: OpenClipTextual("XLM-Roberta-Large-ViT-H-14", (1, 77), pretrained="frozen_laion5b_s13b_b90k"), # Expected normalized_shape to be at least 1-dimensional, i.e., containing at least one element, but got normalized_shape = []
|
||||
]
|
||||
|
||||
oom = [
|
||||
lambda: OpenClipVisual("nllb-clip-base-siglip", (1, 3, 384, 384), pretrained="v1"),
|
||||
lambda: OpenClipTextual("nllb-clip-base-siglip", (1, 77), pretrained="v1"),
|
||||
lambda: OpenClipVisual("nllb-clip-large-siglip", (1, 3, 384, 384), pretrained="v1"),
|
||||
lambda: OpenClipTextual("nllb-clip-large-siglip", (1, 77), pretrained="v1"), # ERROR (tinynn.converter.base) Unsupported ops: aten::logical_not
|
||||
# lambda: OpenClipTextual("ViT-H-14-quickgelu", (1, 77), pretrained="dfn5b"),
|
||||
# lambda: OpenClipTextual("ViT-H-14-378-quickgelu", (1, 77), pretrained="dfn5b"),
|
||||
# lambda: OpenClipVisual("XLM-Roberta-Large-Vit-L-14", (1, 3, 224, 224)),
|
||||
]
|
||||
|
||||
succeeded: list[Callable[[], ExportBase]] = [
|
||||
# lambda: OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="laion2b_e16"),
|
||||
# lambda: OpenClipTextual("ViT-B-32", (1, 77), pretrained="laion2b_e16"),
|
||||
# lambda: OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="laion400m_e31"),
|
||||
# lambda: OpenClipTextual("ViT-B-32", (1, 77), pretrained="laion400m_e31"),
|
||||
# lambda: OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="laion400m_e32"),
|
||||
# lambda: OpenClipTextual("ViT-B-32", (1, 77), pretrained="laion400m_e32"),
|
||||
# lambda: OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="laion2b-s34b-b79k"),
|
||||
# lambda: OpenClipTextual("ViT-B-32", (1, 77), pretrained="laion2b-s34b-b79k"),
|
||||
# lambda: OpenClipVisual("ViT-B-16", (1, 3, 224, 224), pretrained="laion400m_e31"),
|
||||
# lambda: OpenClipTextual("ViT-B-16", (1, 77), pretrained="laion400m_e31"),
|
||||
# lambda: OpenClipVisual("ViT-B-16", (1, 3, 224, 224), pretrained="laion400m_e32"),
|
||||
# lambda: OpenClipTextual("ViT-B-16", (1, 77), pretrained="laion400m_e32"),
|
||||
# lambda: OpenClipVisual("ViT-B-16-plus-240", (1, 3, 240, 240), pretrained="laion400m_e31"),
|
||||
# lambda: OpenClipTextual("ViT-B-16-plus-240", (1, 77), pretrained="laion400m_e31"),
|
||||
# lambda: OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="openai"),
|
||||
# lambda: OpenClipTextual("ViT-B-32", (1, 77), pretrained="openai"),
|
||||
# lambda: OpenClipVisual("ViT-B-16", (1, 3, 224, 224), pretrained="openai"),
|
||||
# lambda: OpenClipTextual("ViT-B-16", (1, 77), pretrained="openai"),
|
||||
# lambda: OpenClipVisual("RN50", (1, 3, 224, 224), pretrained="openai"),
|
||||
# lambda: OpenClipTextual("RN50", (1, 77), pretrained="openai"),
|
||||
# lambda: OpenClipVisual("RN50", (1, 3, 224, 224), pretrained="yfcc15m"),
|
||||
# lambda: OpenClipTextual("RN50", (1, 77), pretrained="yfcc15m"),
|
||||
# lambda: OpenClipVisual("RN50", (1, 3, 224, 224), pretrained="cc12m"),
|
||||
# lambda: OpenClipTextual("RN50", (1, 77), pretrained="cc12m"),
|
||||
# lambda: OpenClipVisual("XLM-Roberta-Large-Vit-B-32", (1, 3, 224, 224)),
|
||||
# lambda: OpenClipVisual("ViT-L-14", (1, 3, 224, 224), pretrained="openai"),
|
||||
# lambda: OpenClipTextual("ViT-L-14", (1, 77), pretrained="openai"),
|
||||
lambda: OpenClipVisual("ViT-L-14", (1, 3, 224, 224), pretrained="laion400m_e31"),
|
||||
lambda: OpenClipTextual("ViT-L-14", (1, 77), pretrained="laion400m_e31"),
|
||||
lambda: OpenClipVisual("ViT-L-14", (1, 3, 224, 224), pretrained="laion400m_e32"),
|
||||
lambda: OpenClipTextual("ViT-L-14", (1, 77), pretrained="laion400m_e32"),
|
||||
lambda: OpenClipVisual("ViT-L-14", (1, 3, 224, 224), pretrained="laion2b-s32b-b82k"),
|
||||
lambda: OpenClipTextual("ViT-L-14", (1, 77), pretrained="laion2b-s32b-b82k"),
|
||||
# lambda: OpenClipVisual("ViT-L-14-336", (1, 3, 336, 336), pretrained="openai"),
|
||||
# lambda: OpenClipTextual("ViT-L-14-336", (1, 77), pretrained="openai"),
|
||||
# lambda: ArcFace("buffalo_s", (1, 3, 112, 112), optimization_level=3),
|
||||
# lambda: RetinaFace("buffalo_s", (1, 3, 640, 640), optimization_level=3),
|
||||
# lambda: ArcFace("buffalo_m", (1, 3, 112, 112), optimization_level=3),
|
||||
# lambda: RetinaFace("buffalo_m", (1, 3, 640, 640), optimization_level=3),
|
||||
# lambda: ArcFace("buffalo_l", (1, 3, 112, 112), optimization_level=3),
|
||||
# lambda: RetinaFace("buffalo_l", (1, 3, 640, 640), optimization_level=3),
|
||||
# lambda: ArcFace("antelopev2", (1, 3, 112, 112), optimization_level=3),
|
||||
# lambda: RetinaFace("antelopev2", (1, 3, 640, 640), optimization_level=3),
|
||||
]
|
||||
|
||||
models: list[Callable[[], ExportBase]] = [*failed, *succeeded]
|
||||
for _model in succeeded:
|
||||
model = _model()
|
||||
try:
|
||||
model_dir = os.path.join("output", model.model_name)
|
||||
output_dir = os.path.join(model_dir, model.task)
|
||||
armnn_fp32, armnn_fp16 = model.to_armnn(output_dir)
|
||||
relative_fp32 = os.path.relpath(armnn_fp32, start=model_dir)
|
||||
relative_fp16 = os.path.relpath(armnn_fp16, start=model_dir)
|
||||
if hf_token and os.path.isfile(armnn_fp32):
|
||||
print(f"Uploading {model.model_name} ({model.task}) ARM NN model with fp32 precision")
|
||||
upload_file(path_or_fileobj=armnn_fp32, path_in_repo=relative_fp32, repo_id=model.repo_name)
|
||||
print(f"Finished uploading {model.model_name} ({model.task}) ARM NN model with fp32 precision")
|
||||
if hf_token and os.path.isfile(armnn_fp16):
|
||||
print(f"Uploading {model.model_name} ({model.task}) ARM NN model with fp16 precision")
|
||||
upload_file(path_or_fileobj=armnn_fp16, path_in_repo=relative_fp16, repo_id=model.repo_name)
|
||||
print(f"Finished uploading {model.model_name} ({model.task}) ARM NN model with fp16 precision")
|
||||
except Exception as exc:
|
||||
print(f"Failed to export {model.model_name} ({model.task}): {exc}")
|
||||
raise exc
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Add table
Reference in a new issue