diff --git a/fast_llm/functional/entropy_loss.py b/fast_llm/functional/entropy_loss.py
index 757832a7..f1212f4b 100644
--- a/fast_llm/functional/entropy_loss.py
+++ b/fast_llm/functional/entropy_loss.py
@@ -121,7 +121,7 @@ def _fused_reverse_kl_base(
     # Compute loss terms: student_probs * log_ratio, then sum over vocab
     # This is equivalent to kl_div(..., log_target=True) but more memory efficient
     log_ratio = predicted_log_probability - target_log_probability
-    per_sample_loss = (predicted_probability * log_ratio).sum(dim=-1)
+    per_sample_loss = (predicted_probability * log_ratio).sum(dim=-1, keepdim=True)
 
     if group is not None:
         all_reduce(per_sample_loss, op=ReduceOp.SUM, group=group)
@@ -130,7 +130,7 @@ def _fused_reverse_kl_base(
     else:
         # Gradient: d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)])
         # where E_q[log(q/p)] is the expected log ratio under the student distribution
-        grad = (log_ratio - per_sample_loss.unsqueeze(-1)) * predicted_probability * grad_output
+        grad = (log_ratio - per_sample_loss) * predicted_probability * grad_output
 
     return per_sample_loss, grad
 
diff --git a/tests/functional/test_entropy_loss.py b/tests/functional/test_entropy_loss.py
index 9c06c191..35d1ef64 100644
--- a/tests/functional/test_entropy_loss.py
+++ b/tests/functional/test_entropy_loss.py
@@ -82,14 +82,13 @@ def test_entropy_loss(num_columns, grad_output, logits_scale_factor, loss_maskin
     out_torch, grad_torch = entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.torch)
     out_fused, grad_fused = entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.fused)
 
-    # TODO: Why is the error so high with loss masking for reverse KL?
     _compare_entropy_loss_outputs(
         out_fused,
         out_torch,
         grad_output is not None,
         grad_fused,
         grad_torch,
-        loss_min_threshold=2e-4 if entropy_loss_type == EntropyLossType.reverse_kl and loss_masking else 5e-6,
+        loss_min_threshold=5e-6,
     )
 
     if entropy_loss_type != EntropyLossType.cross_entropy or not torch.cuda.is_available():
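
A minimal sketch (not part of the diff) checking the gradient identity the fused reverse-KL path relies on, d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)]), and that with keepdim=True the per-sample loss broadcasts over the vocabulary dimension without an explicit unsqueeze. It compares the manual gradient against torch.autograd; the tensor names and shapes are illustrative assumptions, not the repository's API.

# Sketch: verify the reverse-KL gradient formula used in _fused_reverse_kl_base.
# student_logits / teacher_logits are made-up standalone tensors for illustration.
import torch

torch.manual_seed(0)
student_logits = torch.randn(4, 16, dtype=torch.float64, requires_grad=True)
teacher_logits = torch.randn(4, 16, dtype=torch.float64)

predicted_log_probability = torch.log_softmax(student_logits, dim=-1)
target_log_probability = torch.log_softmax(teacher_logits, dim=-1)
predicted_probability = predicted_log_probability.exp()

# Forward: reverse KL(q||p) summed over the vocabulary, one value per sample.
# keepdim=True leaves a trailing dimension of size 1 so it broadcasts below.
log_ratio = predicted_log_probability - target_log_probability
per_sample_loss = (predicted_probability * log_ratio).sum(dim=-1, keepdim=True)

# Reference gradient from autograd (samples are independent, so summing is fine).
(autograd_grad,) = torch.autograd.grad(per_sample_loss.sum(), student_logits)

# Manual gradient, matching the fused implementation after the keepdim change:
# q * (log(q/p) - E_q[log(q/p)]), with per_sample_loss broadcasting over vocab.
manual_grad = (log_ratio - per_sample_loss) * predicted_probability

print(torch.allclose(manual_grad, autograd_grad, atol=1e-12))  # True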