
[PyTorch] Add option to pass kwargs to CUDA graph module #945

Open
wants to merge 7 commits into main
Conversation

@timmoon10 (Collaborator) commented Jun 19, 2024

Description

This PR addresses a request to modify te.make_graphed_callables so we can pass in kwargs like attention masks. See vasunvidia@d0a1057 for an initial implementation. If kwargs are provided to te.make_graphed_callables (via the sample_kwargs argument), they must also be provided whenever the graph is replayed. Note that only tensors are accepted as positional args or kwargs, since otherwise we run into another pile of design problems (what happens if the args differ between graph capture and graph replay?).
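A minimal sketch of the intended usage (the module, shapes, and mask layout here are illustrative assumptions, not taken from this PR):

import torch
import transformer_engine.pytorch as te

# Hypothetical module and shapes, purely for illustration
layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
).cuda()
x = torch.randn(128, 2, 1024, device="cuda")  # (seq, batch, hidden)
mask = torch.zeros(2, 1, 128, 128, dtype=torch.bool, device="cuda")

# Capture: kwargs are registered via sample_kwargs
graphed_layer = te.make_graphed_callables(
    layer,
    (x,),
    sample_kwargs=dict(attention_mask=mask),
)

# Replay: the same kwargs must be passed on every call
y = graphed_layer(torch.randn_like(x), attention_mask=mask)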

To be honest, I don't really like this approach. Ideally te.make_graphed_callables should match the API of torch.cuda.make_graphed_callables, which only supports positional args. But the best way to handle modules with kwargs in plain PyTorch is to create wrappers that convert the kwargs to positional args:

# Non-graphed module
y = mymodule(x, key=val)

# Option 1: wrap module and call with positional args
class MyWrapper(torch.nn.Module):
    def __init__(self, module):
        super().__init__()  # required before registering submodules
        self.module = module
    def forward(self, x, val):
        return self.module(x, key=val)
graphed_forward1 = torch.cuda.make_graphed_callables(
    MyWrapper(mymodule),
    (x, val),
)
y = graphed_forward1(x, val)

# Option 2: wrap module and wrap graphed callable
graphed_forward2 = lambda x, *, key: graphed_forward1(x, key)
y = graphed_forward2(x, key=val)

This is quite clunky. If we accept API divergence from PyTorch, it becomes much cleaner:

# Option 3: modify API for make_graphed_callables
graphed_forward3 = te.make_graphed_callables(
    mymodule,
    x,
    sample_kwargs=dict(key=val),
)
y = graphed_forward3(x, key=val)
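A note on replay semantics under Option 3 (this is my reading, so treat it as an assumption rather than a documented guarantee): since only tensors are accepted, each replay copies the new arg and kwarg data into the static buffers captured with the graph, so kwarg tensors may hold different values on each call but must keep the shapes and dtypes of the sample_kwargs used at capture:

# Replays accept fresh tensor values, as long as new_x and new_val
# match x and val in shape and dtype (names are illustrative)
y = graphed_forward3(new_x, key=new_val)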

While I was touching the code, I also added comments and tests for the custom integration with Megatron-LM's interleaved pipeline parallelism.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

Please list the changes introduced in this PR:

  • Add option to pass kwargs to CUDA graph module
  • Add tests for CUDA graph integration with Megatron-LM interleaved pipeline parallelism

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@timmoon10 timmoon10 added the enhancement (New feature or request) label Jun 19, 2024
@timmoon10 timmoon10 requested a review from ksivaman June 19, 2024 00:15
@timmoon10 timmoon10 marked this pull request as ready for review July 2, 2024 00:14
@timmoon10 (Collaborator, Author) commented:

/te-ci pytorch
