Merge remote-tracking branch 'upstream/master'

trholding committed Aug 26, 2023
2 parents d311a72 + 49daf18 commit a124fd3
Showing 3 changed files with 98 additions and 11 deletions.
81 changes: 81 additions & 0 deletions test.c
@@ -0,0 +1,81 @@
#define TESTING
#include "run.c"

void assert_eq(int a, int b) {
    if (a != b) {
        printf("Assertion failed: %d != %d\n", a, b);
        exit(EXIT_FAILURE);
    }
}

void test_prompt_encoding(Tokenizer* tokenizer, char* prompt, int* expected_tokens, int num_expected_tokens) {
    // encode
    int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int));
    int num_prompt_tokens = 0; // the total number of prompt tokens
    encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);

    #if VERBOSITY == 1
    // print maybe
    printf("expected tokens:\n");
    for (int i = 0; i < num_expected_tokens; i++) printf("%d ", expected_tokens[i]);
    printf("\n");
    printf("actual tokens:\n");
    for (int i = 0; i < num_prompt_tokens; i++) printf("%d ", prompt_tokens[i]);
    printf("\n");
    #endif

    // verify
    assert_eq(num_prompt_tokens, num_expected_tokens);
    for (int i = 0; i < num_prompt_tokens; i++) {
        assert_eq(prompt_tokens[i], expected_tokens[i]);
    }

    #if VERBOSITY == 1
    printf("OK\n");
    printf("---\n");
    #endif
    free(prompt_tokens);
}

void test_prompt_encodings() {
    // let's verify that the Tokenizer works as expected

    char *tokenizer_path = "tokenizer.bin";
    int vocab_size = 32000;
    Tokenizer tokenizer;
    build_tokenizer(&tokenizer, tokenizer_path, vocab_size);

    // test 0 (test the empty string) (I added this as a simple case)
    char *prompt0 = "";
    int expected_tokens0[] = {1};
    test_prompt_encoding(&tokenizer, prompt0, expected_tokens0, sizeof(expected_tokens0) / sizeof(int));

    // the tests below are taken from the Meta Llama 2 repo example code
    // https://github.com/facebookresearch/llama/blob/main/example_text_completion.py
    // and the expected tokens come from me breaking in the debugger in Python

    // test 1
    char *prompt = "I believe the meaning of life is";
    int expected_tokens[] = {1, 306, 4658, 278, 6593, 310, 2834, 338};
    test_prompt_encoding(&tokenizer, prompt, expected_tokens, sizeof(expected_tokens) / sizeof(int));

    // test 2
    char* prompt2 = "Simply put, the theory of relativity states that ";
    int expected_tokens2[] = {1, 3439, 17632, 1925, 29892, 278, 6368, 310, 14215, 537, 5922, 393, 29871};
    test_prompt_encoding(&tokenizer, prompt2, expected_tokens2, sizeof(expected_tokens2) / sizeof(int));

    // test 3
    char* prompt3 = "A brief message congratulating the team on the launch:\n\n        Hi everyone,\n\n        I just ";
    int expected_tokens3[] = {1, 319, 11473, 2643, 378, 629, 271, 18099, 278, 3815, 373, 278, 6826, 29901, 13, 13, 4706, 6324, 14332, 29892, 13, 13, 4706, 306, 925, 29871};
    test_prompt_encoding(&tokenizer, prompt3, expected_tokens3, sizeof(expected_tokens3) / sizeof(int));

    // test 4
    char* prompt4 = "Translate English to French:\n\n        sea otter => loutre de mer\n        peppermint => menthe poivrée\n        plush girafe => girafe peluche\n        cheese =>";
    int expected_tokens4[] = {1, 4103, 9632, 4223, 304, 5176, 29901, 13, 13, 4706, 7205, 4932, 357, 1149, 301, 449, 276, 316, 2778, 13, 4706, 1236, 407, 837, 524, 1149, 6042, 354, 772, 440, 29878, 1318, 13, 4706, 715, 1878, 330, 3055, 1725, 1149, 330, 3055, 1725, 4639, 28754, 13, 4706, 923, 968, 1149};
    test_prompt_encoding(&tokenizer, prompt4, expected_tokens4, sizeof(expected_tokens4) / sizeof(int));
}

int main(int argc, char *argv[]) {
    test_prompt_encodings();
    printf("ALL OK\n");
}
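Note on the expected token ids above: per the comments in test.c, they come from stepping through the Meta Llama 2 example code in a Python debugger, i.e. they are what the Llama 2 sentencepiece tokenizer produces with a BOS token (id 1) prepended. A minimal sketch of how such reference values could be reproduced, assuming access to Meta's tokenizer.model (the path and this snippet are illustrative, not part of this commit):

import sentencepiece as spm

# assumed path to the Llama 2 sentencepiece model; not shipped with this commit
sp = spm.SentencePieceProcessor(model_file="tokenizer.model")

prompt = "I believe the meaning of life is"
# prepend BOS (id 1), then the sentencepiece ids for the prompt
ids = [sp.bos_id()] + sp.encode(prompt, out_type=int)
print(ids)  # should match test 1 above: [1, 306, 4658, 278, 6593, 310, 2834, 338]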
26 changes: 16 additions & 10 deletions tinystories.py
@@ -13,6 +13,7 @@

import numpy as np
import requests
+import sentencepiece as spm
import torch
import torch.distributed as dist
from tqdm import tqdm
@@ -97,16 +98,21 @@ def train_vocab(vocab_size):
                of.write(text + "\n")
    print(f"Size is: {os.path.getsize(tiny_file) / 1024 / 1024:.2f} MB")

-    # 2) run the train_vocab.sh script that trains the sentencepiece model
-    print("Will now train the vocab with:")
-    cmd = f"bash train_vocab.sh {tiny_file} {prefix} {vocab_size}"
-    print(cmd)
-    print("OK? [y/N] ")
-    dec = input()
-    if dec.lower() != "y":
-        print("Exiting...")
-        return
-    os.system(cmd)
+    # 2) train the sentencepiece model
+    print("Will now train the vocab...")
+    spm.SentencePieceTrainer.train(input=tiny_file,
+                                   model_prefix=prefix,
+                                   model_type="bpe",
+                                   vocab_size=vocab_size,
+                                   self_test_sample_size=0,
+                                   input_format="text",
+                                   character_coverage=1.0,
+                                   num_threads=os.cpu_count(),
+                                   split_digits=True,
+                                   allow_whitespace_only_pieces=True,
+                                   byte_fallback=True,
+                                   unk_surface=r" \342\201\207 ",
+                                   normalization_rule_name="identity")

    # 3) optional cleanup, ask the user if they'd like to delete tiny.txt
    dec = input(f"Delete the temporary file {tiny_file}? [y/N] ")
2 changes: 1 addition & 1 deletion train.py
@@ -271,7 +271,7 @@ def get_lr(it):
"loss/val": losses["val"],
"lr": lr,
"mfu": running_mfu * 100, # convert to percentage
}
}, step = iter_num
)
except Exception as e:
print(f"logging to wandb failed: {e}")
