
Commit

[Shardformer] fix the num_heads assert for llama model and qwen model (#5704)

* fix the num_heads assert

* fix the transformers import

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the import

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
wangbluo and pre-commit-ci[bot] committed May 10, 2024
1 parent a3cc68c commit 537f6a3
Showing 3 changed files with 37 additions and 30 deletions.
colossalai/shardformer/modeling/qwen2.py (10 changes: 5 additions & 5 deletions)
@@ -10,16 +10,20 @@

try:
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2Attention,
Qwen2ForCausalLM,
Qwen2ForSequenceClassification,
Qwen2Model,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
apply_rotary_pos_emb,
repeat_kv,
)
except ImportError:
Qwen2Model = "Qwen2Model"
Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
Qwen2ForCausalLM = "Qwen2ForCausalLM"
Qwen2Attention = "Qwen2Attention"
Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"

from transformers.utils import logging

@@ -451,10 +455,6 @@ def qwen2_for_sequence_classification_forward(


def get_qwen2_flash_attention_forward(shard_config: ShardConfig):
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, apply_rotary_pos_emb, repeat_kv

from colossalai.shardformer.layer import ColoAttention

def forward(
self: Qwen2Attention,
hidden_states: torch.Tensor,
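The hunk above lifts the Qwen2 symbols out of get_qwen2_flash_attention_forward and into a module-level try/except. A minimal sketch of that fallback pattern, reduced to a single symbol for illustration (the real file imports several more names):

try:
    from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
except ImportError:
    # Fall back to a placeholder string so the module still imports (and type
    # hints such as `self: Qwen2Attention` still resolve) when the installed
    # transformers version has no Qwen2 support.
    Qwen2Attention = "Qwen2Attention"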
colossalai/shardformer/policies/llama.py (8 changes: 5 additions & 3 deletions)
@@ -141,9 +141,11 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of attention heads must be divisible by tensor parallel size."
assert (
self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of key_value heads must be divisible by tensor parallel size."
if hasattr(self.model.config, "num_key_value_heads"):
assert (
self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size
and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size."
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
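For intuition, here is a hypothetical standalone check that mirrors the guarded assertions added above; the helper name and the bare integer arguments are illustrative only and do not appear in the commit:

def _check_head_counts(num_attention_heads, tp_size, num_key_value_heads=None):
    # Attention heads must split evenly across tensor-parallel ranks.
    assert num_attention_heads % tp_size == 0
    # Only GQA/MQA configs define num_key_value_heads; they additionally need
    # at least one key_value head per rank.
    if num_key_value_heads is not None:
        assert num_key_value_heads >= tp_size and num_key_value_heads % tp_size == 0

_check_head_counts(32, 4, num_key_value_heads=8)     # passes: 8 kv heads over 4 ranks
# _check_head_counts(32, 16, num_key_value_heads=8)  # would fail: 16 ranks, only 8 kv heads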
colossalai/shardformer/policies/qwen2.py (49 changes: 27 additions & 22 deletions)
@@ -21,6 +21,26 @@
get_qwen2_flash_attention_forward,
get_qwen2_model_forward_for_flash_attn,
)

try:
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2Attention,
Qwen2DecoderLayer,
Qwen2FlashAttention2,
Qwen2ForCausalLM,
Qwen2ForSequenceClassification,
Qwen2Model,
Qwen2SdpaAttention,
)
except ImportError:
Qwen2ForCausalLM = "Qwen2ForCausalLM"
Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
Qwen2Attention = "Qwen2Attention"
Qwen2FlashAttention2 = "Qwen2FlashAttention2"
Qwen2SdpaAttention = "Qwen2SdpaAttention"
Qwen2DecoderLayer = "Qwen2DecoderLayer"
Qwen2Model = "Qwen2Model"

from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["Qwen2Policy", "Qwen2ForCausalLMPolicy", "Qwen2ForSequenceClassificationPolicy"]
@@ -45,21 +65,6 @@ def preprocess(self):
return self.model

def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
try:
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2Attention,
Qwen2DecoderLayer,
Qwen2FlashAttention2,
Qwen2Model,
Qwen2SdpaAttention,
)
except ImportError:
Qwen2Attention = "Qwen2Attention"
Qwen2FlashAttention2 = "Qwen2FlashAttention2"
Qwen2SdpaAttention = "Qwen2SdpaAttention"
Qwen2DecoderLayer = "Qwen2DecoderLayer"
Qwen2Model = "Qwen2Model"

ATTN_IMPLEMENTATION = {
"eager": Qwen2Attention,
"flash_attention_2": Qwen2FlashAttention2,
Expand All @@ -82,6 +87,13 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
warnings.warn("Qwen2 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")

if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of attention heads must be divisible by tensor parallel size."
if hasattr(self.model.config, "num_key_value_heads"):
assert (
self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of key_value heads must be divisible by tensor parallel size."
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
@@ -256,7 +268,6 @@ def get_held_layers(self) -> List[Module]:
class Qwen2ModelPolicy(Qwen2Policy):
def module_policy(self):
policy = super().module_policy()
from transformers.models.qwen2.modeling_qwen2 import Qwen2Model

if self.pipeline_stage_manager:
# set None as default
@@ -277,10 +288,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:

class Qwen2ForCausalLMPolicy(Qwen2Policy):
def module_policy(self):
from transformers import Qwen2ForCausalLM

policy = super().module_policy()

setattr(self.shard_config, "causal_lm", True)

if self.shard_config.enable_tensor_parallelism:
@@ -330,10 +338,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:

class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
def module_policy(self):
from transformers import Qwen2ForSequenceClassification

policy = super().module_policy()

if self.shard_config.enable_tensor_parallelism:
# add a new item for sequence classification
new_item = {
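Once these checks pass, the policy splits the per-layer attention attributes evenly across ranks via the decoder_attribute_replacement dict shown in the hunk above. A rough worked example with assumed Qwen2-style values, not taken from the commit:

hidden_size, num_attention_heads, tp_size = 4096, 32, 4
decoder_attribute_replacement = {
    "self_attn.hidden_size": hidden_size // tp_size,        # 1024 columns per rank
    "self_attn.num_heads": num_attention_heads // tp_size,  # 8 heads per rank
}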
