{ "cells": [ { "cell_type": "markdown", "id": "703f27836e444a82", "metadata": {}, "source": [ "# PyTorch to ONNX" ] }, { "cell_type": "markdown", "id": "f1b7a378121f36ea", "metadata": {}, "source": [ "**eIQ AI Toolkit** uses the **Olive** (https://microsoft.github.io/Olive/) framework for model conversions.\n", "This framework applies **passes**, where each pass represents a single transformation of a model.\n", "Passes can also be chained together to perform multiple transformations in sequence.\n", "\n", "This guide explains how to use the **OnnxConversion** pass to convert a PyTorch model into ONNX format.\n", "It demonstrates how to use the **eIQ AI Toolkit** API to:\n", "- Load a PyTorch model into the application\n", "- Run the conversion to ONNX\n", "- Retrieve the converted model\n", "\n", "This guide requires the eIQ AI Toolkit backend to be running.\n", "If you haven’t set it up yet, please refer to the following tutorial:\n", "[eIQ AI Toolkit setup & launch](../tools/aiToolkit/installRun.ipynb)" ] }, { "cell_type": "code", "execution_count": null, "id": "584d5bb9caf04394", "metadata": {}, "outputs": [], "source": [ "import requests\n", "import sys\n", "from pathlib import Path\n", "\n", "# Set your eIQ AI Toolkit backend URL:\n", "AI_TOOLKIT_BACKEND_URL = \"http://localhost:8000\"" ] }, { "cell_type": "markdown", "id": "9f802b2ca8df20f6", "metadata": {}, "source": [ "## Load PyTorch model\n", "\n", "Loading any model into the application involves two steps:\n", "1. Specify the model metadata\n", "2. Upload the model file\n", "\n", "The metadata must include the type of model you are uploading. 
For example, use `pytorch` for a PyTorch model or `onnx` for an ONNX model.\n", "\n", "For PyTorch models, you also need to define the following parameters:\n", "- Input names\n", "- Input shapes\n", "- Input types\n", "- Output names\n", "\n", "You’ll see an example later showing how to structure and upload this data correctly.\n", "\n", "\n", "eIQ AI Toolkit accepts PyTorch models only as a .zip archive containing:\n", "- The model weights (`.pth` or `.pt`)\n", "- A model definition script (`.py`)\n", "\n", "The script must define a `load_model` function that accepts the path to the model weights and returns a `torch.nn.Module` instance (your model)." ] }, { "cell_type": "markdown", "id": "1f399149867d10e2", "metadata": {}, "source": [ "### 1. Prepare PyTorch model\n", "\n", "If you already have a model, update the path to point to its location.\n", "If you don’t have a model yet, set the path to the location where the model will be saved.\n", "(See the following sections for instructions on downloading a sample model.)" ] }, { "cell_type": "code", "execution_count": null, "id": "48ac0f9251cddc79", "metadata": {}, "outputs": [], "source": [ "# Modify the path to your PyTorch model\n", "model_path = Path(\"path_to_pytorch_model.zip\")" ] }, { "cell_type": "markdown", "id": "bd92c3a152b7105b", "metadata": {}, "source": [ "Use the following script to download the example model:\n", "\n", "*Note: Skip this step if you already have your own model.*" ] }, { "cell_type": "code", "execution_count": null, "id": "d07e904436db12c6", "metadata": {}, "outputs": [], "source": [ "example_model_url = \"https://eiq.nxp.com/training-materials/_misc/models/model.zip\"\n", "\n", "# Download the archive and fail fast on an HTTP error\n", "response = requests.get(url=example_model_url)\n", "response.raise_for_status()\n", "\n", "with open(model_path, \"wb\") as f:\n", " f.write(response.content)" ] }, { "cell_type": "markdown", "id": "f5f564d5-cc76-4d6b-b81c-b5fc05cc3a9d", "metadata": {}, "source": [ "### 2. 
Specify metadata\n", "\n", "Now let’s specify the model metadata. As mentioned earlier, metadata for a PyTorch model must include:\n", "- Input names\n", "- Input shapes\n", "- Input types\n", "- Output names\n", "\n", "The structure can look like this:\n", "```json\n", "{\n", " \"model_type\": \"pytorch\",\n", " \"io_config\": {\n", " \"input_config\": [\n", " {\n", " \"name\": \"images\",\n", " \"shape\": [1, 1, 49, 10],\n", " \"type\": \"float32\"\n", " },\n", " {\n", " \"name\": \"descriptions\",\n", " \"shape\": [1, 50, 50],\n", " \"type\": \"float32\"\n", " }\n", " ],\n", " \"output_config\": [\n", " {\"name\": \"prediction\"}\n", " ]\n", " }\n", "}\n", "```\n", "\n", "This illustrative structure describes a PyTorch model with two inputs, *images* and *descriptions*, and one output, *prediction*.\n", "\n", "Next, let’s implement and use it in code.\n", "\n", "*Note: The code below is prepared for the downloadable example model, which has a single input (`images`) and a single output (`y`). If you’re using your own model, update the values accordingly.*" ] }, { "cell_type": "code", "execution_count": null, "id": "8d892a307e12cd6f", "metadata": {}, "outputs": [], "source": [ "model_name = \"PyTorch model\"\n", "\n", "# Define inputs metadata\n", "inputs = [\n", " {\n", " \"name\": \"images\",\n", " \"shape\": [1, 1, 49, 10],\n", " \"type\": \"float32\"\n", " }\n", " ]\n", "\n", "# Define outputs metadata\n", "outputs = [{\"name\": \"y\"}]\n", "\n", "MODELS_API_URL = f\"{AI_TOOLKIT_BACKEND_URL}/models\"\n", "\n", "# Full model metadata\n", "model_metadata = {\n", " \"model_type\": \"pytorch\",\n", " \"io_config\": {\n", " \"input_config\": inputs,\n", " \"output_config\": outputs,\n", " }\n", " }\n", "\n", "# Upload metadata\n", "response = requests.post(MODELS_API_URL, json=model_metadata)\n", "response_data = response.json()\n", "model_uuid = response_data[\"data\"][\"model\"][\"uuid\"]\n", "\n", "# Upload model file\n", "with open(model_path, \"rb\") as zip_file:\n", " files = {\n", " \"model_file\": (\"model.zip\", zip_file, \"application/zip\")\n", 
" }\n", " response = requests.post(url=f\"{AI_TOOLKIT_BACKEND_URL}/models/{model_uuid}\", # Model identifier is part of the request URL\n", " files=files)\n", "\n", " if response.status_code == 200:\n", " response_data = response.json()\n", " print(f'PyTorch model named \"{model_name}\" has been uploaded successfully!')\n", " else:\n", " print(\"Something went wrong while uploading the model: \\n\\n\", file=sys.stderr)\n", " print(response.text, file=sys.stderr)" ] }, { "cell_type": "markdown", "id": "420139b2059d98f9", "metadata": {}, "source": [ "After uploading the model metadata and file, you can verify its registration and readiness status using the following endpoint. If the status remains `in_progress`, repeat the check until it changes to `ready`." ] }, { "cell_type": "code", "execution_count": null, "id": "2d750d9f8731405c", "metadata": {}, "outputs": [], "source": [ "response = requests.get(f\"{AI_TOOLKIT_BACKEND_URL}/models/{model_uuid}\")\n", "data = response.json()\n", "print(f'Model status: {data[\"data\"][\"model\"][\"status\"]}')\n", "print(f'Model status description: {data[\"data\"][\"model\"][\"status_description\"]}')" ] }, { "cell_type": "markdown", "id": "86455caa25c90f68", "metadata": {}, "source": [ "## Conversion of PyTorch model to ONNX\n", "\n", "Now the model is ready, and you can run the conversion.\n", "\n", "In the code below, some parameters are set for the `OnnxConversion` pass. To view the full list of supported parameters, check this endpoint:" ] }, { "cell_type": "code", "execution_count": null, "id": "3ba67ac6b01e0110", "metadata": {}, "outputs": [], "source": [ "available_passes_response = requests.get(f\"{AI_TOOLKIT_BACKEND_URL}/optimizations/passes\")\n", "available_passes = available_passes_response.json()\n", "\n", "# This prints configuration parameters only for the OnnxConversion pass. 
Feel free to change it and explore\n", "# other passes as well.\n", "onnx_conversion_pass_config = next(_pass for _pass in available_passes[\"data\"][\"passes\"] if _pass[\"type\"] == \"OnnxConversion\")\n", "print(onnx_conversion_pass_config)" ] }, { "cell_type": "markdown", "id": "e51f8241efd88b58", "metadata": {}, "source": [ "To run the conversion, send a request to the */optimizations/run* endpoint:" ] }, { "cell_type": "code", "execution_count": null, "id": "3eede4a5edbe9ef6", "metadata": {}, "outputs": [], "source": [ "OPTIMIZATIONS_API_URL = f\"{AI_TOOLKIT_BACKEND_URL}/optimizations\"\n", "RUN_OPTIMIZATION_API_URL = f\"{OPTIMIZATIONS_API_URL}/run\"\n", "\n", "pass_config = {\n", " \"model_uuid\": model_uuid,\n", " \"passes\": [\n", " {\n", " \"type\": \"OnnxConversion\",\n", " \"config\": {\n", " \"target_opset\": 14,\n", " \"optimize\": True\n", " }\n", " }\n", " ]\n", "}\n", "\n", "optimization_response = requests.post(RUN_OPTIMIZATION_API_URL, json=pass_config)\n", "data = optimization_response.json()\n", "optimization_uuid = data[\"data\"][\"optimization\"][\"uuid\"]" ] }, { "cell_type": "markdown", "id": "caa424e618b54af8", "metadata": {}, "source": [ "Running the code above starts the conversion. To check the status, use the following endpoint. 
Again, you may need to run it multiple times until the status changes to `success`:" ] }, { "cell_type": "code", "execution_count": null, "id": "660aed3929e5b8f1", "metadata": {}, "outputs": [], "source": [ "response = requests.get(f\"{OPTIMIZATIONS_API_URL}/{optimization_uuid}\")\n", "data = response.json()\n", "status = data[\"data\"][\"optimization\"][\"status\"]\n", "print(f\"Conversion status: {status}\")\n", "\n", "if status == \"success\":\n", " artifact_id = data[\"data\"][\"optimization\"][\"artifacts\"][0][\"artifact_id\"]" ] }, { "cell_type": "markdown", "id": "30910dae-11f3-486c-b0dc-68db7879d5f3", "metadata": {}, "source": [ "## Download converted model\n" ] }, { "cell_type": "code", "execution_count": null, "id": "c3d81afe-f847-4564-872b-0dc8d734706a", "metadata": {}, "outputs": [], "source": [ "# Change model path to your location\n", "dest_model_path = Path(\"my_path_to_converted_model.onnx\")" ] }, { "cell_type": "code", "execution_count": null, "id": "9187efa3-c11f-498a-8267-bd13b26ef320", "metadata": {}, "outputs": [], "source": [ "download_response = requests.get(f\"{AI_TOOLKIT_BACKEND_URL}/optimizations/{optimization_uuid}/resources/{artifact_id}\")\n", "download_response.raise_for_status()  # Stop if the artifact could not be retrieved\n", "with dest_model_path.open(\"wb\") as f:\n", " f.write(download_response.content)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }