
Commit
Merge remote-tracking branch 'upstream/master'
trholding committed Aug 21, 2023
2 parents 548bf5d + ee95b1b commit 06f25f6
Showing 4 changed files with 84 additions and 122 deletions.
80 changes: 75 additions & 5 deletions export.py
@@ -19,6 +19,9 @@
import shutil
import struct
import argparse
+import json
+from pathlib import Path
+
import numpy as np
import torch
from torch import nn
@@ -30,7 +33,7 @@

def serialize_fp32(file, tensor):
    """ writes one fp32 tensor to file that is open in wb mode """
-    d = tensor.detach().cpu().view(-1).numpy().astype(np.float32)
+    d = tensor.detach().cpu().view(-1).to(torch.float32).numpy()
    b = struct.pack(f'{len(d)}f', *d)
    file.write(b)
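
Side note on the serialize_fp32 change: the new ordering matters for checkpoints stored in bfloat16. numpy has no bf16 dtype, so calling .numpy() before the cast raises a TypeError, while casting to float32 in torch first succeeds. A minimal sketch of the difference, assuming a recent PyTorch (not part of this diff):

    import torch

    t = torch.randn(4, dtype=torch.bfloat16)
    # t.numpy()                      # fails: numpy has no bfloat16 dtype
    d = t.to(torch.float32).numpy()  # cast in torch first, then hand off to numpy
    print(d.dtype)                   # float32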

@@ -281,6 +284,71 @@ def load_checkpoint(checkpoint):
    model.eval()
    return model

+def load_meta_model(model_path):
+    params_path = os.path.join(model_path, 'params.json')
+    with open(params_path) as f:
+        params = json.load(f)
+        print(params)
+
+    model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
+    models = [torch.load(p, map_location='cpu') for p in model_paths]
+
+    def concat_weights(models):
+        state_dict = {}
+        for name in list(models[0]):
+            tensors = [model[name] for model in models]
+            if len(tensors) == 1 or len(tensors[0].shape) == 1:
+                state_dict[name] = tensors[0]
+                continue
+            is_axis_1 = (
+                name.startswith('tok_embeddings.')
+                or name.endswith('.attention.wo.weight')
+                or name.endswith('.feed_forward.w2.weight')
+            )
+            axis = 1 if is_axis_1 else 0
+            state_dict[name] = torch.cat(tensors, dim=axis)
+            for model in models:
+                del model[name]
+        return state_dict
+
+    state_dict = concat_weights(models)
+    del models
+
+    # set ModelArgs
+    config = ModelArgs()
+    config.dim = params["dim"]
+    config.n_layers = params["n_layers"]
+    config.n_heads = params["n_heads"]
+    config.n_kv_heads = params.get('n_kv_heads') or params['n_heads']
+    config.multiple_of = params["multiple_of"]
+    config.norm_eps = params["norm_eps"]
+
+    config.vocab_size = 32000
+    config.max_seq_len = 2048
+
+    # create a new Transformer object and set weights
+    model = Transformer(config)
+
+    model.tok_embeddings.weight = nn.Parameter(state_dict['tok_embeddings.weight'])
+    model.norm.weight = nn.Parameter(state_dict['norm.weight'])
+
+    for layer in model.layers:
+        i = layer.layer_id
+        layer.attention_norm.weight = nn.Parameter(state_dict[f'layers.{i}.attention_norm.weight'])
+        layer.attention.wq.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wq.weight'])
+        layer.attention.wk.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wk.weight'])
+        layer.attention.wv.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wv.weight'])
+        layer.attention.wo.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wo.weight'])
+        layer.ffn_norm.weight = nn.Parameter(state_dict[f'layers.{i}.ffn_norm.weight'])
+        layer.feed_forward.w1.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w1.weight'])
+        layer.feed_forward.w2.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w2.weight'])
+        layer.feed_forward.w3.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w3.weight'])
+
+    # final classifier
+    model.output.weight = nn.Parameter(state_dict['output.weight'])
+    model.eval()
+    return model
+
def load_hf_model(model_path):

    try:
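
The axis logic in concat_weights follows how Meta sharded the Llama weights for tensor parallelism: 1-D tensors (the norm weights) are replicated, so the first shard is taken as-is; the token embedding and the row-parallel matrices (attention.wo, feed_forward.w2) are split along dim 1; every other 2-D matrix is split along dim 0. A toy illustration with made-up shapes (the real shards come from the consolidated.*.pth files):

    import torch

    # two tensor-parallel shards of a row-parallel weight such as attention.wo:
    # each rank holds the full output dim but only half of the input dim
    shard_a = torch.zeros(8, 4)
    shard_b = torch.ones(8, 4)
    merged = torch.cat([shard_a, shard_b], dim=1)
    print(merged.shape)  # torch.Size([8, 8])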
@@ -381,17 +449,19 @@ def torchscript_export(model, filepath, zero_params=False, gzip_output=False):

    parser = argparse.ArgumentParser()
    parser.add_argument("filepath", type=str, help="the output filepath")
-    parser.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
-    parser.add_argument("--hf", type=str, help="huggingface model")
    parser.add_argument("--version", default=0, type=int, help="the version to export with")
+    group = parser.add_mutually_exclusive_group(required=True)
+    group.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
+    group.add_argument("--meta-llama", type=str, help="meta llama model path")
+    group.add_argument("--hf", type=str, help="huggingface model path")
    args = parser.parse_args()

    if args.checkpoint:
        model = load_checkpoint(args.checkpoint)
+    elif args.meta_llama:
+        model = load_meta_model(args.meta_llama)
    elif args.hf:
        model = load_hf_model(args.hf)
-    else:
-        parser.error("Input model missing: --checkpoint or --hf is required")

    if model is None:
        parser.error("Can't load input model!")
112 changes: 0 additions & 112 deletions export_meta_llama_bin.py

This file was deleted; its standalone Meta-checkpoint conversion is superseded by the new load_meta_model path in export.py.

12 changes: 8 additions & 4 deletions run.c
@@ -848,8 +848,8 @@ void error_usage() {
fprintf(stderr, "Usage: run <checkpoint> [options]\n");
fprintf(stderr, "Example: run model.bin -n 256 -i \"Once upon a time\"\n");
fprintf(stderr, "Options:\n");
fprintf(stderr, " -t <float> temperature, default 1.0\n");
fprintf(stderr, " -p <float> p value in top-p (nucleus) sampling. default 0.9\n");
fprintf(stderr, " -t <float> temperature in [0,inf], default 1.0\n");
fprintf(stderr, " -p <float> p value in top-p (nucleus) sampling in [0,1] default 0.9\n");
fprintf(stderr, " -s <int> random seed, default time(NULL)\n");
fprintf(stderr, " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
fprintf(stderr, " -b <int> number of tokens to buffer, default 1. 0 = max_seq_len\n");
@@ -860,7 +860,7 @@ void error_usage() {

int main(int argc, char *argv[]) {

-    // default inits
+    // default parameters
    char *checkpoint_path = NULL;  // e.g. out/model.bin
    char *tokenizer_path = "tokenizer.bin";
    float temperature = 1.0f;  // 0.0 = greedy deterministic. 1.0 = original. don't set higher
@@ -906,7 +906,11 @@ int main(int argc, char *argv[]) {
    }
#endif

-    if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);}
+    // parameter validation/overrides
+    if (rng_seed <= 0) rng_seed = (unsigned int)time(NULL);
+    if (temperature < 0.0) temperature = 0.0;
+    if (topp < 0.0 || 1.0 < topp) topp = 0.9;
+    if (steps <= 0) steps = 0;

    // build the Transformer via the model .bin file
    Transformer transformer;
2 changes: 1 addition & 1 deletion sample.py
@@ -52,7 +52,7 @@
model = torch.compile(model) # requires PyTorch 2.0 (optional)

# load the tokenizer
-vocab_source = checkpoint_dict.get("vocab_source", "llama2")
+vocab_source = checkpoint_dict["config"].get("vocab_source", "llama2")
vocab_size = gptconf.vocab_size
if tokenizer:
    # a specific tokenizer is provided, use it
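
The one-line fix reflects where the training script nests this value: vocab_source lives inside the checkpoint's saved "config" dict, not at the top level of the checkpoint. A sketch of the assumed layout (keys and values are illustrative, not verbatim):

    checkpoint_dict = {
        "model": {},                   # the state_dict with the weights
        "config": {                    # training config saved alongside it
            "vocab_source": "llama2",  # "llama2" or "custom"
            "vocab_size": 32000,
        },
    }
    vocab_source = checkpoint_dict["config"].get("vocab_source", "llama2")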
