mirror of
https://github.com/immich-app/immich.git
synced 2025-01-21 00:52:43 -05:00
handle gather at the end
This commit is contained in:
parent
1ad348c407
commit
3d62011ae3
1 changed files with 77 additions and 40 deletions
|
@ -10,19 +10,19 @@ from huggingface_hub import snapshot_download
|
||||||
from onnx.shape_inference import infer_shapes_path
|
from onnx.shape_inference import infer_shapes_path
|
||||||
from huggingface_hub import login, upload_file
|
from huggingface_hub import login, upload_file
|
||||||
import onnx2tf
|
import onnx2tf
|
||||||
from itertools import chain
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import onnxsim
|
import onnxsim
|
||||||
|
|
||||||
# i can explain
|
# i can explain
|
||||||
# armnn only supports up to 4d tranposes, but the model has a 5d transpose due to a redundant unsqueeze
|
# armnn only supports up to 4d tranposes, but the model has a 5d transpose due to a redundant unsqueeze
|
||||||
# this function folds the unsqueeze+transpose+squeeze into a single 4d transpose
|
# this function folds the unsqueeze+transpose+squeeze into a single 4d transpose
|
||||||
# it also switches from gather ops to slices since armnn doesn't support 3d gather
|
# it also switches from gather ops to slices since armnn has different dimension semantics for gathers
|
||||||
def onnx_transpose_4d(model_path: str):
|
def onnx_transpose_4d(model_path: str):
|
||||||
proto = onnx.load(model_path)
|
proto = onnx.load(model_path)
|
||||||
graph = import_onnx(proto)
|
graph = import_onnx(proto)
|
||||||
|
|
||||||
gather_idx = 1
|
gather_idx = 1
|
||||||
|
squeeze_idx = 1
|
||||||
for node in graph.nodes:
|
for node in graph.nodes:
|
||||||
for link1 in node.outputs:
|
for link1 in node.outputs:
|
||||||
if "Unsqueeze" in link1.name:
|
if "Unsqueeze" in link1.name:
|
||||||
|
@ -48,6 +48,7 @@ def onnx_transpose_4d(model_path: str):
|
||||||
node.outputs = [link2]
|
node.outputs = [link2]
|
||||||
if "Gather" in link4.name:
|
if "Gather" in link4.name:
|
||||||
for node4 in link4.outputs:
|
for node4 in link4.outputs:
|
||||||
|
axis = node1.attrs.get("axis", 0)
|
||||||
index = node4.inputs[1].values
|
index = node4.inputs[1].values
|
||||||
slice_link = Variable(
|
slice_link = Variable(
|
||||||
f"onnx::Slice_123{gather_idx}",
|
f"onnx::Slice_123{gather_idx}",
|
||||||
|
@ -60,11 +61,15 @@ def onnx_transpose_4d(model_path: str):
|
||||||
link3,
|
link3,
|
||||||
Constant(
|
Constant(
|
||||||
f"SliceStart_123{gather_idx}",
|
f"SliceStart_123{gather_idx}",
|
||||||
np.array([index, 0, 0, 0]),
|
np.array([index]),
|
||||||
),
|
),
|
||||||
Constant(
|
Constant(
|
||||||
f"SliceEnd_123{gather_idx}",
|
f"SliceEnd_123{gather_idx}",
|
||||||
np.array([index + 1] + link3.shape[1:]),
|
np.array([index + 1]),
|
||||||
|
),
|
||||||
|
Constant(
|
||||||
|
f"SliceAxis_123{gather_idx}",
|
||||||
|
np.array([axis]),
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[slice_link],
|
outputs=[slice_link],
|
||||||
|
@ -80,6 +85,59 @@ def onnx_transpose_4d(model_path: str):
|
||||||
node5.inputs[idx] = slice_link
|
node5.inputs[idx] = slice_link
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
elif node.op == "LayerNormalization":
|
||||||
|
for node1 in link1.outputs:
|
||||||
|
if node1.op == "Gather":
|
||||||
|
for link2 in node1.outputs:
|
||||||
|
for node2 in link2.outputs:
|
||||||
|
axis = node1.attrs.get("axis", 0)
|
||||||
|
index = node1.inputs[1].values
|
||||||
|
slice_link = Variable(
|
||||||
|
f"onnx::Slice_123{gather_idx}",
|
||||||
|
dtype=link2.dtype,
|
||||||
|
shape=[1] + link2.shape,
|
||||||
|
)
|
||||||
|
slice_node = Node(
|
||||||
|
op="Slice",
|
||||||
|
inputs=[
|
||||||
|
node1.inputs[0],
|
||||||
|
Constant(
|
||||||
|
f"SliceStart_123{gather_idx}",
|
||||||
|
np.array([index]),
|
||||||
|
),
|
||||||
|
Constant(
|
||||||
|
f"SliceEnd_123{gather_idx}",
|
||||||
|
np.array([index + 1]),
|
||||||
|
),
|
||||||
|
Constant(
|
||||||
|
f"SliceAxis_123{gather_idx}",
|
||||||
|
np.array([axis]),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[slice_link],
|
||||||
|
name=f"Slice_123{gather_idx}",
|
||||||
|
)
|
||||||
|
graph.nodes.append(slice_node)
|
||||||
|
gather_idx += 1
|
||||||
|
|
||||||
|
squeeze_link = Variable(
|
||||||
|
f"onnx::Squeeze_123{squeeze_idx}",
|
||||||
|
dtype=link2.dtype,
|
||||||
|
shape=link2.shape,
|
||||||
|
)
|
||||||
|
squeeze_node = Node(
|
||||||
|
op="Squeeze",
|
||||||
|
inputs=[slice_link, Constant(f"SqueezeAxis_123{squeeze_idx}",np.array([0]),)],
|
||||||
|
outputs=[squeeze_link],
|
||||||
|
name=f"Squeeze_123{squeeze_idx}",
|
||||||
|
)
|
||||||
|
graph.nodes.append(squeeze_node)
|
||||||
|
squeeze_idx += 1
|
||||||
|
try:
|
||||||
|
idx = node2.inputs.index(link2)
|
||||||
|
node2.inputs[idx] = squeeze_link
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
|
graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
|
||||||
graph.toposort()
|
graph.toposort()
|
||||||
|
@ -149,6 +207,7 @@ class ExportBase:
|
||||||
os.makedirs(static_dir, exist_ok=True)
|
os.makedirs(static_dir, exist_ok=True)
|
||||||
|
|
||||||
static_path = os.path.join(static_dir, "model.onnx")
|
static_path = os.path.join(static_dir, "model.onnx")
|
||||||
|
if not os.path.isfile(static_path):
|
||||||
print(f"Making {self.model_name} ({self.task}) static")
|
print(f"Making {self.model_name} ({self.task}) static")
|
||||||
onnx_make_fixed(onnx_path_original, static_path, self.input_shape)
|
onnx_make_fixed(onnx_path_original, static_path, self.input_shape)
|
||||||
onnx_transpose_4d(static_path)
|
onnx_transpose_4d(static_path)
|
||||||
|
@ -181,50 +240,28 @@ class ExportBase:
|
||||||
armnn_fp32 = os.path.join(output_dir, "model.armnn")
|
armnn_fp32 = os.path.join(output_dir, "model.armnn")
|
||||||
armnn_fp16 = os.path.join(fp16_dir, "model.armnn")
|
armnn_fp16 = os.path.join(fp16_dir, "model.armnn")
|
||||||
|
|
||||||
input_tensors = list(chain.from_iterable(("-i", input_) for input_ in self.inputs)),
|
|
||||||
output_tensors = list(chain.from_iterable(("-o", output_) for output_ in self.outputs)),
|
|
||||||
print(f"{input_tensors=}")
|
|
||||||
print(f"{output_tensors=}")
|
|
||||||
args = [
|
args = [
|
||||||
"./armnnconverter",
|
"./armnnconverter",
|
||||||
"-f",
|
"-f",
|
||||||
"tflite-binary",
|
"tflite-binary",
|
||||||
"-m",
|
|
||||||
tflite_fp32,
|
|
||||||
"-p",
|
|
||||||
armnn_fp32,
|
|
||||||
]
|
]
|
||||||
for input_ in self.inputs:
|
for input_ in self.inputs:
|
||||||
args.extend(["-i", input_])
|
args.extend(["-i", input_])
|
||||||
for output_ in self.outputs:
|
for output_ in self.outputs:
|
||||||
args.extend(["-o", output_])
|
args.extend(["-o", output_])
|
||||||
|
|
||||||
|
fp32_args = args.copy()
|
||||||
|
fp32_args.extend(["-m", tflite_fp32, "-p", tflite_fp32])
|
||||||
|
|
||||||
print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp32 precision")
|
print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp32 precision")
|
||||||
subprocess.run(
|
subprocess.run(fp32_args, capture_output=True)
|
||||||
args,
|
|
||||||
capture_output=True,
|
|
||||||
)
|
|
||||||
print(f"Finished exporting {self.name} ({self.task}) with fp32 precision")
|
print(f"Finished exporting {self.name} ({self.task}) with fp32 precision")
|
||||||
|
|
||||||
args = [
|
fp16_args = args.copy()
|
||||||
"./armnnconverter",
|
fp32_args.extend(["-m", tflite_fp16, "-p", tflite_fp16])
|
||||||
"-f",
|
|
||||||
"tflite-binary",
|
|
||||||
"-m",
|
|
||||||
tflite_fp16,
|
|
||||||
"-p",
|
|
||||||
armnn_fp16,
|
|
||||||
]
|
|
||||||
for input_ in self.inputs:
|
|
||||||
args.extend(["-i", input_])
|
|
||||||
for output_ in self.outputs:
|
|
||||||
args.extend(["-o", output_])
|
|
||||||
|
|
||||||
print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp16 precision")
|
print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp16 precision")
|
||||||
subprocess.run(
|
subprocess.run(fp16_args, capture_output=True)
|
||||||
args,
|
|
||||||
capture_output=True,
|
|
||||||
)
|
|
||||||
print(f"Finished exporting {self.name} ({self.task}) with fp16 precision")
|
print(f"Finished exporting {self.name} ({self.task}) with fp16 precision")
|
||||||
|
|
||||||
return armnn_fp32, armnn_fp16
|
return armnn_fp32, armnn_fp16
|
||||||
|
|
Loading…
Add table
Reference in a new issue