Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add layer norm weight plus 1 #378

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion megatron/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from apex.normalization import MixedFusedRMSNorm as RMSNorm
else:
from .rmsnorm import RMSNorm
from torch.nn import LayerNorm
from .layer_norm_p1 import LayerNorm1P as LayerNorm
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than replacing torch.nn.LayerNorm and attempting to maintain the logic inside LayerNorm1P, I think a better and maintainable approach is to instantiate self.input_layernorm (and others) to either torch.nn.LayerNorm or LayerNorm1P based on args.apply_layernorm_1p. This will simplify LayerNorm1P to no longer handle the apply_layernorm_1p=False case.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specifically, I am curious whether something like the following could work. Thanks!

   actualLayerNorm = LayerNorm1P if args.apply_layernorm_1p else torch.nn.LayerNorm
   ....
   self.input_layernorm = actualLayerNorm(...
  ...

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specifically, I am curious whether something like the following could work. Thanks!

   actualLayerNorm = LayerNorm1P if args.apply_layernorm_1p else torch.nn.LayerNorm
   ....
   self.input_layernorm = actualLayerNorm(...
  ...

Hi @tjruwase, This case could work. we need to replace all 'Layernorm' with 'actualLayerNorm' if using this logic. We can prepare another PR for this logic.

This PR follows the cuda logic to add LayernormP1,

if get_accelerator().device_name() == 'cuda':
    from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm

Layernorm1P will return the vanilla layernorm when apply_layernorm_1p=False(default path).


from .distributed import DistributedDataParallel
from .bert_model import BertModel
Expand Down
38 changes: 38 additions & 0 deletions megatron/model/layer_norm_p1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import math
import numbers

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch.nn import init


class LayerNorm1P(torch.nn.Module):
def __init__(self, normalized_shape, eps=1e-5, apply_layernorm_1p=False):
super(LayerNorm1P, self).__init__()
self.eps = eps
self.apply_layernorm_1p = apply_layernorm_1p

if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.weight = Parameter(torch.Tensor(*normalized_shape))
self.bias = Parameter(torch.Tensor(*normalized_shape))
self.reset_parameters()

def reset_parameters(self):

if self.apply_layernorm_1p:
init.zeros_(self.weight)
init.zeros_(self.bias)
else:
init.ones_(self.weight)
init.zeros_(self.bias)

def forward(self, input):
if self.apply_layernorm_1p:
weight_plus_1 = (self.weight + 1)
output = torch.nn.functional.layer_norm(input, self.normalized_shape, weight_plus_1, self.bias, self.eps)
return output
else:
return torch.nn.functional.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
9 changes: 6 additions & 3 deletions megatron/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,8 @@ def __init__(self, config,
else:
self.input_layernorm = LayerNorm(
config.hidden_size,
eps=config.layernorm_epsilon)
eps=config.layernorm_epsilon,
apply_layernorm_1p=args.apply_layernorm_1p)
else:
self.input_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
# Self attention.
Expand All @@ -939,7 +940,8 @@ def __init__(self, config,
else:
self.post_attention_layernorm = LayerNorm(
config.hidden_size,
eps=config.layernorm_epsilon)
eps=config.layernorm_epsilon,
apply_layernorm_1p=args.apply_layernorm_1p)
else:
self.post_attention_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
# Cross attention.
Expand Down Expand Up @@ -1762,7 +1764,8 @@ def build_layer(layer_number, n_e):
else:
self.final_layernorm = LayerNorm(
config.hidden_size,
eps=config.layernorm_epsilon)
eps=config.layernorm_epsilon,
apply_layernorm_1p=args.apply_layernorm_1p)
else:
self.final_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)

Expand Down