{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "LTj2A0RUQFNo" }, "source": [ "# Convert our BiRefNet weights to ONNX format.\n", "\n", "> This Colab notebook is adapted from [Kazuhito00](https://github.com/Kazuhito00)'s nice work.\n", "\n", "> Repo: https://github.com/Kazuhito00/BiRefNet-ONNX-Sample  \n", "> Original Colab: https://colab.research.google.com/github/Kazuhito00/BiRefNet-ONNX-Sample/blob/main/Convert2ONNX.ipynb\n", "\n", "+ Converting a standard BiRefNet to ONNX on GPU needs **19.7GB** of GPU memory.\n", "+ Currently, a Colab runtime with 12.7GB RAM / 15GB GPU memory cannot fit the conversion of BiRefNet in its default setting, so BiRefNet with the `swin_v1_tiny` backbone is used as the example on Colab." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Online Colab version: https://colab.research.google.com/drive/1z6OruR52LOvDDpnp516F-N4EyPGrp5om" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install onnx onnxscript onnxruntime-gpu" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cd .."
] },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "781JHjLJmveh"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "weights_file = 'BiRefNet_dynamic-general-epoch_174.pth'  # https://github.com/ZhengPeng7/BiRefNet/releases/download/v1/BiRefNet_dynamic-general-epoch_174.pth\n",
    "# Path of the exported ONNX model: the weights file name with `.pth` swapped for `.onnx`.\n",
    "onnx_file = weights_file.replace('.pth', '.onnx')\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "\n",
    "# Select the backbone entry in config.py: `swin_v1_tiny` weights switch the list index\n",
    "# from 6 to 3; any other weights file switches an index of 3 back to 6.\n",
    "use_swin_tiny = 'swin_v1_tiny' in weights_file\n",
    "old_index, new_index = (6, 3) if use_swin_tiny else (3, 6)\n",
    "if use_swin_tiny:\n",
    "    print('Set `swin_v1_tiny` as the backbone.')\n",
    "\n",
    "with open('config.py') as fp:\n",
    "    config_text = fp.read()\n",
    "\n",
    "# Match the tail of the backbone list (\"'pvt_v2_b2', 'pvt_v2_b5', <comment>\" followed by\n",
    "# \"][<old_index>]\") without depending on the exact indentation or trailing comment.\n",
    "backbone_tail = re.compile(\n",
    "    r\"('pvt_v2_b2',\\s*'pvt_v2_b5',[^\\n]*\\n\\s*\\])\\[%d\\]\" % old_index\n",
    ")\n",
    "patched_text, n_patched = backbone_tail.subn(r\"\\g<1>[%d]\" % new_index, config_text)\n",
    "\n",
    "if n_patched:\n",
    "    with open('config.py', mode=\"w\") as fp:\n",
    "        fp.write(patched_text)\n",
    "else:\n",
    "    # Nothing to rewrite, e.g. config.py already selects the wanted backbone.\n",
    "    print('config.py unchanged: no backbone list ending in [{}] was found.'.format(old_index))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7lFgKfPS8Icy"
   },
   "outputs": [],
   "source": [
    "from utils import check_state_dict\n",
    "from models.birefnet import BiRefNet\n",
    "\n",
    "\n",
    "birefnet = BiRefNet(bb_pretrained=False)\n",
    "state_dict = torch.load('./{}'.format(weights_file), map_location=device, weights_only=True)\n",
    "state_dict = check_state_dict(state_dict)\n",
    "birefnet.load_state_dict(state_dict)\n",
    "\n",
    "torch.set_float32_matmul_precision(['high', 'highest'][0])\n",
    "\n",
    "birefnet.to(device)\n",
    "_ = birefnet.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JVgJAdgxQVJW"
   },
   "source": [
    "# Process deform_conv2d in the conversion to ONNX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!git clone https://github.com/masamitsu-murase/deform_conv2d_onnx_exporter\n",
    "%cp deform_conv2d_onnx_exporter/src/deform_conv2d_onnx_exporter.py .\n",
    "!rm -rf deform_conv2d_onnx_exporter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EXPORTER_PATH = 'deform_conv2d_onnx_exporter.py'\n",
    "# Statement in the exporter that gets replaced by PATCHED_LOOKUP below.\n",
    "ORIGINAL_LOOKUP = \"return sym_help._get_tensor_dim_size(tensor, dim)\"\n",
    "# Fallback that derives the size from the tensor strides when the helper returns None.\n",
    "# Indentation assumes the replaced `return` sits in a 4-space-indented function body.\n",
    "PATCHED_LOOKUP = '''\n",
    "    tensor_dim_size = sym_help._get_tensor_dim_size(tensor, dim)\n",
    "    if tensor_dim_size == None and (dim == 2 or dim == 3):\n",
    "        import typing\n",
    "        from torch import _C\n",
    "\n",
    "        x_type = typing.cast(_C.TensorType, tensor.type())\n",
    "        x_strides = x_type.strides()\n",
    "\n",
    "        tensor_dim_size = x_strides[2] if dim == 3 else x_strides[1] // x_strides[2]\n",
    "    elif tensor_dim_size == None and (dim == 0):\n",
    "        import typing\n",
    "        from torch import _C\n",
    "\n",
    "        x_type = typing.cast(_C.TensorType, tensor.type())\n",
    "        x_strides = x_type.strides()\n",
    "        tensor_dim_size = x_strides[3]\n",
    "\n",
    "    return tensor_dim_size\n",
    "    '''\n",
    "\n",
    "with open(EXPORTER_PATH) as fp:\n",
    "    exporter_source = fp.read()\n",
    "\n",
    "if ORIGINAL_LOOKUP in exporter_source:\n",
    "    with open(EXPORTER_PATH, mode=\"w\") as fp:\n",
    "        fp.write(exporter_source.replace(ORIGINAL_LOOKUP, PATCHED_LOOKUP))\n",
    "else:\n",
    "    # Either the file was already patched or this exporter version lacks the statement.\n",
    "    print('{} not patched: `{}` not found.'.format(EXPORTER_PATH, ORIGINAL_LOOKUP))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vJiZv0L75kTe"
   },
   "outputs": [],
   "source": [
    "from torchvision.ops.deform_conv import DeformConv2d\n",
    "import deform_conv2d_onnx_exporter\n",
    "\n",
    "# register deform_conv2d operator\n",
    "deform_conv2d_onnx_exporter.register_deform_conv2d_onnx_op()\n",
    "\n",
    "\n",
    "def convert_to_onnx(net, file_name='output.onnx', input_shape=(1024, 1024), device=device):\n",
    "    \"\"\"Export `net` to an ONNX file, using a random dummy image as the example input.\n",
    "\n",
    "    Args:\n",
    "        net: Model handed to `torch.onnx.export`.\n",
    "        file_name: Path of the ONNX file to write.\n",
    "        input_shape: Sizes of the last two dimensions of the dummy input, whose full\n",
    "            shape is (1, 3, input_shape[0], input_shape[1]).\n",
    "        device: Device the dummy input is moved to before export.\n",
    "    \"\"\"\n",
    "    # Named `dummy_input` rather than `input` so the builtin is not shadowed.\n",
    "    dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device)\n",
    "\n",
    "    torch.onnx.export(\n",
    "        net,\n",
    "        dummy_input,\n",
    "        file_name,\n",
    "        verbose=False,\n",
    "        opset_version=17,\n",
    "        input_names=['input_image'],\n",
    "        output_names=['output_image'],\n",
    "    )\n",
    "\n",
    "\n",
    "convert_to_onnx(birefnet, onnx_file, input_shape=(1024, 1024), device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-eU-g40P1zS-"
   },
   "source": [
    "# Load ONNX weights and do the inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "from torchvision import transforms\n",
    "\n",
    "\n",
    "transform_image = transforms.Compose([\n",
    "    transforms.Resize((1024, 1024)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "])\n",
    "\n",
    "imagepath = 'images_todo/0-onnx_test-2.png'\n",
    "image = Image.open(imagepath)\n",
    "image = image.convert(\"RGB\") if image.mode != \"RGB\" else image\n",
    "input_images = transform_image(image).unsqueeze(0).to(device)\n",
    "# NumPy copy of the same batch, used as the ONNX Runtime input.\n",
    "input_images_numpy = input_images.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "rwzdKX1EfYkd"
   },
   "outputs": [],
   "source": [
    "import onnxruntime\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider']\n",
    "onnx_session = onnxruntime.InferenceSession(\n",
    "    onnx_file,\n",
    "    providers=providers\n",
    ")\n",
    "input_name = onnx_session.get_inputs()[0].name\n",
    "print(onnxruntime.get_device(), onnx_session.get_providers())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DJVtxZUZum4-"
   },
   "outputs": [],
   "source": [
    "from time import time\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "time_st = time()\n",
    "# The same NumPy batch is fed on CPU and GPU; the provider was chosen when the session was created.\n",
    "pred_onnx = torch.tensor(\n",
    "    onnx_session.run(None, {input_name: input_images_numpy})[-1]\n",
    ").squeeze(0).sigmoid().cpu()\n",
    "print(time() - time_st)\n",
    "\n",
    "plt.imshow(pred_onnx.squeeze(), cmap='gray'); plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata":
{},
   "outputs": [],
   "source": [
    "# Run the PyTorch model on the same batch and keep the last element of its output.\n",
    "with torch.no_grad():\n",
    "    preds = (\n",
    "        birefnet(input_images)[-1]\n",
    "        .sigmoid()\n",
    "        .to(torch.float32)\n",
    "        .cpu()\n",
    "    )\n",
    "plt.imshow(preds.squeeze(), cmap='gray')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Element-wise absolute difference between the PyTorch and ONNX predictions.\n",
    "diff = (preds - pred_onnx).abs()\n",
    "print('sum(diff):', diff.sum())\n",
    "plt.imshow(diff.squeeze(), cmap='gray')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qzYHflt92Bjd"
   },
   "source": [
    "# Efficiency Comparison between .pth and .onnx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "A5IYfT-uzphA",
    "outputId": "2999e345-950e-41b3-ddd3-9f58a71a3f21"
   },
   "outputs": [],
   "source": [
    "%%timeit\n",
    "# Time the PyTorch forward pass, including the copy of the result to CPU.\n",
    "with torch.no_grad():\n",
    "    preds = birefnet(input_images)[-1].sigmoid().to(torch.float32).cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "G0Ul4rfNg1za"
   },
   "outputs": [],
   "source": [
    "%%timeit\n",
    "# Time the ONNX Runtime call, including the conversion of its output to a torch tensor.\n",
    "pred_onnx = torch.tensor(\n",
    "    onnx_session.run(None, {input_name: input_images_numpy})[-1]\n",
    ").squeeze(0).sigmoid().cpu()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}