Skip to content

Commit

Permalink
run.c - Various fixes
Browse files Browse the repository at this point in the history
Various fixes from up streaming effort:

- Better arg parse
- Buffering fixes
  • Loading branch information
trholding committed Aug 4, 2023
1 parent 7bf7eb7 commit af5c2c7
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions run.c
Original file line number Diff line number Diff line change
Expand Up @@ -536,12 +536,13 @@ int main(int argc, char *argv[]) {
// optional temperature. 0.0 = (deterministic) argmax sampling. 1.0 = baseline
case 't': if (i + 1 < argc) { temperature = atof(argv[++i]); } break;
case 's': if (i + 1 < argc) { steps = atoi(argv[++i]); } break;
case 'b': if (i + 1 < argc) { buffertokens = atoi(argv[++i]); } break;
case 'p': if (i + 1 < argc) { prompt = argv[++i]; } break;
default: printf("Invalid option: %s\n", argv[i]);
exit(EXIT_FAILURE);
} break;
default:
printf("Usage: %s <checkpoint_file> -t [temperature] -s [steps] -p [prompt] \n", argv[0]);
printf("Usage: %s <checkpoint_file> -t [temperature] -s [steps] -b [buffertokens] -p [prompt] \n", argv[0]);
exit(EXIT_FAILURE);
}
}
Expand Down Expand Up @@ -620,11 +621,15 @@ int main(int argc, char *argv[]) {
int next; // will store the next token in the sequence
int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
int pos = 0; // position in the sequence
int bufferflush = 1; // buffer flush after token counter
#define BUFFER_SIZE 4096 // max buffer size
static char outbuff[BUFFER_SIZE]; // used for output buffering
int bufferflush = 1; // token counter for flushing buffer
char outbuff[4096 * (6 + 2)] ; // buffersize is context length * average size of subwords + margin
printf("<s>\n"); // explicit print the initial BOS token for stylistic symmetry reasons
setvbuf(stdout, outbuff, _IOFBF, BUFFER_SIZE); // setup output buffering

// setvbuf is used to buffer output into outbuff instead of flushing to screen directly
if (setvbuf(stdout, outbuff, _IOFBF, sizeof(outbuff)) != 0) {
puts("Error: Buffer allocation!"); exit(EXIT_FAILURE);
}

while (pos < steps) {

// forward the transformer to get logits for the next token
Expand All @@ -650,9 +655,10 @@ int main(int argc, char *argv[]) {

// following BOS token (1), sentencepiece decoder strips any leading whitespace (see PR #89)
char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next];

printf("%s", token_str);
if (bufferflush==pos && strlen(outbuff)<=BUFFER_SIZE) { fflush(stdout); bufferflush+=buffertokens; } // flush after every n tokens
// flush output to screen after the defined number of buffertokens have accumulated
if (bufferflush==pos) { fflush(stdout); bufferflush+=buffertokens; }

// advance forward
token = next;
Expand Down

0 comments on commit af5c2c7

Please sign in to comment.