Commit 200bc6d

[model] fix: fix temp dtype (#4813)
### What does this PR do?

As title: prevent `temperature` from being an int.

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is one of `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results such as training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes, if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [ ] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
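The bug this commit fixes can be reproduced in isolation. This is a minimal standalone sketch (assuming PyTorch is installed; shapes and values are illustrative): when `temperature` arrives as a Python int, `torch.tensor([temperature] * bsz)` produces an int64 tensor, and the later masked assignment of `1e-8` over padding positions silently truncates back to 0.

```python
import torch

temperature = 1  # a Python int, e.g. straight from a config
bsz = 4
temp = torch.tensor([temperature] * bsz)  # dtype is int64, not float
assert temp.dtype == torch.int64

# Padding rows carry temperature 0; the engine replaces them with 1e-8,
# but assigning a float into an int64 tensor silently truncates it to 0,
# so the later positivity check / division still sees a zero.
temp[0] = 0
temp[temp <= 0] = 1e-8
assert temp[0].item() == 0  # the guard value was lost

# The fix: cast to float32 right after construction.
temp = torch.tensor([temperature] * bsz).to(torch.float32)
temp[0] = 0
temp[temp <= 0] = 1e-8
assert temp[0].item() > 0  # guard value survives
```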
1 parent 2bb42ba commit 200bc6d

File tree: 2 files changed (+5, −3 lines)


verl/workers/engine/fsdp/transformer_impl.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -736,6 +736,7 @@ def prepare_model_inputs(self, micro_batch: TensorDict):
         if not isinstance(temperature, torch.Tensor):
             temperature = torch.tensor([temperature] * input_ids.shape[0], device=input_ids.device)

+        temperature = temperature.to(torch.float32)
         assert temperature.shape[0] == input_ids.shape[0]

         # args used to get outputs
@@ -876,7 +877,7 @@ def prepare_model_outputs(self, output, output_args, micro_batch: TensorDict):
             entropy_rmpad = output.entropy.squeeze(0)  # (total_nnz,)
         else:
             logits_rmpad = output.logits.squeeze(0)  # (total_nnz, vocab_size)
-            logits_rmpad.div_(temperature_rmpad.unsqueeze(-1))
+            logits_rmpad.div_(temperature_rmpad.clamp(min=1e-8).unsqueeze(-1).to(logits_rmpad.dtype))

         # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)
         inplace_backward = True
@@ -935,7 +936,7 @@ def prepare_model_outputs(self, output, output_args, micro_batch: TensorDict):
         logits = output.logits  # (bsz, response_length, vocab_size)
         temperature = output_args["temperature"]  # (bsz,)
         temperature = temperature.unsqueeze(-1).unsqueeze(-1)
-        logits.div_(temperature)
+        logits.div_(temperature.clamp(min=1e-8).to(logits.dtype))

         if calculate_entropy:
             if not self.engine_config.entropy_checkpointing:
```

verl/workers/engine/megatron/transformer_impl.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -619,6 +619,7 @@ def forward_step(self, batch_iter: Iterator[TensorDict], model, postprocess_micr
         if not isinstance(temperature, torch.Tensor):
             temperature = torch.tensor([temperature] * input_ids.shape[0], device=input_ids.device)

+        temperature = temperature.to(torch.float32)
         assert temperature.shape[0] == input_ids.shape[0]
         temperature = verl_F.expand_as_nested(temperature, input_ids)  # (bsz, j1)

@@ -639,7 +640,7 @@ def logits_processor(logits, label, temperature):
             # avoid non-positive temperature such as padding
             temperature[temperature <= 0] = 1e-8
             assert torch.all(temperature > 0).item(), f"temperature tensor must be positive. Got {temperature}"
-            logits.div_(temperature.unsqueeze(dim=-1))
+            logits.div_(temperature.unsqueeze(dim=-1).to(logits.dtype))
             ret = {}
             if calculate_entropy:
                 logits_bak = logits.clone()
```
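The Megatron-path guard can be sketched end to end. This is a hypothetical, simplified stand-in for the `logits_processor` in the diff (the real function also takes a `label` argument and computes more than the division); it shows why the masked `1e-8` assignment only works once `temperature` is float32, and why the divisor is cast to the logits dtype.

```python
import torch

def scale_logits(logits, temperature):
    # Simplified sketch of the temperature guard: temperature arrives per
    # token as float32; padding entries may be 0.
    temperature = temperature.clone()
    temperature[temperature <= 0] = 1e-8  # only safe on a float tensor
    assert torch.all(temperature > 0).item()
    # Cast the divisor to the logits dtype so the in-place div_ is legal.
    logits.div_(temperature.unsqueeze(dim=-1).to(logits.dtype))
    return logits

out = scale_logits(
    torch.ones(2, 3, 5),                       # (bsz, seqlen, vocab)
    torch.tensor([[1.0, 0.0, 2.0],             # (bsz, seqlen), one padded 0
                  [1.0, 1.0, 1.0]]),
)
assert torch.isfinite(out).all()  # padded position no longer divides by zero
```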
