diff --git a/classification.py b/classification.py index 5d518b2..7d5c481 100644 --- a/classification.py +++ b/classification.py @@ -418,7 +418,7 @@ def __init__(self, config): from transformers import AutoProcessor, AutoModel # Set up device - self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.device = get_device() # Initialize model and processor self.processor = AutoProcessor.from_pretrained(