Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion LGHackerton/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def convert_to_submission(pred_df: pd.DataFrame, sample_path: str) -> pd.DataFra
return out_df

def main():
device = select_device()
# Default to environment variable ``DEVICE`` or CPU without interactive input
device = select_device(os.environ.get("DEVICE", "cpu"))

pp = Preprocessor(); pp.load(ARTIFACTS_PATH)

Expand Down
3 changes: 2 additions & 1 deletion LGHackerton/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def _read_table(path: str) -> pd.DataFrame:
raise ValueError("Unsupported file type. Use .csv or .xlsx")

def main():
device = select_device() # ask user for compute environment
# Default to environment variable ``DEVICE`` or CPU without interactive input
device = select_device(os.environ.get("DEVICE", "cpu"))

df_train_raw = _read_table(TRAIN_PATH)
pp = Preprocessor()
Expand Down
64 changes: 46 additions & 18 deletions LGHackerton/utils/device.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,52 @@
import torch
"""Utilities for selecting the computation device.

This module tries to import :mod:`torch` but falls back to ``None`` if the
library is not available. Downstream code can therefore still run on
environments where PyTorch is not installed (e.g. when only using the LightGBM
model).
"""

def select_device() -> str:
"""Interactively select computing device.
from __future__ import annotations

Returns 'cpu', 'cuda', or 'mps'.
import os

try: # pragma: no cover - best effort in absence of torch
import torch
except Exception: # torch is optional, treat as unavailable if import fails
torch = None


def select_device(default: str | None = None) -> str:
"""Select a computation device.

The function interacts with the user to select ``'cpu'``, ``'cuda'`` or
``'mps'`` (Apple Metal) when no default is provided. If a default device is
supplied via the ``default`` argument or the ``DEVICE`` environment
variable, that value is returned immediately without prompting. When
running in a non-interactive environment where ``input`` raises
:class:`EOFError`, the ``default``/environment variable value is returned or
``'cpu'`` if none was given.
"""
while True:
choice = input("Select compute environment (macOS/gpu/cpu): ").strip().lower()
if choice == "macos":
if torch.backends.mps.is_available():
return "mps"
else:

default_device = default or os.environ.get("DEVICE")
if default_device:
return default_device

try:
while True:
choice = input("Select compute environment (macOS/gpu/cpu): ").strip().lower()
if choice == "macos":
if torch and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
return "mps"
print("MPS not available. Please choose another option.")
elif choice in ("gpu", "cuda"):
if torch.cuda.is_available():
return "cuda"
else:
elif choice in ("gpu", "cuda"):
if torch and torch.cuda.is_available():
return "cuda"
print("CUDA GPU not available. Please choose another option.")
elif choice == "cpu":
return "cpu"
else:
print("Invalid option. Choose from macOS/gpu/cpu.")
elif choice == "cpu":
return "cpu"
else:
print("Invalid option. Choose from macOS/gpu/cpu.")
except EOFError:
# When running without a TTY, return the safest option.
return "cpu"