Skip to content

Commit

Permalink
run.c - arg parse fix
Browse files Browse the repository at this point in the history
Fixes segfault when last parameter is missing
  • Loading branch information
trholding committed Aug 2, 2023
1 parent a8b8b85 commit 8a09a40
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ out/model.bin
run
run.com.dbg
run.com
a.out
21 changes: 10 additions & 11 deletions run.c
Original file line number Diff line number Diff line change
Expand Up @@ -527,12 +527,12 @@ int main(int argc, char *argv[]) {
switch (argv[i][0]) {
case '-':
switch (argv[i][1]) {
case 'c': checkpoint = argv[++i]; break;
case 't': temperature = atof(argv[++i]); break;
case 's': steps = atoi(argv[++i]); break;
case 'b': buffertokens = atoi(argv[++i]); break;
case 'p': prompt = argv[++i]; break;
default: printf("Invalid option: %s\n", argv[i]);
case 'c': if (i + 1 < argc) { checkpoint = argv[++i]; } break;
case 't': if (i + 1 < argc) { temperature = atof(argv[++i]); } break;
case 's': if (i + 1 < argc) { steps = atoi(argv[++i]); } break;
case 'b': if (i + 1 < argc) { buffertokens = atoi(argv[++i]);} break;
case 'p': if (i + 1 < argc) { prompt = argv[++i]; } break;
default: printf("Invalid option: %s\n", argv[i]);
exit(EXIT_FAILURE);
} break;
default:
Expand Down Expand Up @@ -620,11 +620,10 @@ int main(int argc, char *argv[]) {
int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
int pos = 0; // position in the sequence
int bufferflush = 1; // buffer flush after token counter
#define MAX_BUFFER_SIZE 4096 // max buffer size
char outbuff[MAX_BUFFER_SIZE]=""; // used for output buffering
memset( outbuff, '\0', sizeof( outbuff )); // clear buffer area
#define BUFFER_SIZE 4096 // max buffer size
static char outbuff[BUFFER_SIZE]; // used for output buffering
printf("<s>\n"); // explicit print the initial BOS token for stylistic symmetry reasons
setvbuf(stdout, outbuff, _IOFBF, MAX_BUFFER_SIZE); // setup output buffering
setvbuf(stdout, outbuff, _IOFBF, BUFFER_SIZE); // setup output buffering
while (pos < steps) {

// forward the transformer to get logits for the next token
Expand Down Expand Up @@ -652,7 +651,7 @@ int main(int argc, char *argv[]) {
char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next];

printf("%s", token_str);
if (bufferflush==pos && strlen(outbuff)<=MAX_BUFFER_SIZE) { fflush(stdout); bufferflush+=buffertokens; } // flush after every n tokens
if (bufferflush==pos && strlen(outbuff)<=BUFFER_SIZE) { fflush(stdout); bufferflush+=buffertokens; } // flush after every n tokens

// advance forward
token = next;
Expand Down

0 comments on commit 8a09a40

Please sign in to comment.