
Commit

Merge branch 'karpathy:master' into master
trholding committed Aug 6, 2023
2 parents df32a0d + 49e3ff6 commit 3121275
Showing 4 changed files with 46 additions and 44 deletions.
12 changes: 6 additions & 6 deletions Makefile
@@ -18,17 +18,17 @@ run: run.c
$(CC) -O3 -o run run.c -lm

# useful for a debug build, can then e.g. analyze with valgrind, example:
-# $ valgrind --leak-check=full ./run out/model.bin 1.0 3
+# $ valgrind --leak-check=full ./run out/model.bin -n 3
rundebug: run.c
$(CC) -g -o run run.c -lm

# https://gcc.gnu.org/onlinedocs/gcc/Optimize-Options.html
# https://simonbyrne.github.io/notes/fastmath/
# -Ofast enables all -O3 optimizations.
# Disregards strict standards compliance.
# It also enables optimizations that are not valid for all standard-compliant programs.
# It turns on -ffast-math, -fallow-store-data-races and the Fortran-specific
# -fstack-arrays, unless -fmax-stack-var-size is specified, and -fno-protect-parens.
# It turns off -fsemantic-interposition.
# In our specific application this is *probably* okay to use
.PHONY: runfast
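For orientation (not part of this diff): -ffast-math, which -Ofast pulls in, lets the compiler reassociate floating-point operations, and floating-point addition is not associative. A tiny Python sketch of the kind of effect such reordering can have:

# Floating-point addition is not associative, which is why the reassociation
# permitted by -ffast-math (enabled via -Ofast) can change results slightly.
a, b, c = 1e16, -1e16, 1.0
print((a + b) + c)   # 1.0
print(a + (b + c))   # 0.0 -- the 1.0 is absorbed by the large term before the cancellation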
@@ -47,7 +47,7 @@ runoacc: run.c
$(CC) -D OPENACC -Ofast -fopenacc -foffload-options="-Ofast -lm" -march=native run.c -lm -o run

.PHONY: win64
win64:
x86_64-w64-mingw32-gcc -Ofast -D_WIN32 -o run.exe -I. run.c win.c

# compiles with gnu99 standard flags for amazon linux, coreos, etc. compatibility
16 changes: 8 additions & 8 deletions model.py
@@ -61,7 +61,7 @@ def apply_rotary_emb(
# reshape xq and xk to match the complex representation
xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

# reshape freqs_cos and freqs_sin for broadcasting
freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
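For context, the lines above split xq and xk into real/imaginary halves and broadcast the precomputed cos/sin tables. A sketch of how the rotation is conventionally completed from here (standard RoPE; the output names below are illustrative and not shown in this hunk):

# Rotate each (real, imaginary) pair by the cached angles, then re-interleave
# the pairs and restore the original last dimension.
xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)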
@@ -154,7 +154,7 @@ def forward(

# restore time as batch dimension and concat heads
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

# final projection into the residual stream
output = self.wo(output)
output = self.resid_dropout(output)
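A quick shape sketch (hypothetical sizes, for illustration) of the head merge above: attention output arrives as (bsz, n_heads, seqlen, head_dim), transpose(1, 2) swaps the time and head axes, and contiguous().view concatenates the heads back into the model dimension.

import torch
bsz, n_heads, seqlen, head_dim = 2, 4, 8, 16          # made-up sizes
out = torch.randn(bsz, n_heads, seqlen, head_dim)
merged = out.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
assert merged.shape == (bsz, seqlen, n_heads * head_dim)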
@@ -170,7 +170,7 @@ def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.dropout = nn.Dropout(dropout)

def forward(self, x):
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
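The forward above is the SwiGLU feed-forward: silu(w1(x)) gates w3(x) elementwise, and w2 projects the product back to the model dimension. F.silu(x) is x * sigmoid(x), the same formula the C loop in run.c evaluates with expf; a minimal check, for illustration only:

import torch
import torch.nn.functional as F
x = torch.randn(5)
manual = x * torch.sigmoid(x)          # x * (1 / (1 + exp(-x))), as in run.c
assert torch.allclose(F.silu(x), manual)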

@@ -222,7 +222,7 @@ def __init__(self, params: ModelArgs):
freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
self.register_buffer("freqs_sin", freqs_sin, persistent=False)

# init all weights
self.apply(self._init_weights)
# apply special scaled init to the residual projections, per GPT-2 paper
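The two buffers above cache the RoPE cos/sin tables for head dimension dim // n_heads and length max_seq_len, kept out of the state dict via persistent=False. The helper itself is outside this hunk; a hedged sketch of what such a function conventionally computes:

import torch

def precompute_freqs_cis_sketch(head_dim, seq_len, theta=10000.0):
    # per-pair inverse frequencies theta^(-2i/d), then one angle per (position, pair)
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), freqs)   # (seq_len, head_dim // 2)
    return torch.cos(angles), torch.sin(angles)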
@@ -304,7 +304,7 @@ def estimate_mfu(self, fwdbwd_per_iter, dt):
flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
mfu = flops_achieved / flops_promised
return mfu
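estimate_mfu is simply achieved FLOP/s divided by the accelerator's advertised peak (312 TFLOP/s for A100 bfloat16). A toy example with made-up numbers, not measurements:

flops_per_fwdbwd = 4.4e12                      # assumed FLOPs for one forward+backward
fwdbwd_per_iter = 4                            # assumed gradient-accumulation steps
dt = 0.1                                       # assumed seconds per iteration
flops_achieved = flops_per_fwdbwd * fwdbwd_per_iter / dt   # 1.76e14 FLOP/s
mfu = flops_achieved / 312e12                  # ~0.56, i.e. roughly 56% of peak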

@torch.inference_mode()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
"""
@@ -334,7 +334,7 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
idx_next = torch.multinomial(probs, num_samples=1)
# append sampled index to the running sequence and continue
idx = torch.cat((idx, idx_next), dim=1)

return idx
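The generate loop above scales the last-position logits by temperature, optionally masks everything below the k-th largest logit, softmaxes, and samples with torch.multinomial. A self-contained sketch of that single sampling step (an illustration of the calls shown, not a helper from this file):

import torch
import torch.nn.functional as F

def sample_next(logits, temperature=1.0, top_k=None):
    # logits: (batch, vocab_size) for the last time step
    logits = logits / temperature
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits[logits < v[:, [-1]]] = -float('inf')   # keep only the top-k logits
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)    # (batch, 1) sampled token ids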

def export(self, filepath='model.bin'):
@@ -350,13 +350,13 @@ def serialize(t):
hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
p = self.params
n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
n_kv_heads, p.vocab_size, p.max_seq_len)
f.write(header)

# next write out the embedding weights
serialize(self.tok_embeddings.weight)

# now all the layers
# attention weights
for layer in self.layers:
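The export method above writes a seven-int header — dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, max_seq_len — followed by the raw float32 weights that run.c reads back. As an aside, a hedged sketch of parsing that header in Python (the filename and native byte order are assumptions):

import struct

with open("model.bin", "rb") as f:
    dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, max_seq_len = \
        struct.unpack('iiiiiii', f.read(7 * 4))   # seven native int32 values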
14 changes: 7 additions & 7 deletions run.c
@@ -153,8 +153,8 @@ void malloc_run_state(RunState* s, Config* p) {
s->key_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float));
s->value_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float));
// ensure all mallocs went fine
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
|| !s->k || !s->v || !s->att || !s->logits || !s->key_cache
|| !s->value_cache) {
printf("malloc failed!\n");
exit(EXIT_FAILURE);
@@ -330,7 +330,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
float* value_cache_row = s->value_cache + loff + pos * dim;
memcpy(key_cache_row, s->k, dim*sizeof(*key_cache_row));
memcpy(value_cache_row, s->v, dim*sizeof(*value_cache_row));
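The two memcpy calls above write this position's key and value vectors into flat per-layer caches: loff = l * seq_len * dim selects the layer's slab and pos * dim selects the row. The same indexing, sketched in Python for illustration (sizes are hypothetical):

import numpy as np

n_layers, seq_len, dim = 6, 256, 288
key_cache = np.zeros(n_layers * seq_len * dim, dtype=np.float32)

def cache_key(l, pos, k):
    loff = l * seq_len * dim                                   # layer offset, as in run.c
    key_cache[loff + pos * dim : loff + (pos + 1) * dim] = k   # one dim-sized row per position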

// multihead attention. iterate over all heads
int h;
#ifdef ACCEL
@@ -386,7 +386,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
// first calculate self.w1(x) and self.w3(x)
matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim);
matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim);

// F.silu; silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
for (int i = 0; i < hidden_dim; i++) {
s->hb[i] = s->hb[i] * (1.0f / (1.0f + expf(-s->hb[i])));
@@ -403,7 +403,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
// residual connection
accum(x, s->xb, dim);
}

// final rmsnorm
rmsnorm(x, x, w->rms_final_weight, dim);

@@ -425,7 +425,7 @@ int str_lookup(char *str, char **vocab, int vocab_size) {
}

void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, unsigned int max_token_length, int *tokens, int *n_tokens) {

// a temporary buffer to merge two consecutive tokens
char* str_buffer = malloc((max_token_length*2+1) * sizeof(char)); // *2 for concat, +1 for null terminator
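bpe_encode works by greedy merging: on every pass it finds the adjacent token pair whose concatenation exists in the vocabulary with the highest score, replaces the pair with the merged token, and repeats until nothing merges; str_buffer above holds each candidate concatenation (hence max_token_length*2+1). A hedged Python sketch of that loop (a linear-scan stand-in for str_lookup):

def bpe_merge_sketch(tokens, vocab, vocab_scores):
    # tokens: list of vocab indices; vocab: list of token strings; vocab_scores: merge scores
    while True:
        best_score, best_id, best_idx = -1e10, -1, -1
        for i in range(len(tokens) - 1):
            merged = vocab[tokens[i]] + vocab[tokens[i + 1]]
            if merged in vocab:
                j = vocab.index(merged)
                if vocab_scores[j] > best_score:
                    best_score, best_id, best_idx = vocab_scores[j], j, i
        if best_idx == -1:
            break                                  # no pair can be merged any further
        tokens[best_idx:best_idx + 2] = [best_id]  # replace the pair with the merged token
    return tokens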

@@ -631,7 +631,7 @@ int main(int argc, char *argv[]) {
int *prompt_tokens = NULL;
int num_prompt_tokens = 0;
if (prompt != NULL) {
-prompt_tokens = (int*)malloc(config.seq_len * sizeof(int));
+prompt_tokens = (int*)malloc(strlen(prompt) * sizeof(int));
bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens);
}

48 changes: 25 additions & 23 deletions tinystories.py
@@ -8,7 +8,7 @@
import os
import random
from typing import List
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import ProcessPoolExecutor

import numpy as np
import requests
@@ -66,34 +66,35 @@ def download():
print(f"Number of shards: {len(shard_filenames)}")
print(f"Example story:\n{data[0]}")

-def pretokenize():

+def process_shard(args):
+shard_id, shard = args
+enc = Tokenizer()
+with open(shard, "r") as f:
+data = json.load(f)
+all_tokens = []
+for example in tqdm(data, position=shard_id):
+text = example["story"]
+text = text.strip() # get rid of leading/trailing whitespace
+tokens = enc.encode(text, bos=True, eos=False) # encode the text, use BOS
+all_tokens.extend(tokens)
+# convert to uint16 nparray
+all_tokens = np.array(all_tokens, dtype=np.uint16)
+# write to disk
+tokenized_filename = shard.replace(".json", ".bin")
+with open(tokenized_filename, "wb") as f:
+f.write(all_tokens.tobytes())
+print(f"Saved {tokenized_filename}")

-def process_shard(shard):
-with open(shard, "r") as f:
-data = json.load(f)
-all_tokens = []
-for example in tqdm(data):
-text = example["story"]
-text = text.strip() # get rid of leading/trailing whitespace
-tokens = enc.encode(text, bos=True, eos=False) # encode the text, use BOS
-all_tokens.extend(tokens)
-# convert to uint16 nparray
-all_tokens = np.array(all_tokens, dtype=np.uint16)
-# write to disk
-tokenized_filename = shard.replace(".json", ".bin")
-with open(tokenized_filename, "wb") as f:
-f.write(all_tokens.tobytes())
-print(f"Saved {tokenized_filename}")

+def pretokenize():
# iterate the shards and tokenize all of them one by one
data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))

-# process all the shards in a threadpool
-with ThreadPoolExecutor(max_workers=8) as executor:
-executor.map(process_shard, shard_filenames)

+# process all the shards in a process pool
+with ProcessPoolExecutor() as executor:
+executor.map(process_shard, enumerate(shard_filenames))
print("Done.")


@@ -163,4 +164,5 @@ def iter_batches(split, batch_size, max_seq_len, device, num_workers=0):
"download": download,
"pretokenize": pretokenize,
}
fun[args.stage]()
