[PyTorch] Add option to pass kwargs to CUDA graph module #945

Open · wants to merge 7 commits into main
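This PR adds a sample_kwargs option so that make_graphed_callables can capture modules whose forward pass takes keyword arguments (for example attention_mask), along with tests covering kwargs and a Megatron-LM-style interleaved pipeline-parallel schedule. A minimal usage sketch, assuming the transformer_engine.pytorch import path exercised by the tests below; layer sizes and tensor shapes are illustrative, not taken from the PR:

import torch
import transformer_engine.pytorch as te

# A layer whose forward takes a keyword argument (attention_mask).
model = te.TransformerLayer(512, 512, 8, self_attn_mask_type="arbitrary")
sample_input = torch.randn(32, 2, 512, device="cuda")  # (seq, batch, hidden)
sample_mask = torch.zeros(2, 1, 32, 32, dtype=torch.bool, device="cuda")

# New in this PR: sample_kwargs supplies example keyword arguments for capture.
graphed = te.make_graphed_callables(
    model,
    (sample_input,),
    sample_kwargs=dict(attention_mask=sample_mask),
    allow_unused_input=True,
)

# Replaying the graph: as in the tests, later calls may pass fresh tensors,
# provided the keywords, shapes, and dtypes match the captured samples.
out = graphed(sample_input, attention_mask=sample_mask)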
tests/pytorch/test_cuda_graphs.py (203 additions, 6 deletions)
@@ -3,7 +3,8 @@
 # See LICENSE for license information.
 
 from dataclasses import dataclass
-from typing import List, Tuple
+import itertools
+from typing import Iterable, List, Tuple, Union
 import pytest
 
 import torch
@@ -88,7 +89,7 @@ def generate_data(
     dpa: bool = False,
     warmup: bool = False,
     return_grad_output: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> Tuple[List[torch.Tensor], torch.Tensor]:
     """Generate synthetic data."""
     gen_func = torch.ones if warmup else torch.randn
     if dpa:
@@ -129,14 +130,20 @@ def generate_data(
     return inputs, grad_output
 
 
-def get_outputs(model, output):
+def get_outputs(
+    model: torch.nn.Module,
+    output: Union[torch.Tensor, Iterable[torch.Tensor]],
+) -> List[torch.Tensor]:
     """Return grads and params for comparison."""
     values = []
     for param in model.parameters():
         values.append(param)
         if param.grad is not None:
             values.append(param.grad)
-    values.append(output)
+    if isinstance(output, torch.Tensor):
+        values.append(output)
+    else:
+        values.extend(output)
     return values


@@ -161,7 +168,7 @@ def _test_cuda_graphs(
     module: str,
     graph_mode: str,
 ) -> List[torch.Tensor]:
-    """Helper function for test."""
+    """Helper function for CUDA graph test."""
     reset_rng_states()
     FP8GlobalStateManager.reset()
     dpa = module == "dpa"
@@ -247,7 +254,7 @@ def _test_cuda_graphs(
     else:
         model = modules[0] if dpa else _Sequential(*modules)
 
-    # Loss function and optimizer.
+    # Optimizer.
     if not dpa:
         optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
 
@@ -312,3 +319,193 @@ def test_gpt_make_graphed_callables(
     # Check that results match
     assert_all_equal(outputs, graph_outputs_mode1)
     assert_all_equal(outputs, graph_outputs_mode2)


+def _test_cuda_graphs_with_kwargs(
+    *,
+    config: ModelConfig,
+    dtype: torch.dtype,
+    with_graph: bool,
+) -> List[torch.Tensor]:
+    """Test CUDA graphs with keyword arguments."""
+    reset_rng_states()
+
+    # Initialize model.
+    model = TransformerLayer(
+        config.hidden_size,
+        config.hidden_size,
+        config.num_heads,
+        hidden_dropout=0.0,
+        attention_dropout=0.0,
+        self_attn_mask_type="arbitrary",
+        fuse_qkv_params=True,
+        params_dtype=dtype,
+    )
+
+    # Initialize gradient buffers.
+    for param in model.parameters():
+        param.grad = torch.empty_like(param)
+
+    # Make graphed version of model if needed.
+    if with_graph:
+        attn_mask = torch.zeros(
+            (config.batch_size, 1, config.sequence_length, config.sequence_length),
+            dtype=torch.bool,
+            device="cuda",
+        )
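+        # The zero mask here is only a capture-time sample; fresh masks are
+        # passed on every call in the training loop below.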
+        model = make_graphed_callables(
+            model,
+            generate_data(config, dtype, warmup=True),
+            sample_kwargs=dict(attention_mask=attn_mask),
+            allow_unused_input=True,
+        )
+
+    # Optimizer.
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+
+    # Training loop.
+    for _ in range(3):
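+        # set_to_none=False zeroes the pre-allocated .grad buffers in place
+        # instead of freeing them.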
+        optimizer.zero_grad(set_to_none=False)
+        for grad_accumulation_step in range(2):
+            inputs, grad_output = generate_data(config, dtype, return_grad_output=True)
+            attn_mask = torch.randint(
+                2,
+                (config.batch_size, 1, config.sequence_length, config.sequence_length),
+                dtype=torch.bool,
+                device="cuda",
+            )
+            output = model(*inputs, attention_mask=attn_mask)
+            output.backward(grad_output)
+        optimizer.step()
+
+    return get_outputs(model, output)


+def test_make_graphed_callables_with_kwargs(
+    dtype: torch.dtype = torch.float32,
+    model: str = "small",
+) -> None:
+    """Test CUDA graphs with keyword arguments."""
+    config = model_configs[model]
+    kwargs = dict(config=config, dtype=dtype)
+    outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs)
+    graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs)
+    assert_all_equal(outputs, graph_outputs)
+
+
+def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
+    *,
+    config: ModelConfig,
+    dtype: torch.dtype,
+    with_graph: bool,
+) -> List[torch.Tensor]:
+    """Simulate Megatron-LM interleaved pipeline parallelism."""
+    reset_rng_states()
+
+    # Pipeline parallel configuration.
+    num_layers = 2
+    num_microbatches = 3
+    layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]
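+    # Positive entries denote a forward pass of that (1-indexed) layer, negative
+    # entries the matching backward pass; each occurrence handles one microbatch.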

+    # Initialize model.
+    model = torch.nn.ModuleList(
+        [
+            Linear(
+                config.hidden_size,
+                config.hidden_size,
+                params_dtype=dtype,
+            )
+            for _ in range(num_layers)
+        ]
+    )
+
+    # Initialize gradient buffers.
+    for param in model.parameters():
+        param.grad = torch.empty_like(param)
+
+    # Make graphed version of model if needed.
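+    # Maps (layer index, microbatch index) to a forward callable; defaults to
+    # the eager layers.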
+    layer_forwards = {
+        (i % num_layers, i // num_layers): model[i % num_layers]
+        for i in range(num_layers * num_microbatches)
+    }
+    if with_graph:
+        sample_args = tuple(
+            generate_data(config, dtype, warmup=True) for _ in range(num_layers * num_microbatches)
+        )
+        layer_forwards = make_graphed_callables(
+            tuple(model),
+            sample_args,
+            allow_unused_input=True,
+            _order=layer_order,
+        )
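+        # The graphed callables come back ordered layer-major (all microbatch
+        # clones of layer 0 first), so re-key them by (layer, microbatch).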
+        layer_forwards = {
+            (i // num_microbatches, i % num_microbatches): forward
+            for i, forward in enumerate(layer_forwards)
+        }
+
+    # Optimizer.
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+
+    # Training loop.
+    for _ in range(3):
+        optimizer.zero_grad(set_to_none=False)
+
+        # Generate data.
+        inputs = {}
+        grad_outputs = {}
+        for layer_idx in range(num_layers):
+            for microbatch_idx in range(num_microbatches):
+                x, dy = generate_data(config, dtype, return_grad_output=True)
+                idxs = (layer_idx, microbatch_idx)
+                inputs[idxs] = x[0]
+                grad_outputs[idxs] = dy
+
+        # Cache for layer outputs.
+        outputs = {}
+
+        def forward(layer_idx: int, microbatch_idx: int):
+            """Helper function for forward steps"""
+            idxs = (layer_idx, microbatch_idx)
+            outputs[idxs] = layer_forwards[idxs](inputs[idxs])
+
+        def backward(layer_idx: int, microbatch_idx: int):
+            """Helper function for backward steps"""
+            outputs[layer_idx, microbatch_idx].backward(grad_outputs[layer_idx, microbatch_idx])
+
+        # Forward and backward steps.
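+        # Interleaved 1F1B schedule; the sequence below matches layer_order above,
+        # with layers 0-indexed here (entry 1 -> forward(0, mb), -1 -> backward(0, mb)).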
+        forward(0, 0)
+        forward(1, 0)
+        forward(0, 1)
+        forward(1, 1)
+        backward(1, 0)
+        backward(0, 0)
+        forward(0, 2)
+        forward(1, 2)
+        backward(1, 1)
+        backward(0, 1)
+        backward(1, 2)
+        backward(0, 2)
+
+        # Optimizer step.
+        optimizer.step()
+
+    outputs = [y for _, y in sorted(outputs.items())]
+    return get_outputs(model, outputs)


+def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
+    dtype: torch.dtype = torch.float16,
+    model: str = "small",
+) -> None:
+    """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism."""
+    config = model_configs[model]
+    kwargs = dict(config=config, dtype=dtype)
+    outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
+        with_graph=False,
+        **kwargs,
+    )
+    graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
+        with_graph=True,
+        **kwargs,
+    )
+    assert_all_equal(outputs, graph_outputs)