Can you please look at this, how to correctly inference this model? The output image is weired.
Model input shape: ['batch_size', 3, 'width', 'height']
Model output shape: ['batch_size', 3, 'width', 'height']
flutter: ORT Environment initialized
flutter: Model initialized
flutter: Is normalized: true
flutter: Tensor shape: [1, 3, 357, 201]
flutter: Image normalized successfully.
flutter: Input tensor created successfully.
flutter: Width: 1428, Height: 804, Channel: 3
class _MyHomePageState extends State<MyHomePage> {
late OrtSession ortSession;
img.Image? selectedImage;
@override
void initState() {
OrtEnv.instance.init();
debugPrint("ORT Environment initialized");
initializeModel();
super.initState();
}
@override
void dispose() {
OrtEnv.instance.release();
super.dispose();
}
Future<void> initializeModel() async {
final sessionOptions = OrtSessionOptions();
const assetFileName = 'assets/4xSPANkendata_fp32.onnx';
final rawAssetFile = await rootBundle.load(assetFileName);
final bytes = rawAssetFile.buffer.asUint8List();
ortSession = OrtSession.fromBuffer(bytes, sessionOptions);
debugPrint("Model initialized");
}
Future<void> inference() async {
if (selectedImage == null) {
debugPrint('No image selected');
return;
}
if (selectedImage != null) {
Float32List? floatData;
try {
final normalizedImage = selectedImage!.convert(
format: img.Format.float32,
numChannels: 3,
);
final pixelData = normalizedImage.buffer.asFloat32List();
floatData = pixelData;
debugPrint("Is normalized: ${isNormalized(pixelData)}");
} catch (e) {
debugPrint("Error during normalization: $e");
}
final shape = [1, 3, selectedImage!.width, selectedImage!.height];
debugPrint('Tensor shape: $shape');
debugPrint('Image normalized successfully.');
final inputOrt =
OrtValueTensor.createTensorWithDataList(floatData!, shape);
final inputs = {'input': inputOrt};
debugPrint('Input tensor created successfully.');
final runOptions = OrtRunOptions();
final outputs = await ortSession.runAsync(runOptions, inputs);
inputOrt.release();
runOptions.release();
List c = outputs?[0]?.value as List;
if (c is List<List<List<List<double>>>>) {
img.Image generatedImage = generateImageFromOutput(c);
showDialog(
context: context,
builder: (BuildContext context) {
return Dialog(
child: SizedBox(
width: generatedImage.width.toDouble(),
height: generatedImage.height.toDouble(),
child: Image.memory(
Uint8List.fromList(img.encodeJpg(generatedImage)),
fit: BoxFit.contain,
),
),
);
},
);
} else {
debugPrint("Output is of unknown type");
}
outputs?.forEach((element) {
element?.release();
});
}
}
img.Image generateImageFromOutput(
List<List<List<List<double>>>> outputValue) {
int width = outputValue[0][0].length;
int height = outputValue[0][0][0].length;
int channel = outputValue[0].length;
print("Width: $width, Height: $height, Channel: $channel");
img.Image generatedImage = img.Image(
width: width,
height: height,
format: img.Format.uint8,
numChannels: 3
);
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
int r = (outputValue[0][0][x][y] * 255).toInt().clamp(0, 255);
int g = (outputValue[0][1][x][y] * 255).toInt().clamp(0, 255);
int b = (outputValue[0][2][x][y] * 255).toInt().clamp(0, 255);
// Set pixel value in the generated image
generatedImage.setPixelRgb(x, y, r, g, b);
}
}
return generatedImage;
}
bool isNormalized(Float32List data) {
for (var pixelValue in data) {
if (pixelValue < 0 || pixelValue > 1) {
return false;
}
}
return true;
}
Future<void> _pickImage() async {
final picker = ImagePicker();
final pickedFile = await picker.pickImage(source: ImageSource.gallery);
if (pickedFile != null) {
final bytes = await pickedFile.readAsBytes();
setState(() {
selectedImage = img.decodeImage(Uint8List.fromList(bytes));
});
}
}
@override
Widget build(BuildContext context) {
return Scaffold(
appBar: AppBar(
title: const Text('Image Inference'),
),
body: Center(
child: Column(
mainAxisAlignment: MainAxisAlignment.center,
crossAxisAlignment: CrossAxisAlignment.center,
children: <Widget>[
if (selectedImage != null)
Image.memory(Uint8List.fromList(img.encodePng(selectedImage!))),
if (selectedImage != null)
const SizedBox(
height: 10,
),
ElevatedButton(
onPressed: _pickImage,
style: ElevatedButton.styleFrom(
shape: const RoundedRectangleBorder(),
fixedSize: const Size(150, 35)),
child: const Text('Select Image'),
),
const SizedBox(
height: 10,
),
ElevatedButton(
onPressed: inference,
style: ElevatedButton.styleFrom(
shape: const RoundedRectangleBorder(),
fixedSize: const Size(150, 35)),
child: const Text('Run Inference'),
),
],
),
),
);
}
}
Can you please look at this, how to correctly inference this model? The output image is weired.
Model Link: https://drive.google.com/file/d/1BIiHr5JI-QhQHXWU9QPNqVkMCGg1Y3So/view?usp=sharing
Input Image :

Output Image :

LOGS:
Full Code: