test/chunked_loss/test_jsd_loss.py (64 additions, 6 deletions)
@@ -124,6 +124,58 @@ def forward(self, student_input, teacher_input, target):
        )
        return jsd_loss

    def backward_with_grad_and_value(self, student_input, teacher_input, target):
        """
        Compute the loss and gradients with torch.func.grad_and_value so that
        the reference path matches how the Liger implementation computes
        gradients. Used by the correctness test on NPU devices.
        """
        # Use grad_and_value to compute gradients and loss in a single call
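        # torch.func.grad_and_value(fn, argnums=...) wraps fn into a function
        # that returns (grads, value): the gradients w.r.t. the positional
        # args selected by argnums, plus the loss, without calling .backward().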
        if self.student_lin.bias is not None:

            def loss_fn(student_input, student_weight, student_bias):
                return self.jsd(
                    student_input,
                    student_weight,
                    teacher_input,
                    self.teacher_lin.weight,
                    target,
                    student_bias,
                    self.teacher_lin.bias,
                    beta=self.beta,
                )

            (grad_input, grad_weight, grad_bias), loss = torch.func.grad_and_value(loss_fn, argnums=(0, 1, 2))(
                student_input, self.student_lin.weight, self.student_lin.bias
            )

            # Set gradients manually; torch.func does not populate .grad
            student_input.grad = grad_input
            self.student_lin.weight.grad = grad_weight
            self.student_lin.bias.grad = grad_bias
        else:

            def loss_fn(student_input, student_weight):
                return self.jsd(
                    student_input,
                    student_weight,
                    teacher_input,
                    self.teacher_lin.weight,
                    target,
                    None,  # student_bias is None when bias=False
                    self.teacher_lin.bias,
                    beta=self.beta,
                )

            (grad_input, grad_weight), loss = torch.func.grad_and_value(loss_fn, argnums=(0, 1))(
                student_input, self.student_lin.weight
            )

            # Set gradients manually; torch.func does not populate .grad
            student_input.grad = grad_input
            self.student_lin.weight.grad = grad_weight

        return loss


class LigerLMHeadJSD(torch.nn.Module):
    def __init__(

@@ -261,12 +313,18 @@ def test_correctness(
    target[indices_to_assign] = ignore_index

    loss1 = torch_lm_head_jsd(student_input1, teacher_input, target)
    loss2 = liger_lm_head_jsd(student_input2, teacher_input, target)
    assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)

    loss1.backward()
    loss2.backward()
    # On NPU, use grad_and_value in the reference implementation to match the Liger implementation
    if device == "npu":
        loss1 = torch_lm_head_jsd.backward_with_grad_and_value(student_input1, teacher_input, target)
        loss2 = liger_lm_head_jsd(student_input2, teacher_input, target)
        assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
        # Reference gradients were already assigned inside backward_with_grad_and_value
        loss2.backward()
    else:
        loss1 = torch_lm_head_jsd(student_input1, teacher_input, target)
        loss2 = liger_lm_head_jsd(student_input2, teacher_input, target)
        assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
        loss1.backward()
        loss2.backward()

    assert_verbose_allclose(student_input1.grad, student_input2.grad, atol=atol, rtol=rtol)
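For context, a minimal standalone sketch of the torch.func.grad_and_value pattern the NPU reference path relies on (the toy loss and tensor names here are illustrative, not part of this PR):

import torch

# Toy inputs; torch.func transforms do not require requires_grad=True.
x = torch.randn(2, 3)
weight = torch.randn(4, 3)

def loss_fn(x, weight):
    # Any differentiable scalar-valued function of its inputs works here.
    return (x @ weight.T).pow(2).mean()

# grad_and_value returns a new function that computes both the gradients
# w.r.t. the args selected by argnums and the loss value in one call.
(grad_x, grad_w), loss = torch.func.grad_and_value(loss_fn, argnums=(0, 1))(x, weight)

# Unlike loss.backward(), nothing is written to .grad automatically,
# which is why the test assigns the returned gradients by hand.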
