diff --git a/lingvo/core/batch_major_attention.py b/lingvo/core/batch_major_attention.py index d68d99fd5..86455b883 100644 --- a/lingvo/core/batch_major_attention.py +++ b/lingvo/core/batch_major_attention.py @@ -8734,6 +8734,7 @@ def __init__(self, params): p = self.params if p.num_splits > 1 or p.num_micro_batches > 1: assert p.deterministic_dropout + assert p.atten_tpl is not None, 'atten_tpl must be set.' def _Dropout(self, name, drop_prob): """Returns a DropoutLayer Params.""" @@ -8802,13 +8803,15 @@ def _MultiHeadedAtten(self, name, num_heads=None, enable_qkv_proj_in_onestep=False, enable_qk_proj_in_onestep=False, query_stride=1, - query_first_n=None): + query_first_n=None, + atten_tpl=None): """Returns a MultiHeadedAttention params.""" p = self.params if num_heads is None: num_heads = p.num_heads - - atten_p = p.atten_tpl.Copy().Set( + if atten_tpl is None: + atten_tpl = p.atten_tpl + atten_p = atten_tpl.Copy().Set( name=name, input_dim=p.model_dim, hidden_dim=p.attention_hidden_dim or p.model_dim, @@ -9145,7 +9148,12 @@ def _Stride(self, name, stride, first_n=None, axis=1): """ return StrideLayer.Params().Set(stride=stride, first_n=first_n, axis=axis, name=name) - def _StridedAttention(self, name, stride=1, first_n=None, num_heads=None): + def _StridedAttention(self, + name, + stride=1, + first_n=None, + num_heads=None, + layer_idx=None): """Computes self attention with optional stride. Args: @@ -9161,6 +9169,7 @@ def _StridedAttention(self, name, stride=1, first_n=None, num_heads=None): 1. first_n can't be 0. If first_n <= stride, only the first token is used. num_heads: the number of heads. + layer_idx: Valid layer index if p.atten_tpl is a list. Returns: A self attention layer params. @@ -9185,18 +9194,25 @@ def _StridedAttention(self, name, stride=1, first_n=None, num_heads=None): if num_heads is None: num_heads = p.num_heads + atten_tpl = None + if isinstance(p.atten_tpl, list): + assert layer_idx is not None, 'layer_idx must be specified.' + atten_tpl = p.atten_tpl[layer_idx] + else: + atten_tpl = p.atten_tpl + # compute qkv in one step only if default_enable_qkv_proj_in_onestep # and no striding (stride==1) and not first_n enable_qkv_proj_in_onestep = (p.default_enable_qkv_proj_in_onestep and stride == 1 and not first_n) # Overriding default param based on stride. - enable_qk_proj_in_onestep = (p.atten_tpl.enable_qk_proj_in_onestep + enable_qk_proj_in_onestep = (atten_tpl.enable_qk_proj_in_onestep and stride == 1 and not first_n) # Rope template doesn't handle first_n, so resetting it to None until it is # handled correctly. 
- if first_n is not None and p.atten_tpl.rope_tpl is not None: + if first_n is not None and atten_tpl.rope_tpl is not None: tf.logging.warning('Rope Attn needs to handle valid query_first_n.') first_n = None @@ -9206,7 +9222,7 @@ def _StridedAttention(self, name, stride=1, first_n=None, num_heads=None): ('after_ln->strided_query', self._Stride('query_after_stride', stride, first_n)), ('{}->after_att,prob'.format(attention_inputs), - self._MultiHeadedAtten('atten', num_heads, enable_qkv_proj_in_onestep, enable_qk_proj_in_onestep, stride, first_n)), + self._MultiHeadedAtten('atten', num_heads, enable_qkv_proj_in_onestep, enable_qk_proj_in_onestep, stride, first_n, atten_tpl)), ('after_att->after_dropout', self._Dropout('dropout', p.residual_dropout_prob)), ('{}->strided_input'.format(input_to_add), @@ -9248,7 +9264,7 @@ def _Pool(self, name, stride, first_n=None): first_n=first_n, name=name) - def _FunnelAttention(self, name, stride=1, first_n=None, num_heads=None): + def _FunnelAttention(self, name, stride=1, first_n=None, num_heads=None, layer_idx=None): """Computes self attention with optional stride. Args: @@ -9263,6 +9279,7 @@ def _FunnelAttention(self, name, stride=1, first_n=None, num_heads=None): be None or 1. first_n can't be 0. If first_n <= stride, only the first token is used. num_heads: the number of heads. + layer_idx: Valid layer index if p.atten_tpl is a list.. Returns: A self attention layer params. @@ -9283,6 +9300,11 @@ def _FunnelAttention(self, name, stride=1, first_n=None, num_heads=None): if num_heads is None: num_heads = p.num_heads + if isinstance(p.atten_tpl, list): + assert layer_idx is not None, 'layer_idx must be specified.' + atten_tpl = p.atten_tpl[layer_idx] + else: + atten_tpl = p.atten_tpl sub_list = [] if p.packed_input: if stride > 1: @@ -9302,12 +9324,12 @@ def _FunnelAttention(self, name, stride=1, first_n=None, num_heads=None): and stride == 1 and not first_n) # Overriding default param based on stride. - enable_qk_proj_in_onestep = (p.atten_tpl.enable_qk_proj_in_onestep + enable_qk_proj_in_onestep = (atten_tpl.enable_qk_proj_in_onestep and stride == 1 and not first_n) # Rope template doesn't handle first_n, so resetting it to None until it is # handled correctly. - if first_n is not None and p.atten_tpl.rope_tpl is not None: + if first_n is not None and atten_tpl.rope_tpl is not None: tf.logging.warning('Rope Attn needs to handle valid query_first_n.') first_n = None @@ -9317,7 +9339,7 @@ def _FunnelAttention(self, name, stride=1, first_n=None, num_heads=None): ('after_ln,i.paddings->strided_query,o.paddings', self._Pool('query_after_pooling', stride, first_n)), ('{}->after_att,prob'.format(attention_inputs), - self._MultiHeadedAtten('atten', num_heads, enable_qkv_proj_in_onestep, enable_qk_proj_in_onestep, stride, first_n)), + self._MultiHeadedAtten('atten', num_heads, enable_qkv_proj_in_onestep, enable_qk_proj_in_onestep, stride, first_n, atten_tpl)), ('after_att->after_dropout', self._Dropout('dropout', p.residual_dropout_prob)), shortcut_sub, @@ -9340,7 +9362,8 @@ def FunnelEncoderLayer(self, name, stride=1, first_n=None, ff_hidden_dim=None, num_heads=None, ff_gated_fn=None, num_ffns=1, - use_moe=False): + use_moe=False, + layer_idx=None): """(inputs, paddings) -> (encoded, paddings). Args: @@ -9363,6 +9386,7 @@ def FunnelEncoderLayer(self, name, stride=1, first_n=None, 'gelu', and callable. num_ffns: number of ffn layers. use_moe: The first of the ffn layer use moe. + layer_idx: Valid layer index if p.atten_tpl is a list.. 
Returns: A transformer encoder layer params that supports optional stride. @@ -9390,7 +9414,7 @@ def FunnelEncoderLayer(self, name, stride=1, first_n=None, moe_p = self.MoE('moe', ff_hidden_dim=ff_hidden_dim) s_layers = [self._FunnelAttention('self_atten', stride=stride, - first_n=first_n, num_heads=num_heads), + first_n=first_n, num_heads=num_heads, layer_idx=layer_idx), moe_p if use_moe else ff_layer] if num_ffns > 1: for ffn_id in range(1, num_ffns): @@ -9399,7 +9423,7 @@ def FunnelEncoderLayer(self, name, stride=1, first_n=None, def TransformerEncoderLayer(self, name, stride=1, first_n=None, ff_hidden_dim=None, num_heads=None, - use_moe=False): + use_moe=False, layer_idx=None): """(inputs, paddings) -> (encoded, paddings). Args: @@ -9418,6 +9442,7 @@ def TransformerEncoderLayer(self, name, stride=1, first_n=None, num_heads: The number of heads for the multi-head attention module. If specified, this will override p.num_heads. use_moe: whether to use moe feedforward layer or not. + layer_idx: Valid layer index if p.atten_tpl is a list.. Returns: A transformer encoder layer params that supports optional stride. @@ -9434,7 +9459,7 @@ def TransformerEncoderLayer(self, name, stride=1, first_n=None, return self._Seq(name, self._Seq( 'block', self._StridedAttention('self_atten', stride=stride, - first_n=first_n, num_heads=num_heads), + first_n=first_n, num_heads=num_heads, layer_idx=layer_idx), ffw_p)) def Stack(self, name, blocks, output_all_layer_hiddens=False): @@ -9455,8 +9480,11 @@ def Stack(self, name, blocks, output_all_layer_hiddens=False): def TransformerEncoderStack(self, name, num_layers=1): """Returns a stack of num_layers self-attention layers.""" + p = self.params + if isinstance(p.atten_tpl, list): + assert len(p.atten_tpl) == num_layers, 'atten_tpl list must have the same length as num_layers.' blocks = [ - self.TransformerEncoderLayer(name='iter_{:0>3d}'.format(d)) + self.TransformerEncoderLayer(name='iter_{:0>3d}'.format(d), layer_idx=d) for d in range(num_layers) ] return self.Stack(name, blocks) @@ -9494,6 +9522,7 @@ def _MultiHeadedAtten( enable_qk_proj_in_onestep=False, query_stride=1, query_first_n=None, + atten_tpl=None, ): """Returns a MultiHeadedAttention params.""" p = self.params @@ -9635,6 +9664,7 @@ def _MultiHeadedAtten( enable_qk_proj_in_onestep=False, query_stride=1, query_first_n=None, + atten_tpl=None, ): """Returns a MultiHeadedAttention params.""" p = self.params @@ -10037,12 +10067,13 @@ def _Attention(self, name, is_causal=True): ('i.paddings->o.paddings', self._Id('id')), ) - def TransformerEncoderLayer(self, name, is_causal=True): + def TransformerEncoderLayer(self, name, is_causal=True, layer_idx=None): """(inputs, paddings) -> (encoded, paddings). Args: name: the string name of the encoder layer params. is_causal: If true, add cause per_step padding to the attention layer. + layer_idx: the index of the layer. Returns: A transformer encoder layer params that supports optional stride. 
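Usage note (illustrative, not part of the patch): with this change, `Builder.atten_tpl` may be either a single attention template or a per-layer list; when a list is passed, its length must match the number of layers and each layer forwards its `layer_idx` so the matching template is picked. Below is a minimal sketch of the list form, mirroring the updated test in batch_major_attention_test.py that follows; the import alias, model dimensions and strides are illustrative assumptions, not values mandated by the patch.

from lingvo.core import batch_major_attention as attention  # alias assumed, as in the tests

strides = [1, 2]  # one entry per layer; arbitrary values for illustration
atten_builder = attention.Builder.Params().Set(
    model_dim=16,
    num_heads=2,
    ff_hidden_dim=5,
    # One MultiHeadedAttention template per layer; the list length must equal
    # the number of layers or the new assertions fire.
    atten_tpl=[attention.MultiHeadedAttention.Params() for _ in strides],
).Instantiate()
layers = []
for layer_i, stride in enumerate(strides):
  layers.append(
      atten_builder.TransformerEncoderLayer(
          name='atten_{}'.format(layer_i),
          stride=stride,
          layer_idx=layer_i,  # selects atten_tpl[layer_i] for this layer
      )
  )
p = atten_builder.Seq('model', *layers)

The single-template form is unchanged: passing a plain `attention.MultiHeadedAttention.Params()` for `atten_tpl` and omitting `layer_idx` behaves as before.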
diff --git a/lingvo/core/batch_major_attention_test.py b/lingvo/core/batch_major_attention_test.py index 4c8020c1b..831078301 100644 --- a/lingvo/core/batch_major_attention_test.py +++ b/lingvo/core/batch_major_attention_test.py @@ -5801,6 +5801,14 @@ def testBProp(self): 'testcase_name': '_baseline', 'strides': [1, 1], }, + { + 'testcase_name': '_baseline_atten_tpl_list', + 'strides': [1, 1], + 'atten_tpl': [ + attention.MultiHeadedAttention.Params(), + attention.MultiHeadedAttention.Params(), + ], + }, { 'testcase_name': '_stride_2', 'strides': [1, 2], @@ -5835,6 +5843,7 @@ def testFunnelTransformerStack( trunc_seq=True, num_splits=1, num_micro_batches=1, + atten_tpl=None, ): with self.session(use_gpu=False) as sess: bs = 2 @@ -5851,6 +5860,7 @@ def testFunnelTransformerStack( funnel_pool_tpl=attention.FunnelPoolingLayer.Params().Set( begin_intact=begin_intact, trunc_seq=trunc_seq ), + atten_tpl=atten_tpl or attention.MultiHeadedAttention.Params(), ) atten_builder = atten_builder_params.Instantiate() layers = [] @@ -5859,7 +5869,9 @@ def testFunnelTransformerStack( accumulate_stride *= stride layers.append( atten_builder.FunnelEncoderLayer( - name='atten_{}'.format(layer_i), stride=stride + name='atten_{}'.format(layer_i), + stride=stride, + layer_idx=layer_i, ) ) p = atten_builder.Stack('model', layers) @@ -6327,6 +6339,14 @@ def testFunnelEncoderLayerWithPerLayerFfns(self): 'testcase_name': '_baseline', 'strides': [1, 1], }, + { + 'testcase_name': '_baseline_atten_tpl_list', + 'strides': [1, 1], + 'atten_tpl': [ + attention.MultiHeadedAttention.Params(), + attention.MultiHeadedAttention.Params(), + ], + }, { 'testcase_name': '_stride_2', 'strides': [2, 1], @@ -6336,7 +6356,9 @@ def testFunnelEncoderLayerWithPerLayerFfns(self): 'strides': [2, 0], }, ) - def testTransformerStackWithStride(self, strides): + def testTransformerStackWithStride( + self, strides, atten_tpl=attention.MultiHeadedAttention.Params() + ): with self.session(use_gpu=False) as sess: bs = 2 sl = 10 @@ -6344,7 +6366,7 @@ def testTransformerStackWithStride(self, strides): tf.random.set_seed(12345) atten_builder = ( attention.Builder.Params() - .Set(model_dim=d, num_heads=2, ff_hidden_dim=5) + .Set(model_dim=d, num_heads=2, ff_hidden_dim=5, atten_tpl=atten_tpl) .Instantiate() ) layers = [] @@ -6353,7 +6375,9 @@ def testTransformerStackWithStride(self, strides): accumulate_stride *= stride layers.append( atten_builder.TransformerEncoderLayer( - name='atten_{}'.format(layer_i), stride=stride + name='atten_{}'.format(layer_i), + stride=stride, + layer_idx=layer_i, ) ) p = atten_builder.Seq('model', *layers) diff --git a/lingvo/core/self_attention_layer.py b/lingvo/core/self_attention_layer.py index 8bdb3b60c..79ad1ace6 100644 --- a/lingvo/core/self_attention_layer.py +++ b/lingvo/core/self_attention_layer.py @@ -30,7 +30,7 @@ class Builder(batch_major_attention.Builder): """Builder for self-attention layers.""" - def SelfAttention(self, name): + def SelfAttention(self, name, layer_idx=None): p = self.params input_to_add = ( 'i.vec' if p.selfatten_add_unnormalized_input else 'after_ln') @@ -38,6 +38,11 @@ def SelfAttention(self, name): attention_inputs = 'after_ln,after_ln,after_ln,i.paddings' if p.packed_input: attention_inputs += ',i.segment_mask' + if isinstance(p.atten_tpl, list): + assert layer_idx is not None, 'layer_idx must be specified.' 
+ atten_tpl = p.atten_tpl[layer_idx] + else: + atten_tpl = p.atten_tpl sub_list = [ ('i.vec->after_ln', self._DefaultLN('LN')), @@ -46,7 +51,8 @@ def SelfAttention(self, name): self._MultiHeadedAtten( 'atten', enable_qkv_proj_in_onestep=p.default_enable_qkv_proj_in_onestep, - enable_qk_proj_in_onestep=p.atten_tpl.enable_qk_proj_in_onestep, + enable_qk_proj_in_onestep=atten_tpl.enable_qk_proj_in_onestep, + atten_tpl=atten_tpl, ), ), ( @@ -70,33 +76,52 @@ def SelfAttention(self, name): *sub_list ) - def _TransformerLayerBlock(self, name, feed_forward_qdomain=None): + def _TransformerLayerBlock( + self, name, feed_forward_qdomain=None, layer_idx=None + ): """(inputs, paddings) -> (encoded, paddings).""" return self._Seq( name, - self.SelfAttention('self_atten'), + self.SelfAttention('self_atten', layer_idx=layer_idx), self.Feedforward('ff', qdomain=feed_forward_qdomain), ) def TransformerStack(self, name, num_layers=1, feed_forward_qdomain=None): """Returns a stack of num_layers self-attention layers.""" - blocks = [ - self._TransformerLayerBlock( - 'block_{}'.format(d), feed_forward_qdomain=feed_forward_qdomain - ) - for d in range(num_layers) - ] - return self._MaybeSplit(name, blocks) or ( - self._Rep(name, num_layers, self._TransformerLayerBlock('block')) - ) + p = self.params + if isinstance(p.atten_tpl, list): + assert ( + len(p.atten_tpl) == num_layers + ), 'atten_tpl list must have the same length as num_layers.' + blocks = [] + for i in range(num_layers): + blocks.append( + self._Seq( + 'iter_%03d' % i, + self._TransformerLayerBlock( + 'block', + feed_forward_qdomain=feed_forward_qdomain, + layer_idx=i, + ), + ) + ) + return self._MaybeSplit(name, blocks) or self._Seq(name, *blocks) def _StridedTransformerLayerBlock( - self, name, *, stride=1, first_n=None, feed_forward_qdomain=None + self, + name, + *, + stride=1, + first_n=None, + feed_forward_qdomain=None, + layer_idx=None ): """(inputs, paddings) -> (encoded, paddings).""" return self._Seq( name, - self._StridedAttention('self_atten', stride=stride, first_n=first_n), + self._StridedAttention( + 'self_atten', stride=stride, first_n=first_n, layer_idx=layer_idx + ), self.Feedforward('ff', qdomain=feed_forward_qdomain), ) @@ -110,6 +135,11 @@ def TransformerStackV2( feed_forward_qdomain=None ): """Returns a stack of num_layers self-attention layers.""" + p = self.params + if isinstance(p.atten_tpl, list): + assert ( + len(p.atten_tpl) == num_layers + ), 'atten_tpl list must have the same length as num_layers.' blocks = [] for i in range(num_layers): if i < num_layers - 1: @@ -124,6 +154,7 @@ def TransformerStackV2( stride=stride, first_n=first_n, feed_forward_qdomain=feed_forward_qdomain, + layer_idx=i, ), ) ) @@ -151,6 +182,7 @@ def TransformerLayerBlock( first_n=None, num_heads=None, feed_forward_qdomain=None, + layer_idx=None, ): """(inputs, paddings) -> (encoded, paddings).""" p = self.params @@ -182,6 +214,12 @@ def TransformerLayerBlock( if num_heads is None: num_heads = p.num_heads + if isinstance(p.atten_tpl, list): + assert layer_idx is not None, 'layer_idx must be specified.' + atten_tpl = p.atten_tpl[layer_idx] + else: + atten_tpl = p.atten_tpl + # compute qkv in one step only if default_enable_qkv_proj_in_onestep # and no striding (stride==1) and not first_n enable_qkv_proj_in_onestep = ( @@ -189,10 +227,10 @@ def TransformerLayerBlock( ) # Overriding default param based on stride. 
enable_qk_proj_in_onestep = ( - p.atten_tpl.enable_qk_proj_in_onestep + atten_tpl.enable_qk_proj_in_onestep and stride == 1 and not first_n - and not p.atten_tpl.use_mqa + and not atten_tpl.use_mqa ) sub_list += [ @@ -215,6 +253,7 @@ def TransformerLayerBlock( enable_qk_proj_in_onestep=enable_qk_proj_in_onestep, query_stride=stride, query_first_n=first_n, + atten_tpl=atten_tpl, ), ), ( @@ -334,15 +373,24 @@ def TransformerLayerBlock( def TransformerStack(self, name, num_layers=1, feed_forward_qdomain=None): """Returns a stack of num_layers self-attention layers.""" - blocks = [ - self.TransformerLayerBlock( - 'block_{}'.format(d), feed_forward_qdomain=feed_forward_qdomain - ) - for d in range(num_layers) - ] - return self._MaybeSplit(name, blocks) or ( - self._Rep(name, num_layers, self.TransformerLayerBlock('block')) - ) + p = self.params + if isinstance(p.atten_tpl, list): + assert ( + len(p.atten_tpl) == num_layers + ), 'atten_tpl list must have the same length as num_layers.' + blocks = [] + for i in range(num_layers): + blocks.append( + self._Seq( + 'iter_%03d' % i, + self.TransformerLayerBlock( + 'block', + feed_forward_qdomain=feed_forward_qdomain, + layer_idx=i, + ), + ) + ) + return self._MaybeSplit(name, blocks) or self._Seq(name, *blocks) def TransformerStackV2( self, @@ -354,6 +402,11 @@ def TransformerStackV2( feed_forward_qdomain=None ): """Returns a stack of num_layers self-attention layers.""" + p = self.params + if isinstance(p.atten_tpl, list): + assert ( + len(p.atten_tpl) == num_layers + ), 'atten_tpl list must have the same length as num_layers.' blocks = [] for i in range(num_layers): if i < num_layers - 1: @@ -368,6 +421,7 @@ def TransformerStackV2( stride=stride, first_n=first_n, feed_forward_qdomain=feed_forward_qdomain, + layer_idx=i, ), ) ) diff --git a/lingvo/core/self_attention_layer_test.py b/lingvo/core/self_attention_layer_test.py index e4b5e5e99..5eb9a9d71 100644 --- a/lingvo/core/self_attention_layer_test.py +++ b/lingvo/core/self_attention_layer_test.py @@ -33,6 +33,19 @@ class BuilderTest(test_utils.TestCase, parameterized.TestCase): 'num_splits': 1, 'num_micro_batches': 2, }, + { + 'testcase_name': '_atten_tpl_list', + 'num_splits': 1, + 'num_micro_batches': 2, + 'atten_tpl': [ + mt_attention.MultiHeadedAttention.Params(), + mt_attention.MultiHeadedAttention.Params(), + mt_attention.MultiHeadedAttention.Params(), + mt_attention.MultiHeadedAttention.Params(), + mt_attention.MultiHeadedAttention.Params(), + mt_attention.MultiHeadedAttention.Params(), + ], + }, { 'testcase_name': '_simplified_transformer', 'num_splits': 1, @@ -40,6 +53,21 @@ class BuilderTest(test_utils.TestCase, parameterized.TestCase): 'builder': self_attention.SimplifiedTransformerBuilder, 'expected_output': 39.930980, }, + { + 'testcase_name': '_simplified_transformer_atten_tpl_list', + 'num_splits': 1, + 'num_micro_batches': 1, + 'builder': self_attention.SimplifiedTransformerBuilder, + 'atten_tpl': [ + mt_attention.MultiHeadedAttention.Params(), + mt_attention.MultiHeadedAttention.Params(), + mt_attention.MultiHeadedAttention.Params(), + mt_attention.MultiHeadedAttention.Params(), + mt_attention.MultiHeadedAttention.Params(), + mt_attention.MultiHeadedAttention.Params(), + ], + 'expected_output': 39.930980, + }, { 'testcase_name': '_simplified_transformer_parallel', 'num_splits': 1, @@ -55,12 +83,14 @@ def testTransformerStack( num_micro_batches, builder=self_attention.Builder, parallel_attention_mlp=False, + atten_tpl=None, expected_output=386.16742, ): with 
self.session(use_gpu=False) as sess: bs = 2 sl = 21 d = 16 + num_layers = 6 tf.random.set_seed(12345) deterministic_dropout = num_splits > 1 or num_micro_batches > 1 atten_builder = builder.Params().Set( @@ -70,13 +100,20 @@ def testTransformerStack( deterministic_dropout=deterministic_dropout, num_splits=num_splits, num_micro_batches=num_micro_batches, + atten_tpl=atten_tpl or mt_attention.MultiHeadedAttention.Params(), ) if builder is self_attention.SimplifiedTransformerBuilder: atten_builder.Set( parallel_attention_mlp=parallel_attention_mlp, ) - atten_builder.atten_tpl.enable_shaped_attention = True - p = atten_builder.Instantiate().TransformerStack('atten', 6) + if isinstance(atten_builder.atten_tpl, list): + atten_builder.atten_tpl = [ + atten_builder.atten_tpl[i].Set(enable_shaped_attention=True) + for i in range(len(atten_builder.atten_tpl)) + ] + else: + atten_builder.atten_tpl.enable_shaped_attention = True + p = atten_builder.Instantiate().TransformerStack('atten', num_layers) p.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0) l = p.Instantiate() input_embs = tf.constant(np.random.random(size=[bs, sl, d]), dtype=float) @@ -101,24 +138,54 @@ def testTransformerStack( { 'testcase_name': '_v1_stack', 'use_v1_stack': True, - }, { + }, + { + 'testcase_name': '_v1_stack_atten_tpl_list', + 'use_v1_stack': True, + 'atten_tpl': [ + mt_attention.MultiHeadedAttention.Params(), + mt_attention.MultiHeadedAttention.Params(), + mt_attention.MultiHeadedAttention.Params(), + ], + }, + { 'testcase_name': '_baseline', 'first_n': None, - }, { + }, + { + 'testcase_name': '_baseline_atten_tpl_list', + 'first_n': None, + 'atten_tpl': [ + mt_attention.MultiHeadedAttention.Params(), + mt_attention.MultiHeadedAttention.Params(), + mt_attention.MultiHeadedAttention.Params(), + ], + }, + { 'testcase_name': '_first_1', 'first_n': 1, - }, { + }, + { 'testcase_name': '_first_2', 'first_n': 2, - }, { + }, + { 'testcase_name': '_stride_2', 'stride': 2, - }) - def testTransformerStackV2(self, use_v1_stack=False, stride=1, first_n=None): + }, + ) + def testTransformerStackV2( + self, + use_v1_stack=False, + stride=1, + first_n=None, + atten_tpl=mt_attention.MultiHeadedAttention.Params(), + ): with self.session(use_gpu=False) as sess: bs = 2 sl = 21 d = 16 + num_layers = 3 tf.random.set_seed(12345) atten_builder = self_attention.Builder.Params().Set( model_dim=d, @@ -126,16 +193,19 @@ def testTransformerStackV2(self, use_v1_stack=False, stride=1, first_n=None): ff_hidden_dim=5, deterministic_dropout=False, num_splits=1, - num_micro_batches=1) + num_micro_batches=1, + atten_tpl=atten_tpl, + ) builder = atten_builder.Instantiate() if use_v1_stack: - p = builder.TransformerStack('atten', num_layers=3) + p = builder.TransformerStack('atten', num_layers=num_layers) else: p = builder.TransformerStackV2( 'atten', - num_layers=3, + num_layers=num_layers, final_layer_stride=stride, - final_layer_first_n=first_n) + final_layer_first_n=first_n, + ) p.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0) l = p.Instantiate() self.assertAllEqual([ @@ -220,10 +290,28 @@ def testTransformerStackV2(self, use_v1_stack=False, stride=1, first_n=None): 'testcase_name': '_v1_stack', 'use_v1_stack': True, }, + { + 'testcase_name': '_v1_stack_atten_tpl_list', + 'use_v1_stack': True, + 'atten_tpl': [ + mt_attention.MultiHeadedAttention.Params(), + mt_attention.MultiHeadedAttention.Params(), + mt_attention.MultiHeadedAttention.Params(), + ], + }, { 'testcase_name': '_baseline', 'first_n': None, }, + { + 'testcase_name': 
'_baseline_atten_tpl_list', + 'first_n': None, + 'atten_tpl': [ + mt_attention.MultiHeadedAttention.Params(), + mt_attention.MultiHeadedAttention.Params(), + mt_attention.MultiHeadedAttention.Params(), + ], + }, { 'testcase_name': '_first_1', 'first_n': 1, @@ -243,12 +331,17 @@ def testTransformerStackV2(self, use_v1_stack=False, stride=1, first_n=None): }, ) def testTransformerStackV2WithSimplifiedTransformer( - self, use_v1_stack=False, stride=1, first_n=None + self, + use_v1_stack=False, + stride=1, + first_n=None, + atten_tpl=mt_attention.MultiHeadedAttention.Params(), ): with self.session(use_gpu=False) as sess: bs = 2 sl = 21 d = 16 + num_layers = 3 tf.random.set_seed(12345) atten_builder = self_attention.SimplifiedTransformerBuilder.Params().Set( model_dim=d, @@ -258,16 +351,25 @@ def testTransformerStackV2WithSimplifiedTransformer( num_splits=1, num_micro_batches=1, selfatten_enable_value_proj=False, + atten_tpl=atten_tpl, ) - atten_builder.atten_tpl.enable_shaped_attention = True - atten_builder.atten_tpl.enable_ctx_post_proj = False + if isinstance(atten_builder.atten_tpl, list): + atten_builder.atten_tpl = [ + atten_builder.atten_tpl[i].Set( + enable_shaped_attention=True, enable_ctx_post_proj=False + ) + for i in range(len(atten_builder.atten_tpl)) + ] + else: + atten_builder.atten_tpl.enable_shaped_attention = True + atten_builder.atten_tpl.enable_ctx_post_proj = False builder = atten_builder.Instantiate() if use_v1_stack: - p = builder.TransformerStack('atten', num_layers=3) + p = builder.TransformerStack('atten', num_layers=num_layers) else: p = builder.TransformerStackV2( 'atten', - num_layers=3, + num_layers=num_layers, final_layer_stride=stride, final_layer_first_n=first_n, )
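Usage note (illustrative, not part of the patch): the self_attention.Builder stacks accept the same per-layer list, and both TransformerStack and TransformerStackV2 now assert that the list length equals num_layers. A minimal sketch following the conventions of self_attention_layer_test.py above; the module aliases, dimensions and stride settings are assumptions for illustration only.

from lingvo.core import batch_major_attention as mt_attention  # alias assumed, as in the test
from lingvo.core import self_attention_layer as self_attention  # alias assumed, as in the test

num_layers = 3
atten_builder = self_attention.Builder.Params().Set(
    model_dim=16,
    num_heads=2,
    ff_hidden_dim=5,
    deterministic_dropout=False,
    num_splits=1,
    num_micro_batches=1,
    # Per-layer templates; TransformerStackV2 asserts len(atten_tpl) == num_layers.
    atten_tpl=[
        mt_attention.MultiHeadedAttention.Params() for _ in range(num_layers)
    ],
)
builder = atten_builder.Instantiate()
p = builder.TransformerStackV2(
    'atten',
    num_layers=num_layers,
    final_layer_stride=1,
    final_layer_first_n=None,
)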