diff --git a/.gitignore b/.gitignore index 534fed2..9389a60 100644 --- a/.gitignore +++ b/.gitignore @@ -108,4 +108,5 @@ venv.bak/ *.jpg *.jpeg *.pt +*.onnx .DS_Store diff --git a/dpt/vit.py b/dpt/vit.py index 00ae380..e023e1d 100644 --- a/dpt/vit.py +++ b/dpt/vit.py @@ -318,11 +318,17 @@ def forward(self, x): out_size = torch.Size((h // self.patch_size[1], w // self.patch_size[0])) if not self.hybrid_backbone: - layer_1 = self.act_postprocess1(layer_1.unflatten(2, out_size)) - layer_2 = self.act_postprocess2(layer_2.unflatten(2, out_size)) - - layer_3 = self.act_postprocess3(layer_3.unflatten(2, out_size)) - layer_4 = self.act_postprocess4(layer_4.unflatten(2, out_size)) + # according to https://github.com/isl-org/DPT/issues/42#issuecomment-944657114 + # layer_1 = self.act_postprocess1(layer_1.unflatten(2, out_size)) + # layer_2 = self.act_postprocess2(layer_2.unflatten(2, out_size)) + layer_1 = self.act_postprocess1(layer_1.view(layer_1.shape[0], layer_1.shape[1], *out_size)) + layer_2 = self.act_postprocess2(layer_2.view(layer_2.shape[0], layer_2.shape[1], *out_size)) + + # according to https://github.com/isl-org/DPT/issues/42#issuecomment-944657114 + # layer_3 = self.act_postprocess3(layer_3.unflatten(2, out_size)) + # layer_4 = self.act_postprocess4(layer_4.unflatten(2, out_size)) + layer_3 = self.act_postprocess3(layer_3.view(layer_3.shape[0], layer_3.shape[1], *out_size)) + layer_4 = self.act_postprocess4(layer_4.view(layer_4.shape[0], layer_4.shape[1], *out_size)) return layer_1, layer_2, layer_3, layer_4 diff --git a/export_monodepth_onnx.py b/export_monodepth_onnx.py new file mode 100644 index 0000000..8b0a5c1 --- /dev/null +++ b/export_monodepth_onnx.py @@ -0,0 +1,163 @@ +import torch +import argparse +import onnx +import onnxruntime +import json +import numpy as np +import cv2 + +from dpt.models import DPTDepthModel +from dpt.midas_net import MidasNet_large +import util.io + + +def main(model_path, model_type, output_path, batch_size, test_image_path): + # load network + if model_type == "dpt_large": # DPT-Large + net_w = net_h = 384 + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + enable_attention_hooks=False, + ) + normalization = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + prediction_factor = 1 + elif model_type == "dpt_hybrid": # DPT-Hybrid + net_w = net_h = 384 + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + enable_attention_hooks=False, + ) + normalization = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + prediction_factor = 1 + elif model_type == "dpt_hybrid_kitti": + net_w = 1216 + net_h = 352 + + model = DPTDepthModel( + path=model_path, + scale=0.00006016, + shift=0.00579, + invert=True, + backbone="vitb_rn50_384", + non_negative=True, + enable_attention_hooks=False, + ) + + normalization = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + prediction_factor = 256 + elif model_type == "dpt_hybrid_nyu": + net_w = 640 + net_h = 480 + + model = DPTDepthModel( + path=model_path, + scale=0.000305, + shift=0.1378, + invert=True, + backbone="vitb_rn50_384", + non_negative=True, + enable_attention_hooks=False, + ) + + normalization = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + prediction_factor = 1000.0 + elif model_type == "midas_v21": # Convolutional model + net_w = net_h = 384 + + model = MidasNet_large(model_path, non_negative=True) + normalization = dict( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + prediction_factor = 1 + else: + assert ( + False + ), f"model_type '{model_type}' not implemented, use: --model_type [dpt_large|dpt_hybrid|dpt_hybrid_kitti|dpt_hybrid_nyu|midas_v21]" + + model.eval() + + dummy_input = torch.zeros((batch_size, 3, net_h, net_w)) + # TODO: right now, the batch size is not dynamic due to the PyTorch tracer + # treating the batch size as constant (see get_attention() in vit.py). + # Therefore you have to use a batch size of one to use this together with + # run_monodepth_onnx.py. + torch.onnx.export( + model, + dummy_input, + output_path, + input_names=["input"], + output_names=["output"], + opset_version=11, + dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, + ) + + # store normalization configuration + model_onnx = onnx.load(output_path) + meta_imagesize = model_onnx.metadata_props.add() + meta_imagesize.key = "ImageSize" + meta_imagesize.value = json.dumps([net_w, net_h]) + meta_normalization = model_onnx.metadata_props.add() + meta_normalization.key = "Normalization" + meta_normalization.value = json.dumps(normalization) + meta_prediction_factor = model_onnx.metadata_props.add() + meta_prediction_factor.key = "PredictionFactor" + meta_prediction_factor.value = str(prediction_factor) + onnx.save(model_onnx, output_path) + del model_onnx + + if test_image_path is not None: + # load test image + img = util.io.read_image(test_image_path) + + # resize + img_input = cv2.resize(img, (net_h, net_w), cv2.INTER_AREA) + + # normalize + img_input = (img_input - np.array(normalization["mean"])) / np.array(normalization["std"]) + + # transpose from HWC to CHW + img_input = img_input.transpose(2, 0, 1) + + # add batch dimension + img_input = np.stack([img_input] * batch_size) + + # validate accuracy of exported model + torch_out = model(torch.from_numpy(img_input.astype(np.float32))).detach().cpu().numpy() + session = onnxruntime.InferenceSession( + output_path, + providers=[ + "TensorrtExecutionProvider", + "CUDAExecutionProvider", + "CPUExecutionProvider", + ], + ) + onnx_out = session.run(["output"], {"input": img_input.astype(np.float32)})[0] + + # compare ONNX Runtime and PyTorch results + np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-02, atol=1e-04) + print("Exported model predictions match original") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("model_weights", help="path to input model weights") + parser.add_argument("output_path", help="path to output model weights") + parser.add_argument( + "-t", + "--model_type", + default="dpt_hybrid", + help="model type [dpt_large|dpt_hybrid|midas_v21]", + ) + parser.add_argument("--batch_size", default=1, help="batch size used for tracing") + parser.add_argument( + "--test_image_path", + type=str, + help="path to some image to test the accuracy of the exported model against the original" + ) + + args = parser.parse_args() + main(args.model_weights, args.model_type, args.output_path, args.batch_size, args.test_image_path) \ No newline at end of file diff --git a/export_segmentation_onnx.py b/export_segmentation_onnx.py new file mode 100644 index 0000000..0675ad5 --- /dev/null +++ b/export_segmentation_onnx.py @@ -0,0 +1,114 @@ +import torch +import argparse +import onnx +import onnxruntime +import json +import numpy as np +import cv2 + +from dpt.models import DPTSegmentationModel +import util.io + + +def main(model_path, model_type, output_path, batch_size, test_image_path): + net_w = net_h = 480 + normalization = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + # load network + if model_type == "dpt_large": + model = DPTSegmentationModel( + 150, + path=model_path, + backbone="vitl16_384", + ) + elif model_type == "dpt_hybrid": + model = DPTSegmentationModel( + 150, + path=model_path, + backbone="vitb_rn50_384", + ) + else: + assert ( + False + ), f"model_type '{model_type}' not implemented, use: --model_type [dpt_large|dpt_hybrid]" + + model.eval() + + dummy_input = torch.zeros((batch_size, 3, net_h, net_w)) + # TODO: right now, the batch size is not dynamic due to the PyTorch tracer + # treating the batch size as constant (see get_attention() in vit.py). + # Therefore you have to use a batch size of one to use this together with + # run_monodepth_onnx.py. + torch.onnx.export( + model, + dummy_input, + output_path, + input_names=["input"], + output_names=["output"], + opset_version=11, + dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, + ) + + # store normalization configuration + model_onnx = onnx.load(output_path) + meta_imagesize = model_onnx.metadata_props.add() + meta_imagesize.key = "ImageSize" + meta_imagesize.value = json.dumps([net_w, net_h]) + meta_normalization = model_onnx.metadata_props.add() + meta_normalization.key = "Normalization" + meta_normalization.value = json.dumps(normalization) + onnx.save(model_onnx, output_path) + del model_onnx + + if test_image_path is not None: + # load test image + img = util.io.read_image(test_image_path) + + # resize + img_input = cv2.resize(img, (net_h, net_w), cv2.INTER_AREA) + + # normalize + img_input = (img_input - np.array(normalization["mean"])) / np.array(normalization["std"]) + + # transpose from HWC to CHW + img_input = img_input.transpose(2, 0, 1) + + # add batch dimension + img_input = np.stack([img_input] * batch_size) + + # validate accuracy of exported model + torch_out = model(torch.from_numpy(img_input.astype(np.float32))).detach().cpu().numpy() + session = onnxruntime.InferenceSession( + output_path, + providers=[ + "TensorrtExecutionProvider", + "CUDAExecutionProvider", + "CPUExecutionProvider", + ], + ) + onnx_out = session.run(["output"], {"input": img_input.astype(np.float32)})[0] + + # compare ONNX Runtime and PyTorch results + np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-02, atol=1e-04) + print("Exported model predictions match original") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("model_weights", help="path to input model weights") + parser.add_argument("output_path", help="path to output model weights") + parser.add_argument( + "-t", + "--model_type", + default="dpt_hybrid", + help="model type [dpt_large|dpt_hybrid]", + ) + parser.add_argument("--batch_size", default=1, help="batch size used for tracing") + parser.add_argument( + "--test_image_path", + type=str, + help="path to some image to test the accuracy of the exported model against the original" + ) + + args = parser.parse_args() + main(args.model_weights, args.model_type, args.output_path, args.batch_size, args.test_image_path) \ No newline at end of file diff --git a/run_monodepth_onnx.py b/run_monodepth_onnx.py new file mode 100644 index 0000000..87b92ed --- /dev/null +++ b/run_monodepth_onnx.py @@ -0,0 +1,115 @@ +import argparse +import json +import os +import glob +import numpy as np +import cv2 +import onnx +import onnxruntime + +import util.io + + +def run(input_path, output_path, model_path, kitti_crop, absolute_depth): + + model = onnx.load(model_path) + net_w, net_h = json.loads(model.metadata_props[0].value) + normalization = json.loads(model.metadata_props[1].value) + prediction_factor = float(model.metadata_props[2].value) + mean = np.array(normalization["mean"]) + std = np.array(normalization["std"]) + del model + + session = onnxruntime.InferenceSession( + model_path, + providers=[ + "TensorrtExecutionProvider", + "CUDAExecutionProvider", + "CPUExecutionProvider", + ], + ) + + # get input + img_names = glob.glob(os.path.join(input_path, "*")) + num_images = len(img_names) + + # create output folder + os.makedirs(output_path, exist_ok=True) + + print("start processing") + for ind, img_name in enumerate(img_names): + if os.path.isdir(img_name): + continue + + print(" processing {} ({}/{})".format(img_name, ind + 1, num_images)) + # input + + img = util.io.read_image(img_name) + + if kitti_crop is True: + height, width, _ = img.shape + top = height - 352 + left = (width - 1216) // 2 + img = img[top : top + 352, left : left + 1216, :] + + # resize + img_input = cv2.resize(img, (net_h, net_w), cv2.INTER_AREA) + + # normalize + img_input = (img_input - mean) / std + + # transpose from HWC to CHW + img_input = img_input.transpose(2, 0, 1) + + # add batch dimension + img_input = img_input[None, ...] + + # compute + prediction = session.run(["output"], {"input": img_input.astype(np.float32)})[0][0] + prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), cv2.INTER_CUBIC) + prediction *= prediction_factor + + filename = os.path.join( + output_path, os.path.splitext(os.path.basename(img_name))[0] + ) + util.io.write_depth( + filename, prediction, bits=2, absolute_depth=absolute_depth + ) + + print("finished") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-i", "--input_path", default="input", help="folder with input images" + ) + + parser.add_argument( + "-o", + "--output_path", + default="output_monodepth", + help="folder for output images", + ) + + parser.add_argument( + "-m", "--model_weights", default=None, help="path to model weights" + ) + + parser.add_argument("--kitti_crop", dest="kitti_crop", action="store_true") + parser.add_argument("--absolute_depth", dest="absolute_depth", action="store_true") + + parser.set_defaults(kitti_crop=False) + parser.set_defaults(absolute_depth=False) + + args = parser.parse_args() + + # compute depth maps + run( + args.input_path, + args.output_path, + args.model_weights, + args.kitti_crop, + args.absolute_depth, + ) \ No newline at end of file diff --git a/run_segmentation_onnx.py b/run_segmentation_onnx.py new file mode 100644 index 0000000..e0f2713 --- /dev/null +++ b/run_segmentation_onnx.py @@ -0,0 +1,96 @@ +import argparse +import json +import os +import glob +import numpy as np +import cv2 +import onnx +import onnxruntime + +import util.io + + +def run(input_path, output_path, model_path): + + model = onnx.load(model_path) + net_w, net_h = json.loads(model.metadata_props[0].value) + normalization = json.loads(model.metadata_props[1].value) + mean = np.array(normalization["mean"]) + std = np.array(normalization["std"]) + del model + + session = onnxruntime.InferenceSession( + model_path, + providers=[ + "TensorrtExecutionProvider", + "CUDAExecutionProvider", + "CPUExecutionProvider", + ], + ) + + # get input + img_names = glob.glob(os.path.join(input_path, "*")) + num_images = len(img_names) + + # create output folder + os.makedirs(output_path, exist_ok=True) + + print("start processing") + for ind, img_name in enumerate(img_names): + if os.path.isdir(img_name): + continue + + print(" processing {} ({}/{})".format(img_name, ind + 1, num_images)) + # input + + img = util.io.read_image(img_name) + + # resize + img_input = cv2.resize(img, (net_h, net_w), cv2.INTER_AREA) + + # normalize + img_input = (img_input - mean) / std + + # transpose from HWC to CHW + img_input = img_input.transpose(2, 0, 1) + + # add batch dimension + img_input = img_input[None, ...] + + # compute + prediction = session.run(["output"], {"input": img_input.astype(np.float32)})[0][0] + prediction = prediction.transpose(1, 2, 0) + prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), cv2.INTER_CUBIC) + prediction = np.argmax(prediction, axis=-1) + 1 + + # output + filename = os.path.join( + output_path, os.path.splitext(os.path.basename(img_name))[0] + ) + util.io.write_segm_img(filename, img, prediction, alpha=0.5) + + print("finished") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-i", "--input_path", default="input", help="folder with input images" + ) + + parser.add_argument( + "-o", + "--output_path", + default="output_monodepth", + help="folder for output images", + ) + + parser.add_argument( + "-m", "--model_weights", default=None, help="path to model weights" + ) + + args = parser.parse_args() + + # compute segmentation maps + run(args.input_path, args.output_path, args.model_weights) \ No newline at end of file