Monday, 25 November 2024

The Great Inference Adventure: From Debugging Despair to Deployment Delight

Background, Be Gone: Running Inference with ONNX and U2Net

Ever looked at an image and thought, “The background really ties this mess together”? Me neither. That’s why we’re diving into ONNX inference today—specifically how to remove backgrounds using the U^2-Net model. Whether you’re an AI pro or just here for fun, let’s explore the magic behind the scenes of turning an ONNX model into a background-removal wizard.

 Easily go from a picture with a background to one without in roughly 20 seconds:


Follow this link for an ultra-small model that runs in a browser if you want to test background removal. Unfortunately, it will only work on a PC, as mobile browsers have too many restrictions.

For a complete working program that can be downloaded and installed on your Windows 10/11 PC, please click here (download TinyBGR_Setup, unzip it, and run setup.exe to install).

What is ONNX Inference?

First, ONNX (Open Neural Network Exchange) isn’t just an acronym; it’s an invitation to make AI work smarter, not harder. It lets you use pre-trained models across different platforms and frameworks—like bringing your favourite lunch to work without worrying if the microwave is compatible.

Inference, in non-jargon, means putting the AI model to work. It’s the moment your model steps onto the stage to predict, classify, or—our focus today—erase pesky image backgrounds.

Here is a representation of the top of the model in the ONNX model viewer, showing exactly how the model expects the image input:

[Screenshot: the U^2-Net model opened in the ONNX model viewer, showing the input node and the expected input format.]
The U^2-Net Model

U^2-Net is the drama queen of image segmentation. It’s trained to identify objects in an image and separate them from the background with surgical precision. Think of it as the scalpel of AI, slicing through pixels to leave you with just what you want (and none of what you don’t).

 

Step 1: Preparing the Image

Before U^2-Net flexes its algorithmic muscles, the image needs a makeover. Here’s what you’ll do:

  1. Resize the Image: U^2-Net expects a specific size (e.g., 320x320 pixels). It’s picky, like a cat that won’t drink tap water.
  2. Normalize Pixel Values: Your image’s RGB values (0-255) must be scaled to 0-1 and then normalized with the usual ImageNet mean and standard deviation, as in the code below. It’s like telling the model, “Calm down, we’re just pixels.”
  3. Strip the Alpha Channel: If the image is RGBA, the transparency layer needs to be removed so that only the RGB data is left. This is easy to do in code; drawing onto a 24-bit RGB bitmap takes care of it. Here is an example in C++:

std::vector<float> PreprocessImage(
    Gdiplus::Bitmap* bitmap, size_t targetWidth, size_t targetHeight) {
    std::vector<float> floatArr(targetWidth * targetHeight * 3);

    // Resize to the model's expected size; drawing onto a 24bpp RGB bitmap
    // also discards any alpha channel
    Gdiplus::Bitmap resizedBitmap(targetWidth, targetHeight, PixelFormat24bppRGB);
    Gdiplus::Graphics graphics(&resizedBitmap);
    graphics.DrawImage(bitmap, 0, 0, targetWidth, targetHeight);

    Gdiplus::Rect rect(0, 0, targetWidth, targetHeight);
    Gdiplus::BitmapData bitmapData;
    resizedBitmap.LockBits(&rect, Gdiplus::ImageLockModeRead,
        PixelFormat24bppRGB, &bitmapData);

    BYTE* pixels = (BYTE*)bitmapData.Scan0;
    int stride = bitmapData.Stride;

    // Scale each channel to [0, 1] and normalize with the ImageNet mean/std;
    // GDI+ stores 24bpp pixels in BGR order, so read them back out as RGB
    size_t j = 0;
    for (size_t y = 0; y < targetHeight; ++y) {
        for (size_t x = 0; x < targetWidth; ++x) {
            BYTE* pixel = pixels + y * stride + x * 3;
            floatArr[j++] = (pixel[2] / 255.0f - 0.485f) / 0.229f; // R
            floatArr[j++] = (pixel[1] / 255.0f - 0.456f) / 0.224f; // G
            floatArr[j++] = (pixel[0] / 255.0f - 0.406f) / 0.225f; // B
        }
    }

    resizedBitmap.UnlockBits(&bitmapData);

    // Channel-wise separation: interleaved HWC -> planar CHW (RRR...GGG...BBB),
    // which is the layout the model expects
    std::vector<float> floatArr2(floatArr.size());
    size_t k = 0, l = targetWidth * targetHeight, m = targetWidth * targetHeight * 2;
    for (size_t i = 0; i < floatArr.size(); i += 3) {
        floatArr2[k++] = floatArr[i];      // R
        floatArr2[l++] = floatArr[i + 1];  // G
        floatArr2[m++] = floatArr[i + 2];  // B
    }

    return floatArr2;
}
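
For context, calling it is straightforward. Here is a minimal usage sketch, assuming GDI+ has already been initialized with GdiplusStartup and that "photo.png" stands in for your own file path (U^2-Net uses a 320x320 input):

// Usage sketch: load an image from disk and turn it into the model's input tensor
Gdiplus::Bitmap* original = Gdiplus::Bitmap::FromFile(L"photo.png");
std::vector<float> inputTensorValues = PreprocessImage(original, 320, 320);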

       

Step 2: Running the ONNX Model

Once the image is prepped, you feed it into the ONNX Runtime. Here’s what happens:

  1. Load the ONNX File: Ensure you’ve got the U^2-Net ONNX file handy. Loading it into your app should feel like the calm before the storm (a minimal loading sketch follows this list).
  2. Run Inference: This is where the magic happens—assuming you haven’t forgotten to set up the input tensor correctly (guilty as charged). A single forward pass processes the image, producing a segmentation mask; the C++ inference example follows the loading sketch below.

[Screenshot: an example of the garbled output mask you get when the input tensor values are off.]
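
Loading the model only takes a few lines with the ONNX Runtime C++ API. Here is a minimal sketch, assuming the U^2-Net file sits next to the executable as u2net.onnx (the file name, environment label, and thread count are placeholders; adjust them to your setup):

#include <onnxruntime_cxx_api.h>

// Minimal loading sketch (assumed names and paths): create the runtime
// environment, configure the session, and load the U^2-Net model file
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "u2net-bg-removal");  // label is arbitrary
Ort::SessionOptions sessionOptions;
sessionOptions.SetIntraOpNumThreads(1);  // placeholder; tune for your machine
sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);

// On Windows the model path is a wide string; "u2net.onnx" is a placeholder path
Ort::Session session(env, L"u2net.onnx", sessionOptions);

The resulting session is what gets passed to the inference function below.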



Gdiplus::Bitmap* RunModelInference(
    Ort::Session& session, std::vector<float>& inputTensorValues,
    int targetWidth, int targetHeight) {
    // Known input and output names based on your model's specifications
    const char* inputName = "input.1";  // Update this name if needed
    const char* outputName = "output";  // Update this name if needed

    // Input tensor dimensions in NCHW order: batch 1, 3 channels, height, width
    std::vector<int64_t> inputDims = { 1, 3, targetHeight, targetWidth };
    Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(
        OrtDeviceAllocator, OrtMemTypeCPU);
    Ort::Value inputTensor = Ort::Value::CreateTensor<float>(
        memoryInfo, inputTensorValues.data(), inputTensorValues.size(),
        inputDims.data(), inputDims.size());

    // Run inference: one forward pass producing the segmentation mask
    auto outputTensors = session.Run(Ort::RunOptions{ nullptr }, &inputName,
        &inputTensor, 1, &outputName, 1);

    // Access output tensor data
    float* outputData = outputTensors[0].GetTensorMutableData<float>();

    // Create a grayscale image to store the output mask
    Gdiplus::Bitmap* maskBitmap = new Gdiplus::Bitmap(targetWidth, targetHeight,
        PixelFormat32bppARGB);

    // Populate the mask with output values (std::clamp needs <algorithm>)
    for (int y = 0; y < targetHeight; ++y) {
        for (int x = 0; x < targetWidth; ++x) {
            float pixelValue = outputData[y * targetWidth + x] * 255.0f;  // Scale to 0-255
            BYTE intensity = static_cast<BYTE>(std::clamp(pixelValue, 0.0f, 255.0f));
            Gdiplus::Color color(intensity, intensity, intensity);  // Grayscale
            maskBitmap->SetPixel(x, y, color);
        }
    }

    return maskBitmap;
}
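
Wiring the pieces together, a call might look like this, reusing the session from the loading sketch and the tensor from PreprocessImage (320x320 being the U^2-Net input size used throughout):

// Usage sketch: one forward pass, returning a 320x320 grayscale mask bitmap
Gdiplus::Bitmap* maskBitmap = RunModelInference(session, inputTensorValues, 320, 320);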

 

Step 3: Postprocessing

If preprocessing is the starter, postprocessing is dessert. The raw output from U^2-Net isn’t exactly pretty—it’s like staring at a cryptic Rorschach test.

Here’s how to polish it:

  1. Thresholding: Convert the grayscale mask into a binary one, where the object is white, and the background is black.
  2. Apply the Mask: Use the binary mask to “cut out” the object. Cue the “ta-da” moment as the background vanishes. Some more C++, with a sketch of the mask-application helper after it:

// GenerateMaskBitmap and ApplyMaskToImage are helpers defined elsewhere in the
// project (a rough sketch of ApplyMaskToImage follows this function)
Gdiplus::Bitmap* PostprocessOutput(
    const std::vector<float>& outputTensorValues, Gdiplus::Bitmap* originalBitmap,
    size_t targetWidth, size_t targetHeight) {
    // Scale the sigmoid outputs from [0, 1] to the range [0, 255]
    const float amplificationFactor = 255.0f;

    // Generate the mask bitmap from the raw model output
    Gdiplus::Bitmap* maskBitmap = GenerateMaskBitmap(outputTensorValues,
        targetWidth, targetHeight, amplificationFactor);

    // Resize the mask to match the original image dimensions
    Gdiplus::Bitmap* resizedMaskBitmap = new Gdiplus::Bitmap(
        originalBitmap->GetWidth(), originalBitmap->GetHeight(), PixelFormat32bppARGB);
    Gdiplus::Graphics graphics(resizedMaskBitmap);
    graphics.DrawImage(maskBitmap, 0, 0, originalBitmap->GetWidth(),
        originalBitmap->GetHeight());

    // Apply the resized mask to the original image
    Gdiplus::Bitmap* resultBitmap = ApplyMaskToImage(originalBitmap, resizedMaskBitmap);

    // Clean up intermediate bitmaps
    delete maskBitmap;
    delete resizedMaskBitmap;

    return resultBitmap;
}
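
GenerateMaskBitmap and ApplyMaskToImage above are small helpers from the application and aren’t reproduced here in full. As a rough idea, here is a minimal sketch of what ApplyMaskToImage could look like: it binarizes the mask (the thresholding from step 1) and uses the result as the alpha channel, so background pixels become transparent. This is an illustrative sketch under those assumptions, not the app’s exact code:

// Illustrative sketch of a mask-application helper (not the app's exact code).
// The mask is binarized with a hard threshold and used as the alpha channel of
// the result, so dark (background) pixels turn transparent.
Gdiplus::Bitmap* ApplyMaskToImage(Gdiplus::Bitmap* original, Gdiplus::Bitmap* mask) {
    const int width = static_cast<int>(original->GetWidth());
    const int height = static_cast<int>(original->GetHeight());
    const BYTE threshold = 128;  // Tune, or skip binarization to keep soft edges

    Gdiplus::Bitmap* result = new Gdiplus::Bitmap(width, height, PixelFormat32bppARGB);

    for (int y = 0; y < height; ++y) {
        for (int x = 0; x < width; ++x) {
            Gdiplus::Color srcColor, maskColor;
            original->GetPixel(x, y, &srcColor);
            mask->GetPixel(x, y, &maskColor);

            // 255 keeps the pixel, 0 makes it fully transparent
            BYTE alpha = (maskColor.GetR() >= threshold) ? 255 : 0;
            Gdiplus::Color outColor(alpha, srcColor.GetR(), srcColor.GetG(), srcColor.GetB());
            result->SetPixel(x, y, outColor);
        }
    }

    return result;
}

GetPixel and SetPixel keep the idea readable but are slow on large images; LockBits (as used in PreprocessImage) is the faster route.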

 

Challenges and Frustrations

Not everything goes smoothly:

  • ONNX Error Messages: These are cryptic and unhelpful. You’ll question your life choices multiple times.
  • Preprocessing Mismatches: Forgetting to normalize the image properly leads to hilarious (or horrifying) results.
  • Output Mask Issues: Sometimes the mask looks great on one image but terrible on another. It’s the “one size fits all” paradox.

 

Why Use ONNX?

Because it’s:

  • Portable: Train in one framework, run anywhere.
  • Efficient: Optimized for speed (when set up correctly).
  • Cool: Who doesn’t love a buzzword-friendly solution?

 

Conclusion

Running inference on an ONNX model, particularly U^2-Net, is like cooking: follow the recipe, improvise a bit, and expect the fire alarm to go off at least once. But when it works, it’s a chef’s kiss moment. Whether you’re removing backgrounds for professional use or just to meme your friends, the process is satisfying, educational, and occasionally exasperating—just like all things tech.

 Cheers!


 

 

 

 

 
