Skip to content

Commit

Permalink
Manual merges - CLI arg parse
Browse files Browse the repository at this point in the history
run.c

cli arg parse - manually added a few changes from upstream

Need to fix this later to fully reflect upstream.
  • Loading branch information
trholding committed Aug 5, 2023
1 parent f0e5f61 commit df32a0d
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions run.c
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ typedef struct {
// final rmsnorm
float* rms_final_weight; // (dim,)
// freq_cis for RoPE relatively positional embeddings
float* freq_cis_real; // (seq_len, dim/2)
float* freq_cis_imag; // (seq_len, dim/2)
float* freq_cis_real; // (seq_len, head_size/2)
float* freq_cis_imag; // (seq_len, head_size/2)
// (optional) classifier weights for the logits, on the last layer
float* wcls;
} TransformerWeights;
Expand Down Expand Up @@ -524,10 +524,11 @@ int argmax(float* v, int n) {

int main(int argc, char *argv[]) {

// poor man's C argparse
// default inits
char *checkpoint = NULL; // e.g. out/model.bin
float temperature = 0.9f; // e.g. 1.0, or 0.0
int steps = 256; // max number of steps to run for, 0: use seq_len
float temperature = 0.9f; // 0.0 = greedy & deterministic, 1.0 = max uncertainty
rng_seed = (unsigned int)time(NULL); // seed rng with time by default
int steps = 256; // number of steps to run for
char *prompt = NULL; // prompt string
int buffertokens = 1; // output token buffer size

Expand All @@ -542,34 +543,33 @@ int main(int argc, char *argv[]) {
gets(promptbuffer); // Read prompt
prompt=promptbuffer; // Set prompt
#else
// 'checkpoint' is necessary arg
// poor man's C argparse so we can override the defaults above from the command line
if (argc < 2) {
printf("Usage: %s <checkpoint_file> \n", argv[0]);
exit(EXIT_FAILURE);
}
if (argc >= 2) { checkpoint = argv[1]; }
for (int i = 2; i < argc; i++) {
// do some basic validation
switch (argv[i][0]) {
case '-':
switch (argv[i][1]) {
// optional temperature. 0.0 = (deterministic) argmax sampling. 1.0 = baseline
case 't': if (i + 1 < argc) { temperature = atof(argv[++i]); } break;
case 's': if (i + 1 < argc) { steps = atoi(argv[++i]); } break;
case 't': if (i + 1 < argc) { temperature = atof(argv[++i]); } break;
case 's': if (i + 1 < argc) { rng_seed = atoi(argv[++i]); } break;
case 'n': if (i + 1 < argc) { steps = atoi(argv[++i]); } break;
case 'b': if (i + 1 < argc) { buffertokens = atoi(argv[++i]); } break;
case 'p': if (i + 1 < argc) { prompt = argv[++i]; } break;
default: printf("Invalid option: %s\n", argv[i]);
exit(EXIT_FAILURE);
} break;
default:
printf("Usage: %s <checkpoint_file> -t [temperature] -s [steps] -b [buffertokens] -p [prompt] \n", argv[0]);
printf("Usage: %s <checkpoint_file> -t [temperature] -s [seed] -n [steps] -b [buffertokens] -p [prompt] \n", argv[0]);
exit(EXIT_FAILURE);
}
}
#endif

// seed rng with time. if you want deterministic behavior use temperature 0.0
rng_seed = (unsigned int)time(NULL);

// read in the model.bin file
Config config;
TransformerWeights weights;
Expand Down

0 comments on commit df32a0d

Please sign in to comment.