diff --git a/PytorchWildlife/data/datasets.py b/PytorchWildlife/data/datasets.py index ce280f02f..39870a626 100644 --- a/PytorchWildlife/data/datasets.py +++ b/PytorchWildlife/data/datasets.py @@ -14,6 +14,7 @@ # Making the DetectionImageFolder class available for import from this module __all__ = [ + "ClassificationImageFolder", "DetectionImageFolder", ] diff --git a/PytorchWildlife/models/classification/resnet_base/base_classifier.py b/PytorchWildlife/models/classification/resnet_base/base_classifier.py index 965ff4389..a7448c182 100644 --- a/PytorchWildlife/models/classification/resnet_base/base_classifier.py +++ b/PytorchWildlife/models/classification/resnet_base/base_classifier.py @@ -158,10 +158,9 @@ def batch_image_classification(self, data_path=None, det_results=None, id_strip= """ if data_path: - dataset = pw_data.ImageFolder( + dataset = pw_data.ClassificationImageFolder( data_path, transform=self.transform, - path_head='.' ) elif det_results: dataset = pw_data.DetectionCrops( diff --git a/PytorchWildlife/models/classification/timm_base/base_classifier.py b/PytorchWildlife/models/classification/timm_base/base_classifier.py index 682333b17..5a4268e04 100644 --- a/PytorchWildlife/models/classification/timm_base/base_classifier.py +++ b/PytorchWildlife/models/classification/timm_base/base_classifier.py @@ -177,10 +177,9 @@ def batch_image_classification(self, data_path=None, det_results=None, id_strip= """ if data_path: - dataset = pw_data.ImageFolder( + dataset = pw_data.ClassificationImageFolder( data_path, transform=self.transform, - path_head='.' ) elif det_results: dataset = pw_data.DetectionCrops(