diff --git a/roll/utils/functionals.py b/roll/utils/functionals.py index 6e251a092..8b353828b 100644 --- a/roll/utils/functionals.py +++ b/roll/utils/functionals.py @@ -8,7 +8,7 @@ import enum import traceback import heapq -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -24,7 +24,16 @@ logger = get_logger() -def tensor_to_cpu_visitor(obj, path): +def tensor_to_cpu_visitor(obj: Any, path: tuple) -> bool: + """Move tensor to CPU if it's not already on CPU. + + Args: + obj: Any Python object to check. + path: Current traversal path (used for debugging/logging). + + Returns: + True if the object is a tensor, False otherwise. + """ if torch.is_tensor(obj): if not obj.is_cpu: obj.data = obj.data.detach().cpu() @@ -32,7 +41,16 @@ def tensor_to_cpu_visitor(obj, path): return False -def tensor_to_cuda_visitor(obj, path): +def tensor_to_cuda_visitor(obj: Any, path: tuple) -> bool: + """Move tensor to CUDA device if it's not already on GPU. + + Args: + obj: Any Python object to check. + path: Current traversal path (used for debugging/logging). + + Returns: + True if the object is a tensor, False otherwise. + """ if torch.is_tensor(obj): if not obj.is_cuda: obj.data = obj.data.detach().to(device=torch.device(current_platform.device_type)) @@ -40,19 +58,33 @@ def tensor_to_cuda_visitor(obj, path): return False -def delete_tensor_grad_visitor(obj, path): +def delete_tensor_grad_visitor(obj: Any, path: tuple) -> bool: + """Delete gradient of a tensor if present. + + Args: + obj: Any Python object to check. + path: Current traversal path (used for debugging/logging). + + Returns: + True if the object is a tensor, False otherwise. + """ if torch.is_tensor(obj): obj.grad = None return True return False -def traverse_obj(value, visitor, path=()): - """ - 遍历对象的所有属性,包括属性的属性,找到所有的 Tensor。 - :param value: 任意 Python 对象 - :visitor - :path +def traverse_obj(value: Any, visitor: Callable[[Any, tuple], bool], path: tuple = ()) -> None: + """Traverse all attributes of an object recursively to find all tensors. + + This function recursively traverses through nested dictionaries, lists, tuples, + and object attributes, applying the visitor function to each element. + + Args: + value: Any Python object to traverse. + visitor: A callable that takes (obj, path) and returns True if traversal + should stop for that branch, False to continue traversing. + path: Current traversal path as a tuple of keys/indices. Defaults to empty tuple. """ if visitor(value, path): return @@ -123,9 +155,19 @@ def append_to_dict(data: Dict, new_data: Dict): data[key].append(val) -def flatten_sum(values): - """Flatten nested lists/tuples and sum all numeric values.""" - total = 0 +def flatten_sum(values: list | tuple) -> float: + """Flatten nested lists/tuples and sum all numeric values. + + Recursively traverses nested list/tuple structures and sums all + integer and float values found. + + Args: + values: A nested structure of lists and/or tuples containing numeric values. + + Returns: + The sum of all numeric values in the nested structure. + """ + total = 0.0 for v in values: if isinstance(v, (list, tuple)): total += flatten_sum(v) @@ -296,8 +338,22 @@ def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = None) -> tor return (tensor * mask).sum() -def masked_var(values, mask, unbiased=True): - """Compute variance of tensor with masked values.""" +def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor: + """Compute variance of tensor with masked values. + + Args: + values: Input tensor to compute variance for. + mask: Boolean mask tensor, True for valid positions. + unbiased: If True, applies Bessel's correction (N-1 denominator). + Defaults to True. + + Returns: + The masked variance as a scalar tensor. + + Raises: + ValueError: If mask has no valid elements (sum is 0). + ValueError: If mask has exactly one valid element with unbiased=True. + """ mean = masked_mean(values, mask) centered_values = values - mean variance = masked_mean(centered_values**2, mask)