Add support to specify different attention templates for different layers in batch major attention.

PiperOrigin-RevId: 633385811
lingvo-bot authored and Copybara-Service committed Jun 5, 2024
1 parent d0c8081 commit 8ba944a
Showing 4 changed files with 275 additions and 64 deletions.
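In short, this change lets Builder.atten_tpl be either a single attention template or a list with one template per layer, and threads a layer_idx argument through the layer-building methods so each layer picks its own template. The snippet below is a minimal usage sketch mirroring the updated tests; the import alias, model dimensions, and the choice of plain MultiHeadedAttention templates are illustrative assumptions, not part of the commit.

# Sketch only: per-layer attention templates via a list-valued atten_tpl.
from lingvo.core import batch_major_attention as attention

builder_p = attention.Builder.Params().Set(
    model_dim=16,
    num_heads=2,
    ff_hidden_dim=5,
    # One template per layer; TransformerEncoderStack asserts that the list
    # length equals num_layers.
    atten_tpl=[
        attention.MultiHeadedAttention.Params(),
        attention.MultiHeadedAttention.Params(),
    ],
)
atten_builder = builder_p.Instantiate()

# Each TransformerEncoderLayer is built with layer_idx=i, so layer i ends up
# using atten_tpl[i] for its multi-headed attention sub-layer.
stack_p = atten_builder.TransformerEncoderStack('encoder', num_layers=2)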
65 changes: 48 additions & 17 deletions lingvo/core/batch_major_attention.py
@@ -8734,6 +8734,7 @@ def __init__(self, params):
p = self.params
if p.num_splits > 1 or p.num_micro_batches > 1:
assert p.deterministic_dropout
assert p.atten_tpl is not None, 'atten_tpl must be set.'

def _Dropout(self, name, drop_prob):
"""Returns a DropoutLayer Params."""
@@ -8802,13 +8803,15 @@ def _MultiHeadedAtten(self, name, num_heads=None,
enable_qkv_proj_in_onestep=False,
enable_qk_proj_in_onestep=False,
query_stride=1,
query_first_n=None):
query_first_n=None,
atten_tpl=None):
"""Returns a MultiHeadedAttention params."""
p = self.params
if num_heads is None:
num_heads = p.num_heads

atten_p = p.atten_tpl.Copy().Set(
if atten_tpl is None:
atten_tpl = p.atten_tpl
atten_p = atten_tpl.Copy().Set(
name=name,
input_dim=p.model_dim,
hidden_dim=p.attention_hidden_dim or p.model_dim,
@@ -9145,7 +9148,12 @@ def _Stride(self, name, stride, first_n=None, axis=1):
"""
return StrideLayer.Params().Set(stride=stride, first_n=first_n, axis=axis, name=name)

def _StridedAttention(self, name, stride=1, first_n=None, num_heads=None):
def _StridedAttention(self,
name,
stride=1,
first_n=None,
num_heads=None,
layer_idx=None):
"""Computes self attention with optional stride.
Args:
@@ -9161,6 +9169,7 @@ def _StridedAttention(self, name, stride=1, first_n=None, num_heads=None):
1. first_n can't be 0. If first_n <= stride, only the first token is
used.
num_heads: the number of heads.
layer_idx: Valid layer index if p.atten_tpl is a list.
Returns:
A self attention layer params.
@@ -9185,18 +9194,25 @@ def _StridedAttention(self, name, stride=1, first_n=None, num_heads=None):
if num_heads is None:
num_heads = p.num_heads

atten_tpl = None
if isinstance(p.atten_tpl, list):
assert layer_idx is not None, 'layer_idx must be specified.'
atten_tpl = p.atten_tpl[layer_idx]
else:
atten_tpl = p.atten_tpl

# compute qkv in one step only if default_enable_qkv_proj_in_onestep
# and no striding (stride==1) and not first_n
enable_qkv_proj_in_onestep = (p.default_enable_qkv_proj_in_onestep
and stride == 1
and not first_n)
# Overriding default param based on stride.
enable_qk_proj_in_onestep = (p.atten_tpl.enable_qk_proj_in_onestep
enable_qk_proj_in_onestep = (atten_tpl.enable_qk_proj_in_onestep
and stride == 1
and not first_n)
# Rope template doesn't handle first_n, so resetting it to None until it is
# handled correctly.
if first_n is not None and p.atten_tpl.rope_tpl is not None:
if first_n is not None and atten_tpl.rope_tpl is not None:
tf.logging.warning('Rope Attn needs to handle valid query_first_n.')
first_n = None

@@ -9206,7 +9222,7 @@ def _StridedAttention(self, name, stride=1, first_n=None, num_heads=None):
('after_ln->strided_query',
self._Stride('query_after_stride', stride, first_n)),
('{}->after_att,prob'.format(attention_inputs),
self._MultiHeadedAtten('atten', num_heads, enable_qkv_proj_in_onestep, enable_qk_proj_in_onestep, stride, first_n)),
self._MultiHeadedAtten('atten', num_heads, enable_qkv_proj_in_onestep, enable_qk_proj_in_onestep, stride, first_n, atten_tpl)),
('after_att->after_dropout',
self._Dropout('dropout', p.residual_dropout_prob)),
('{}->strided_input'.format(input_to_add),
@@ -9248,7 +9264,7 @@ def _Pool(self, name, stride, first_n=None):
first_n=first_n,
name=name)

def _FunnelAttention(self, name, stride=1, first_n=None, num_heads=None):
def _FunnelAttention(self, name, stride=1, first_n=None, num_heads=None, layer_idx=None):
"""Computes self attention with optional stride.
Args:
@@ -9263,6 +9279,7 @@ def _FunnelAttention(self, name, stride=1, first_n=None, num_heads=None):
be None or 1. first_n can't be 0. If first_n <= stride, only the first
token is used.
num_heads: the number of heads.
layer_idx: Valid layer index if p.atten_tpl is a list.
Returns:
A self attention layer params.
@@ -9283,6 +9300,11 @@ def _FunnelAttention(self, name, stride=1, first_n=None, num_heads=None):

if num_heads is None:
num_heads = p.num_heads
if isinstance(p.atten_tpl, list):
assert layer_idx is not None, 'layer_idx must be specified.'
atten_tpl = p.atten_tpl[layer_idx]
else:
atten_tpl = p.atten_tpl
sub_list = []
if p.packed_input:
if stride > 1:
@@ -9302,12 +9324,12 @@ def _FunnelAttention(self, name, stride=1, first_n=None, num_heads=None):
and stride == 1
and not first_n)
# Overriding default param based on stride.
enable_qk_proj_in_onestep = (p.atten_tpl.enable_qk_proj_in_onestep
enable_qk_proj_in_onestep = (atten_tpl.enable_qk_proj_in_onestep
and stride == 1
and not first_n)
# Rope template doesn't handle first_n, so resetting it to None until it is
# handled correctly.
if first_n is not None and p.atten_tpl.rope_tpl is not None:
if first_n is not None and atten_tpl.rope_tpl is not None:
tf.logging.warning('Rope Attn needs to handle valid query_first_n.')
first_n = None

@@ -9317,7 +9339,7 @@ def _FunnelAttention(self, name, stride=1, first_n=None, num_heads=None):
('after_ln,i.paddings->strided_query,o.paddings',
self._Pool('query_after_pooling', stride, first_n)),
('{}->after_att,prob'.format(attention_inputs),
self._MultiHeadedAtten('atten', num_heads, enable_qkv_proj_in_onestep, enable_qk_proj_in_onestep, stride, first_n)),
self._MultiHeadedAtten('atten', num_heads, enable_qkv_proj_in_onestep, enable_qk_proj_in_onestep, stride, first_n, atten_tpl)),
('after_att->after_dropout',
self._Dropout('dropout', p.residual_dropout_prob)),
shortcut_sub,
@@ -9340,7 +9362,8 @@ def FunnelEncoderLayer(self, name, stride=1, first_n=None,
ff_hidden_dim=None, num_heads=None,
ff_gated_fn=None,
num_ffns=1,
use_moe=False):
use_moe=False,
layer_idx=None):
"""(inputs, paddings) -> (encoded, paddings).
Args:
@@ -9363,6 +9386,7 @@ def FunnelEncoderLayer(self, name, stride=1, first_n=None,
'gelu', and callable.
num_ffns: number of ffn layers.
use_moe: The first of the ffn layer use moe.
layer_idx: Valid layer index if p.atten_tpl is a list.
Returns:
A transformer encoder layer params that supports optional stride.
@@ -9390,7 +9414,7 @@ def FunnelEncoderLayer(self, name, stride=1, first_n=None,
moe_p = self.MoE('moe', ff_hidden_dim=ff_hidden_dim)

s_layers = [self._FunnelAttention('self_atten', stride=stride,
first_n=first_n, num_heads=num_heads),
first_n=first_n, num_heads=num_heads, layer_idx=layer_idx),
moe_p if use_moe else ff_layer]
if num_ffns > 1:
for ffn_id in range(1, num_ffns):
@@ -9399,7 +9423,7 @@ def FunnelEncoderLayer(self, name, stride=1, first_n=None,

def TransformerEncoderLayer(self, name, stride=1, first_n=None,
ff_hidden_dim=None, num_heads=None,
use_moe=False):
use_moe=False, layer_idx=None):
"""(inputs, paddings) -> (encoded, paddings).
Args:
@@ -9418,6 +9442,7 @@ def TransformerEncoderLayer(self, name, stride=1, first_n=None,
num_heads: The number of heads for the multi-head attention module. If
specified, this will override p.num_heads.
use_moe: whether to use moe feedforward layer or not.
layer_idx: Valid layer index if p.atten_tpl is a list.
Returns:
A transformer encoder layer params that supports optional stride.
@@ -9434,7 +9459,7 @@ def TransformerEncoderLayer(self, name, stride=1, first_n=None,
return self._Seq(name, self._Seq(
'block',
self._StridedAttention('self_atten', stride=stride,
first_n=first_n, num_heads=num_heads),
first_n=first_n, num_heads=num_heads, layer_idx=layer_idx),
ffw_p))

def Stack(self, name, blocks, output_all_layer_hiddens=False):
@@ -9455,8 +9480,11 @@ def Stack(self, name, blocks, output_all_layer_hiddens=False):

def TransformerEncoderStack(self, name, num_layers=1):
"""Returns a stack of num_layers self-attention layers."""
p = self.params
if isinstance(p.atten_tpl, list):
assert len(p.atten_tpl) == num_layers, 'atten_tpl list must have the same length as num_layers.'
blocks = [
self.TransformerEncoderLayer(name='iter_{:0>3d}'.format(d))
self.TransformerEncoderLayer(name='iter_{:0>3d}'.format(d), layer_idx=d)
for d in range(num_layers)
]
return self.Stack(name, blocks)
@@ -9494,6 +9522,7 @@ def _MultiHeadedAtten(
enable_qk_proj_in_onestep=False,
query_stride=1,
query_first_n=None,
atten_tpl=None,
):
"""Returns a MultiHeadedAttention params."""
p = self.params
@@ -9635,6 +9664,7 @@ def _MultiHeadedAtten(
enable_qk_proj_in_onestep=False,
query_stride=1,
query_first_n=None,
atten_tpl=None,
):
"""Returns a MultiHeadedAttention params."""
p = self.params
@@ -10037,12 +10067,13 @@ def _Attention(self, name, is_causal=True):
('i.paddings->o.paddings', self._Id('id')),
)

def TransformerEncoderLayer(self, name, is_causal=True):
def TransformerEncoderLayer(self, name, is_causal=True, layer_idx=None):
"""(inputs, paddings) -> (encoded, paddings).
Args:
name: the string name of the encoder layer params.
is_causal: If true, add cause per_step padding to the attention layer.
layer_idx: the index of the layer.
Returns:
A transformer encoder layer params that supports optional stride.
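The test changes below also exercise the funnel-encoder path. As a rough sketch of that usage (layer count, strides, and the default pooling template are illustrative assumptions), each FunnelEncoderLayer receives its layer_idx so it can select the matching entry from the atten_tpl list:

# Sketch only: per-layer templates with strided funnel encoder layers.
from lingvo.core import batch_major_attention as attention

builder_p = attention.Builder.Params().Set(
    model_dim=16,
    num_heads=2,
    ff_hidden_dim=5,
    funnel_pool_tpl=attention.FunnelPoolingLayer.Params(),
    atten_tpl=[
        attention.MultiHeadedAttention.Params(),
        attention.MultiHeadedAttention.Params(),
    ],
)
atten_builder = builder_p.Instantiate()

layers = []
for layer_i, stride in enumerate([1, 2]):
  layers.append(
      atten_builder.FunnelEncoderLayer(
          name='atten_{}'.format(layer_i),
          stride=stride,
          layer_idx=layer_i,  # Selects atten_tpl[layer_i] for this layer.
      )
  )
model_p = atten_builder.Stack('model', layers)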
32 changes: 28 additions & 4 deletions lingvo/core/batch_major_attention_test.py
@@ -5801,6 +5801,14 @@ def testBProp(self):
'testcase_name': '_baseline',
'strides': [1, 1],
},
{
'testcase_name': '_baseline_atten_tpl_list',
'strides': [1, 1],
'atten_tpl': [
attention.MultiHeadedAttention.Params(),
attention.MultiHeadedAttention.Params(),
],
},
{
'testcase_name': '_stride_2',
'strides': [1, 2],
@@ -5835,6 +5843,7 @@ def testFunnelTransformerStack(
trunc_seq=True,
num_splits=1,
num_micro_batches=1,
atten_tpl=None,
):
with self.session(use_gpu=False) as sess:
bs = 2
@@ -5851,6 +5860,7 @@ def testFunnelTransformerStack(
funnel_pool_tpl=attention.FunnelPoolingLayer.Params().Set(
begin_intact=begin_intact, trunc_seq=trunc_seq
),
atten_tpl=atten_tpl or attention.MultiHeadedAttention.Params(),
)
atten_builder = atten_builder_params.Instantiate()
layers = []
@@ -5859,7 +5869,9 @@ def testFunnelTransformerStack(
accumulate_stride *= stride
layers.append(
atten_builder.FunnelEncoderLayer(
name='atten_{}'.format(layer_i), stride=stride
name='atten_{}'.format(layer_i),
stride=stride,
layer_idx=layer_i,
)
)
p = atten_builder.Stack('model', layers)
@@ -6327,6 +6339,14 @@ def testFunnelEncoderLayerWithPerLayerFfns(self):
'testcase_name': '_baseline',
'strides': [1, 1],
},
{
'testcase_name': '_baseline_atten_tpl_list',
'strides': [1, 1],
'atten_tpl': [
attention.MultiHeadedAttention.Params(),
attention.MultiHeadedAttention.Params(),
],
},
{
'testcase_name': '_stride_2',
'strides': [2, 1],
@@ -6336,15 +6356,17 @@ def testFunnelEncoderLayerWithPerLayerFfns(self):
'strides': [2, 0],
},
)
def testTransformerStackWithStride(self, strides):
def testTransformerStackWithStride(
self, strides, atten_tpl=attention.MultiHeadedAttention.Params()
):
with self.session(use_gpu=False) as sess:
bs = 2
sl = 10
d = 16
tf.random.set_seed(12345)
atten_builder = (
attention.Builder.Params()
.Set(model_dim=d, num_heads=2, ff_hidden_dim=5)
.Set(model_dim=d, num_heads=2, ff_hidden_dim=5, atten_tpl=atten_tpl)
.Instantiate()
)
layers = []
@@ -6353,7 +6375,9 @@ def testTransformerStackWithStride(self, strides):
accumulate_stride *= stride
layers.append(
atten_builder.TransformerEncoderLayer(
name='atten_{}'.format(layer_i), stride=stride
name='atten_{}'.format(layer_i),
stride=stride,
layer_idx=layer_i,
)
)
p = atten_builder.Seq('model', *layers)