Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 71 additions & 15 deletions roll/utils/functionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import enum
import traceback
import heapq
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
Expand All @@ -24,35 +24,67 @@
logger = get_logger()


def tensor_to_cpu_visitor(obj: Any, path: tuple) -> bool:
    """Relocate a tensor to host memory, leaving CPU tensors untouched.

    Args:
        obj: Candidate object; only tensors are acted upon.
        path: Traversal path of the object within its container (kept for
            debugging/logging by the traversal driver).

    Returns:
        True when ``obj`` is a tensor (so traversal can stop for this branch),
        False otherwise.
    """
    if not torch.is_tensor(obj):
        return False
    if not obj.is_cpu:
        # Detach first so the copy carries no autograd history.
        obj.data = obj.data.detach().cpu()
    return True


def tensor_to_cuda_visitor(obj: Any, path: tuple) -> bool:
    """Relocate a tensor onto the current accelerator device if needed.

    Args:
        obj: Candidate object; only tensors are acted upon.
        path: Traversal path of the object within its container (kept for
            debugging/logging by the traversal driver).

    Returns:
        True when ``obj`` is a tensor (so traversal can stop for this branch),
        False otherwise.
    """
    if not torch.is_tensor(obj):
        return False
    if not obj.is_cuda:
        # Detach first so the device copy carries no autograd history;
        # the target device comes from the platform abstraction.
        obj.data = obj.data.detach().to(device=torch.device(current_platform.device_type))
    return True


def delete_tensor_grad_visitor(obj: Any, path: tuple) -> bool:
    """Drop any stored gradient on a tensor to free its memory.

    Args:
        obj: Candidate object; only tensors are acted upon.
        path: Traversal path of the object within its container (kept for
            debugging/logging by the traversal driver).

    Returns:
        True when ``obj`` is a tensor (so traversal can stop for this branch),
        False otherwise.
    """
    if not torch.is_tensor(obj):
        return False
    # Setting .grad to None (rather than zeroing) releases the grad buffer.
    obj.grad = None
    return True


def traverse_obj(value, visitor, path=()):
"""
遍历对象的所有属性,包括属性的属性,找到所有的 Tensor。
:param value: 任意 Python 对象
:visitor
:path
def traverse_obj(value: Any, visitor: Callable[[Any, tuple], bool], path: tuple = ()) -> None:
"""Traverse all attributes of an object recursively to find all tensors.

This function recursively traverses through nested dictionaries, lists, tuples,
and object attributes, applying the visitor function to each element.

Args:
value: Any Python object to traverse.
visitor: A callable that takes (obj, path) and returns True if traversal
should stop for that branch, False to continue traversing.
path: Current traversal path as a tuple of keys/indices. Defaults to empty tuple.
"""
if visitor(value, path):
return
Expand Down Expand Up @@ -123,9 +155,19 @@ def append_to_dict(data: Dict, new_data: Dict):
data[key].append(val)


def flatten_sum(values):
"""Flatten nested lists/tuples and sum all numeric values."""
total = 0
def flatten_sum(values: list | tuple) -> float:
"""Flatten nested lists/tuples and sum all numeric values.

Recursively traverses nested list/tuple structures and sums all
integer and float values found.

Args:
values: A nested structure of lists and/or tuples containing numeric values.

Returns:
The sum of all numeric values in the nested structure.
"""
total = 0.0
for v in values:
if isinstance(v, (list, tuple)):
total += flatten_sum(v)
Expand Down Expand Up @@ -296,8 +338,22 @@ def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = None) -> tor
return (tensor * mask).sum()


def masked_var(values, mask, unbiased=True):
"""Compute variance of tensor with masked values."""
def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
"""Compute variance of tensor with masked values.

Args:
values: Input tensor to compute variance for.
mask: Boolean mask tensor, True for valid positions.
unbiased: If True, applies Bessel's correction (N-1 denominator).
Defaults to True.

Returns:
The masked variance as a scalar tensor.

Raises:
ValueError: If mask has no valid elements (sum is 0).
ValueError: If mask has exactly one valid element with unbiased=True.
"""
mean = masked_mean(values, mask)
centered_values = values - mean
variance = masked_mean(centered_values**2, mask)
Expand Down