diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py
index 8d3a9dca4f..60a5a1ea99 100644
--- a/tests/pytorch/test_cuda_graphs.py
+++ b/tests/pytorch/test_cuda_graphs.py
@@ -3,7 +3,8 @@
 # See LICENSE for license information.
 
 from dataclasses import dataclass
-from typing import List, Tuple
+import itertools
+from typing import Iterable, List, Tuple, Union
 
 import pytest
 import torch
@@ -88,7 +89,7 @@ def generate_data(
     dpa: bool = False,
     warmup: bool = False,
     return_grad_output: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> Tuple[List[torch.Tensor], torch.Tensor]:
     """Generate synthetic data."""
     gen_func = torch.ones if warmup else torch.randn
     if dpa:
@@ -129,14 +130,20 @@ def generate_data(
     return inputs, grad_output
 
 
-def get_outputs(model, output):
+def get_outputs(
+    model: torch.nn.Module,
+    output: Union[torch.Tensor, Iterable[torch.Tensor]],
+) -> List[torch.Tensor]:
     """Return grads and params for comparison."""
     values = []
     for param in model.parameters():
         values.append(param)
         if param.grad is not None:
             values.append(param.grad)
-    values.append(output)
+    if isinstance(output, torch.Tensor):
+        values.append(output)
+    else:
+        values.extend(output)
     return values
 
 
@@ -161,7 +168,7 @@ def _test_cuda_graphs(
     module: str,
     graph_mode: str,
 ) -> List[torch.Tensor]:
-    """Helper function for test."""
+    """Helper function for CUDA graph test."""
     reset_rng_states()
     FP8GlobalStateManager.reset()
     dpa = module == "dpa"
@@ -247,7 +254,7 @@ def _test_cuda_graphs(
     else:
         model = modules[0] if dpa else _Sequential(*modules)
 
-    # Loss function and optimizer.
+    # Optimizer.
     if not dpa:
         optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
 
@@ -312,3 +319,193 @@ def test_gpt_make_graphed_callables(
     # Check that results match
     assert_all_equal(outputs, graph_outputs_mode1)
     assert_all_equal(outputs, graph_outputs_mode2)
+
+
+def _test_cuda_graphs_with_kwargs(
+    *,
+    config: ModelConfig,
+    dtype: torch.dtype,
+    with_graph: bool,
+) -> List[torch.Tensor]:
+    """Test CUDA graphs with callables that take keyword arguments."""
+    reset_rng_states()
+
+    # Initialize model.
+    model = TransformerLayer(
+        config.hidden_size,
+        config.hidden_size,
+        config.num_heads,
+        hidden_dropout=0.0,
+        attention_dropout=0.0,
+        self_attn_mask_type="arbitrary",
+        fuse_qkv_params=True,
+        params_dtype=dtype,
+    )
+
+    # Initialize gradient buffers.
+    for param in model.parameters():
+        param.grad = torch.empty_like(param)
+
+    # Make graphed version of model if needed.
+    if with_graph:
+        attn_mask = torch.zeros(
+            (config.batch_size, 1, config.sequence_length, config.sequence_length),
+            dtype=torch.bool,
+            device="cuda",
+        )
+        model = make_graphed_callables(
+            model,
+            generate_data(config, dtype, warmup=True),
+            sample_kwargs=dict(attention_mask=attn_mask),
+            allow_unused_input=True,
+        )
+
+    # Optimizer.
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+
+    # Training loop.
+    for _ in range(3):
+        optimizer.zero_grad(set_to_none=False)
+        for grad_accumulation_step in range(2):
+            inputs, grad_output = generate_data(config, dtype, return_grad_output=True)
+            attn_mask = torch.randint(
+                2,
+                (config.batch_size, 1, config.sequence_length, config.sequence_length),
+                dtype=torch.bool,
+                device="cuda",
+            )
+            output = model(*inputs, attention_mask=attn_mask)
+            output.backward(grad_output)
+        optimizer.step()
+
+    return get_outputs(model, output)
+
+
+def test_make_graphed_callables_with_kwargs(
+    dtype: torch.dtype = torch.float32,
+    model: str = "small",
+) -> None:
+    """Test CUDA graphs with keyword arguments."""
+    config = model_configs[model]
+    kwargs = dict(config=config, dtype=dtype)
+    outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs)
+    graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs)
+    assert_all_equal(outputs, graph_outputs)
+
+
+def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
+    *,
+    config: ModelConfig,
+    dtype: torch.dtype,
+    with_graph: bool,
+) -> List[torch.Tensor]:
+    """Simulate Megatron-LM interleaved pipeline parallelism."""
+    reset_rng_states()
+
+    # Pipeline parallel configuration.
+    num_layers = 2
+    num_microbatches = 3
+    layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]
+
+    # Initialize model.
+    model = torch.nn.ModuleList(
+        [
+            Linear(
+                config.hidden_size,
+                config.hidden_size,
+                params_dtype=dtype,
+            )
+            for _ in range(num_layers)
+        ]
+    )
+
+    # Initialize gradient buffers.
+    for param in model.parameters():
+        param.grad = torch.empty_like(param)
+
+    # Make graphed version of model if needed.
+    layer_forwards = {
+        (i % num_layers, i // num_layers): model[i % num_layers]
+        for i in range(num_layers * num_microbatches)
+    }
+    if with_graph:
+        sample_args = tuple(
+            generate_data(config, dtype, warmup=True) for _ in range(num_layers * num_microbatches)
+        )
+        layer_forwards = make_graphed_callables(
+            tuple(model),
+            sample_args,
+            allow_unused_input=True,
+            _order=layer_order,
+        )
+        layer_forwards = {
+            (i // num_microbatches, i % num_microbatches): forward
+            for i, forward in enumerate(layer_forwards)
+        }
+
+    # Optimizer.
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+
+    # Training loop.
+    for _ in range(3):
+        optimizer.zero_grad(set_to_none=False)
+
+        # Generate data.
+        inputs = {}
+        grad_outputs = {}
+        for layer_idx in range(num_layers):
+            for microbatch_idx in range(num_microbatches):
+                x, dy = generate_data(config, dtype, return_grad_output=True)
+                idxs = (layer_idx, microbatch_idx)
+                inputs[idxs] = x[0]
+                grad_outputs[idxs] = dy
+
+        # Cache for layer outputs.
+        outputs = {}
+
+        def forward(layer_idx: int, microbatch_idx: int):
+            """Helper function for forward steps"""
+            idxs = (layer_idx, microbatch_idx)
+            outputs[idxs] = layer_forwards[idxs](inputs[idxs])
+
+        def backward(layer_idx: int, microbatch_idx: int):
+            """Helper function for backward steps"""
+            outputs[layer_idx, microbatch_idx].backward(grad_outputs[layer_idx, microbatch_idx])
+
+        # Forward and backward steps.
+        forward(0, 0)
+        forward(1, 0)
+        forward(0, 1)
+        forward(1, 1)
+        backward(1, 0)
+        backward(0, 0)
+        forward(0, 2)
+        forward(1, 2)
+        backward(1, 1)
+        backward(0, 1)
+        backward(1, 2)
+        backward(0, 2)
+
+        # Optimizer step.
+        optimizer.step()
+
+    outputs = [y for _, y in sorted(outputs.items())]
+    return get_outputs(model, outputs)
+
+
+def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
+    dtype: torch.dtype = torch.float16,
+    model: str = "small",
+) -> None:
+    """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism."""
+    config = model_configs[model]
+    kwargs = dict(config=config, dtype=dtype)
+    outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
+        with_graph=False,
+        **kwargs,
+    )
+    graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
+        with_graph=True,
+        **kwargs,
+    )
+    assert_all_equal(outputs, graph_outputs)
diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index a6f62ac457..f6331c9b2a 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -3,11 +3,14 @@
 # See LICENSE for license information.
 
 """Functions for CUDA Graphs support in FP8"""
+from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
+
 import torch
 from torch.utils._pytree import tree_flatten as _tree_flatten
 from torch.utils._pytree import tree_unflatten as _tree_unflatten
 from torch._C import _graph_pool_handle
 
+from transformer_engine.common.recipe import DelayedScaling
 from .fp8 import (
     fp8_autocast,
     FP8GlobalStateManager,
@@ -22,6 +25,9 @@
 
 _IS_GRAPH_CAPTURING = False
 
+_T = TypeVar("_T")
+SingleOrTuple = Union[_T, Tuple[_T, ...]]
+
 
 def set_capture_start() -> None:
     """Record beginning of `make_graphed_callables`."""
@@ -48,13 +54,14 @@ def graph_pool_handle():
 
 
 def _make_graphed_callables(
-    callables,
-    sample_args,
-    num_warmup_iters=3,
-    allow_unused_input=False,
-    fp8_weight_caching=False,
-    _order=None,
-):
+    callables: SingleOrTuple[Callable],
+    sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]],
+    num_warmup_iters: int = 3,
+    allow_unused_input: bool = False,
+    fp8_weight_caching: bool = False,
+    sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
+    _order: Optional[List[int]] = None,
+) -> SingleOrTuple[Callable]:
     """
     Helper method for `make_graphed_callables`
     """
@@ -65,16 +72,38 @@ def _make_graphed_callables(
             "caching. Please set `cache_enabled=False`."
         )
 
-    just_one_callable = False
+    # Default is to pass no kwargs to callables
+    if sample_kwargs is None:
+        if isinstance(callables, tuple):
+            sample_kwargs = tuple({} for _ in range(len(sample_args)))
+        else:
+            sample_kwargs = {}
+
+    # Canonicalize args as tuples
+    just_one_callable = False
     if not isinstance(callables, tuple):
         just_one_callable = True
         callables = (callables,)
         sample_args = (sample_args,)
+        sample_kwargs = (sample_kwargs,)
 
-    flatten_sample_args = []
-    if _order is not None:
-        # order is a list containing 1..model_chunk values in the order of microbatch schedule
+    # Check sizes of args
+    if _order is None:
+        assert len(sample_args) == len(callables)
+        assert len(sample_kwargs) == len(callables)
+    else:
+        # Custom logic for interleaved pipeline parallelism
+        # Note: This is tightly coupled with the Megatron-core
+        # implementation of interleaved pipeline parallelism at
+        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py.
+        # Note: The model is assumed to consist of layers
+        # (corresponding to callables) that are grouped into
+        # equally-sized model chunks. _order is a list of chunk
+        # indices (1-indexed) that indicates the order in which the
+        # layers are evaluated. Positive values indicate forward
+        # passes and negative values indicate backward passes. Each
+        # entry in sample_args corresponds to one of the forward
+        # passes.
         num_model_chunks = max(_order)
         num_microbatches = len(_order) // num_model_chunks // 2
         assert num_model_chunks * num_microbatches * 2 == len(_order)
@@ -90,10 +119,13 @@ def _make_graphed_callables(
             f"Expected {num_model_chunks * num_microbatches}"
             + f"args tuple, but got {len(sample_args)}."
         )
+        assert len(sample_kwargs) == len(sample_args)
 
     if fp8_weight_caching:
+        # Initialize flag that controls FP8 weight updates
         FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)
 
+    # Check callables
     for c in callables:
         if isinstance(c, torch.nn.Module):
             assert (
@@ -110,9 +142,14 @@ def _make_graphed_callables(
                 + ":func:`~make_graphed_callables`, only parameters may be trainable. "
                 + "All buffers must have ``requires_grad=False``."
             )
-    for args in sample_args:
+
+    # Flatten callable arguments
+    per_callable_kwargs_keys = [list(kwargs.keys()) for kwargs in sample_kwargs]
+    flatten_sample_args = []
+    for args, kwargs, kwargs_keys in zip(sample_args, sample_kwargs, per_callable_kwargs_keys):
         flatten_arg, _ = _tree_flatten(args)
-        flatten_sample_args.append(tuple(flatten_arg))
+        flatten_kwarg, _ = _tree_flatten([kwargs[key] for key in kwargs_keys])
+        flatten_sample_args.append(tuple(flatten_arg + flatten_kwarg))
         assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), (
             "In the beta API, sample_args "
            + "for each callable must contain only Tensors. Other types are not allowed."
@@ -120,6 +157,10 @@ def _make_graphed_callables(
 
     # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
     # passes to forward (ie, its sample_args) AND the module's parameter attributes.
+    # Note: These per_callable_* variables are not actually
+    # per-callable, but per-forward-pass (see description of _order).
+    # The names are kept for consistency with
+    # torch.cuda.make_graphed_callables.
     per_callable_len_user_args = [len(args) for args in flatten_sample_args]
     if _order is None:
         per_callable_module_params = [
@@ -144,6 +185,7 @@ def _make_graphed_callables(
     fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
     bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
     graph_callables = [None for _ in range(len(flatten_sample_args))]
+
     # For cases with multiple active RNG states, e.g. TP.
     if graph_safe_rng_available():
         for _, state in get_all_rng_states().items():
@@ -158,11 +200,12 @@ def _make_graphed_callables(
     # from ending up in any captures.
     torch.cuda.synchronize()
     with torch.cuda.stream(torch.cuda.Stream()):
-        for c_i, func in enumerate(callables):
-            args = sample_args[c_i]
-            static_input_surface = per_callable_static_input_surfaces[c_i]
+        for func_idx, func in enumerate(callables):
+            args = sample_args[func_idx]
+            kwargs = sample_kwargs[func_idx]
+            static_input_surface = per_callable_static_input_surfaces[func_idx]
             for _ in range(num_warmup_iters):
-                outputs, _ = _tree_flatten(func(*args))
+                outputs, _ = _tree_flatten(func(*args, **kwargs))
                 grad_inputs = torch.autograd.grad(
                     outputs=tuple(o for o in outputs if o.requires_grad),
                     inputs=tuple(i for i in static_input_surface if i.requires_grad),
@@ -194,9 +237,10 @@ def _make_graphed_callables(
                         fwd_idx[m_chunk] * num_layers + l_no
                     )
                     args = sample_args[per_callable_fwd_idx]
+                    kwargs = sample_kwargs[per_callable_fwd_idx]
                    fwd_graph = fwd_graphs[per_callable_fwd_idx]
                     with torch.cuda.graph(fwd_graph, pool=mempool):
-                        outputs = func(*args)
+                        outputs = func(*args, **kwargs)
                     flatten_outputs, spec = _tree_flatten(outputs)
                     per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs)
                     per_callable_output_unflatten_spec[per_callable_fwd_idx] = spec
@@ -245,9 +289,9 @@ def _make_graphed_callables(
         per_callable_static_outputs = []
         per_callable_output_unflatten_spec = []
         graph_id = 0
-        for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
+        for func, args, kwargs, fwd_graph in zip(callables, sample_args, sample_kwargs, fwd_graphs):
             with torch.cuda.graph(fwd_graph, pool=mempool):
-                outputs = func(*args)
+                outputs = func(*args, **kwargs)
             graph_callables[graph_id] = func
             graph_id += 1
 
@@ -300,6 +344,7 @@ def make_graphed_autograd_function(
     fwd_graph,
     bwd_graph,
     module_params,
+    kwargs_keys,
     len_user_args,
     output_unflatten_spec,
     static_input_surface,
@@ -312,14 +357,18 @@ class Graphed(torch.autograd.Function):
         @staticmethod
         def forward(ctx, skip_fp8_weight_update, *inputs):
-            # At this stage, only the user args may (potentially) be new tensors.
+
+            # Set flag for whether to update FP8 weights
             ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
             if ctx.is_first_module and skip_fp8_weight_update is not None:
                 FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update)
 
+            # Copy values from new tensors into static tensors
             for i in range(len_user_args):
                 if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
                     static_input_surface[i].copy_(inputs[i])
+
+            # Replay forward graph
             fwd_graph.replay()
             assert isinstance(static_outputs, tuple)
             return tuple(o.detach() for o in static_outputs)
@@ -327,6 +376,8 @@ def forward(ctx, skip_fp8_weight_update, *inputs):
         @staticmethod
         @torch.autograd.function.once_differentiable
         def backward(ctx, *grads):
+
+            # Replay backward graph
             assert len(grads) == len(static_grad_outputs)
             for g, grad in zip(static_grad_outputs, grads):
                 if g is not None:
@@ -336,6 +387,7 @@ def backward(ctx, *grads):
                         g.copy_(grad)
             bwd_graph.replay()
 
+            # Update FP8 scale factors if needed
             if ctx.is_first_module:
                 FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
 
@@ -346,10 +398,8 @@ def backward(ctx, *grads):
             )
 
     def functionalized(*user_args, **user_kwargs):
-        # Runs the autograd function with inputs == all
-        # inputs to the graph that might require grad
-        # (explicit user args + module parameters)
-        # Assumes module params didn't change since capture.
+
+        # Decide whether to update FP8 weights
         skip_fp8_weight_update = None
         if fp8_weight_caching:
             assert "is_first_microbatch" in user_kwargs and isinstance(
@@ -358,8 +408,22 @@ def functionalized(*user_args, **user_kwargs):
             skip_fp8_weight_update = not user_kwargs["is_first_microbatch"]
 
+        # Check that required kwargs are provided
+        for key in kwargs_keys:
+            if key not in user_kwargs:
+                raise TypeError(
+                    f"Graphed callable was initialized with kwarg {key}, "
+                    "but it was not provided in graph replay"
+                )
+
+        # Runs the autograd function with inputs == all inputs to
+        # the graph that might require grad (explicit user args +
+        # module parameters)
+        # Assumes module params didn't change since capture.
         flatten_user_args, _ = _tree_flatten(user_args)
-        out = Graphed.apply(skip_fp8_weight_update, *(tuple(flatten_user_args) + module_params))
+        flatten_user_kwargs, _ = _tree_flatten([user_kwargs[key] for key in kwargs_keys])
+        func_args = tuple(flatten_user_args) + tuple(flatten_user_kwargs) + module_params
+        out = Graphed.apply(skip_fp8_weight_update, *func_args)
         return _tree_unflatten(out, output_unflatten_spec)
 
     return functionalized
@@ -371,6 +435,7 @@ def functionalized(*user_args, **user_kwargs):
             fwd_graphs[i],
             bwd_graphs[i],
             per_callable_module_params[i],
+            per_callable_kwargs_keys[i],
             per_callable_len_user_args[i],
             per_callable_output_unflatten_spec[i],
             per_callable_static_input_surfaces[i],
@@ -443,25 +508,42 @@ def restore_fp8_tensors(modules, fp8_tensors):
 
 
 def make_graphed_callables(
-    modules,
-    sample_args,
-    num_warmup_iters=3,
-    allow_unused_input=False,
-    fp8_enabled=False,
-    fp8_calibrating=False,
-    fp8_recipe=None,
-    fp8_weight_caching=False,
-    _order=None,
-):
+    modules: SingleOrTuple[Callable],
+    sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]],
+    num_warmup_iters: int = 3,
+    allow_unused_input: bool = False,
+    sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
+    fp8_enabled: bool = False,
+    fp8_calibrating: bool = False,
+    fp8_recipe: Optional[DelayedScaling] = None,
+    fp8_weight_caching: bool = False,
+    _order: Optional[List[int]] = None,
+) -> Union[Callable, Tuple[Callable, ...]]:
     """
-    A version of PyTorch's `make_graphed_callables` utility function with support for
-    TransformerEngine modules and FP8. Please see the original version in upstream PyTorch
-    `here <https://pytorch.org/docs/stable/generated/torch.cuda.make_graphed_callables.html>`_
-    for extensive documentation. The documentation for additional parameters which are
-    specific to FP8 are given below.
-
-    FP8 specific parameters
-    -----------------------
+    Make CUDA graph version of Transformer Engine modules
+
+    A variation of PyTorch's `make_graphed_callables` utility function
+    with support for Transformer Engine modules and FP8. Please see
+    the
+    `original PyTorch implementation <https://pytorch.org/docs/stable/generated/torch.cuda.make_graphed_callables.html>`_
+    for more documentation.
+
+    Graphing parameters
+    -------------------
+    modules: (tuple of) callable
+             Callable or callables to graph.
+    sample_args: (tuple of) tuple of torch.Tensor
+                 Positional arguments to callable(s).
+    num_warmup_iters: int, default = 3
+                      Number of warmup iterations.
+    allow_unused_input: bool, default = `False`
+                        Whether to handle case where callable inputs
+                        and outputs are disconnected in compute graph.
+    sample_kwargs: (tuple of) dict, optional
+                   Keyword arguments to callable(s)
+
+    FP8-related parameters
+    ----------------------
     fp8_enabled: bool, default = `True`
                  whether or not to enable fp8
     fp8_calibrating: bool, default = `False`
                      calibration mode allows collecting statistics such as amax and scale
                      data of fp8 tensors even when executing without fp8 enabled. This is
                      useful for saving an inference ready fp8 checkpoint while training
                      using a higher precision.
     fp8_recipe: recipe.DelayedScaling, default = `None`
                 recipe used for FP8 training.
     fp8_weight_caching: bool, default = `False`
                         Whether or not to cache FP8 weights across microbatches. if set to `True`,
                         the `is_first_microbatch` boolean argument must be passed into the forward
                         method for TransformerEngine modules. When storing primary weights in FP8
@@ -478,6 +560,7 @@ def make_graphed_callables(
                         using TE's `fp8_model_init` API and using an FP8 aware optimizer, this arg
                         must be set to `False` if calculating weight transposes' outside TE, e.g.,
                         in the optimizer step.
+
     """
     set_capture_start()
 
@@ -532,6 +615,7 @@ def forward_func(*args, **kwargs):
             num_warmup_iters=num_warmup_iters,
             allow_unused_input=allow_unused_input,
             fp8_weight_caching=fp8_weight_caching,
+            sample_kwargs=sample_kwargs,
             _order=_order,
         )
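
Usage note (not part of the patch): a minimal sketch of how the new `sample_kwargs` argument is exercised, mirroring `_test_cuda_graphs_with_kwargs` above. The layer sizes, tensor shapes, and dtype below are illustrative assumptions, not values taken from this change.

    import torch
    import transformer_engine.pytorch as te

    # Illustrative sizes; any module whose forward takes tensor kwargs works the same way.
    hidden_size, seq_len, batch_size = 128, 32, 2

    # Transformer layer whose forward accepts an attention_mask keyword argument.
    layer = te.TransformerLayer(
        hidden_size,
        hidden_size,
        8,  # num_attention_heads
        self_attn_mask_type="arbitrary",
        params_dtype=torch.float32,
    )

    sample_input = torch.ones(seq_len, batch_size, hidden_size, device="cuda", requires_grad=True)
    sample_mask = torch.zeros(batch_size, 1, seq_len, seq_len, dtype=torch.bool, device="cuda")

    # Capture: keyword sample inputs are passed through the new sample_kwargs argument.
    graphed_layer = te.make_graphed_callables(
        layer,
        (sample_input,),
        sample_kwargs=dict(attention_mask=sample_mask),
        allow_unused_input=True,
    )

    # Replay: every kwarg used at capture time must be provided again by name.
    out = graphed_layer(
        torch.randn(seq_len, batch_size, hidden_size, device="cuda", requires_grad=True),
        attention_mask=torch.randint(
            2, (batch_size, 1, seq_len, seq_len), dtype=torch.bool, device="cuda"
        ),
    )
    out.backward(torch.randn_like(out))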
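
Usage note (not part of the patch): the `_order` schedule exercised by `_test_cuda_graphs_with_interleaved_pipeline_parallelism` decodes as follows for the test's configuration of 2 model chunks (one Linear layer each) and 3 microbatches. Chunk indices are 1-based; positive entries are forward passes and negative entries are backward passes, as described in the `_make_graphed_callables` comment. The per-entry comments are an illustrative reading that matches the forward/backward call sequence in the test.

    # Same values as layer_order in the test, annotated.
    layer_order = [
        1, 2,    # forward  chunk 1, chunk 2   (microbatch 0)
        1, 2,    # forward  chunk 1, chunk 2   (microbatch 1)
        -2, -1,  # backward chunk 2, chunk 1   (microbatch 0)
        1, 2,    # forward  chunk 1, chunk 2   (microbatch 2)
        -2, -1,  # backward chunk 2, chunk 1   (microbatch 1)
        -2, -1,  # backward chunk 2, chunk 1   (microbatch 2)
    ]
    # num_model_chunks = max(layer_order) = 2
    # num_microbatches = len(layer_order) // num_model_chunks // 2 = 3
    # len(layer_order) == num_model_chunks * num_microbatches * 2 == 12,
    # matching the assertion in _make_graphed_callables.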