From 5dae920ac6e80fe54115eb3bc67432e3511e2da4 Mon Sep 17 00:00:00 2001
From: mertalev <101130780+mertalev@users.noreply.github.com>
Date: Sat, 6 Jul 2024 21:17:55 -0400
Subject: [PATCH] onnx2tf, 4d transpose

---
 machine-learning/export/ann/run.py | 331 ++++++++++++++---------------
 1 file changed, 160 insertions(+), 171 deletions(-)

diff --git a/machine-learning/export/ann/run.py b/machine-learning/export/ann/run.py
index f3f9610a99..1475f1d670 100644
--- a/machine-learning/export/ann/run.py
+++ b/machine-learning/export/ann/run.py
@@ -1,52 +1,63 @@
 import os
 import platform
 import subprocess
-from tempfile import TemporaryDirectory
 from typing import Callable, ClassVar
 
 import onnx
-import torch
-from onnx2torch import convert
-from onnx2torch.node_converters.registry import add_converter
+from onnx_graphsurgeon import import_onnx, export_onnx
 from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed
-from tinynn.converter import TFLiteConverter
 from huggingface_hub import snapshot_download
-from onnx2torch.onnx_graph import OnnxGraph
-from onnx2torch.onnx_node import OnnxNode
-from onnx2torch.utils.common import OperationConverterResult, onnx_mapping_from_node
 from onnx.shape_inference import infer_shapes_path
 from huggingface_hub import login, upload_file
+import onnx2tf
 
-# egregious hacks:
-# changed `Clip`'s min/max logic to skip empty strings
-# changed OnnxSqueezeDynamicAxes to use `sorted` instead of `torch.sort``
-# commented out shape inference in `fix_output_shapes``
+# i can explain
+# armnn only supports up to 4d transposes, but the model has a 5d transpose due to a redundant unsqueeze
+# this function folds the unsqueeze+transpose+squeeze into a single 4d transpose
+def onnx_transpose_4d(model_path: str):
+    proto = onnx.load(model_path)
+    graph = import_onnx(proto)
+
+    for node in graph.nodes:
+        for i, link1 in enumerate(node.outputs):
+            if "Unsqueeze" in link1.name:
+                for node1 in link1.outputs:
+                    for link2 in node1.outputs:
+                        if "Transpose" in link2.name:
+                            for node2 in link2.outputs:
+                                if node2.attrs.get("perm") == [3, 1, 2, 0, 4]:
+                                    node2.attrs["perm"] = [2, 0, 1, 3]
+                                    link2.shape = link1.shape
+                                    for link3 in node2.outputs:
+                                        if "Squeeze" in link3.name:
+                                            for node3 in link3.outputs:
+                                                for link4 in node3.outputs:
+                                                    link4.shape = [link3.shape[x] for x in [0, 1, 2, 4]]
+                                                    for inputs in link4.inputs:
+                                                        if inputs.name == node3.name:
+                                                            i = link2.inputs.index(node1)
+                                                            if i >= 0:
+                                                                link2.inputs[i] = node
+
+                                                            i = link4.inputs.index(node3)
+                                                            if i >= 0:
+                                                                link4.inputs[i] = node2
+
+                                                            node.outputs = [link2]
+                                                            node1.inputs = []
+                                                            node1.outputs = []
+                                                            node3.inputs = []
+                                                            node3.outputs = []
+
+    graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
+    graph.toposort()
+    graph.fold_constants()
+    updated = export_onnx(graph)
+    onnx.save(updated, model_path, save_as_external_data=True, all_tensors_to_one_file=False)
+    infer_shapes_path(model_path, check_type=True, strict_mode=True, data_prop=True)
 
-class ArgMax(torch.nn.Module):
-    def __init__(self, dim: int = -1, keepdim: bool = False):
-        super().__init__()
-        self.dim = dim
-        self.keepdim = keepdim
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return torch.argmax(input, dim=self.dim, keepdim=self.keepdim)
-
-
-class Erf(torch.nn.Module):
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return torch.erf(input)
-
-
-@add_converter(operation_type="ArgMax", version=13)
-def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
-    return OperationConverterResult(
-        torch_module=ArgMax(),
-        onnx_mapping=onnx_mapping_from_node(node=node),
-    )
-
-
-class ExportBase(torch.nn.Module):
+class ExportBase:
     task: ClassVar[str]
 
     def __init__(
@@ -62,84 +73,94 @@ class ExportBase(torch.nn.Module):
         self.nchw_transpose = False
         self.input_shape = input_shape
         self.pretrained = pretrained
-        self.dummy_param = torch.nn.Parameter(torch.empty(0))
-        self.model = self.load().eval()
-        for param in self.parameters():
-            param.requires_grad_(False)
-        self.eval()
 
-    def load(self) -> torch.nn.Module:
+    def to_onnx_static(self) -> str:
         cache_dir = os.path.join(os.environ["CACHE_DIR"], self.model_name)
         task_path = os.path.join(cache_dir, self.task)
         model_path = os.path.join(task_path, "model.onnx")
         if not os.path.isfile(model_path):
+            print(f"Downloading {self.model_name}...")
             snapshot_download(self.repo_name, cache_dir=cache_dir, local_dir=cache_dir)
-        infer_shapes_path(model_path, check_type=True, strict_mode=True, data_prop=True)
-        onnx_model = onnx.load_model(model_path)
-        make_input_shape_fixed(onnx_model.graph, onnx_model.graph.input[0].name, self.input_shape)
-        fix_output_shapes(onnx_model)
-        # try:
-        #     onnx.save(onnx_model, model_path)
-        # except:
-        #     onnx.save(onnx_model, model_path, save_as_external_data=True, all_tensors_to_one_file=False)
-        #     infer_shapes_path(model_path, check_type=True, strict_mode=True, data_prop=True)
-        #     onnx_model = onnx.load_model(model_path)
-        # onnx_model = infer_shapes(onnx_model, check_type=True, strict_mode=True, data_prop=True)
-        return convert(onnx_model)
-
-    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor]:
-        if self.precision == "fp16":
-            inputs = tuple(i.half() for i in inputs)
-
-        out = self._forward(*inputs)
-        if self.precision == "fp16":
-            if isinstance(out, tuple):
-                return tuple(o.float() for o in out)
-            return out.float()
-        return out
-
-    def _forward(self, *inputs: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor]:
-        return self.model(*inputs)
-
-    def to_armnn(self, output_path: str) -> None:
-        output_dir = os.path.dirname(output_path)
+
+        static_dir = os.path.join(task_path, "static")
+        static_path = os.path.join(static_dir, "model.onnx")
+        os.makedirs(static_dir, exist_ok=True)
+
+        if not os.path.isfile(static_path):
+            print(f"Making {self.model_name} ({self.task}) static")
+            onnx_path_original = os.path.join(cache_dir, "model.onnx")
+            infer_shapes_path(onnx_path_original, check_type=True, strict_mode=True, data_prop=True)
+            static_model = onnx.load_model(onnx_path_original)
+            make_input_shape_fixed(static_model.graph, static_model.graph.input[0].name, self.input_shape)
+            fix_output_shapes(static_model)
+            onnx.save(static_model, static_path, save_as_external_data=True, all_tensors_to_one_file=False)
+            infer_shapes_path(static_path, check_type=True, strict_mode=True, data_prop=True)
+            onnx_transpose_4d(static_path)
+        return static_path
+
+    def to_tflite(self, output_dir: str) -> tuple[str, str]:
+        input_path = self.to_onnx_static()
         os.makedirs(output_dir, exist_ok=True)
-        self(*self.dummy_inputs)
-        print(f"Exporting {self.model_name} ({self.task}) with {self.precision} precision")
-        jit = torch.jit.trace(self, self.dummy_inputs).eval()
-        with TemporaryDirectory() as tmpdir:
-            tflite_model_path = os.path.join(tmpdir, "model.tflite")
-            converter = TFLiteConverter(
-                jit,
-                self.dummy_inputs,
-                tflite_model_path,
-                optimize=self.optimize,
-                nchw_transpose=self.nchw_transpose,
+        tflite_fp32 = os.path.join(output_dir, "model_float32.tflite")
+        tflite_fp16 = os.path.join(output_dir, "model_float16.tflite")
+        if not os.path.isfile(tflite_fp32) or not os.path.isfile(tflite_fp16):
+            print(f"Exporting {self.model_name} ({self.task}) to TFLite")
+            onnx2tf.convert(
+                input_onnx_file_path=input_path,
+                output_folder_path=output_dir,
+                copy_onnx_input_output_names_to_tflite=True,
             )
-            # segfaults on ARM, must run on x86_64 / AMD64
-            converter.convert()
+
+        return tflite_fp32, tflite_fp16
 
-        subprocess.run(
-            [
-                "./armnnconverter",
-                "-f",
-                "tflite-binary",
-                "-m",
-                tflite_model_path,
-                "-i",
-                "input_tensor",
-                "-o",
-                "output_tensor",
-                "-p",
-                output_path,
-            ],
-            capture_output=True,
-        )
-        print(f"Finished exporting {self.name} ({self.task}) with {self.precision} precision")
-
-    @property
-    def dummy_inputs(self) -> tuple[torch.FloatTensor]:
-        return (torch.rand(self.input_shape, device=self.device, dtype=self.dtype),)
+    def to_armnn(self, output_dir: str) -> tuple[str, str]:
+        tflite_model_dir = os.path.join(output_dir, "tflite")
+        tflite_fp32, tflite_fp16 = self.to_tflite(tflite_model_dir)
+
+        fp16_dir = os.path.join(output_dir, "fp16")
+        os.makedirs(fp16_dir, exist_ok=True)
+        armnn_fp32 = os.path.join(output_dir, "model.armnn")
+        armnn_fp16 = os.path.join(fp16_dir, "model.armnn")
+
+        print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp32 precision")
+        subprocess.run(
+            [
+                "./armnnconverter",
+                "-f",
+                "tflite-binary",
+                "-m",
+                tflite_fp32,
+                "-i",
+                "input_tensor",
+                "-o",
+                "output_tensor",
+                "-p",
+                armnn_fp32,
+            ],
+            capture_output=True,
+        )
+        print(f"Finished exporting {self.name} ({self.task}) with fp32 precision")
+
+        print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp16 precision")
+        subprocess.run(
+            [
+                "./armnnconverter",
+                "-f",
+                "tflite-binary",
+                "-m",
+                tflite_fp16,
+                "-i",
+                "input_tensor",
+                "-o",
+                "output_tensor",
+                "-p",
+                armnn_fp16,
+            ],
+            capture_output=True,
+        )
+        print(f"Finished exporting {self.name} ({self.task}) with fp16 precision")
+
+        return armnn_fp32, armnn_fp16
 
     @property
     def model_name(self) -> str:
@@ -149,25 +170,6 @@ class ExportBase(torch.nn.Module):
     def repo_name(self) -> str:
         return f"immich-app/{self.model_name}"
-
-    @property
-    def device(self) -> torch.device:
-        return self.dummy_param.device
-
-    @property
-    def dtype(self) -> torch.dtype:
-        return self.dummy_param.dtype
-
-    @property
-    def precision(self) -> str:
-        match self.dtype:
-            case torch.float32:
-                return "fp32"
-            case torch.float16:
-                return "fp16"
-            case _:
-                raise ValueError(f"Unsupported dtype {self.dtype}")
-
 
 class ArcFace(ExportBase):
     task = "recognition"
 
@@ -183,28 +185,18 @@ class OpenClipVisual(ExportBase):
 class OpenClipTextual(ExportBase):
     task = "textual"
 
-    @property
-    def dummy_inputs(self) -> tuple[torch.LongTensor]:
-        return (torch.randint(0, 5000, self.input_shape, device=self.device, dtype=torch.int32),)
-
 
 class MClipTextual(ExportBase):
     task = "textual"
 
-    @property
-    def dummy_inputs(self) -> tuple[torch.LongTensor]:
-        return (
-            torch.randint(0, 5000, self.input_shape, device=self.device, dtype=torch.int32),
-            torch.randint(0, 1, self.input_shape, device=self.device, dtype=torch.int32),
-        )
-
 
 def main() -> None:
     if platform.machine() not in ("x86_64", "AMD64"):
         raise RuntimeError(f"Can only run on x86_64 / AMD64, not {platform.machine()}")
-    login(token=os.environ["HF_AUTH_TOKEN"])
+    upload_to_hf = "HF_AUTH_TOKEN" in os.environ
+    if upload_to_hf:
+        login(token=os.environ["HF_AUTH_TOKEN"])
     os.environ["LD_LIBRARY_PATH"] = "armnn"
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
     failed: list[Callable[[], ExportBase]] = [
         lambda: OpenClipVisual("ViT-H-14-378-quickgelu", (1, 3, 378, 378), pretrained="dfn5b"),  # flatbuffers: cannot grow buffer beyond 2 gigabytes (will probably work with fp16)
         lambda: OpenClipVisual("ViT-H-14-quickgelu", (1, 3, 224, 224), pretrained="dfn5b"),  # flatbuffers: cannot grow buffer beyond 2 gigabytes (will probably work with fp16)
@@ -249,49 +241,46 @@ def main() -> None:
     ]
 
     succeeded: list[Callable[[], ExportBase]] = [
-        lambda: OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="openai"),
-        lambda: OpenClipTextual("ViT-B-32", (1, 77), pretrained="openai"),
+        # lambda: OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="openai"),
+        # lambda: OpenClipTextual("ViT-B-32", (1, 77), pretrained="openai"),
         lambda: OpenClipVisual("ViT-B-16", (1, 3, 224, 224), pretrained="openai"),
         lambda: OpenClipTextual("ViT-B-16", (1, 77), pretrained="openai"),
-        lambda: OpenClipVisual("ViT-L-14", (1, 3, 224, 224), pretrained="openai"),
-        lambda: OpenClipTextual("ViT-L-14", (1, 77), pretrained="openai"),
-        lambda: OpenClipVisual("ViT-L-14-336", (1, 3, 336, 336), pretrained="openai"),
-        lambda: OpenClipTextual("ViT-L-14-336", (1, 77), pretrained="openai"),
-        lambda: OpenClipVisual("RN50", (1, 3, 224, 224), pretrained="openai"),
-        lambda: OpenClipTextual("RN50", (1, 77), pretrained="openai"),
-        lambda: OpenClipTextual("ViT-H-14-quickgelu", (1, 77), pretrained="dfn5b"),
-        lambda: OpenClipTextual("ViT-H-14-378-quickgelu", (1, 77), pretrained="dfn5b"),
-        lambda: OpenClipVisual("XLM-Roberta-Large-Vit-L-14", (1, 3, 224, 224)),
-        lambda: OpenClipVisual("XLM-Roberta-Large-Vit-B-32", (1, 3, 224, 224)),
-        lambda: ArcFace("buffalo_s", (1, 3, 112, 112), optimization_level=3),
-        lambda: RetinaFace("buffalo_s", (1, 3, 640, 640), optimization_level=3),
-        lambda: ArcFace("buffalo_m", (1, 3, 112, 112), optimization_level=3),
-        lambda: RetinaFace("buffalo_m", (1, 3, 640, 640), optimization_level=3),
-        lambda: ArcFace("buffalo_l", (1, 3, 112, 112), optimization_level=3),
-        lambda: RetinaFace("buffalo_l", (1, 3, 640, 640), optimization_level=3),
-        lambda: ArcFace("antelopev2", (1, 3, 112, 112), optimization_level=3),
-        lambda: RetinaFace("antelopev2", (1, 3, 640, 640), optimization_level=3),
+        # lambda: OpenClipVisual("ViT-L-14", (1, 3, 224, 224), pretrained="openai"),
+        # lambda: OpenClipTextual("ViT-L-14", (1, 77), pretrained="openai"),
+        # lambda: OpenClipVisual("ViT-L-14-336", (1, 3, 336, 336), pretrained="openai"),
+        # lambda: OpenClipTextual("ViT-L-14-336", (1, 77), pretrained="openai"),
+        # lambda: OpenClipVisual("RN50", (1, 3, 224, 224), pretrained="openai"),
+        # lambda: OpenClipTextual("RN50", (1, 77), pretrained="openai"),
+        # lambda: OpenClipTextual("ViT-H-14-quickgelu", (1, 77), pretrained="dfn5b"),
+        # lambda: OpenClipTextual("ViT-H-14-378-quickgelu", (1, 77), pretrained="dfn5b"),
+        # lambda: OpenClipVisual("XLM-Roberta-Large-Vit-L-14", (1, 3, 224, 224)),
+        # lambda: OpenClipVisual("XLM-Roberta-Large-Vit-B-32", (1, 3, 224, 224)),
+        # lambda: ArcFace("buffalo_s", (1, 3, 112, 112), optimization_level=3),
+        # lambda: RetinaFace("buffalo_s", (1, 3, 640, 640), optimization_level=3),
+        # lambda: ArcFace("buffalo_m", (1, 3, 112, 112), optimization_level=3),
+        # lambda: RetinaFace("buffalo_m", (1, 3, 640, 640), optimization_level=3),
+        # lambda: ArcFace("buffalo_l", (1, 3, 112, 112), optimization_level=3),
+        # lambda: RetinaFace("buffalo_l", (1, 3, 640, 640), optimization_level=3),
+        # lambda: ArcFace("antelopev2", (1, 3, 112, 112), optimization_level=3),
+        # lambda: RetinaFace("antelopev2", (1, 3, 640, 640), optimization_level=3),
     ]
 
     models: list[Callable[[], ExportBase]] = [*failed, *succeeded]
     for _model in succeeded:
-        model = _model().to(device)
+        model = _model()
         try:
-            relative_path = os.path.join(model.task, "model.armnn")
-            output_path = os.path.join("output", model.model_name, relative_path)
-            model.to_armnn(output_path)
-            upload_file(path_or_fileobj=output_path, path_in_repo=relative_path, repo_id=model.repo_name)
-            if device == torch.device("cuda"):
-                model.half()
-                relative_path = os.path.join(model.task, "fp16", "model.armnn")
-                output_path = os.path.join("output", model.model_name, relative_path)
-                model.to_armnn(output_path)
-                upload_file(path_or_fileobj=output_path, path_in_repo=relative_path, repo_id=model.repo_name)
-
+            model_dir = os.path.join("output", model.model_name)
+            output_dir = os.path.join(model_dir, model.task)
+            armnn_fp32, armnn_fp16 = model.to_armnn(output_dir)
+            relative_fp32 = os.path.relpath(armnn_fp32, start=model_dir)
+            relative_fp16 = os.path.relpath(armnn_fp16, start=model_dir)
+            if upload_to_hf and os.path.isfile(armnn_fp32):
+                upload_file(path_or_fileobj=armnn_fp32, path_in_repo=relative_fp32, repo_id=model.repo_name)
+            if upload_to_hf and os.path.isfile(armnn_fp16):
+                upload_file(path_or_fileobj=armnn_fp16, path_in_repo=relative_fp16, repo_id=model.repo_name)
        except Exception as exc:
             print(f"Failed to export {model.model_name} ({model.task}): {exc}")


 if __name__ == "__main__":
-    with torch.no_grad():
-        main()
+    main()
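
Note: the comment above onnx_transpose_4d claims that the Unsqueeze -> Transpose(perm=[3, 1, 2, 0, 4]) -> Squeeze chain collapses into a single 4d Transpose(perm=[2, 0, 1, 3]). A minimal numpy sketch of that equivalence (not part of the patch; numpy used only for illustration), assuming the redundant Unsqueeze inserts a leading axis, which is what the [0, 1, 2, 4] shape indexing in the function implies (the squeezed axis is axis 3 of the transposed tensor, i.e. the inserted one):

    import numpy as np

    x = np.random.rand(2, 3, 4, 5)                                       # stand-in 4d activation
    folded = x.transpose(2, 0, 1, 3)                                     # single 4d transpose, perm [2, 0, 1, 3]
    unfolded = np.expand_dims(x, 0).transpose(3, 1, 2, 0, 4).squeeze(3)  # unsqueeze -> 5d transpose -> squeeze
    assert np.array_equal(folded, unfolded)

Both paths map a tensor shaped (a, b, c, d) to (c, a, b, d), which is also why link4.shape can be rebuilt from indices [0, 1, 2, 4] of the transposed 5d shape.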