diff --git a/roll/utils/functionals.py b/roll/utils/functionals.py index 6e251a092..b465fb30a 100644 --- a/roll/utils/functionals.py +++ b/roll/utils/functionals.py @@ -73,15 +73,15 @@ def traverse_obj(value, visitor, path=()): continue -def union_two_dict(dict1: Dict, dict2: Dict): - """Union two dict. Will throw an error if there is an item not the same object with the same key. +def union_two_dict(dict1: Dict, dict2: Dict) -> Dict: + """Union two dicts. Will throw an error if there is an item not the same object with the same key. Args: - dict1: - dict2: + dict1: First dictionary to merge into. + dict2: Second dictionary to merge from. Returns: - + The merged dictionary (dict1 with dict2's items added). """ for key, val in dict2.items(): if key in dict1: @@ -116,7 +116,16 @@ def divide_by_chunk_size( return split_data -def append_to_dict(data: Dict, new_data: Dict): +def append_to_dict(data: Dict, new_data: Dict) -> None: + """Append values from new_data to lists in data dictionary. + + Args: + data: Dictionary to append to (values are lists). + new_data: Dictionary with values to append. + + Returns: + None. Modifies data in place. + """ for key, val in new_data.items(): if key not in data: data[key] = [] @@ -170,7 +179,17 @@ def update(self, xs: torch.Tensor) -> Tuple[float, float]: return xs_mean.item(), (xs_var * xs_count / (xs_count - 1)).float().sqrt().item() -def compute_clip_fraction(values: torch.Tensor, clip_max: float, clip_min: float): +def compute_clip_fraction(values: torch.Tensor, clip_max: float, clip_min: float) -> float: + """Compute the fraction of values that are outside the clip range. + + Args: + values: Input tensor to check. + clip_max: Maximum value for clipping. + clip_min: Minimum value for clipping. + + Returns: + Fraction of values outside [clip_min, clip_max] range. + """ numel = values.numel() num_clipped = (values > clip_max).sum().item() + (values < clip_min).sum().item() clipfrac = num_clipped / numel if numel > 0 else 0.0 @@ -217,16 +236,29 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T return log_probs_labels.squeeze(-1) -def entropy_from_logits(logits: torch.Tensor): - """Calculate entropy from logits.""" +def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor: + """Calculate entropy from logits. + + Args: + logits: Input logits tensor. + + Returns: + Entropy tensor computed from the logits. + """ logits = logits.float() pd = torch.nn.functional.softmax(logits, dim=-1) entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1) return entropy -def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str, batch_num_tokens: int = None, - global_valid_samples: int = None, weights: Optional[torch.Tensor] = None): +def agg_loss( + loss_mat: torch.Tensor, + loss_mask: torch.Tensor, + loss_agg_mode: str, + batch_num_tokens: int = None, + global_valid_samples: int = None, + weights: Optional[torch.Tensor] = None, +) -> torch.Tensor: """ ref: https://github.com/volcengine/verl/blob/78532923368aeb058f62201489546d013df47710/verl/trainer/ppo/core_algos.py#L370 Aggregate the loss matrix into a scalar. @@ -314,11 +346,20 @@ def masked_var(values, mask, unbiased=True): return variance -def get_eos_mask(response_id: torch.Tensor, eos_token: int = 2, dtype=torch.int64): - """ - e.g. end of sentence token=1 - response_id: [0, 0, 2, 42, 3, 5, 1, 0, 0] - eos_mask: [1, 1, 1, 1, 1, 1, 1, 0, 0] +def get_eos_mask(response_id: torch.Tensor, eos_token: int = 2, dtype=torch.int64) -> torch.Tensor: + """Generate a mask that is 1 before (and including) the first EOS token, 0 after. + + Args: + response_id: Token IDs tensor, shape (batch_size, seq_len). + eos_token: The end-of-sentence token ID. Defaults to 2. + dtype: Output tensor dtype. Defaults to torch.int64. + + Returns: + EOS mask tensor where positions before/including first EOS are 1, rest are 0. + + Example: + >>> response_id: [0, 0, 2, 42, 3, 5, 1, 0, 0] + >>> eos_mask: [1, 1, 1, 1, 1, 1, 1, 0, 0] """ eos_mask = response_id.eq(eos_token).long() eos_mask = (torch.cumsum(eos_mask, dim=1) - eos_mask).bool() @@ -326,20 +367,33 @@ def get_eos_mask(response_id: torch.Tensor, eos_token: int = 2, dtype=torch.int6 return eos_mask -def get_pad_mask(response_id: torch.Tensor, pad_token: int = 0, eos_token: int = 1, dtype=torch.int64): - """ - e.g. pad token=0 - response_id: [1, 2, 2, 42, 3, 5, 1, 0, 0] - pad_mask: [1, 1, 1, 1, 1, 1, 1, 0, 0] - - If eos_token == pad_token, the first pad token (which is the eos token) should be kept. - e.g. pad_token=0, eos_token=0 - response_id: [1, 2, 2, 42, 3, 5, 0, 0, 0] - pad_mask: [1, 1, 1, 1, 1, 1, 1, 0, 0] (first pad token/eos token is kept) +def get_pad_mask( + response_id: torch.Tensor, pad_token: int = 0, eos_token: int = 1, dtype=torch.int64 +) -> torch.Tensor: + """Generate a mask that is 1 for non-pad tokens, 0 for pad tokens. + + Args: + response_id: Token IDs tensor, shape (batch_size, seq_len). + pad_token: The padding token ID. Defaults to 0. + eos_token: The end-of-sentence token ID. Defaults to 1. + dtype: Output tensor dtype. Defaults to torch.int64. + + Returns: + Pad mask tensor where non-pad positions are 1, pad positions are 0. + + Example: + >>> response_id: [1, 2, 2, 42, 3, 5, 1, 0, 0] + >>> pad_mask: [1, 1, 1, 1, 1, 1, 1, 0, 0] + + Note: + If eos_token == pad_token, the first pad token (which is the EOS token) is kept. + >>> pad_token=0, eos_token=0 + >>> response_id: [1, 2, 2, 42, 3, 5, 0, 0, 0] + >>> pad_mask: [1, 1, 1, 1, 1, 1, 1, 0, 0] """ pad_mask = response_id.not_equal(pad_token).to(dtype) - # eos_token == pad_token, 需要保留第一个pad token否则会误将eos token mask掉 + # If eos_token == pad_token, keep the first pad token to avoid masking the EOS token if eos_token == pad_token: pad_positions = response_id.eq(pad_token).to(dtype) cumsum_pad = torch.cumsum(pad_positions, dim=-1) @@ -360,8 +414,18 @@ def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps return mean_centered * var.clamp(min=eps).rsqrt() -def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True): - """Whiten values with masked values.""" +def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor: + """Whiten values using masked mean and variance normalization. + + Args: + values: Input tensor to whiten. + mask: Boolean mask tensor, True for valid positions. + shift_mean: If False, add the original mean back after whitening. + Defaults to True. + + Returns: + Whitened tensor with zero mean and unit variance (over masked positions). + """ mean, var = masked_mean(values, mask), masked_var(values, mask) whitened = (values - mean) * torch.rsqrt(var + 1e-8) if not shift_mean: @@ -369,9 +433,23 @@ def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = T return whitened -def response_level_masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True): - """Whiten values with masked values.""" - # 考虑response的影响? +def response_level_masked_whiten( + values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True +) -> torch.Tensor: + """Whiten values at the response level using masked mean and variance. + + This function computes the mean per response (along dim=-1) and then + whitens based on the variance of those response-level means. + + Args: + values: Input tensor to whiten, shape (batch_size, seq_len). + mask: Boolean mask tensor, True for valid positions. + shift_mean: If False, add the original mean back after whitening. + Defaults to True. + + Returns: + Whitened tensor normalized at the response level. + """ mean = masked_mean(values, mask, dim=-1) var = masked_var(mean, mask) mean = mean.mean() @@ -452,7 +530,22 @@ def reduce_metrics_list(metrics_list: list, reduce_func=np.mean) -> dict: return merged_metrics -def pad_to_length(tensor: torch.Tensor, length, pad_value, dim=-1): +def pad_to_length( + tensor: torch.Tensor, length: int, pad_value: float, dim: int = -1 +) -> torch.Tensor: + """Pad or truncate tensor to a specified length along a dimension. + + Args: + tensor: Input tensor to pad or truncate. + length: Target length along the specified dimension. + pad_value: Value to use for padding. + dim: Dimension along which to pad. Defaults to -1 (last dimension). + + Returns: + Tensor with the specified length along the given dimension. + If the tensor is longer than the target length, it is truncated. + If shorter, it is padded with `pad_value`. + """ if tensor.size(dim) >= length: indices = [slice(None)] * tensor.ndim indices[dim] = slice(0, length)