@noemotiovon
Contributor

@noemotiovon noemotiovon commented Feb 2, 2026

Add backward_with_grad_and_value method to TorchLMHeadJSD to compute gradients using torch.func.grad_and_value on NPU devices. This ensures the reference implementation uses the same gradient computation method as the Liger implementation, allowing tests to pass on NPU despite known precision issues with grad_and_value on this device.

Changes:

  • Add backward_with_grad_and_value method to TorchLMHeadJSD class
  • Update test_correctness to use grad_and_value for reference model on NPU
  • Handle both bias=True and bias=False cases in backward_with_grad_and_value
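For readers unfamiliar with `torch.func.grad_and_value`, the bias handling listed above can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: `toy_loss` is a hypothetical stand-in (linear head plus cross-entropy) for the real JSD loss, and the function signature here is assumed for the example only.

```python
import torch
import torch.nn.functional as F
from torch.func import grad_and_value


def toy_loss(weight, bias, x, target):
    """Hypothetical stand-in for the JSD loss: linear head + cross-entropy."""
    logits = x @ weight.t()
    if bias is not None:
        logits = logits + bias
    return F.cross_entropy(logits, target)


def backward_with_grad_and_value(weight, bias, x, target):
    """Return (loss, grad_weight, grad_bias); grad_bias is None when bias is None."""
    if bias is not None:
        # bias=True: differentiate with respect to both weight and bias.
        def fn(w, b):
            return toy_loss(w, b, x, target)

        (grad_w, grad_b), loss = grad_and_value(fn, argnums=(0, 1))(weight, bias)
    else:
        # bias=False: only the weight receives a gradient.
        def fn(w):
            return toy_loss(w, None, x, target)

        grad_w, loss = grad_and_value(fn)(weight)
        grad_b = None
    return loss, grad_w, grad_b
```

Note that `grad_and_value` returns the gradient first and the loss value second, and that with a tuple `argnums` the gradient is itself a tuple in the same order.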

Hardware Type: Atlas 800I A2

  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@noemotiovon
Contributor Author

We found minor numerical discrepancies between grad_and_value and Autograd in the core JSD computation, and have filed a corresponding issue with torch-npu.
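A comparison of this kind can be sketched as below. This is only an illustration under stated assumptions: `toy_jsd` is a simplified symmetric JSD (beta = 0.5) written for this example, not the Liger kernel or the reference model, and `max_grad_diff` is a hypothetical helper name. On CPU the two paths should agree to within floating-point noise; the discrepancy described above was observed on NPU.

```python
import torch
import torch.nn.functional as F
from torch.func import grad_and_value


def toy_jsd(student_logits, teacher_logits):
    """Simplified symmetric JSD (beta = 0.5); a stand-in, not the Liger kernel."""
    log_p = F.log_softmax(student_logits, dim=-1)
    log_q = F.log_softmax(teacher_logits, dim=-1)
    # Log of the mixture distribution m = 0.5 * p + 0.5 * q.
    log_m = torch.log(0.5 * log_p.exp() + 0.5 * log_q.exp())
    # F.kl_div(input, target) computes KL(target || input).
    kl_pm = F.kl_div(log_m, log_p, reduction="sum", log_target=True)
    kl_qm = F.kl_div(log_m, log_q, reduction="sum", log_target=True)
    return 0.5 * (kl_pm + kl_qm) / student_logits.shape[0]


def max_grad_diff(student_logits, teacher_logits):
    """Largest absolute gap between the autograd and grad_and_value gradients."""
    # Path 1: classic autograd via .backward().
    s = student_logits.detach().clone().requires_grad_(True)
    toy_jsd(s, teacher_logits).backward()
    # Path 2: functional transform; returns (gradient, loss value).
    grad_fn, _ = grad_and_value(toy_jsd)(student_logits.detach(), teacher_logits)
    return (s.grad - grad_fn).abs().max().item()
```

Running `max_grad_diff` on the same inputs on two devices is one way to tell a device-specific precision gap apart from a difference in the loss definition itself.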

@noemotiovon
Contributor Author

Test Script:
torch.inductor support in torch-npu 2.6.0 is not yet fully mature, so we disable torch.compile by default.

TORCH_COMPILE_DISABLE=1 python -m pytest test/chunked_loss/test_jsd_loss.py

Test Result:

========================================================================= slowest durations =========================================================================
12.25s call     test/chunked_loss/test_jsd_loss.py::test_correctness[-100-1.0-0.5-0.5-0.5-True-1.0-dtype0-0.05-0.5-8-128-1024-4096]
0.09s call     test/chunked_loss/test_jsd_loss.py::test_correctness[-100-1.0-0.5-0.5-0.5-True-1.0-dtype0-0.05-0.5-3-47-31-123]
0.05s call     test/chunked_loss/test_jsd_loss.py::test_correctness[-100-1.0-0.5-0.5-0.5-True-1.0-dtype1-1e-05-0.0005-8-128-1024-4096]
0.04s call     test/chunked_loss/test_jsd_loss.py::test_correctness[-100-0.5-1.0-0.0-0.2-True-1.0-dtype1-1e-05-0.0005-8-128-1024-4096]
0.04s call     test/chunked_loss/test_jsd_loss.py::test_correctness[42-2.0-0.0-1.0-0.8-False-1.0-dtype1-1e-05-0.0005-8-128-1024-4096]
0.04s call     test/chunked_loss/test_jsd_loss.py::test_correctness[-100-2.0-0.0-1.0-0.8-True-1.0-dtype1-1e-05-0.0005-8-128-1024-4096]
0.03s call     test/chunked_loss/test_jsd_loss.py::test_correctness[42-1.0-0.5-0.5-0.5-False-1.0-dtype1-1e-05-0.0005-8-128-1024-4096]
0.03s call     test/chunked_loss/test_jsd_loss.py::test_correctness_functional[1.0-0.5-0.5-0.5--100-True-1.0-dtype0-0.05-0.05-2-2-8-8]
0.03s call     test/chunked_loss/test_jsd_loss.py::test_correctness_functional[1.0-0.5-0.5-0.5--100-True-1.0-dtype0-0.05-0.05-9-7-41-41]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness_functional[1.0-0.5-0.5-0.5--100-False-1.0-dtype0-0.05-0.05-2-2-8-8]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness_functional[1.0-0.5-0.5-0.5--100-True-1.0-dtype1-0.0001-0.005-9-7-41-41]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness_functional[1.0-0.5-0.5-0.5--100-True-1.0-dtype1-0.0001-0.005-2-2-8-8]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[-100-0.5-1.0-0.0-0.2-True-1.0-dtype0-0.05-0.5-8-128-1024-4096]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[42-0.5-1.0-0.0-0.2-False-1.0-dtype1-1e-05-0.0005-8-128-1024-4096]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness_functional[2.0-0.1-0.9-0.5-42-False-1.0-dtype0-0.05-0.05-9-7-41-41]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[42-1.0-0.5-0.5-0.5-True-1.0-dtype1-1e-05-0.0005-8-128-1024-4096]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[42-0.5-1.0-0.0-0.2-True-1.0-dtype1-1e-05-0.0005-8-128-1024-4096]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[42-2.0-0.0-1.0-0.8-True-1.0-dtype1-1e-05-0.0005-8-128-1024-4096]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness_functional[2.0-0.1-0.9-0.5-42-True-1.0-dtype0-0.05-0.05-9-7-41-41]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness_functional[2.0-0.1-0.9-0.5-42-True-1.0-dtype0-0.05-0.05-2-2-8-8]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[-100-1.0-0.5-0.5-0.5-True-1.0-dtype1-1e-05-0.0005-3-47-31-123]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[-100-2.0-0.0-1.0-0.8-True-1.0-dtype0-0.05-0.5-8-128-1024-4096]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[42-1.0-0.5-0.5-0.5-True-1.0-dtype0-0.05-0.5-8-128-1024-4096]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness_functional[2.0-0.1-0.9-0.5-42-True-1.0-dtype1-0.0001-0.005-2-2-8-8]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness_functional[2.0-0.1-0.9-0.5-42-True-1.0-dtype1-0.0001-0.005-9-7-41-41]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[42-0.5-1.0-0.0-0.2-True-1.0-dtype0-0.05-0.5-8-128-1024-4096]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[42-2.0-0.0-1.0-0.8-True-1.0-dtype0-0.05-0.5-8-128-1024-4096]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[-100-0.5-1.0-0.0-0.2-False-1.0-dtype1-1e-05-0.0005-8-128-1024-4096]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[42-1.0-0.5-0.5-0.5-True-1.0-dtype0-0.05-0.5-3-47-31-123]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[-100-0.5-1.0-0.0-0.2-True-1.0-dtype1-1e-05-0.0005-3-47-31-123]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[-100-2.0-0.0-1.0-0.8-True-1.0-dtype0-0.05-0.5-3-47-31-123]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[-100-1.0-0.5-0.5-0.5-False-1.0-dtype1-1e-05-0.0005-8-128-1024-4096]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[42-2.0-0.0-1.0-0.8-True-1.0-dtype1-1e-05-0.0005-3-47-31-123]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[42-1.0-0.5-0.5-0.5-True-1.0-dtype1-1e-05-0.0005-3-47-31-123]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[42-0.5-1.0-0.0-0.2-True-1.0-dtype0-0.05-0.5-3-47-31-123]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[-100-0.5-1.0-0.0-0.2-True-1.0-dtype0-0.05-0.5-3-47-31-123]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[42-0.5-1.0-0.0-0.2-True-1.0-dtype1-1e-05-0.0005-3-47-31-123]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness_functional[1.0-0.5-0.5-0.5--100-False-1.0-dtype0-0.05-0.05-9-7-41-41]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[42-2.0-0.0-1.0-0.8-True-1.0-dtype0-0.05-0.5-3-47-31-123]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[-100-2.0-0.0-1.0-0.8-False-1.0-dtype1-1e-05-0.0005-8-128-1024-4096]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness_functional[2.0-0.1-0.9-0.5-42-False-1.0-dtype0-0.05-0.05-2-2-8-8]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[-100-2.0-0.0-1.0-0.8-True-1.0-dtype1-1e-05-0.0005-3-47-31-123]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[-100-0.5-1.0-0.0-0.2-False-1.0-dtype0-0.05-0.5-8-128-1024-4096]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness_functional[1.0-0.5-0.5-0.5--100-False-1.0-dtype1-0.0001-0.005-9-7-41-41]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness_functional[2.0-0.1-0.9-0.5-42-False-1.0-dtype1-0.0001-0.005-9-7-41-41]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[42-1.0-0.5-0.5-0.5-False-1.0-dtype0-0.05-0.5-8-128-1024-4096]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness_functional[2.0-0.1-0.9-0.5-42-False-1.0-dtype1-0.0001-0.005-2-2-8-8]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness_functional[1.0-0.5-0.5-0.5--100-False-1.0-dtype1-0.0001-0.005-2-2-8-8]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[42-2.0-0.0-1.0-0.8-False-1.0-dtype0-0.05-0.5-8-128-1024-4096]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[-100-1.0-0.5-0.5-0.5-False-1.0-dtype1-1e-05-0.0005-3-47-31-123]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[42-0.5-1.0-0.0-0.2-False-1.0-dtype0-0.05-0.5-8-128-1024-4096]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[-100-2.0-0.0-1.0-0.8-False-1.0-dtype0-0.05-0.5-8-128-1024-4096]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[-100-1.0-0.5-0.5-0.5-False-1.0-dtype0-0.05-0.5-8-128-1024-4096]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[42-1.0-0.5-0.5-0.5-False-1.0-dtype1-1e-05-0.0005-3-47-31-123]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[-100-0.5-1.0-0.0-0.2-False-1.0-dtype0-0.05-0.5-3-47-31-123]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[-100-0.5-1.0-0.0-0.2-False-1.0-dtype1-1e-05-0.0005-3-47-31-123]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[42-0.5-1.0-0.0-0.2-False-1.0-dtype1-1e-05-0.0005-3-47-31-123]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[42-2.0-0.0-1.0-0.8-False-1.0-dtype1-1e-05-0.0005-3-47-31-123]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[42-2.0-0.0-1.0-0.8-False-1.0-dtype0-0.05-0.5-3-47-31-123]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[42-0.5-1.0-0.0-0.2-False-1.0-dtype0-0.05-0.5-3-47-31-123]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[42-1.0-0.5-0.5-0.5-False-1.0-dtype0-0.05-0.5-3-47-31-123]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[-100-2.0-0.0-1.0-0.8-False-1.0-dtype1-1e-05-0.0005-3-47-31-123]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[-100-1.0-0.5-0.5-0.5-False-1.0-dtype0-0.05-0.5-3-47-31-123]
0.02s call     test/chunked_loss/test_jsd_loss.py::test_correctness[-100-2.0-0.0-1.0-0.8-False-1.0-dtype0-0.05-0.5-3-47-31-123]
0.01s setup    test/chunked_loss/test_jsd_loss.py::test_correctness[-100-1.0-0.5-0.5-0.5-True-1.0-dtype0-0.05-0.5-8-128-1024-4096]

(127 durations < 0.005s hidden.  Use -vv to show these durations.)
======================================================================== 64 passed in 25.73s ========================================================================
