
Fp8 model init factory #880

Draft
wants to merge 2 commits into main
Conversation

@sudhakarsingh27 (Collaborator) commented May 30, 2024

Description

Trying to bake fp8_model_init into layer initialization.

Currently, the fp8_model_init context manager has to be added and managed by the user. Baking it into TE layer initialization would keep that option while also exposing the same behavior as just a constructor argument.

When integrating with larger code bases like Megatron or HF Accelerate, we would then only need to pass the argument; otherwise we have to figure out where to add this context manager. So, in theory, this would result in less code change.

One thing I'm not sure about is whether calling fp8_model_init per layer is fine.

(Another thought: could this also allow selectively controlling which layers get FP8 weights, and would that be helpful? See the sketch below.)
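For reference, a minimal sketch contrasting the existing context-manager API with the kind of per-layer argument this PR is aiming at; the fp8_init kwarg here is hypothetical and only illustrates the idea:

    import transformer_engine.pytorch as te

    # Existing API: every TE module constructed inside the context gets FP8 weights.
    with te.fp8_model_init(enabled=True):
        proj = te.Linear(4096, 4096)

    # Idea in this PR (hypothetical fp8_init kwarg, not part of the current API):
    # each layer is told directly at construction time, so integrations like
    # Megatron or HF Accelerate only need to forward one extra argument, and
    # FP8 weights could be enabled for some layers but not others.
    fp8_proj  = te.Linear(4096, 4096, fp8_init=True)
    bf16_proj = te.Linear(4096, 4096, fp8_init=False)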

@ptrendx @timmoon10 @ksivaman, do you think this makes sense?

@ksivaman (Member)

Could you outline the motivation behind this? Currently we have the fp8_model_init user API that works as a context manager and IIRC this is trying to effectively expose the same as a parameter. Why?

@sudhakarsingh27 (Collaborator, Author) commented May 30, 2024

The CM API currently has to be added and managed by the user. This change would keep that while also allowing the same behavior via just an argument.

When integrating with larger code bases like Megatron or HF Accelerate, we would then only need to pass the argument; otherwise we have to figure out where to add this context manager. So, in theory, this would mean less code change, but I'm not sure whether calling fp8_model_init per layer is fine. (Could this also allow selectively controlling which layers use FP8 weights, and would that be helpful?)

@timmoon10 (Collaborator)

I think initializing FP8 weights with a constructor kwarg makes a lot of sense. In effect, the fp8_model_init context is an indirect way of passing a boolean arg to the module constructors (although it has the advantage/disadvantage of setting all modules to the same value). For backward compatibility, how about an API like:

class Linear(TransformerEngineBaseModule):

    def __init__(
        self,
        ...,
        with_fp8_weight: Optional[bool] = None,
    ):
        # If the kwarg is not given, fall back to the fp8_model_init global state,
        # so the existing context-manager API keeps working unchanged.
        if with_fp8_weight is None:
            with_fp8_weight = FP8GlobalStateManager.with_fp8_parameters()

        ...