From 1ebb27f090e10117964f1fa54a0be32d10a5a6e1 Mon Sep 17 00:00:00 2001
From: Jani Monoses
Date: Sun, 27 Aug 2023 12:21:11 +0300
Subject: [PATCH 1/3] Do parameter count calculations in 64 bits to not overflow in case of very large models

---
 run.c | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/run.c b/run.c
index 9329b932..fdbb16d8 100644
--- a/run.c
+++ b/run.c
@@ -115,26 +115,28 @@ void free_run_state(RunState* s) {
 
 void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
     int head_size = p->dim / p->n_heads;
+    // make sure the multiplications below are done in 64bit to fit the parameter counts of 13B+ models
+    unsigned long n_layers = p->n_layers;
     w->token_embedding_table = ptr;
     ptr += p->vocab_size * p->dim;
     w->rms_att_weight = ptr;
-    ptr += p->n_layers * p->dim;
+    ptr += n_layers * p->dim;
     w->wq = ptr;
-    ptr += p->n_layers * p->dim * (p->n_heads * head_size);
+    ptr += n_layers * p->dim * (p->n_heads * head_size);
     w->wk = ptr;
-    ptr += p->n_layers * p->dim * (p->n_kv_heads * head_size);
+    ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
     w->wv = ptr;
-    ptr += p->n_layers * p->dim * (p->n_kv_heads * head_size);
+    ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
     w->wo = ptr;
-    ptr += p->n_layers * (p->n_heads * head_size) * p->dim;
+    ptr += n_layers * (p->n_heads * head_size) * p->dim;
     w->rms_ffn_weight = ptr;
-    ptr += p->n_layers * p->dim;
+    ptr += n_layers * p->dim;
     w->w1 = ptr;
-    ptr += p->n_layers * p->dim * p->hidden_dim;
+    ptr += n_layers * p->dim * p->hidden_dim;
     w->w2 = ptr;
-    ptr += p->n_layers * p->hidden_dim * p->dim;
+    ptr += n_layers * p->hidden_dim * p->dim;
     w->w3 = ptr;
-    ptr += p->n_layers * p->dim * p->hidden_dim;
+    ptr += n_layers * p->dim * p->hidden_dim;
     w->rms_final_weight = ptr;
     ptr += p->dim;
     ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_real (for RoPE)
@@ -249,7 +251,7 @@ float* forward(Transformer* transformer, int token, int pos) {
     memcpy(x, content_row, dim*sizeof(*x));
 
     // forward all the layers
-    for(int l = 0; l < p->n_layers; l++) {
+    for(unsigned long l = 0; l < p->n_layers; l++) {
 
         // attention rmsnorm
         rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);

From c5ec6e21b8659d6d3500a2af3ac1dfe7f3e19ae1 Mon Sep 17 00:00:00 2001
From: Jani Monoses
Date: Tue, 29 Aug 2023 17:47:55 +0300
Subject: [PATCH 2/3] Use long long so it works with MSVC

---
 run.c | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/run.c b/run.c
index fdbb16d8..c6ec94a2 100644
--- a/run.c
+++ b/run.c
@@ -116,7 +116,7 @@ void free_run_state(RunState* s) {
 void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
     int head_size = p->dim / p->n_heads;
     // make sure the multiplications below are done in 64bit to fit the parameter counts of 13B+ models
-    unsigned long n_layers = p->n_layers;
+    unsigned long long n_layers = p->n_layers;
     w->token_embedding_table = ptr;
     ptr += p->vocab_size * p->dim;
     w->rms_att_weight = ptr;
@@ -251,7 +251,7 @@ float* forward(Transformer* transformer, int token, int pos) {
     memcpy(x, content_row, dim*sizeof(*x));
 
     // forward all the layers
-    for(unsigned long l = 0; l < p->n_layers; l++) {
+    for(unsigned long long l = 0; l < p->n_layers; l++) {
 
         // attention rmsnorm
         rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);

From ab19aa08045f0f30db4291641ece301d7cc339f3 Mon Sep 17 00:00:00 2001
From: Brandon Rowlett
Date: Wed, 30 Aug 2023 14:54:41 -0500
Subject: [PATCH 3/3] Setting UTF encoding, otherwise windows breaks with UnicodeEncodeError: 'charmap' codec can't encode character '\u200b' in position 971: character maps to <undefined>

---
 tinystories.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tinystories.py b/tinystories.py
index 800d73af..814732d4 100644
--- a/tinystories.py
+++ b/tinystories.py
@@ -88,7 +88,7 @@ def train_vocab(vocab_size):
     shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
 
     print(f"Writing temporary file {tiny_file} with {num_shards} shards...")
-    with open(tiny_file, "w") as of:
+    with open(tiny_file, "w", encoding="utf-8") as of:
         for shard in tqdm(shard_filenames[:num_shards]):
             with open(shard, "r") as f:
                 data = json.load(f)
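
A minimal standalone sketch of why the 64-bit promotion in the first two patches matters. The dimensions below are assumptions, roughly Llama-2-13B sized, chosen only so the weight count exceeds INT_MAX; they are not read from any checkpoint.

/* With plain int operands, the product in ptr += p->n_layers * p->dim * p->hidden_dim
 * is computed in 32 bits and overflows before it is ever added to the pointer.
 * Promoting one operand to unsigned long long, as the patches do, keeps the whole
 * chain of multiplications in 64 bits. */
#include <stdio.h>

int main(void) {
    int n_layers = 40, dim = 5120, hidden_dim = 13824; // assumed 13B-ish shapes

    // all operands are int: 40 * 5120 * 13824 = 2,831,155,200 does not fit in a
    // signed 32-bit int (overflow; formally undefined behavior in C)
    int bad = n_layers * dim * hidden_dim;

    // promote the leftmost operand so every multiplication is done in 64 bits
    unsigned long long nl = n_layers;
    unsigned long long good = nl * dim * hidden_dim;

    printf("32-bit product: %d\n", bad);    // overflowed garbage
    printf("64-bit product: %llu\n", good); // 2831155200
    return 0;
}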