diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py
index e7679f0ec846..7710b56e7cd9 100644
--- a/colossalai/shardformer/modeling/bert.py
+++ b/colossalai/shardformer/modeling/bert.py
@@ -1,4 +1,3 @@
-import math
 import warnings
 from typing import List, Optional, Tuple, Union
 
@@ -1005,115 +1004,6 @@ def bert_for_question_answering_forward(
         return {"hidden_states": hidden_states}
 
 
-def get_bert_flash_attention_forward():
-    try:
-        from xformers.ops import memory_efficient_attention as me_attention
-    except:
-        raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
-    from transformers.models.bert.modeling_bert import BertAttention
-
-    def forward(
-        self: BertAttention,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-        output_attentions: Optional[bool] = False,
-    ) -> Tuple[torch.Tensor]:
-        mixed_query_layer = self.query(hidden_states)
-
-        # If this is instantiated as a cross-attention module, the keys
-        # and values come from an encoder; the attention mask needs to be
-        # such that the encoder's padding tokens are not attended to.
-        is_cross_attention = encoder_hidden_states is not None
-
-        if is_cross_attention and past_key_value is not None:
-            # reuse k,v, cross_attentions
-            key_layer = past_key_value[0]
-            value_layer = past_key_value[1]
-            attention_mask = encoder_attention_mask
-        elif is_cross_attention:
-            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
-            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
-            attention_mask = encoder_attention_mask
-        elif past_key_value is not None:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
-            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
-            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
-        else:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
-
-        query_layer = self.transpose_for_scores(mixed_query_layer)
-
-        use_cache = past_key_value is not None
-        if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
-            past_key_value = (key_layer, value_layer)
-
-        final_attention_mask = None
-        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
-            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
-            if use_cache:
-                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(-1, 1)
-            else:
-                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
-            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
-            distance = position_ids_l - position_ids_r
-
-            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
-            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
-
-            if self.position_embedding_type == "relative_key":
-                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
-                final_attention_mask = relative_position_scores
-            elif self.position_embedding_type == "relative_key_query":
-                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
-                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
-                final_attention_mask = relative_position_scores_query + relative_position_scores_key
-
-        scale = 1 / math.sqrt(self.attention_head_size)
-        if attention_mask is not None:
-            if final_attention_mask != None:
-                final_attention_mask = final_attention_mask * scale + attention_mask
-            else:
-                final_attention_mask = attention_mask
-
-        if final_attention_mask is not None:
-            batch_size, src_len = query_layer.size()[0], query_layer.size()[2]
-            tgt_len = key_layer.size()[2]
-            final_attention_mask = final_attention_mask.expand(
-                batch_size, self.num_attention_heads, src_len, tgt_len
-            ).contiguous()
-
-        query_layer = query_layer.permute(0, 2, 1, 3).contiguous()
-        key_layer = key_layer.permute(0, 2, 1, 3).contiguous()
-        value_layer = value_layer.permute(0, 2, 1, 3).contiguous()
-
-        context_layer = me_attention(
-            query_layer, key_layer, value_layer, attn_bias=final_attention_mask, p=self.dropout.p, scale=scale
-        )
-        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(new_context_layer_shape)
-
-        outputs = (context_layer, None)
-
-        if self.is_decoder:
-            outputs = outputs + (past_key_value,)
-        return outputs
-
-    return forward
-
-
 def get_jit_fused_bert_self_output_forward():
     from transformers.models.bert.modeling_bert import BertSelfOutput
 
diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py
index 1f34215c5175..1541436264e9 100644
--- a/colossalai/shardformer/modeling/bloom.py
+++ b/colossalai/shardformer/modeling/bloom.py
@@ -714,93 +714,6 @@ def bloom_for_question_answering_forward(
         return {"hidden_states": hidden_states}
 
 
-def get_bloom_flash_attention_forward(enable_jit_fused=False):
-    try:
-        from xformers.ops import memory_efficient_attention as me_attention
-    except:
-        raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
-    from transformers.models.bloom.modeling_bloom import BloomAttention
-
-    def forward(
-        self: BloomAttention,
-        hidden_states: torch.Tensor,
-        residual: torch.Tensor,
-        alibi: torch.Tensor,
-        attention_mask: torch.Tensor,
-        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        use_cache: bool = False,
-        output_attentions: bool = False,
-    ):
-        fused_qkv = self.query_key_value(hidden_states)
-        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
-        batch_size, tgt_len, _, _ = query_layer.size()
-
-        _, kv_length, _, _ = key_layer.size()
-
-        proj_shape = (batch_size, tgt_len, self.num_heads, self.head_dim)
-        query_layer = query_layer.contiguous().view(*proj_shape)
-        key_layer = key_layer.contiguous().view(*proj_shape)
-        value_layer = value_layer.contiguous().view(*proj_shape)
-
-        if layer_past is not None:
-            past_key, past_value = layer_past
-            # concatenate along seq_length dimension:
-            #  - key: [batch_size * self.num_heads, head_dim, kv_length]
-            #  - value: [batch_size * self.num_heads, kv_length, head_dim]
-            key_layer = torch.cat((past_key, key_layer), dim=1)
-            value_layer = torch.cat((past_value, value_layer), dim=1)
-
-        if use_cache is True:
-            present = (key_layer, value_layer)
-        else:
-            present = None
-
-        tgt_len = key_layer.size()[1]
-
-        attention_numerical_mask = torch.zeros(
-            (batch_size, self.num_heads, tgt_len, kv_length),
-            dtype=torch.float32,
-            device=query_layer.device,
-            requires_grad=True,
-        )
-        attention_numerical_mask = (
-            attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta
-        )
-        attention_numerical_mask = torch.masked_fill(
-            attention_numerical_mask, attention_mask, torch.finfo(torch.float32).min
-        )
-        attention_numerical_mask = attention_numerical_mask.to(query_layer.dtype)
-
-        context_layer = me_attention(
-            query_layer,
-            key_layer,
-            value_layer,
-            attn_bias=attention_numerical_mask,
-            scale=self.inv_norm_factor,
-            p=self.attention_dropout.p,
-        )
-        context_layer = context_layer.reshape(-1, kv_length, self.hidden_size)
-        if self.pretraining_tp > 1 and self.slow_but_exact:
-            slices = self.hidden_size / self.pretraining_tp
-            output_tensor = torch.zeros_like(context_layer)
-            for i in range(self.pretraining_tp):
-                output_tensor = output_tensor + F.linear(
-                    context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
-                    self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
-                )
-        else:
-            output_tensor = self.dense(context_layer)
-
-        # TODO to replace with the bias_dropout_add function in jit
-        output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
-        outputs = (output_tensor, present, None)
-
-        return outputs
-
-    return forward
-
-
 def get_jit_fused_bloom_attention_forward():
     from transformers.models.bloom.modeling_bloom import BloomAttention
 
diff --git a/colossalai/shardformer/modeling/sam.py b/colossalai/shardformer/modeling/sam.py
index 26e0b224d3ab..49fce0556750 100644
--- a/colossalai/shardformer/modeling/sam.py
+++ b/colossalai/shardformer/modeling/sam.py
@@ -1,9 +1,4 @@
-import math
-from typing import Tuple
-
 import torch
-import torch.nn.functional as F
-from torch import Tensor
 
 
 def forward_fn():
@@ -45,163 +40,3 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch
         return outputs
 
     return forward
-
-
-def get_sam_flash_attention_forward():
-    from transformers.models.sam.modeling_sam import SamAttention
-
-    try:
-        from xformers.ops import memory_efficient_attention as me_attention
-    except:
-        raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
-
-    def _separate_heads(hidden_states: Tensor, num_attention_heads: int) -> Tensor:
-        batch, point_batch_size, n_tokens, channel = hidden_states.shape
-        c_per_head = channel // num_attention_heads
-        hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
-        return hidden_states
-
-    def _recombine_heads(hidden_states: Tensor, point_batch_size: int) -> Tensor:
-        batch, n_tokens, n_heads, c_per_head = hidden_states.shape
-        return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
-
-    def forward(
-        self: SamAttention, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None
-    ) -> Tensor:
-        # Input projections
-        query = self.q_proj(query)
-        key = self.k_proj(key)
-        value = self.v_proj(value)
-
-        point_batch_size = query.shape[1]
-        # Separate into heads
-        query = _separate_heads(query, self.num_attention_heads)
-        key = _separate_heads(key, self.num_attention_heads)
-        value = _separate_heads(value, self.num_attention_heads)
-
-        # SamAttention
-        _, _, _, c_per_head = query.shape
-        bias = None
-        if attention_similarity is not None:
-            bias = attention_similarity
-
-        scale = 1.0 / math.sqrt(c_per_head)
-        out = me_attention(query, key, value, attn_bias=bias, scale=scale)
-
-        out = _recombine_heads(out, point_batch_size)
-        out = self.out_proj(out)
-
-        return out
-
-    return forward
-
-
-def get_sam_vision_flash_attention_forward():
-    from transformers.models.sam.modeling_sam import SamVisionAttention
-
-    try:
-        from xformers.ops import memory_efficient_attention as me_attention
-    except:
-        raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
-
-    def add_decomposed_rel_pos(
-        query: torch.Tensor,
-        rel_pos_h: torch.Tensor,
-        rel_pos_w: torch.Tensor,
-        q_size: Tuple[int, int],
-        k_size: Tuple[int, int],
-    ) -> torch.Tensor:
-        """
-        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
-        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
-
-        Args:
-            attn (`torch.Tensor`):
-                attention map.
-            query (`torch.Tensor`):
-                query q in the attention layer with shape (batch_size, query_height * query_width, channel).
-            rel_pos_h (`torch.Tensor`):
-                relative position embeddings (Lh, channel) for height axis.
-            rel_pos_w (`torch.Tensor`):
-                relative position embeddings (Lw, channel) for width axis.
-            q_size (tuple):
-                spatial sequence size of query q with (query_height, query_width).
-            k_size (tuple):
-                spatial sequence size of key k with (key_height, key_width).
-
-        Returns:
-            attn (`torch.Tensor`):
-                attention map with added relative positional embeddings.
-        """
-
-        query_height, query_width = q_size
-        key_height, key_width = k_size
-        relative_position_height = get_rel_pos(query_height, key_height, rel_pos_h)
-        relative_position_width = get_rel_pos(query_width, key_width, rel_pos_w)
-
-        batch_size, _, nHead, dim = query.shape
-        reshaped_query = query.transpose(1, 2).reshape(batch_size * nHead, query_height, query_width, dim)
-        rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
-        rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
-        rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
-        rel_pos = rel_pos.reshape(batch_size, nHead, query_height * query_width, key_height * key_width)
-        return rel_pos
-
-    def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
-        """
-        Get relative positional embeddings according to the relative positions of
-        query and key sizes.
-
-        Args:
-            q_size (int):
-                size of the query.
-            k_size (int):
-                size of key k.
-            rel_pos (`torch.Tensor`):
-                relative position embeddings (L, channel).
-
-        Returns:
-            Extracted positional embeddings according to relative positions.
-        """
-        max_rel_dist = int(2 * max(q_size, k_size) - 1)
-        # Interpolate rel pos.
-        rel_pos_resized = F.interpolate(
-            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
-            size=max_rel_dist,
-            mode="linear",
-        )
-        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
-
-        # Scale the coords with short length if shapes for q and k are different.
-        q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
-        k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
-        relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
-
-        return rel_pos_resized[relative_coords.long()]
-
-    def forward(self: SamVisionAttention, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
-        batch_size, height, width, _ = hidden_states.shape
-        # qkv with shape (3, batch_size, nHead, height * width, channel)
-        qkv = (
-            self.qkv(hidden_states)
-            .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
-            .permute(2, 0, 1, 3, 4)
-        )
-
-        query, key, value = qkv.reshape(3, batch_size, height * width, self.num_attention_heads, -1).unbind(0)
-
-        rel_pos = None
-        if self.use_rel_pos:
-            rel_pos = add_decomposed_rel_pos(query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width))
-
-        attn_output = me_attention(query, key, value, attn_bias=rel_pos, p=self.dropout, scale=self.scale)
-
-        attn_output = attn_output.reshape(batch_size, height, width, -1)
-
-        attn_output = self.proj(attn_output)
-
-        outputs = (attn_output, None)
-
-        return outputs
-
-    return forward
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index c11ed99ac470..b84a372a5d5f 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -11,7 +11,6 @@
 from ..modeling.bert import (
     BertPipelineForwards,
     bert_sequence_parallel_forward_fn,
-    get_bert_flash_attention_forward,
     get_jit_fused_bert_intermediate_forward,
     get_jit_fused_bert_output_forward,
     get_jit_fused_bert_self_output_forward,
@@ -49,7 +48,6 @@ def module_policy(self):
             BertLayer,
             BertModel,
             BertOutput,
-            BertSelfAttention,
             BertSelfOutput,
         )
 
@@ -218,16 +216,6 @@ def module_policy(self):
                 target_key=BertEmbeddings,
             )
 
-        # use flash attention
-        if self.shard_config.enable_flash_attention:
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_bert_flash_attention_forward(),
-                },
-                policy=policy,
-                target_key=BertSelfAttention,
-            )
-
         # use jit operator
         if self.shard_config.enable_jit_fused:
             self.append_or_create_method_replacement(
diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py
index 20a75cf904a8..d80adb84a756 100644
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -11,14 +11,13 @@
 from ..modeling.bloom import (
     BloomPipelineForwards,
     build_bloom_alibi_tensor_fn,
-    get_bloom_flash_attention_forward,
     get_bloom_sequence_parallel_forward_fn,
     get_jit_fused_bloom_attention_forward,
     get_jit_fused_bloom_gelu_forward,
     get_jit_fused_bloom_mlp_forward,
     get_lm_forward_with_dist_cross_entropy,
 )
-from ..modeling.jit import get_dropout_add_func, get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func
+from ..modeling.jit import get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
@@ -165,16 +164,6 @@ def module_policy(self):
                 target_key=BloomModel,
             )
 
-        if self.shard_config.enable_flash_attention:
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": get_bloom_flash_attention_forward(),
-                    "dropout_add": get_dropout_add_func(),
-                },
-                policy=policy,
-                target_key=BloomAttention,
-            )
-
         # enable jit fused operator
         if self.shard_config.enable_jit_fused:
             self.append_or_create_method_replacement(
diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py
index c224d776957a..53faf8997f02 100644
--- a/colossalai/shardformer/policies/sam.py
+++ b/colossalai/shardformer/policies/sam.py
@@ -1,5 +1,3 @@
-import warnings
-
 import colossalai.shardformer.layer as col_nn
 
 from ..modeling.sam import forward_fn
@@ -212,24 +210,6 @@ def module_policy(self):
             target_key=SamTwoWayTransformer,
         )
 
-        # use flash attention
-        if self.shard_config.enable_flash_attention:
-            warnings.warn("Flash attention is not supported in SAM model. Fallback to normal attention.")
-            # self.append_or_create_method_replacement(
-            #     description={
-            #         "forward": get_sam_flash_attention_forward(),
-            #     },
-            #     policy=policy,
-            #     target_key=SamAttention,
-            # )
-            # self.append_or_create_method_replacement(
-            #     description={
-            #         "forward": get_sam_vision_flash_attention_forward(),
-            #     },
-            #     policy=policy,
-            #     target_key=SamVisionAttention,
-            # )
-
         return policy
 
     def postprocess(self):
diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md
index a42c7cc2eb99..00e1a13d6950 100644
--- a/docs/source/zh-Hans/features/shardformer.md
+++ b/docs/source/zh-Hans/features/shardformer.md
@@ -71,8 +71,8 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.
 ✔️
 ✔️
 ✔️
-✔️
-✔️
+❌
+❌
 ✔️
 ✔️
 ✔️
@@ -95,8 +95,8 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.
 ✔️
 ✔️
 ✔️
-✔️
-✔️
+❌
+❌
 ✔️
 ✔️
 ✔️
@@ -155,8 +155,8 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.
 ✔️
 ❌
 ❌
-✔️
-✔️
+❌
+❌
 ✔️
 ✔️
 ❌