Skip to content

Commit c4933ec

Browse files
[Wan 2.2][Diffusion] Add TP Support (#964)
Signed-off-by: weichen <calvin_zhu0210@outlook.com>
1 parent 398ae95 commit c4933ec

File tree

3 files changed

+239
-56
lines changed

3 files changed

+239
-56
lines changed

docs/user_guide/diffusion/parallelism_acceleration.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ The following table shows which models are currently supported by parallelism me
4949

5050
| Model | Model Identifier | Ulysses-SP | Ring-SP | Tensor-Parallel |
5151
|-------|------------------|------------|---------|--------------------------|
52-
| **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` ||| |
52+
| **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` ||| |
5353

5454
### Tensor Parallelism
5555

examples/offline_inference/text_to_video/text_to_video.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,12 @@ def parse_args() -> argparse.Namespace:
109109
choices=[1, 2],
110110
help="Number of GPUs used for classifier free guidance parallel size.",
111111
)
112-
112+
parser.add_argument(
113+
"--tensor_parallel_size",
114+
type=int,
115+
default=1,
116+
help="Number of GPUs used for tensor parallelism (TP) inside the DiT.",
117+
)
113118
return parser.parse_args()
114119

115120

@@ -141,6 +146,7 @@ def main():
141146
ulysses_degree=args.ulysses_degree,
142147
ring_degree=args.ring_degree,
143148
cfg_parallel_size=args.cfg_parallel_size,
149+
tensor_parallel_size=args.tensor_parallel_size,
144150
)
145151

146152
# Check if profiling is requested via environment variable
@@ -173,7 +179,7 @@ def main():
173179
print(f" Inference steps: {args.num_inference_steps}")
174180
print(f" Frames: {args.num_frames}")
175181
print(
176-
f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}"
182+
f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}, tensor_parallel_size={args.tensor_parallel_size}"
177183
)
178184
print(f" Video size: {args.width}x{args.height}")
179185
print(f"{'=' * 60}\n")

0 commit comments

Comments
 (0)