reformat llama transfer script, retrieve model info from LlamaConfig
fengyu05 committed Sep 14, 2023
1 parent 86dcc48 commit fbb2ef9
Showing 6 changed files with 66 additions and 98 deletions.
6 changes: 4 additions & 2 deletions tools/checkpoint_saver_megatron.py
@@ -162,13 +162,15 @@ def check_message(msg):
setattr(margs, arg, value)

validate_args(margs)

margs.ckpt_transfer = True
if args.tokenizer_model:
margs.tokenizer_model = args.tokenizer_model
set_global_variables(margs)

# margs = megatron args
margs = get_args()
margs.ckpt_transfer = True

print("args.tokenizer_model", args.tokenizer_model)
if hasattr(md, 'consumed_train_samples'):
margs.consumed_train_samples = md.consumed_train_samples
margs.consumed_valid_samples = md.consumed_valid_samples
21 changes: 21 additions & 0 deletions tools/convert_checkpoint/README.md
@@ -76,3 +76,24 @@ cd /hf/transformers
python src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py \
/path/to/Megatron/checkpoint/iter_0097500/mp_rank_00/model_optim_rng.pt
```

## HF Transformers to Megatron-DeepSpeed (currently only supports Llama)

To convert a Llama model from HF Transformers to Megatron-DeepSpeed, follow these two steps:

```bash
# 1. Convert the Llama weights from HF Transformers format to Megatron
python tools/convert_checkpoint/transformers_to_megatron_llama.py \
--out=/path/to/Megatron-Deepspeed/checkpoint/ \
--cache-dir=/path/to/hf/transformers/llama_checkpoint

# 2. Convert the Megatron-DeepSpeed checkpoint to a distributed (tensor/pipeline-parallel) version
python3 tools/checkpoint_util.py \
--target-tensor-parallel-size 4 \
--target-pipeline-parallel-size 2 \
--load-dir /path/to/Megatron-Deepspeed/checkpoint/ \
--save-dir /path/to/Megatron-Deepspeed/distribute_checkpoint/ \
--model-type GPT
```
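
To sanity-check step 1, you can load the single-rank checkpoint it produces and inspect the recorded arguments. This is only an illustrative check, not part of the conversion scripts; the path below is a placeholder, and `weights_only=False` is needed only on newer PyTorch because the checkpoint stores an `argparse.Namespace`:

```python
import torch

# Path written by step 1: <out>/release/mp_rank_00/model_optim_rng.pt
ckpt = torch.load(
    "/path/to/Megatron-Deepspeed/checkpoint/release/mp_rank_00/model_optim_rng.pt",
    map_location="cpu",
    weights_only=False,  # the checkpoint stores an argparse.Namespace under "args"
)
print(ckpt["checkpoint_version"])             # 3.0
print(ckpt["args"])                           # num_layers, hidden_size, num_attention_heads, ...
print(list(ckpt["model"]["language_model"]))  # embedding, encoder, lm_head
```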


@@ -86,12 +86,9 @@ def merge_meta_llama(size: int, root_dir: Path):
return merged_ckpt


def merge_hf_llama(size: int, version: int, cache_dir: Optional[Path] = None, model_path=None, tokenizer_len=32000):
assert version == 2, "Only llama v2 available using huggingface"
print(cache_dir)
def merge_hf_llama(cache_dir: Optional[Path] = None):
# assert version == 2, "Only llama v2 available using huggingface"
model = LlamaForCausalLM.from_pretrained(cache_dir, cache_dir=cache_dir, local_files_only=True, use_safetensors=False)
# resize token embeddings size according saved tokenizer for model extend token size.
# model.resize_token_embeddings(tokenizer_len)
weights = model.state_dict()
weights["tok_embeddings.weight"] = weights.pop("model.embed_tokens.weight")
weights["norm.weight"] = weights.pop("model.norm.weight")
@@ -110,12 +107,5 @@ def merge_hf_llama(size: int, version: int, cache_dir: Optional[Path] = None, mo
"post_attention_layernorm": "ffn_norm"
}[rmatch.group(2)]
weights[rmatch.group(1) + new_key + rmatch.group(3)] = weights.pop(key)
return weights
return weights, model.config


def merge_llama(size: int, version: int, root_dir: Optional[Path] = None, tokenizer_len: Optional[int] = 32000):
if root_dir is not None and (root_dir/"consolidated.00.pth").exists():
return merge_meta_llama(size, root_dir), "meta"
print(f"Weights at {root_dir} do not look like a meta checkpoint, assuming "
"huggingface cache_dir instead")
return merge_hf_llama(size, version, root_dir, tokenizer_len), "hf"
File renamed without changes.
@@ -8,24 +8,14 @@
import torch
from tqdm.auto import trange
from transformers import AutoModelForCausalLM, LlamaTokenizer
from transformers import LlamaConfig

from permute_qkv import permute_qkv
from merge_llama import merge_llama
from transformers import AutoTokenizer
from merge_llama import merge_hf_llama

llama_s2layer = {7: 32, 13: 40, 30: 60, 65: 80, 70: 80}
llama_s2heads = {7: 32, 13: 40, 30: 52, 65: 64, 70: 64}
llama_s2dense = {7: 11008, 13: 13824, 30: 17920, 65: 22016,
70: 28672} # should be (2/3)*4*d, but it isn't exaclty that
llama_s2hidden = {7: 4096, 13: 5120, 30: 6656, 65: 8192, 70: 8192}


def llama_to_megatron(weights: dict, size: int, source: str = "meta",
version: int = 1) -> dict:
def llama_to_megatron(weights: dict, llama_config: LlamaConfig = None) -> dict:
def permute(qkv_w):
if source == "hf":
return permute_qkv(qkv_w, hidden, n_heads, n_kv_heads)
return qkv_w
return permute_qkv(qkv_w, hidden, n_heads, n_kv_heads)

def rearrange_qkv(wq, wk, wv):
wq = torch.split(wq, n_hidden_per_head, dim=0)
@@ -42,12 +32,11 @@ def rearrange_qkv(wq, wk, wv):
return permute(torch.concat(w_qkv))

# config
n_layer = llama_s2layer[size]
hidden = llama_s2hidden[size]
n_heads = llama_s2heads[size]
n_layer = llama_config.num_hidden_layers
hidden = llama_config.hidden_size
n_heads = llama_config.num_attention_heads
n_hidden_per_head = hidden//n_heads
n_kv_heads = n_heads if version == 1 or size <= 13 else 8

n_kv_heads = llama_config.num_key_value_heads
# weights independent of layers
embedding = {"word_embeddings": {"weight": weights["tok_embeddings.weight"]}}
transformer = {"final_layernorm.weight": weights["norm.weight"]}
@@ -86,32 +75,34 @@ def rearrange_qkv(wq, wk, wv):
return {"embedding": embedding, "encoder": transformer,
"lm_head": lm_head}

def main(model_name: str = "llama2", size: int = 7, out: Optional[Path] = None,
cache_dir: Optional[Path] = None, megatron_path: Optional[Path] = None, padded_vocab_size: Optional[int] = 32000):
def main(out: Optional[Path] = None,
cache_dir: Optional[Path] = None, megatron_path: Optional[Path] = None):

if megatron_path:
print("Add megatron to os path")
os.path.append(megatron_path)
# get weights from or specified directory
print("Getting llama...")
version = 2 if "2" in model_name else 1
hf_weights, llama_source = merge_llama(size, version, cache_dir, padded_vocab_size)
hf_weights, llama_config = merge_hf_llama(cache_dir)

# convert state dict to be megatron-compatible
megatron_weights = llama_to_megatron(hf_weights, size, llama_source,
version=1 if model_name == "llama" else 2)
megatron_weights = llama_to_megatron(hf_weights, llama_config=llama_config)

# set args
# llama1, llama2
args = {"num_layers": llama_s2layer[size],
"hidden_size": llama_s2hidden[size],
"num_attention_heads": llama_s2heads[size],
"ffn_hidden_size": llama_s2dense[size],
"num_key_value_heads": llama_s2heads[size],
args = {"num_layers": llama_config.num_hidden_layers,
"hidden_size": llama_config.hidden_size,
"num_attention_heads": llama_config.num_attention_heads,
"ffn_hidden_size": llama_config.intermediate_size,
"num_key_value_heads": llama_config.num_key_value_heads,
"parallel_attn": False,
"make_vocab_size_divisible_by": 1,
"glu_activation": "swiglu",
"max_position_embeddings": llama_config.max_length, # should use max_length rather than max_position_embeddings, detail in https://github.com/lm-sys/FastChat/issues/2046#issuecomment-1645265800
"seq_length": llama_config.max_length,
"layernorm_epsilon": llama_config.rms_norm_eps,
# llama args
"padded_vocab_size": padded_vocab_size,
"use_rms_norm": True,
"tie_embed_logits": False,
"padded_vocab_size": llama_config.vocab_size,
"tokenizer_type": "GPTSentencePieceTokenizer",
"no-query-key-layer-scaling": True,
"attention-dropout": 0,
@@ -124,19 +115,13 @@ def main(model_name: str = "llama2", size: int = 7, out: Optional[Path] = None,
"add_position_embedding": False,
"add_bias_linear": False,
}
if model_name == "llama":
args.update({"max_position_embeddings": 2048, "seq_length": 2048,
"layernorm_epsilon": 1e-6})
else: # llama2
args.update({"max_position_embeddings": 2048, "seq_length": 2048,
"layernorm_epsilon": 1e-5})
if size >= 34:
args.update({"num_attention_heads_kv": 8})
if llama_config.num_key_value_heads:
args.update({"num_attention_heads_kv": llama_config.num_key_value_heads})

args.update({
"tensor_model_parallel_size": 1,
"pipeline_model_parallel_size": 1,
"iteration": "release",
"iteration": 0,
"bias_gelu_fusion": False,
"bias_droput_fusion": False,
})
@@ -145,42 +130,31 @@ def main(model_name: str = "llama2", size: int = 7, out: Optional[Path] = None,
(out/"release"/"mp_rank_00").mkdir(parents=True)
with open(out/"latest_checkpointed_iteration.txt", "w+") as f:
f.write("release")
final_dict = {"iteration": "release", "model": {"language_model": megatron_weights},
final_dict = {"iteration": 'release', "model": {"language_model": megatron_weights},
"checkpoint_version": 3.0, "args": Namespace(**args)}
torch.save(final_dict, out/"release"/"mp_rank_00"/"model_optim_rng.pt")
print("Saved weights in", out)

if model_name == "llama2" and llama_source == "hf":
tokenizer = LlamaTokenizer.from_pretrained(
cache_dir, cache_dir=cache_dir, local_files_only=True,
)
token_path = out/"tokenizer.model"
vocab_file = tokenizer.vocab_file
shutil.copy(vocab_file, token_path)
print("Saved tokenizer.model in", token_path)
tokenizer = LlamaTokenizer.from_pretrained(
cache_dir, cache_dir=cache_dir, local_files_only=True,
)
token_path = out/"tokenizer.model"
vocab_file = tokenizer.vocab_file
shutil.copy(vocab_file, token_path)
print("Saved tokenizer.model in", token_path)
print("Done")

if __name__ == "__main__":
parser = ArgumentParser(description="Convert Huggingface falcon weights to "
parser = ArgumentParser(description="Convert Huggingface llama weights to "
"megatron-compatible weights")
parser.add_argument("model", choices={"falcon", "llama", "llama2"})
parser.add_argument("--size", default=7, choices={7, 13, 30, 34, 40, 65, 70}, type=int,
help="The size of the model")
parser.add_argument("--out", type=Path,
help="Directory to store the megatron weights (as checkpoint)")
parser.add_argument("--cache-dir", type=Path,
help=("Directory to store the huggingface weights, or "
"in case of the llama model, where to look for "
"the consolidated.xx.pth"))
parser.add_argument("--megatron-path", type=Path,
parser.add_argument("--megatron-path", type=Path, default=None,
help="Path where to find megatron code")
parser.add_argument("--tokenizer-size", type=int, help="Directory to store the megatron weights (as checkpoint)", default=None)
args = parser.parse_args()

# small arg verification
if args.model == "llama":
assert args.size in {7, 13, 30, 65}
else:
assert args.size in {7, 13, 70}

main(args.model, args.size, args.out, args.cache_dir, args.megatron_path, args.tokenizer_size)
main(args.out, args.cache_dir, args.megatron_path)
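
For reference, the converter now derives all of these hyperparameters from the checkpoint's own `LlamaConfig` rather than hard-coded per-size tables. A minimal standalone sketch of that lookup (the checkpoint path is a placeholder):

```python
from transformers import LlamaConfig

# Placeholder path; any local HF Llama checkpoint directory works.
config = LlamaConfig.from_pretrained("/path/to/hf/transformers/llama_checkpoint")

megatron_args = {
    "num_layers": config.num_hidden_layers,
    "hidden_size": config.hidden_size,
    "num_attention_heads": config.num_attention_heads,
    "ffn_hidden_size": config.intermediate_size,
    "num_key_value_heads": config.num_key_value_heads,
    "layernorm_epsilon": config.rms_norm_eps,
    "padded_vocab_size": config.vocab_size,
    # max_length rather than max_position_embeddings; see the FastChat issue linked in the diff above.
    "seq_length": config.max_length,
}
print(megatron_args)
```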
19 changes: 0 additions & 19 deletions tools/convert_checkpoint/weights2megatron/README.md

This file was deleted.
