4 changes: 2 additions & 2 deletions fast_llm/functional/entropy_loss.py
@@ -121,7 +121,7 @@ def _fused_reverse_kl_base(
     # Compute loss terms: student_probs * log_ratio, then sum over vocab
     # This is equivalent to kl_div(..., log_target=True) but more memory efficient
     log_ratio = predicted_log_probability - target_log_probability
-    per_sample_loss = (predicted_probability * log_ratio).sum(dim=-1)
+    per_sample_loss = (predicted_probability * log_ratio).sum(dim=-1, keepdim=True)
     if group is not None:
         all_reduce(per_sample_loss, op=ReduceOp.SUM, group=group)

@@ -130,7 +130,7 @@ def _fused_reverse_kl_base(
     else:
         # Gradient: d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)])
         # where E_q[log(q/p)] is the expected log ratio under the student distribution
-        grad = (log_ratio - per_sample_loss.unsqueeze(-1)) * predicted_probability * grad_output
+        grad = (log_ratio - per_sample_loss) * predicted_probability * grad_output

     return per_sample_loss, grad
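The net effect of this hunk: `per_sample_loss` now leaves the sum with shape `(batch, 1)` instead of `(batch,)`, so it broadcasts directly against the `(batch, vocab)` `log_ratio` in the gradient and the old `unsqueeze(-1)` is no longer needed. Below is a minimal, self-contained sketch (not Fast-LLM code; tensor names mirror the diff, sizes are arbitrary) checking the fused per-sample loss against `torch.nn.functional.kl_div(..., log_target=True)`, the reference the in-code comment cites:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, vocab = 4, 16
# Student (q) and teacher (p) log-probabilities over the vocabulary.
predicted_log_probability = torch.log_softmax(torch.randn(batch, vocab), dim=-1)
target_log_probability = torch.log_softmax(torch.randn(batch, vocab), dim=-1)

# Fused-style loss: keepdim=True leaves shape (batch, 1), ready to broadcast
# against (batch, vocab) tensors in the gradient computation.
predicted_probability = predicted_log_probability.exp()
log_ratio = predicted_log_probability - target_log_probability
per_sample_loss = (predicted_probability * log_ratio).sum(dim=-1, keepdim=True)

# Reference: with log_target=True, kl_div(input, target) computes
# target.exp() * (target - input) pointwise, i.e. KL(q || p) for
# target = log q (student) and input = log p (teacher).
reference = F.kl_div(
    target_log_probability, predicted_log_probability, reduction="none", log_target=True
).sum(dim=-1, keepdim=True)

print(torch.allclose(per_sample_loss, reference))  # True
```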
3 changes: 1 addition & 2 deletions tests/functional/test_entropy_loss.py
@@ -82,14 +82,13 @@ def test_entropy_loss(num_columns, grad_output, logits_scale_factor, loss_masking
     out_torch, grad_torch = entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.torch)
     out_fused, grad_fused = entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.fused)

-    # TODO: Why is the error so high with loss masking for reverse KL?
     _compare_entropy_loss_outputs(
         out_fused,
         out_torch,
         grad_output is not None,
         grad_fused,
         grad_torch,
-        loss_min_threshold=2e-4 if entropy_loss_type == EntropyLossType.reverse_kl and loss_masking else 5e-6,
+        loss_min_threshold=5e-6,
     )

     if entropy_loss_type != EntropyLossType.cross_entropy or not torch.cuda.is_available():
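The test drops the TODO and the relaxed `2e-4` threshold for reverse KL with loss masking, applying a single tight `5e-6` bound instead; the implication is that the `keepdim` change above resolves the old masked-path discrepancy, though the diff alone does not show where that error arose. As a sanity check, here is a hedged, standalone sketch (the masking convention — masked samples contribute zero — is an assumption, not necessarily Fast-LLM's) comparing the fused gradient formula `q * (log(q/p) - E_q[log(q/p)])` against autograd under loss masking:

```python
import torch

torch.manual_seed(0)
batch, vocab = 8, 32
student_logits = torch.randn(batch, vocab, requires_grad=True)
teacher_logits = torch.randn(batch, vocab)
loss_mask = torch.rand(batch) > 0.5  # True where a sample contributes to the loss

# Autograd reference: masked mean of per-sample reverse KL(q || p).
log_q = torch.log_softmax(student_logits, dim=-1)
log_p = torch.log_softmax(teacher_logits, dim=-1)
per_sample = (log_q.exp() * (log_q - log_p)).sum(dim=-1)
(per_sample * loss_mask).sum().div(loss_mask.sum()).backward()

# Fused-style gradient, mirroring the diff: per_sample_loss keeps the vocab dim,
# so it broadcasts against log_ratio without an unsqueeze. Masked samples get a
# zero grad_output under the assumed convention.
with torch.no_grad():
    q = log_q.exp()
    log_ratio = log_q - log_p
    per_sample_loss = (q * log_ratio).sum(dim=-1, keepdim=True)
    grad_output = loss_mask.unsqueeze(-1).float() / loss_mask.sum()
    grad = (log_ratio - per_sample_loss) * q * grad_output

# Prints float32 round-off (on the order of 1e-7 or smaller), the kind of
# agreement that makes a single tight threshold feasible.
print((grad - student_logits.grad).abs().max())
```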