Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Output image is weired while trying to inference esrgan, Someone help me please. #423

Open
md-rifatkhan opened this issue May 11, 2024 · 1 comment

Comments

@md-rifatkhan
Copy link

I'm trying to inference real esrgan, but cant able to get output correctly.
I'm using com.microsoft.onnxruntime:onnxruntime-android:1.17.3

WhatsApp Image 2024-05-11 at 2 33 10 PM

Model Link: Google Drive

Inference Class:

public class ImageInference {

    private static final String TAG = "ImageInference";

    // Load the image from assets
    public static Bitmap loadImageFromAssets(Context context, String fileName) throws IOException {
        Log.d(TAG, "Loading image from assets: " + fileName);
        InputStream is = context.getAssets().open(fileName);
        Bitmap image = BitmapFactory.decodeStream(is);
        if (image != null) {
            Log.d(TAG, "Image loaded successfully: " + fileName);
        } else {
            Log.d(TAG, "Failed to load image: " + fileName);
        }
        return image;
    }

    // Convert Bitmap to FloatBuffer for ONNX Runtime
    public static FloatBuffer bitmapToFloatBuffer(Bitmap bitmap, float mean, float std) {
        int width = bitmap.getWidth();
        int height = bitmap.getHeight();
        Log.d(TAG, "Preparing to convert bitmap to FloatBuffer. Width: " + width + ", Height: " + height);

        int[] pixels = new int[width * height];
        bitmap.getPixels(pixels, 0, width, 0, 0, width, height);
        FloatBuffer buffer = FloatBuffer.allocate(width * height * 3);

        for (final int val : pixels) {
            buffer.put(((val >> 16) & 0xFF) / 255.f - mean / std); // RED
            buffer.put(((val >> 8) & 0xFF) / 255.f - mean / std);  // GREEN
            buffer.put((val & 0xFF) / 255.f - mean / std);         // BLUE
        }
        buffer.flip(); // Prepare buffer for reading
        Log.d(TAG, "Bitmap successfully converted to FloatBuffer.");
        return buffer;
    }

    public static Bitmap tensorToBitmap(float[][][][] tensor) {
        // Assume tensor dimensions are [1][3][height][width]
        int height = tensor[0][0].length;
        int width = tensor[0][0][0].length;

        Bitmap bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
        int[] pixels = new int[width * height];

        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int r = (int) (tensor[0][0][y][x] * 255);
                int g = (int) (tensor[0][1][y][x] * 255);
                int b = (int) (tensor[0][2][y][x] * 255);
                pixels[y * width + x] = 0xFF000000 | (r << 16) | (g << 8) | b;
            }
        }

        bitmap.setPixels(pixels, 0, width, 0, 0, width, height);
        return bitmap;
    }

    public static float[][][][] runInference(Context context, String modelPath, Bitmap image) throws OrtException, IOException {

        InputStream is = context.getAssets().open(modelPath);
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        byte[] buf = new byte[1024];
        for (int readNum; (readNum = is.read(buf)) != -1;) {
            bos.write(buf, 0, readNum);
        }
        byte[] modelBytes = bos.toByteArray();

        Log.d(TAG, "Starting inference with model: " + modelPath);

        OrtEnvironment env = OrtEnvironment.getEnvironment();
        InputStream modelInputStream = context.getAssets().open(modelPath);
        OrtSession session = env.createSession(modelBytes, new OrtSession.SessionOptions());
        Log.d(TAG, "Model and Session created successfully." );
        try {
            int width = image.getWidth();
            int height = image.getHeight();
            FloatBuffer inputBuffer = bitmapToFloatBuffer(image, 0f, 1f);
            OnnxTensor tensor = OnnxTensor.createTensor(env, inputBuffer, new long[]{1, 3, height, width});
            OrtSession.Result results = session.run(Collections.singletonMap("input", tensor));
            float[][][][] output = (float[][][][]) results.get(0).getValue();
            tensor.close();
            Log.d(TAG, "Inference completed successfully.");
            return output;
        } finally {
            session.close();
            modelInputStream.close();
            env.close();
            Log.d(TAG, "Cleaned up ONNX resources.");
        }
    }
}

Main Activity

public class MainActivity extends Activity {

    private ImageView orginalImageView;
    private ImageView outputImageView;
    private TextView textViewResult;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        orginalImageView = findViewById(R.id.orginalImage);
        outputImageView = findViewById(R.id.outputImage);
        textViewResult = findViewById(R.id.textViewResult);

        try {
            Bitmap image = ImageInference.loadImageFromAssets(this, "LR.png");
            orginalImageView.setImageBitmap(image);
            float[][][][] output = ImageInference.runInference(this, "realesr-general-x4v3-fp32.onnx", image);
            if (output != null && output.length > 0 && output[0].length > 0 && output[0][0].length > 0) {
                int height = output[0][0].length;
                int width = output[0][0][0].length;

                Bitmap outputImage = ImageInference.tensorToBitmap(output);
                outputImageView.setImageBitmap(outputImage);  // Display the output image
                textViewResult.setText("Inference complete with output shape: " + "Height " + height + " Width "  + width);
            } else {
                textViewResult.setText("Inference complete but no valid output!");
            }
        } catch (Exception e) {
            textViewResult.setText("Inference failed: " + e.getMessage());
            e.printStackTrace();
        }
    }
@md-rifatkhan
Copy link
Author

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant