Commit 32705dc

[trainer] feat: add padding for tensor alignment in preprocess_thd_no_padding function (#5410)
### What does this PR do?

Fix the context parallel path to align the tensor size.

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, `fully_async`, `one_step_off`
  - If this PR involves multiple modules, separate them with `,`, e.g. `[megatron, fsdp, doc]`
  - `{type}` is one of `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results such as training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes, if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review; otherwise the reviewer might deprioritize this PR.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [ ] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.
1 parent 5f7c345 commit 32705dc

File tree

1 file changed: +18 −0 lines changed

verl/models/mcore/util.py

Lines changed: 18 additions & 0 deletions
```diff
@@ -13,14 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import math
+import os
 
 import torch
 from megatron.core import parallel_state as mpu
 from megatron.core.packed_seq_params import PackedSeqParams
 
 from verl.utils.model import CausalLMOutputForPPO
 
+logger = logging.getLogger(__file__)
+logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
+
 
 def preprocess_packed_seqs(
     input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True, use_fp8_padding=False
@@ -333,6 +338,19 @@ def preprocess_thd_no_padding(
         start_idx = cu_seqlens_padded_cpu[i] // cp_size
         # split to 2 chunks
         d = input_ids[i]
+        # If the number of elements in `d` is smaller than the required
+        # alignment size, pad the tensor with zeros so that its total
+        # length matches `align_size`. This ensures size alignment for
+        # downstream operations (e.g., communication or memory alignment).
+        if d.numel() < align_size:
+            original_size = d.numel()
+            pad = torch.zeros(align_size - d.numel(), dtype=d.dtype, device=d.device)
+            d = torch.cat([d, pad], dim=0)
+            logger.warning_once(
+                f"Padding tensor for context parallel alignment, original_size={original_size}, "
+                f"align_size={align_size}"
+            )
+
         input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[
             half_seqlen * cp_rank : half_seqlen * (cp_rank + 1)
         ]
```
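The core of the change is right-padding a per-sequence tensor with zeros so its length reaches `align_size` before it is sliced across context-parallel ranks. A minimal pure-Python sketch of that padding logic, with `pad_to_alignment` as a hypothetical helper name (the actual code operates on `torch` tensors via `torch.zeros` and `torch.cat`):

```python
import logging

logger = logging.getLogger(__name__)


def pad_to_alignment(seq, align_size, pad_value=0):
    """Right-pad `seq` with `pad_value` until len(seq) == align_size.

    If `seq` is already at least `align_size` long, it is returned unchanged,
    mirroring the `if d.numel() < align_size:` guard in the patch.
    """
    if len(seq) < align_size:
        original_size = len(seq)
        seq = seq + [pad_value] * (align_size - original_size)
        # The patch logs this once per process so oversized warnings don't
        # flood the output during training.
        logger.warning(
            "Padding tensor for context parallel alignment, "
            "original_size=%d, align_size=%d",
            original_size,
            align_size,
        )
    return seq


print(pad_to_alignment([1, 2, 3], 8))  # -> [1, 2, 3, 0, 0, 0, 0, 0]
```

Once padded, slicing `d[half_seqlen * cp_rank : half_seqlen * (cp_rank + 1)]` is guaranteed to yield a full-length chunk on every rank, which is what the downstream copy into `input_ids_rmpad` requires.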

0 commit comments
