0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-21 00:52:43 -05:00
This commit is contained in:
mertalev 2024-07-08 18:19:35 -04:00
parent 3d62011ae3
commit b39cca1b43
No known key found for this signature in database
GPG key ID: 9181CD92C0A1C5E3

View file

@ -12,6 +12,7 @@ from huggingface_hub import login, upload_file
import onnx2tf import onnx2tf
import numpy as np import numpy as np
import onnxsim import onnxsim
from shutil import rmtree
# i can explain # i can explain
# armnn only supports up to 4d tranposes, but the model has a 5d transpose due to a redundant unsqueeze # armnn only supports up to 4d tranposes, but the model has a 5d transpose due to a redundant unsqueeze
@ -167,9 +168,9 @@ def onnx_make_fixed(input_path: str, output_path: str, input_shape: tuple[int, .
simplified, success = onnxsim.simplify(input_path) simplified, success = onnxsim.simplify(input_path)
if not success: if not success:
raise RuntimeError(f"Failed to simplify {input_path}") raise RuntimeError(f"Failed to simplify {input_path}")
onnx.save(simplified, input_path) onnx.save(simplified, output_path, save_as_external_data=True, all_tensors_to_one_file=False)
infer_shapes_path(input_path, check_type=True, strict_mode=True, data_prop=True) infer_shapes_path(output_path, check_type=True, strict_mode=True, data_prop=True)
model = onnx.load_model(input_path) model = onnx.load_model(output_path)
make_input_shape_fixed(model.graph, model.graph.input[0].name, input_shape) make_input_shape_fixed(model.graph, model.graph.input[0].name, input_shape)
fix_output_shapes(model) fix_output_shapes(model)
onnx.save(model, output_path, save_as_external_data=True, all_tensors_to_one_file=False) onnx.save(model, output_path, save_as_external_data=True, all_tensors_to_one_file=False)
@ -218,20 +219,23 @@ class ExportBase:
def to_tflite(self, output_dir: str) -> tuple[str, str]: def to_tflite(self, output_dir: str) -> tuple[str, str]:
input_path = self.to_onnx_static() input_path = self.to_onnx_static()
os.makedirs(output_dir, exist_ok=True)
tflite_fp32 = os.path.join(output_dir, "model_float32.tflite") tflite_fp32 = os.path.join(output_dir, "model_float32.tflite")
tflite_fp16 = os.path.join(output_dir, "model_float16.tflite") tflite_fp16 = os.path.join(output_dir, "model_float16.tflite")
if not os.path.isfile(tflite_fp32) or not os.path.isfile(tflite_fp16): if not os.path.isfile(tflite_fp32) or not os.path.isfile(tflite_fp16):
print(f"Exporting {self.model_name} ({self.task}) to TFLite") print(f"Exporting {self.model_name} ({self.task}) to TFLite (this might take a few minutes)")
onnx2tf.convert( onnx2tf.convert(
input_onnx_file_path=input_path, input_onnx_file_path=input_path,
output_folder_path=output_dir, output_folder_path=output_dir,
keep_shape_absolutely_input_names=self.inputs,
verbosity="warn",
copy_onnx_input_output_names_to_tflite=True, copy_onnx_input_output_names_to_tflite=True,
output_signaturedefs=True,
) )
return tflite_fp32, tflite_fp16 return tflite_fp32, tflite_fp16
def to_armnn(self, output_dir: str) -> tuple[str, str]: def to_armnn(self, output_dir: str) -> tuple[str, str]:
output_dir = os.path.abspath(output_dir)
tflite_model_dir = os.path.join(output_dir, "tflite") tflite_model_dir = os.path.join(output_dir, "tflite")
tflite_fp32, tflite_fp16 = self.to_tflite(tflite_model_dir) tflite_fp32, tflite_fp16 = self.to_tflite(tflite_model_dir)
@ -240,28 +244,38 @@ class ExportBase:
armnn_fp32 = os.path.join(output_dir, "model.armnn") armnn_fp32 = os.path.join(output_dir, "model.armnn")
armnn_fp16 = os.path.join(fp16_dir, "model.armnn") armnn_fp16 = os.path.join(fp16_dir, "model.armnn")
args = [ args = ["./armnnconverter", "-f", "tflite-binary"]
"./armnnconverter",
"-f",
"tflite-binary",
]
for input_ in self.inputs: for input_ in self.inputs:
args.extend(["-i", input_]) args.extend(["-i", input_])
for output_ in self.outputs: for output_ in self.outputs:
args.extend(["-o", output_]) args.extend(["-o", output_])
fp32_args = args.copy() fp32_args = args.copy()
fp32_args.extend(["-m", tflite_fp32, "-p", tflite_fp32]) fp32_args.extend(["-m", tflite_fp32, "-p", armnn_fp32])
print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp32 precision") print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp32 precision")
subprocess.run(fp32_args, capture_output=True) try:
print(subprocess.check_output(fp32_args, stderr=subprocess.STDOUT).decode())
except subprocess.CalledProcessError as e:
print(e.output.decode())
try:
rmtree(tflite_model_dir, ignore_errors=True)
finally:
raise e
print(f"Finished exporting {self.name} ({self.task}) with fp32 precision") print(f"Finished exporting {self.name} ({self.task}) with fp32 precision")
fp16_args = args.copy() fp16_args = args.copy()
fp32_args.extend(["-m", tflite_fp16, "-p", tflite_fp16]) fp16_args.extend(["-m", tflite_fp16, "-p", armnn_fp16])
print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp16 precision") print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp16 precision")
subprocess.run(fp16_args, capture_output=True) try:
print(subprocess.check_output(fp16_args, stderr=subprocess.STDOUT).decode())
except subprocess.CalledProcessError as e:
print(e.output.decode())
try:
rmtree(tflite_model_dir, ignore_errors=True)
finally:
raise e
print(f"Finished exporting {self.name} ({self.task}) with fp16 precision") print(f"Finished exporting {self.name} ({self.task}) with fp16 precision")
return armnn_fp32, armnn_fp16 return armnn_fp32, armnn_fp16