Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 126 additions & 33 deletions roll/utils/functionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,15 @@ def traverse_obj(value, visitor, path=()):
continue


def union_two_dict(dict1: Dict, dict2: Dict):
"""Union two dict. Will throw an error if there is an item not the same object with the same key.
def union_two_dict(dict1: Dict, dict2: Dict) -> Dict:
"""Union two dicts. Will throw an error if there is an item not the same object with the same key.

Args:
dict1:
dict2:
dict1: First dictionary to merge into.
dict2: Second dictionary to merge from.

Returns:

The merged dictionary (dict1 with dict2's items added).
"""
for key, val in dict2.items():
if key in dict1:
Expand Down Expand Up @@ -116,7 +116,16 @@ def divide_by_chunk_size(
return split_data


def append_to_dict(data: Dict, new_data: Dict):
def append_to_dict(data: Dict, new_data: Dict) -> None:
"""Append values from new_data to lists in data dictionary.

Args:
data: Dictionary to append to (values are lists).
new_data: Dictionary with values to append.

Returns:
None. Modifies data in place.
"""
for key, val in new_data.items():
if key not in data:
data[key] = []
Expand Down Expand Up @@ -170,7 +179,17 @@ def update(self, xs: torch.Tensor) -> Tuple[float, float]:
return xs_mean.item(), (xs_var * xs_count / (xs_count - 1)).float().sqrt().item()


def compute_clip_fraction(values: torch.Tensor, clip_max: float, clip_min: float):
def compute_clip_fraction(values: torch.Tensor, clip_max: float, clip_min: float) -> float:
"""Compute the fraction of values that are outside the clip range.

Args:
values: Input tensor to check.
clip_max: Maximum value for clipping.
clip_min: Minimum value for clipping.

Returns:
Fraction of values outside [clip_min, clip_max] range.
"""
numel = values.numel()
num_clipped = (values > clip_max).sum().item() + (values < clip_min).sum().item()
clipfrac = num_clipped / numel if numel > 0 else 0.0
Expand Down Expand Up @@ -217,16 +236,29 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T
return log_probs_labels.squeeze(-1)


def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """Compute the per-position entropy of the categorical distribution given by logits.

    Args:
        logits: Input logits tensor; the distribution is taken over the last dimension.

    Returns:
        Entropy tensor with the last dimension reduced.
    """
    # Promote to float32 for numerically stable softmax/logsumexp.
    float_logits = logits.float()
    probs = torch.nn.functional.softmax(float_logits, dim=-1)
    # H(p) = logsumexp(logits) - sum(p * logits), an algebraically stable
    # rearrangement of -sum(p * log p).
    return torch.logsumexp(float_logits, dim=-1) - (probs * float_logits).sum(dim=-1)


def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str, batch_num_tokens: int = None,
global_valid_samples: int = None, weights: Optional[torch.Tensor] = None):
def agg_loss(
loss_mat: torch.Tensor,
loss_mask: torch.Tensor,
loss_agg_mode: str,
batch_num_tokens: int = None,
global_valid_samples: int = None,
weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
ref: https://github.com/volcengine/verl/blob/78532923368aeb058f62201489546d013df47710/verl/trainer/ppo/core_algos.py#L370
Aggregate the loss matrix into a scalar.
Expand Down Expand Up @@ -314,32 +346,54 @@ def masked_var(values, mask, unbiased=True):
return variance


def get_eos_mask(response_id: torch.Tensor, eos_token: int = 2, dtype=torch.int64) -> torch.Tensor:
    """Build a mask keeping tokens up to and including the first EOS token.

    Args:
        response_id: Token IDs tensor, shape (batch_size, seq_len).
        eos_token: The end-of-sentence token ID. Defaults to 2.
        dtype: Output tensor dtype. Defaults to torch.int64.

    Returns:
        Mask tensor: 1 at positions before/including the first EOS, 0 afterwards.

    Example:
        >>> response_id: [0, 0, 2, 42, 3, 5, 1, 0, 0]   # eos_token=1
        >>> eos_mask:    [1, 1, 1, 1, 1, 1, 1, 0, 0]
    """
    is_eos = response_id.eq(eos_token).long()
    # A position lies strictly after the first EOS iff the running EOS count,
    # with the current position's own hit subtracted out, is positive.
    after_first_eos = (torch.cumsum(is_eos, dim=1) - is_eos).bool()
    return (~after_first_eos).to(dtype)


def get_pad_mask(response_id: torch.Tensor, pad_token: int = 0, eos_token: int = 1, dtype=torch.int64):
"""
e.g. pad token=0
response_id: [1, 2, 2, 42, 3, 5, 1, 0, 0]
pad_mask: [1, 1, 1, 1, 1, 1, 1, 0, 0]

If eos_token == pad_token, the first pad token (which is the eos token) should be kept.
e.g. pad_token=0, eos_token=0
response_id: [1, 2, 2, 42, 3, 5, 0, 0, 0]
pad_mask: [1, 1, 1, 1, 1, 1, 1, 0, 0] (first pad token/eos token is kept)
def get_pad_mask(
response_id: torch.Tensor, pad_token: int = 0, eos_token: int = 1, dtype=torch.int64
) -> torch.Tensor:
"""Generate a mask that is 1 for non-pad tokens, 0 for pad tokens.

Args:
response_id: Token IDs tensor, shape (batch_size, seq_len).
pad_token: The padding token ID. Defaults to 0.
eos_token: The end-of-sentence token ID. Defaults to 1.
dtype: Output tensor dtype. Defaults to torch.int64.

Returns:
Pad mask tensor where non-pad positions are 1, pad positions are 0.

Example:
>>> response_id: [1, 2, 2, 42, 3, 5, 1, 0, 0]
>>> pad_mask: [1, 1, 1, 1, 1, 1, 1, 0, 0]

Note:
If eos_token == pad_token, the first pad token (which is the EOS token) is kept.
>>> pad_token=0, eos_token=0
>>> response_id: [1, 2, 2, 42, 3, 5, 0, 0, 0]
>>> pad_mask: [1, 1, 1, 1, 1, 1, 1, 0, 0]
"""
pad_mask = response_id.not_equal(pad_token).to(dtype)

# eos_token == pad_token, 需要保留第一个pad token否则会误将eos token mask掉
# If eos_token == pad_token, keep the first pad token to avoid masking the EOS token
if eos_token == pad_token:
pad_positions = response_id.eq(pad_token).to(dtype)
cumsum_pad = torch.cumsum(pad_positions, dim=-1)
Expand All @@ -360,18 +414,42 @@ def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps
return mean_centered * var.clamp(min=eps).rsqrt()


def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
    """Whiten values using the mean and variance computed over masked positions.

    Args:
        values: Input tensor to whiten.
        mask: Mask tensor selecting the valid positions for the statistics.
        shift_mean: If False, add the original masked mean back after
            normalization. Defaults to True.

    Returns:
        Whitened tensor (unit variance over masked positions; zero mean
        unless shift_mean is False).
    """
    masked_avg = masked_mean(values, mask)
    masked_variance = masked_var(values, mask)
    # 1e-8 guards the rsqrt against a vanishing masked variance.
    normalized = (values - masked_avg) * torch.rsqrt(masked_variance + 1e-8)
    if shift_mean:
        return normalized
    # Preserve the original mean while still scaling to unit variance.
    return normalized + masked_avg


def response_level_masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True):
"""Whiten values with masked values."""
# 考虑response的影响?
def response_level_masked_whiten(
values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True
) -> torch.Tensor:
"""Whiten values at the response level using masked mean and variance.

This function computes the mean per response (along dim=-1) and then
whitens based on the variance of those response-level means.

Args:
values: Input tensor to whiten, shape (batch_size, seq_len).
mask: Boolean mask tensor, True for valid positions.
shift_mean: If False, add the original mean back after whitening.
Defaults to True.

Returns:
Whitened tensor normalized at the response level.
"""
mean = masked_mean(values, mask, dim=-1)
var = masked_var(mean, mask)
mean = mean.mean()
Expand Down Expand Up @@ -452,7 +530,22 @@ def reduce_metrics_list(metrics_list: list, reduce_func=np.mean) -> dict:
return merged_metrics


def pad_to_length(tensor: torch.Tensor, length, pad_value, dim=-1):
def pad_to_length(
tensor: torch.Tensor, length: int, pad_value: float, dim: int = -1
) -> torch.Tensor:
"""Pad or truncate tensor to a specified length along a dimension.

Args:
tensor: Input tensor to pad or truncate.
length: Target length along the specified dimension.
pad_value: Value to use for padding.
dim: Dimension along which to pad. Defaults to -1 (last dimension).

Returns:
Tensor with the specified length along the given dimension.
If the tensor is longer than the target length, it is truncated.
If shorter, it is padded with `pad_value`.
"""
if tensor.size(dim) >= length:
indices = [slice(None)] * tensor.ndim
indices[dim] = slice(0, length)
Expand Down