From 24a0a55416d14f0e043c3b8b96f791c0fd789955 Mon Sep 17 00:00:00 2001 From: jQuinRivero Date: Tue, 3 Mar 2026 17:32:15 -0300 Subject: [PATCH] Fix batch classification TypeError when using data_path (#611)\n\nReplace pw_data.ImageFolder with pw_data.ClassificationImageFolder in\nboth timm_base and resnet_base classifiers, and remove the invalid\npath_head keyword argument that ImageFolder does not accept.\n\nImageFolder.__getitem__ is abstract (returns None), so the correct\nsubclass for classification is ClassificationImageFolder, which returns\nthe (img, img_path) tuple the dataloader loop expects.\n\nAlso export ClassificationImageFolder from datasets __all__." --- PytorchWildlife/data/datasets.py | 1 + .../models/classification/resnet_base/base_classifier.py | 3 +-- .../models/classification/timm_base/base_classifier.py | 3 +-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/PytorchWildlife/data/datasets.py b/PytorchWildlife/data/datasets.py index ce280f02f..39870a626 100644 --- a/PytorchWildlife/data/datasets.py +++ b/PytorchWildlife/data/datasets.py @@ -14,6 +14,7 @@ # Making the DetectionImageFolder class available for import from this module __all__ = [ + "ClassificationImageFolder", "DetectionImageFolder", ] diff --git a/PytorchWildlife/models/classification/resnet_base/base_classifier.py b/PytorchWildlife/models/classification/resnet_base/base_classifier.py index 965ff4389..a7448c182 100644 --- a/PytorchWildlife/models/classification/resnet_base/base_classifier.py +++ b/PytorchWildlife/models/classification/resnet_base/base_classifier.py @@ -158,10 +158,9 @@ def batch_image_classification(self, data_path=None, det_results=None, id_strip= """ if data_path: - dataset = pw_data.ImageFolder( + dataset = pw_data.ClassificationImageFolder( data_path, transform=self.transform, - path_head='.' ) elif det_results: dataset = pw_data.DetectionCrops( diff --git a/PytorchWildlife/models/classification/timm_base/base_classifier.py b/PytorchWildlife/models/classification/timm_base/base_classifier.py index 682333b17..5a4268e04 100644 --- a/PytorchWildlife/models/classification/timm_base/base_classifier.py +++ b/PytorchWildlife/models/classification/timm_base/base_classifier.py @@ -177,10 +177,9 @@ def batch_image_classification(self, data_path=None, det_results=None, id_strip= """ if data_path: - dataset = pw_data.ImageFolder( + dataset = pw_data.ClassificationImageFolder( data_path, transform=self.transform, - path_head='.' ) elif det_results: dataset = pw_data.DetectionCrops(