From 49150338fd012dc2d8086b1a2bfc3c3f652d4b3a Mon Sep 17 00:00:00 2001 From: Castorp <50649074+ShinDongWoon@users.noreply.github.com> Date: Sat, 16 Aug 2025 21:33:01 +0900 Subject: [PATCH] Allow default device selection without interaction --- LGHackerton/predict.py | 3 +- LGHackerton/train.py | 3 +- LGHackerton/utils/device.py | 64 ++++++++++++++++++++++++++----------- 3 files changed, 50 insertions(+), 20 deletions(-) diff --git a/LGHackerton/predict.py b/LGHackerton/predict.py index 6ddb083..69cbda4 100644 --- a/LGHackerton/predict.py +++ b/LGHackerton/predict.py @@ -44,7 +44,8 @@ def convert_to_submission(pred_df: pd.DataFrame, sample_path: str) -> pd.DataFra return out_df def main(): - device = select_device() + # Default to environment variable ``DEVICE`` or CPU without interactive input + device = select_device(os.environ.get("DEVICE", "cpu")) pp = Preprocessor(); pp.load(ARTIFACTS_PATH) diff --git a/LGHackerton/train.py b/LGHackerton/train.py index c960a52..6ced657 100644 --- a/LGHackerton/train.py +++ b/LGHackerton/train.py @@ -29,7 +29,8 @@ def _read_table(path: str) -> pd.DataFrame: raise ValueError("Unsupported file type. Use .csv or .xlsx") def main(): - device = select_device() # ask user for compute environment + # Default to environment variable ``DEVICE`` or CPU without interactive input + device = select_device(os.environ.get("DEVICE", "cpu")) df_train_raw = _read_table(TRAIN_PATH) pp = Preprocessor() diff --git a/LGHackerton/utils/device.py b/LGHackerton/utils/device.py index b6b5812..165bd1b 100644 --- a/LGHackerton/utils/device.py +++ b/LGHackerton/utils/device.py @@ -1,24 +1,52 @@ -import torch +"""Utilities for selecting the computation device. +This module tries to import :mod:`torch` but falls back to ``None`` if the +library is not available. Downstream code can therefore still run on +environments where PyTorch is not installed (e.g. when only using the LightGBM +model). +""" -def select_device() -> str: - """Interactively select computing device. +from __future__ import annotations - Returns 'cpu', 'cuda', or 'mps'. +import os + +try: # pragma: no cover - best effort in absence of torch + import torch +except Exception: # torch is optional, treat as unavailable if import fails + torch = None + + +def select_device(default: str | None = None) -> str: + """Select a computation device. + + The function interacts with the user to select ``'cpu'``, ``'cuda'`` or + ``'mps'`` (Apple Metal) when no default is provided. If a default device is + supplied via the ``default`` argument or the ``DEVICE`` environment + variable, that value is returned immediately without prompting. When + running in a non-interactive environment where ``input`` raises + :class:`EOFError`, the ``default``/environment variable value is returned or + ``'cpu'`` if none was given. """ - while True: - choice = input("Select compute environment (macOS/gpu/cpu): ").strip().lower() - if choice == "macos": - if torch.backends.mps.is_available(): - return "mps" - else: + + default_device = default or os.environ.get("DEVICE") + if default_device: + return default_device + + try: + while True: + choice = input("Select compute environment (macOS/gpu/cpu): ").strip().lower() + if choice == "macos": + if torch and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available(): + return "mps" print("MPS not available. Please choose another option.") - elif choice in ("gpu", "cuda"): - if torch.cuda.is_available(): - return "cuda" - else: + elif choice in ("gpu", "cuda"): + if torch and torch.cuda.is_available(): + return "cuda" print("CUDA GPU not available. Please choose another option.") - elif choice == "cpu": - return "cpu" - else: - print("Invalid option. Choose from macOS/gpu/cpu.") + elif choice == "cpu": + return "cpu" + else: + print("Invalid option. Choose from macOS/gpu/cpu.") + except EOFError: + # When running without a TTY, return the safest option. + return "cpu"