Skip to content

Help me inferencing model #2

@md-rifatkhan

Description

@md-rifatkhan

Can you please look at this, how to correctly inference this model? The output image is weired.

Model Link: https://drive.google.com/file/d/1BIiHr5JI-QhQHXWU9QPNqVkMCGg1Y3So/view?usp=sharing

Input Image :
inputimage

Output Image :
outputimage (1)

LOGS:

Model input shape: ['batch_size', 3, 'width', 'height']
Model output shape: ['batch_size', 3, 'width', 'height']
flutter: ORT Environment initialized
flutter: Model initialized
flutter: Is normalized: true
flutter: Tensor shape: [1, 3, 357, 201]
flutter: Image normalized successfully.
flutter: Input tensor created successfully.
flutter: Width: 1428, Height: 804, Channel: 3

Full Code:

class _MyHomePageState extends State<MyHomePage> {
  late OrtSession ortSession;
  img.Image? selectedImage;

  @override
  void initState() {
    OrtEnv.instance.init();
    debugPrint("ORT Environment initialized");
    initializeModel();
    super.initState();
  }

  @override
  void dispose() {
    OrtEnv.instance.release();
    super.dispose();
  }

  Future<void> initializeModel() async {
    final sessionOptions = OrtSessionOptions();
    const assetFileName = 'assets/4xSPANkendata_fp32.onnx';
    final rawAssetFile = await rootBundle.load(assetFileName);
    final bytes = rawAssetFile.buffer.asUint8List();
    ortSession = OrtSession.fromBuffer(bytes, sessionOptions);
    debugPrint("Model initialized");
  }

  Future<void> inference() async {
    if (selectedImage == null) {
      debugPrint('No image selected');
      return;
    }

    if (selectedImage != null) {
      Float32List? floatData;
      try {
        final normalizedImage = selectedImage!.convert(
          format: img.Format.float32,
          numChannels: 3,
        );
        final pixelData = normalizedImage.buffer.asFloat32List();
        floatData = pixelData;
        debugPrint("Is normalized: ${isNormalized(pixelData)}");
      } catch (e) {
        debugPrint("Error during normalization: $e");
      }
      final shape = [1, 3, selectedImage!.width, selectedImage!.height];
      debugPrint('Tensor shape: $shape');

      debugPrint('Image normalized successfully.');

      final inputOrt =
          OrtValueTensor.createTensorWithDataList(floatData!, shape);

      final inputs = {'input': inputOrt};

      debugPrint('Input tensor created successfully.');
      final runOptions = OrtRunOptions();
      final outputs = await ortSession.runAsync(runOptions, inputs);

      inputOrt.release();
      runOptions.release();

      List c = outputs?[0]?.value as List;
      if (c is List<List<List<List<double>>>>) {
        img.Image generatedImage = generateImageFromOutput(c);
        showDialog(
          context: context,
          builder: (BuildContext context) {
            return Dialog(
              child: SizedBox(
                width: generatedImage.width.toDouble(),
                height: generatedImage.height.toDouble(),
                child: Image.memory(
                  Uint8List.fromList(img.encodeJpg(generatedImage)),
                  fit: BoxFit.contain,
                ),
              ),
            );
          },
        );
      } else {
        debugPrint("Output is of unknown type");
      }
      outputs?.forEach((element) {
        element?.release();
      });
    }
  }
  img.Image generateImageFromOutput(
      List<List<List<List<double>>>> outputValue) {
    int width = outputValue[0][0].length;
    int height = outputValue[0][0][0].length;
    int channel = outputValue[0].length;

    print("Width: $width, Height: $height, Channel: $channel");
    img.Image generatedImage = img.Image(
      width: width,
      height: height,
      format: img.Format.uint8,
      numChannels: 3
    );
    for (int y = 0; y < height; y++) {
      for (int x = 0; x < width; x++) {
        int r = (outputValue[0][0][x][y] * 255).toInt().clamp(0, 255);
        int g = (outputValue[0][1][x][y] * 255).toInt().clamp(0, 255);
        int b = (outputValue[0][2][x][y] * 255).toInt().clamp(0, 255);
        // Set pixel value in the generated image
        generatedImage.setPixelRgb(x, y, r, g, b);
      }
    }
    return generatedImage;
  }

  bool isNormalized(Float32List data) {
    for (var pixelValue in data) {
      if (pixelValue < 0 || pixelValue > 1) {
        return false;
      }
    }
    return true;
  }

  Future<void> _pickImage() async {
    final picker = ImagePicker();
    final pickedFile = await picker.pickImage(source: ImageSource.gallery);
    if (pickedFile != null) {
      final bytes = await pickedFile.readAsBytes();
      setState(() {
        selectedImage = img.decodeImage(Uint8List.fromList(bytes));
      });
    }
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(
        title: const Text('Image Inference'),
      ),
      body: Center(
        child: Column(
          mainAxisAlignment: MainAxisAlignment.center,
          crossAxisAlignment: CrossAxisAlignment.center,
          children: <Widget>[
            if (selectedImage != null)
              Image.memory(Uint8List.fromList(img.encodePng(selectedImage!))),
            if (selectedImage != null)
              const SizedBox(
                height: 10,
              ),
            ElevatedButton(
              onPressed: _pickImage,
              style: ElevatedButton.styleFrom(
                  shape: const RoundedRectangleBorder(),
                  fixedSize: const Size(150, 35)),
              child: const Text('Select Image'),
            ),
            const SizedBox(
              height: 10,
            ),
            ElevatedButton(
              onPressed: inference,
              style: ElevatedButton.styleFrom(
                  shape: const RoundedRectangleBorder(),
                  fixedSize: const Size(150, 35)),
              child: const Text('Run Inference'),
            ),
          ],
        ),
      ),
    );
  }
}

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions