Draft
45 commits
c2395e0
WIP: bringing Yifan's changes to main
pzelasko Feb 4, 2026
8d4c570
Add workaround for exp_manager issue
pzelasko Feb 10, 2026
ff54b12
Support reading indexed JSONL datasets with ShareGPT format
pzelasko Feb 10, 2026
b9e7b23
Merge remote-tracking branch 'origin/speechlm-yifan-mod-port' into sp…
pzelasko Feb 10, 2026
4a6324d
Support reading indexed tarred datasets with ShareGPT format
pzelasko Feb 10, 2026
9a2f78a
Refactor for compactness
pzelasko Feb 10, 2026
e222048
Fixes for real-life data
pzelasko Feb 10, 2026
c538a45
Fixes for real-life data
pzelasko Feb 11, 2026
9fc4b72
Fixes for real-life data
pzelasko Feb 11, 2026
4b4c529
Fixes for missing wids-meta.json
pzelasko Feb 11, 2026
fc3dffb
Fixes for tarfile edge cases
pzelasko Feb 11, 2026
c45ea47
Fixes for real-world tar files
pzelasko Feb 11, 2026
c80ed96
move salm llm init to configure_model
pzelasko Feb 12, 2026
794d300
fix: delayed perception init
pzelasko Feb 12, 2026
0516726
Add AutomodelParallelStrategy for Automodel LLM support
pzelasko Feb 12, 2026
c6c818c
Merge branch 'speechlm2-with-nemo-automodel' of https://github.com/NV…
pzelasko Feb 12, 2026
024c8d0
Replace HF Automodel with NeMo Automodel for SALM's LLM backbone
pzelasko Feb 12, 2026
c4a2a3b
Update salm default config with new options
pzelasko Feb 12, 2026
162117f
Init fixes
pzelasko Feb 12, 2026
43e1bb1
Fix dtype initialization
pzelasko Feb 12, 2026
20a2824
Fix mesh selection for speech encoder
pzelasko Feb 12, 2026
cd6ddf3
Fix for mismatched device_mesh axis names in gradient clipping - use …
pzelasko Feb 12, 2026
ff4beab
Fix for using embed_tokens in FSDP context before running forward on …
pzelasko Feb 12, 2026
b3658b1
Definitive fix for using embed_tokens outside of llm with fsdp
pzelasko Feb 12, 2026
71b6744
this version actually works with Automodel
pzelasko Feb 13, 2026
a5d33d2
fix from_pretrained with transformers v5
pzelasko Feb 17, 2026
1d9ed29
fix from_pretrained with transformers v5
pzelasko Feb 17, 2026
aaf828a
fix generate/eval
pzelasko Feb 17, 2026
4732230
fix to_hf
pzelasko Feb 17, 2026
f4bb443
Fixes for AutoTokenizer decoding in v5
pzelasko Feb 18, 2026
4c21c4d
Flag to run configure_model() at the end of __init__ for safetensors …
pzelasko Feb 18, 2026
e09418c
preliminary: support distributed models in to_hf.py
pzelasko Feb 19, 2026
a54828c
fix passing automodel kwargs
pzelasko Feb 19, 2026
cf40405
fix
pzelasko Feb 19, 2026
2b7f9d0
Enable inference with model parallelism
pzelasko Feb 19, 2026
05c69b8
Fix for lightning save_hyperparameters() call
pzelasko Feb 19, 2026
cf0b97f
Fix for loading into DTensor
pzelasko Feb 19, 2026
5c84827
Accelerate loading DTensor
pzelasko Feb 19, 2026
d595b7b
Accelerate loading DTensor
pzelasko Feb 19, 2026
b4ec5d2
Accelerate loading DTensor
pzelasko Feb 19, 2026
b6c8725
Fix for pe buffers not in ckpt (essentially strict=False)
pzelasko Feb 19, 2026
823c4ab
Add Nemotron Nano v3 prompt formatter with <think> reasoning support
pzelasko Feb 21, 2026
b241c67
fix
pzelasko Feb 21, 2026
80ff976
Automodel LoRA support
pzelasko Feb 23, 2026
42678cf
Merge branch 'main' into speechlm2-with-nemo-automodel-merge
pzelasko Feb 26, 2026
57 changes: 42 additions & 15 deletions examples/speechlm2/conf/salm.yaml
@@ -7,9 +7,8 @@ model:

# Regexp (re.compile) patterns matching parameters to be frozen.
freeze_params:
# Frozen LLM
- "^llm\\..+$" # LLM
- "^embed_tokens\\..+$" # LLM embedding is moved
# Frozen LLM (embed_tokens stays inside llm, so this pattern covers it too)
- "^llm\\..+$"
# Frozen pretrained ASR (only the modality adapter layers are trainable)
- "^perception\\.preprocessor\\..+$"
- "^perception\\.encoder\\..+$"
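The freeze list above is applied as `re.compile` patterns against full parameter names. A minimal sketch of that selection logic (the parameter names below are hypothetical, for illustration only):

```python
import re

# Patterns mirroring the freeze_params entries above.
FREEZE_PATTERNS = [re.compile(p) for p in (
    r"^llm\..+$",                       # whole LLM, embed_tokens included
    r"^perception\.preprocessor\..+$",  # pretrained ASR preprocessor
    r"^perception\.encoder\..+$",       # pretrained ASR encoder
)]

def is_frozen(param_name: str) -> bool:
    """Return True if any freeze pattern matches the full parameter name."""
    return any(p.match(param_name) for p in FREEZE_PATTERNS)
```

Under these patterns only the modality adapter (and any LoRA weights, which are kept trainable regardless) would receive gradients.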
@@ -18,16 +17,19 @@ model:
prompt_format: qwen
audio_locator_tag: "<|audioplaceholder|>" # placeholder token expected in audio turns

# Note: Uncomment the block below to enable LoRA on LLM via HuggingFace PEFT library.
# It will automatically freeze LLM parameters even if freeze_params was unused,
# and prevent freezing any parameter that has the string '.lora_' in its name.
# Uncomment the block below to enable LoRA on the LLM via Automodel.
# LoRA parameters are kept trainable even when the LLM is frozen.
# lora:
# task_type: CAUSAL_LM
# r: 128
# lora_alpha: 256
# lora_dropout: 0.01
# # target_modules are only necessary if the `pretrained_llm` is not yet registered in PEFT library
# dim: 128
# alpha: 256
# dropout: 0.01
# target_modules: ["q_proj", "v_proj"]
# # match_all_linear: false
# # exclude_modules: []
# # use_dora: false
# # dropout_position: post
# # lora_A_init: xavier
# # use_triton: false

perception:
target: nemo.collections.speechlm2.modules.perception.AudioPerceptionModule
@@ -72,13 +74,38 @@ trainer:
gradient_clip_val: 1.0
accumulate_grad_batches: 1
strategy:
# Replace DDPStrategy with ModelParallelStrategy to enable model parallelism
# Replace DDPStrategy with AutomodelParallelStrategy to enable FSDP2/TP/MoE parallelism.
# AutomodelParallelStrategy delegates device mesh creation to nemo_automodel and supports
# FSDP2, TP, PP, CP, EP (MoE), and HSDP. The model's configure_model() receives the
# device_mesh and passes it to automodel's from_pretrained for memory-efficient loading
# (each GPU only loads its own shard).
_target_: lightning.pytorch.strategies.DDPStrategy
gradient_as_bucket_view: true
find_unused_parameters: true
# _target_: lightning.pytorch.strategies.ModelParallelStrategy
# tensor_parallel_size: 1
# data_parallel_size: 8 # This is FSDP2

# _target_: nemo.collections.speechlm2.parts.parallel.AutomodelParallelStrategy
#
# --- Parallelism dimensions ---
# dp_size: null # Data parallel size (null = inferred from world_size / other dims)
# dp_replicate_size: 1 # HSDP replication group size (>1 enables hybrid sharding)
# tp_size: 1 # Tensor parallel size
# pp_size: 1 # Pipeline parallel size
# cp_size: 1 # Context parallel size
# ep_size: 1 # Expert parallel size (for MoE models)
#
# --- FSDP2 distributed config (plain dict, resolved to FSDP2Config automatically) ---
# distributed_config:
# sequence_parallel: false # Enable sequence parallelism (requires tp_size > 1)
# activation_checkpointing: false # Checkpoint activations to save memory
# # offload_policy: # Uncomment to enable CPU offloading
# # _target_: torch.distributed.fsdp.CPUOffloadPolicy
#
# --- MoE config (plain dict, resolved to MoEParallelizerConfig automatically) ---
# moe_config:
# activation_checkpointing: false # Checkpoint activations in MoE blocks
# reshard_after_forward: false # Reshard params after forward (saves memory, more comms)
#
# save_distributed_checkpoint: true # Each rank saves its shard (false = gather to rank 0)
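The `dp_size: null` default above means the data-parallel dimension is inferred from the world size. A hedged sketch of that rule (the actual resolution happens inside nemo_automodel; this simplified version assumes EP shares ranks with DP rather than consuming its own):

```python
def infer_dp_size(world_size: int, tp_size: int = 1, pp_size: int = 1,
                  cp_size: int = 1) -> int:
    """Infer dp_size when it is left as null: the ranks not consumed by
    TP/PP/CP become the FSDP2 data-parallel dimension (sketch only)."""
    denom = tp_size * pp_size * cp_size
    if world_size % denom != 0:
        raise ValueError(f"world_size={world_size} is not divisible by tp*pp*cp={denom}")
    return world_size // denom
```

For example, 8 GPUs with `tp_size: 1` give 8 FSDP2 shards, while `tp_size: 2` leaves 4 ranks for sharding.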

data:
train_ds:
29 changes: 25 additions & 4 deletions examples/speechlm2/salm_eval.py
@@ -52,17 +52,38 @@ class SalmEvalConfig:
system_prompt: Optional[str] = None
user_prompt: Optional[str] = None
use_asr_decoder: bool = False # set this to True if using SALMWithAsrDecoder
# Parallelism sizes for distributed inference (launch with torchrun)
tp_size: int = 1
ep_size: int = 1
pp_size: int = 1
cp_size: int = 1


@hydra_runner(config_name="SalmEvalConfig", schema=SalmEvalConfig)
def main(cfg: SalmEvalConfig):
logging.info(f'Hydra config:\n{OmegaConf.to_yaml(cfg)}')

if cfg.use_asr_decoder:
model = SALMWithAsrDecoder.from_pretrained(cfg.pretrained_name)
is_distributed = any(s > 1 for s in [cfg.tp_size, cfg.ep_size, cfg.pp_size, cfg.cp_size])
model_cls = SALMWithAsrDecoder if cfg.use_asr_decoder else SALM

if is_distributed:
from nemo.collections.speechlm2.parts.parallel import setup_distributed

strategy = setup_distributed(
tp_size=cfg.tp_size, ep_size=cfg.ep_size, pp_size=cfg.pp_size, cp_size=cfg.cp_size
)
model = model_cls.from_pretrained(
cfg.pretrained_name,
device_mesh=strategy.device_mesh,
distributed_config=strategy.distributed_config,
moe_config=strategy.moe_config,
moe_mesh=strategy.moe_mesh,
torch_dtype=cfg.dtype,
)
else:
model = SALM.from_pretrained(cfg.pretrained_name)
model = model.eval().to(getattr(torch, cfg.dtype)).to(cfg.device)
model = model_cls.from_pretrained(cfg.pretrained_name)
model = model.to(getattr(torch, cfg.dtype)).to(cfg.device)
model = model.eval()
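The branch above switches into distributed loading whenever any parallelism dimension exceeds one. The dispatch condition can be sketched on its own (a toy stand-in, not the NeMo API):

```python
from dataclasses import dataclass

@dataclass
class ParallelSizes:
    """Toy container mirroring the tp/ep/pp/cp fields of SalmEvalConfig."""
    tp_size: int = 1
    ep_size: int = 1
    pp_size: int = 1
    cp_size: int = 1

    @property
    def is_distributed(self) -> bool:
        # Same check as in main(): any dimension > 1 requires
        # setup_distributed() and mesh-aware from_pretrained().
        return any(s > 1 for s in (self.tp_size, self.ep_size, self.pp_size, self.cp_size))
```

With all sizes at 1 the script falls back to the plain single-process `from_pretrained` path.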

cuts = guess_parse_cutset(cfg.inputs).sort_by_duration()
dloader = torch.utils.data.DataLoader(
29 changes: 25 additions & 4 deletions examples/speechlm2/salm_generate.py
@@ -48,17 +48,38 @@ class SalmEvalConfig:
system_prompt: Optional[str] = None
user_prompt: Optional[str] = None
use_asr_decoder: bool = False # set this to True if using SALMWithAsrDecoder
# Parallelism sizes for distributed inference (launch with torchrun)
tp_size: int = 1
ep_size: int = 1
pp_size: int = 1
cp_size: int = 1


@hydra_runner(config_name="SalmEvalConfig", schema=SalmEvalConfig)
def main(cfg: SalmEvalConfig):
logging.info(f"Hydra config:\n{OmegaConf.to_yaml(cfg)}")

if cfg.use_asr_decoder:
model = SALMWithAsrDecoder.from_pretrained(cfg.pretrained_name)
is_distributed = any(s > 1 for s in [cfg.tp_size, cfg.ep_size, cfg.pp_size, cfg.cp_size])
model_cls = SALMWithAsrDecoder if cfg.use_asr_decoder else SALM

if is_distributed:
from nemo.collections.speechlm2.parts.parallel import setup_distributed

strategy = setup_distributed(
tp_size=cfg.tp_size, ep_size=cfg.ep_size, pp_size=cfg.pp_size, cp_size=cfg.cp_size
)
model = model_cls.from_pretrained(
cfg.pretrained_name,
device_mesh=strategy.device_mesh,
distributed_config=strategy.distributed_config,
moe_config=strategy.moe_config,
moe_mesh=strategy.moe_mesh,
torch_dtype=cfg.dtype,
)
else:
model = SALM.from_pretrained(cfg.pretrained_name)
model = model.eval().to(getattr(torch, cfg.dtype)).to(cfg.device)
model = model_cls.from_pretrained(cfg.pretrained_name)
model = model.to(getattr(torch, cfg.dtype)).to(cfg.device)
model = model.eval()

conversations = (
guess_parse_cutset(cfg.inputs)
6 changes: 4 additions & 2 deletions examples/speechlm2/salm_train.py
@@ -22,13 +22,15 @@
from nemo.utils.exp_manager import exp_manager
from nemo.utils.trainer_utils import resolve_trainer_cfg

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
if torch.cuda.is_available():
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))


@hydra_runner(config_path="conf", config_name="salm")
def train(cfg):
OmegaConf.resolve(cfg)
torch.distributed.init_process_group(backend="nccl")
if torch.cuda.is_available():
torch.distributed.init_process_group(backend="nccl")
torch.set_float32_matmul_precision("medium")
trainer = Trainer(**resolve_trainer_cfg(cfg.trainer))
log_dir = exp_manager(trainer, cfg.get("exp_manager", None))
134 changes: 126 additions & 8 deletions examples/speechlm2/to_hf.py
@@ -11,19 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from dataclasses import dataclass
from pathlib import Path

import torch
from omegaconf import OmegaConf
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import save_file

from nemo.core.config import hydra_runner
from nemo.utils.model_utils import import_class_by_path


@dataclass
class HfExportConfig:
# Name of the model class to be imported, e.g. nemo.collections.speechlm2.models.DuplexS2SModel
# Name of the model class to be imported, e.g. nemo.collections.speechlm2.models.SALM
class_path: str

# Path to PyTorch Lightning checkpoint file (normal ckpt) or directory (distributed ckpt)
@@ -51,21 +54,136 @@ def load_checkpoint(model: torch.nn.Module, checkpoint_path: str):
model.load_state_dict(ckpt_data["state_dict"])


def setup_distributed_from_config(strategy_cfg: dict):
"""Initialize torch.distributed and create a device mesh from a Hydra strategy config.

Instantiates the strategy from the trainer config dict (as found in the
experiment YAML), initializes the process group, resolves automodel
configs, and calls :meth:`strategy.create_device_mesh`.

Returns:
An :class:`AutomodelParallelStrategy` with device_mesh ready.
"""
import hydra
import torch.distributed as dist

from nemo.utils.trainer_utils import _resolve_automodel_configs

if not dist.is_initialized():
dist.init_process_group(backend="nccl")

local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)

strategy = hydra.utils.instantiate(strategy_cfg)
_resolve_automodel_configs(strategy)
strategy.create_device_mesh()
return strategy


def consolidate_state_dict(model: torch.nn.Module):
"""Gather a full (non-sharded) state dict from a model with DTensor parameters."""
from torch.distributed.tensor import DTensor

consolidated = {}
for key, value in model.state_dict().items():
if isinstance(value, DTensor):
consolidated[key] = value.full_tensor().cpu()
else:
consolidated[key] = value.cpu()
return consolidated
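`consolidate_state_dict` above gathers each sharded DTensor into a full tensor before moving it to CPU. The same gather-then-copy pattern can be exercised without GPUs via a duck-typed stand-in (`FakeShard` is purely illustrative):

```python
class FakeShard:
    """Illustrative stand-in for a DTensor shard."""
    def __init__(self, local, full):
        self.local, self.full = local, full

    def full_tensor(self):
        # A real DTensor would all-gather the shards across ranks here.
        return self.full

def consolidate(state_dict: dict) -> dict:
    """Mirror of consolidate_state_dict(): gather sharded values, pass
    plain ones through (the .cpu() step is omitted since plain Python
    lists stand in for tensors here)."""
    return {
        k: (v.full_tensor() if hasattr(v, "full_tensor") else v)
        for k, v in state_dict.items()
    }
```

After consolidation every value is a regular full tensor, so rank 0 can write a normal single-file safetensors checkpoint.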


def save_hf_checkpoint(model: torch.nn.Module, state_dict: dict, cfg: HfExportConfig):
"""Save a consolidated state dict and model config in HuggingFace Hub format."""
output_dir = Path(cfg.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

target_dtype = getattr(torch, cfg.dtype)
state_dict = {k: v.to(target_dtype) for k, v in state_dict.items()}

save_file(state_dict, output_dir / "model.safetensors")

config = OmegaConf.to_container(model.cfg) if isinstance(model.cfg, DictConfig) else model.cfg
with open(output_dir / "config.json", "w") as f:
json.dump(config, f, indent=2)


def _uses_automodel_parallel(strategy_cfg: dict) -> bool:
"""Check if the strategy config targets AutomodelParallelStrategy."""
target = strategy_cfg.get("_target_", "")
return "AutomodelParallelStrategy" in target


@hydra_runner(config_name="HfExportConfig", schema=HfExportConfig)
def main(cfg: HfExportConfig):
"""
Read a PyTorch Lightning checkpoint and export the model to HuggingFace Hub format.
The resulting model can then be initialized via ModelClass.from_pretrained(path).

Also supports distributed checkpoints for models trained with FSDP2/TP.
Also supports distributed checkpoints for models trained with FSDP2/TP
via AutomodelParallelStrategy. Parallelism sizes (tp_size, pp_size, etc.)
are read automatically from the ``trainer.strategy`` section of the
experiment config (``ckpt_config``).

When the checkpoint is a distributed checkpoint (a directory), launch this
script via ``torchrun`` with the same number of GPUs used for training.

Examples:
# Single-file checkpoint (no parallelism needed):
python to_hf.py \\
class_path=nemo.collections.speechlm2.models.SALM \\
ckpt_path=/path/to/checkpoint.ckpt \\
ckpt_config=/path/to/config.yaml \\
output_dir=/path/to/hf_output

# Distributed checkpoint (parallelism read from config automatically):
torchrun --nproc-per-node=8 to_hf.py \\
class_path=nemo.collections.speechlm2.models.SALM \\
ckpt_path=/path/to/distributed_ckpt_dir \\
ckpt_config=/path/to/config.yaml \\
output_dir=/path/to/hf_output
"""
model_cfg = OmegaConf.to_container(OmegaConf.load(cfg.ckpt_config).model, resolve=True)
full_cfg = OmegaConf.to_container(OmegaConf.load(cfg.ckpt_config), resolve=True)
model_cfg = full_cfg["model"]
model_cfg["torch_dtype"] = cfg.dtype
cls = import_class_by_path(cfg.class_path)
model = cls(model_cfg)
load_checkpoint(model, cfg.ckpt_path)
model = model.to(getattr(torch, cfg.dtype))
model.save_pretrained(cfg.output_dir)

strategy_cfg = full_cfg.get("trainer", {}).get("strategy", {})
is_distributed = Path(cfg.ckpt_path).is_dir() and _uses_automodel_parallel(strategy_cfg)

if is_distributed:
import torch.distributed as dist

strategy = setup_distributed_from_config(strategy_cfg)

# Don't call configure_model() inside __init__ — we set device_mesh first.
model_cfg["init_configure_model"] = False
model = cls(model_cfg)
model.configure_model(
device_mesh=strategy.device_mesh,
distributed_config=strategy.distributed_config,
moe_config=strategy.moe_config,
moe_mesh=strategy.moe_mesh,
)
model_cfg["pretrained_weights"] = False

load_checkpoint(model, cfg.ckpt_path)

# Consolidate DTensors to regular tensors and save on rank 0.
consolidated = consolidate_state_dict(model)
if dist.get_rank() == 0:
save_hf_checkpoint(model, consolidated, cfg)

dist.barrier()
dist.destroy_process_group()
else:
model_cfg["init_configure_model"] = True
model = cls(model_cfg)
load_checkpoint(model, cfg.ckpt_path)
model = model.to(getattr(torch, cfg.dtype))
model_cfg["pretrained_weights"] = False
model.save_pretrained(cfg.output_dir)


if __name__ == "__main__":
1 change: 1 addition & 0 deletions nemo/collections/common/data/lhotse/text_adapters.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json

Check notice (Code scanning / CodeQL): Unused import. Import of 'json' is not used.

Suggested fix: remove the unused import.

@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
 import logging
 import math
 import random
import logging
import math
import random
1 change: 1 addition & 0 deletions nemo/collections/common/prompts/__init__.py
@@ -19,6 +19,7 @@
from nemo.collections.common.prompts.llama import Llama2PromptFormatter, Llama3PromptFormatter
from nemo.collections.common.prompts.mistral import MistralPromptFormatter
from nemo.collections.common.prompts.nemotron_h import NemotronHPromptFormatter
from nemo.collections.common.prompts.nemotron_nano_v3 import NemotronNanoV3PromptFormatter
from nemo.collections.common.prompts.phi2 import (
Phi2ChatPromptFormatter,
Phi2CodePromptFormatter,