From 4c46f3e61db07b05363f4d506e6d89fd3340ad87 Mon Sep 17 00:00:00 2001 From: Joongwon Kim Date: Mon, 25 Mar 2024 13:19:08 -0700 Subject: [PATCH 1/9] Add MATH evaluation --- README.md | 1 + eval/MATH/answer_extraction.py | 271 +++++++++++++++++++++ eval/MATH/examplars.py | 23 ++ eval/MATH/run_eval.py | 280 ++++++++++++++++++++++ eval/MATH/utilities.py | 425 +++++++++++++++++++++++++++++++++ requirements.txt | 6 +- scripts/eval/MATH.sh | 93 ++++++++ scripts/prepare_eval_data.sh | 3 + scripts/submit_eval_jobs.py | 29 +++ 9 files changed, 1130 insertions(+), 1 deletion(-) create mode 100644 eval/MATH/answer_extraction.py create mode 100644 eval/MATH/examplars.py create mode 100644 eval/MATH/run_eval.py create mode 100644 eval/MATH/utilities.py create mode 100644 scripts/eval/MATH.sh diff --git a/README.md b/README.md index 919b20bd..a42c0874 100644 --- a/README.md +++ b/README.md @@ -124,6 +124,7 @@ We provide the scripts for running evaluation of Huggingface/OpenAI models on a - [MMLU](https://github.com/hendrycks/test) - [Grade School Math (GSM)](https://github.com/openai/grade-school-math) +- [MATH](https://github.com/hendrycks/math) - [Big-Bench Hard (BBH)](https://github.com/suzgunmirac/BIG-Bench-Hard/tree/main) - [TydiQA](https://github.com/google-research-datasets/tydiqa) - [Codex HumanEval](https://github.com/openai/human-eval/tree/master) diff --git a/eval/MATH/answer_extraction.py b/eval/MATH/answer_extraction.py new file mode 100644 index 00000000..aba16ac6 --- /dev/null +++ b/eval/MATH/answer_extraction.py @@ -0,0 +1,271 @@ +import re +import regex + +def _fix_fracs(string): + substrs = string.split("\\frac") + new_str = substrs[0] + if len(substrs) > 1: + substrs = substrs[1:] + for substr in substrs: + new_str += "\\frac" + if len(substr) > 0 and substr[0] == "{": + new_str += substr + else: + try: + assert len(substr) >= 2 + except: + return string + a = substr[0] + b = substr[1] + if b != "{": + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + new_str += "{" + a + "}{" + b + "}" + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}" + b + post_substr + else: + new_str += "{" + a + "}" + b + string = new_str + return string + + +def _fix_a_slash_b(string): + if len(string.split("/")) != 2: + return string + a = string.split("/")[0] + b = string.split("/")[1] + try: + if "sqrt" not in a: + a = int(a) + if "sqrt" not in b: + b = int(b) + assert string == "{}/{}".format(a, b) + new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" + return new_string + except: + return string + + +def _fix_sqrt(string): + _string = re.sub(r"\\sqrt(-?[0-9.a-zA-Z]+)", r"\\sqrt{\1}", string) + _string = re.sub(r"\\sqrt\s+(\w+)$", r"\\sqrt{\1}", _string) + return _string + + +def _fix_tan(string): + _string = re.sub(r"\\tan(-?[0-9.a-zA-Z]+)", r"\\tan{\1}", string) + _string = re.sub(r"\\tan\s+(\w+)$", r"\\tan{\1}", _string) + return _string + + +def strip_string(string): + string = str(string).strip() + # linebreaks + string = string.replace("\n", "") + + # right "." 
+ string = string.rstrip(".") + + # remove inverse spaces + string = string.replace("\\!", "") + # string = string.replace("\\ ", "") + + # replace \\ with \ + # string = string.replace("\\\\", "\\") + # string = string.replace("\\\\", "\\") + + if string.startswith("\\text{") and string.endswith("}"): + string = string.split("{", 1)[1][:-1] + + # replace tfrac and dfrac with frac + string = string.replace("tfrac", "frac") + string = string.replace("dfrac", "frac") + string = string.replace("cfrac", "frac") + + # remove \left and \right + string = string.replace("\\left", "") + string = string.replace("\\right", "") + + # Remove unit: miles, dollars if after is not none + _string = re.sub(r"\\text{.*?}$", "", string).strip() + if _string != "" and _string != string: + # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string)) + string = _string + + # Remove circ (degrees) + string = string.replace("^{\\circ}", "").strip() + string = string.replace("^\\circ", "").strip() + + string = regex.sub(r"\{(c|m)?m\}(\^(2|3))?", "", string).strip() + string = regex.sub(r"p\.m\.$", "", string).strip() + string = regex.sub(r"(\d)\s*t$", r"\1", string).strip() + + # remove dollar signs + string = string.replace("\\$", "") + string = string.replace("$", "") + + # string = string.replace("\\text", "") + string = string.replace("x\\in", "") + + # remove percentage + string = string.replace("\\%", "%") + string = string.replace("\%", "%") + # string = string.replace("%", "") + + # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string + string = string.replace(" .", " 0.") + string = string.replace("{.", "{0.") + + # cdot + string = string.replace("\\cdot", "") + + # inf + string = string.replace("infinity", "\\infty") + if "\\infty" not in string: + string = string.replace("inf", "\\infty") + string = string.replace("+\\inity", "\\infty") + + # and + # string = string.replace("and", "") + string = string.replace("\\mathbf", "") + string = string.replace("\\mathrm", "") + + # use regex to remove \mbox{...} + string = re.sub(r"\\mbox{.*?}", "", string) + + # quote + string.replace("'", "") + string.replace("\"", "") + + # i, j + if "j" in string and "i" not in string: + string = string.replace("j", "i") + + # replace a.000b where b is not number or b is end, with ab, use regex + string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string) + string = re.sub(r"(\d+)\.0+$", r"\1", string) + + # if empty, return empty string + if len(string) == 0: + return string + if string[0] == ".": + string = "0" + string + + # to consider: get rid of e.g. "k = " or "q = " at beginning + # if len(string.split("=")) == 2: + # if len(string.split("=")[0]) <= 2: + # string = string.split("=")[1] + + string = _fix_sqrt(string) + string = _fix_tan(string) + string = string.replace(" ", "") + + # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). 
Also does a/b --> \\frac{a}{b} + string = _fix_fracs(string) + + # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y + string = _fix_a_slash_b(string) + + string = regex.sub(r"(\\|,|\.)+$", "", string) + + return string + +def extract_boxed_answers(text): + answers = [] + for piece in text.split('boxed{')[1:]: + n = 0 + for i in range(len(piece)): + if piece[i] == '{': + n += 1 + elif piece[i] == '}': + n -= 1 + if n < 0: + if i + 1 < len(piece) and piece[i + 1] == '%': + answers.append(piece[: i + 1]) + else: + answers.append(piece[:i]) + break + return answers + +def extract_program_output(pred_str): + """ + extract output between the last ```output\n...\n``` + """ + if "```output" not in pred_str: + return "" + if '```output' in pred_str: + pred_str = pred_str.split('```output')[-1] + if '```' in pred_str: + pred_str = pred_str.split('```')[0] + output = pred_str.strip() + return output + +def extract_answer(pred_str, exhaust=False): + pred = [] + if 'final answer is $' in pred_str and '$. I hope' in pred_str: + tmp = pred_str.split('final answer is $', 1)[1] + pred = [tmp.split('$. I hope', 1)[0].strip()] + elif 'boxed' in pred_str: + pred = extract_boxed_answers(pred_str) + elif ('he answer is' in pred_str): + pred = [pred_str.split('he answer is')[-1].strip()] + else: + program_output = extract_program_output(pred_str) + if program_output != "": + # fall back to program + pred.append(program_output) + else: # use the last number + pattern = '-?\d*\.?\d+' + ans = re.findall(pattern, pred_str.replace(",", "")) + if(len(ans) >= 1): + ans = ans[-1] + else: + ans = '' + if ans: + pred.append(ans) + + # multiple line + _pred = [] + for ans in pred: + ans = ans.strip().split("\n")[0] + ans = ans.lstrip(":") + ans = ans.rstrip(".") + ans = ans.rstrip("/") + ans = strip_string(ans) + _pred.append(ans) + if exhaust: + return _pred + else: + return _pred[-1] if _pred else "" + +def extract_math_answer(question, reasoning, task): + answer = [] + for ans in extract_answer(reasoning, exhaust=True): + if 'separated by commas' in question and all(ch not in ans for ch in '()[]'): + answer.extend([a.strip() for a in ans.split(",")]) + elif regex.search(r"\\text\{\s*and\s*\}", ans): + answer.extend([a.strip() for a in regex.sub(r"\\text\{\s*and\s*\}", "[SEP]", ans).split("[SEP]")]) + else: + answer.append(ans.strip()) + return answer + +def extract_math_few_shot_cot_answer(question, reasoning, task): + if 'Problem:' in reasoning: + reasoning = reasoning.split("Problem:", 1)[0] + return extract_math_answer(question, reasoning, task) + +def extract_last_single_answer(question, reasoning, task): + return extract_answer(reasoning, exhaust=False) + +def extract_gsm_few_shot_cot_answer(question, reasoning, task): + if 'Q: ' in reasoning: + reasoning = reasoning.split("Q: ", 1)[0] + pred = [s for s in regex.findall(r'-?\d+\.?\d*', reasoning)] + if pred: + return pred[-1] + else: + return "[invalid]" diff --git a/eval/MATH/examplars.py b/eval/MATH/examplars.py new file mode 100644 index 00000000..2c1e20a7 --- /dev/null +++ b/eval/MATH/examplars.py @@ -0,0 +1,23 @@ +# These examplars are from the DeepSeekMath GitHub repository (https://github.com/deepseek-ai/DeepSeek-Math/tree/main/evaluation/few_shot_prompts) +EXAMPLARS = [ + { + "question": "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.}", + "cot_answer": "The expressions inside each square root must be non-negative.\nTherefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 
5$.\nAlso, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.\nTherefore, the domain of the expression is $\\boxed{[2,5)}$.", + "short_answer": "[2,5)" + }, + { + "question": "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$", + "cot_answer": "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$", + "short_answer": "24" + }, + { + "question": "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?", + "cot_answer": "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*}\n30n&=480\\\\\n\\Rightarrow\\qquad n&=480/30=\\boxed{16}\n\\end{align*}", + "short_answer": "16" + }, + { + "question": "If the system of equations\n\n\\begin{align*}\n6x-4y&=a,\\\\\n6y-9x &=b.\n\\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b},$ assuming $b$ is nonzero.", + "cot_answer": "If we multiply the first equation by $-\\frac{3}{2}$, we obtain\n\n$$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have\n\n$$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$", + "short_answer": "-\\frac{2}{3}" + } +] \ No newline at end of file diff --git a/eval/MATH/run_eval.py b/eval/MATH/run_eval.py new file mode 100644 index 00000000..ae35ec6c --- /dev/null +++ b/eval/MATH/run_eval.py @@ -0,0 +1,280 @@ +import argparse +import json +import os +import random +import torch +import vllm + +from eval.utils import ( + generate_completions, + load_hf_lm, + query_openai_chat_model, + dynamic_import_function, + load_hf_tokenizer +) +from eval.MATH.answer_extraction import extract_answer +from eval.MATH.examplars import EXAMPLARS as MATH_EXAMPLARS +from eval.MATH.utilities import last_boxed_only_string, remove_boxed, eval_math + +DEFAULT_PROMPT_PREFIX_COT = "Solve the question below by reasoning step by step, and put the final answer within \\boxed{}." +DEFAULT_PROMPT_PREFIX_NO_COT = "Answer the following question." 
+ +DEFAULT_PROMPT_TEMPLATE_COT = """Question: %s\nSolution: %s""" +DEFAULT_PROMPT_TEMPLATE_NO_COT = """Question: %s\nAnswer: %s""" + +def main(args): + random.seed(42) + + print("Loading data...") + test_data = [] + with open(os.path.join(args.data_dir, f"test.jsonl")) as fin: + for line in fin: + example = json.loads(line) + test_data.append({ + "question": example["question"], + "answer": extract_answer(example["answer"]) + }) + + if args.max_num_examples and len(test_data) > args.max_num_examples: + test_data = random.sample(test_data, args.max_num_examples) + + if not os.path.exists(args.save_dir): + os.makedirs(args.save_dir, exist_ok=True) + + global MATH_EXAMPLARS + if args.n_shot: + if len(MATH_EXAMPLARS) > args.n_shot: + MATH_EXAMPLARS = random.sample(MATH_EXAMPLARS, args.n_shot) + demonstrations = [] + for example in MATH_EXAMPLARS: + if args.no_cot: + demonstrations.append( + "Question: " + example["question"] + "\n" + "Answer: " + example["short_answer"] + ) + else: + demonstrations.append( + "Question:" + "\n" + example["question"] + "\n" + "Solution: " + "\n" + example["cot_answer"] + "\n" + "Final Answer: " + f"The final answer is ${example['short_answer']}$." + ) + prompt_prefix = args.prompt_prefix + "\n\n" + "\n\n".join(demonstrations) + "\n\n" + else: + prompt_prefix = args.prompt_prefix + "\n\n" + + if args.use_chat_format: + chat_formatting_function = dynamic_import_function(args.chat_formatting_function) + def apply_chat_format(example, tokenizer): + messages = [{"role": "user", "content": prompt_prefix + "Question: " + example["question"].strip()}] + prompt = chat_formatting_function(messages, tokenizer, add_bos=False) + prompt += "Answer:" if prompt[-1] in ["\n", " "] else " Answer:" + return prompt + + if args.model_name_or_path: + print("Loading model and tokenizer...") + tokenizer = load_hf_tokenizer( + model_name_or_path=args.model_name_or_path, + tokenizer_name_or_path=args.tokenizer_name_or_path, + use_fast_tokenizer=not args.use_slow_tokenizer, + ) + if args.use_vllm: + model = vllm.LLM( + model=args.model_name_or_path, + tokenizer=args.tokenizer_name_or_path if args.tokenizer_name_or_path else args.model_name_or_path, + tokenizer_mode="slow" if args.use_slow_tokenizer else "auto", + tensor_parallel_size=torch.cuda.device_count(), + ) + sampling_params = vllm.SamplingParams( + temperature=0, + max_tokens=512, + stop=["\n"] if not args.use_chat_format else None, # we only use stop token for non-chat format (usually applied to vanilla pretrained language models). For chat format, we will rely on the model knows when to stop. 
+ ) + if args.use_chat_format: + prompts = [apply_chat_format(example, tokenizer) for example in test_data] + else: + if args.no_cot: + prompts = [prompt_prefix + "Question: " + example["question"].strip() + "\nAnswer:" for example in test_data] + else: + prompts = [prompt_prefix + "Question: " + "\n" + example["question"].strip() + "\nSolution: " + "\n" for example in test_data] + generations = model.generate(prompts, sampling_params) + prompt_to_output = { + g.prompt: g.outputs[0].text for g in generations + } + outputs = [prompt_to_output[prompt] if prompt in prompt_to_output else "" for prompt in prompts] + else: + model = load_hf_lm( + model_name_or_path=args.model_name_or_path, + load_in_8bit=args.load_in_8bit, + device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto", + gptq_model=args.gptq, + ) + from transformers import GPTNeoXForCausalLM, OPTForCausalLM + if isinstance(model, GPTNeoXForCausalLM) or isinstance(model, OPTForCausalLM): + tokenizer.model_max_length = model.config.max_position_embeddings + print("Set tokenizer.model_max_length to model.config.max_position_embeddings: {}".format(model.config.max_position_embeddings)) + if args.use_chat_format: + prompts = [apply_chat_format(example, tokenizer) for example in test_data] + else: + if args.no_cot: + prompts = [prompt_prefix + "Question: " + example["question"].strip() + "\nAnswer:" for example in test_data] + else: + prompts = [prompt_prefix + "Question: " + "\n" + example["question"].strip() + "\nSolution: " + "\n" for example in test_data] + new_line_token = tokenizer.encode("\n", add_special_tokens=False)[-1] # get the last token because the tokenizer may add space tokens at the start. + outputs = generate_completions( + model=model, + tokenizer=tokenizer, + prompts=prompts, + max_new_tokens=512, + batch_size=args.eval_batch_size, + stop_id_sequences=[[new_line_token]] if not args.use_chat_format else None, # we only use stop token for non-chat format (usually applied to vanilla pretrained language models). For chat format, we will rely on the model knows when to stop. 
+ do_sample=False, + ) + else: + prompts = [prompt_prefix + "Question: " + example["question"].strip() + "\nAnswer:" for example in test_data] + instances = [{"id": prompt, "prompt": prompt} for _, prompt in enumerate(prompts)] + results = query_openai_chat_model( + engine=args.openai_engine, + instances=instances, + batch_size=args.eval_batch_size if args.eval_batch_size else 10, + output_path=os.path.join(args.save_dir, f"openai_results.jsonl"), + ) + outputs = [result["output"] for result in results] + + predictions = [] + for output in outputs: + last_boxed_output = last_boxed_only_string(output) + if last_boxed_output: + output = remove_boxed(last_boxed_output) + predictions.append(output) + + predictions = [{ + "question": example["question"], + "answer": example["answer"], + "model_output": output, + "prediction": pred + } for example, output, pred in zip(test_data, outputs, predictions)] + + print("Calculating accuracy...") + correct_list = [] + for pred in predictions: + correct = eval_math(pred) + correct_list.append(correct) + accuracy = round(sum(correct_list) / len(correct_list), ndigits=4) + print(f"Accuracy: {accuracy}") + + with open(os.path.join(args.save_dir, f"predictions.jsonl"), "w") as fout: + for prediction in predictions: + fout.write(json.dumps(prediction) + "\n") + + with open(os.path.join(args.save_dir, "metrics.json"), "w") as fout: + json.dump({ + "accuracy": accuracy + }, fout, indent=4) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_dir", + type=str, + default="data/gsm" + ) + parser.add_argument( + "--max_num_examples", + type=int, + default=None, + help="maximum number of examples to evaluate." + ) + parser.add_argument( + "--save_dir", + type=str, + default="results/gsm" + ) + parser.add_argument( + "--model_name_or_path", + type=str, + default=None, + help="if specified, we will load the model to generate the predictions." + ) + parser.add_argument( + "--tokenizer_name_or_path", + type=str, + default=None, + help="if specified, we will load the tokenizer from here." + ) + parser.add_argument( + "--use_slow_tokenizer", + action="store_true", + help="If given, we will use the slow tokenizer." + ) + parser.add_argument( + "--openai_engine", + type=str, + default=None, help="if specified, we will use the OpenAI API to generate the predictions." + ) + parser.add_argument( + "--n_shot", + type=int, + default=8, + help="max number of examples to use for demonstration." + ) + parser.add_argument( + "--no_cot", + action="store_true", + help="If given, we're evaluating a model without chain-of-thought." + ) + parser.add_argument( + "--eval_batch_size", + type=int, + default=1, + help="batch size for evaluation." + ) + parser.add_argument( + "--load_in_8bit", + action="store_true", + help="load model in 8bit mode, which will reduce memory and speed up inference." + ) + parser.add_argument( + "--gptq", + action="store_true", + help="If given, we're evaluating a 4-bit quantized GPTQ model." + ) + parser.add_argument( + "--use_vllm", + action="store_true", + help="If given, we will use the vllm library, which will likely increase the inference throughput." + ) + parser.add_argument( + "--use_chat_format", + action="store_true", + help="If given, we will use the chat format for the prompts." + ) + parser.add_argument( + "--chat_formatting_function", + type=str, + default="eval.templates.create_prompt_with_tulu_chat_format", + help="The function to use to create the chat format. 
This function will be dynamically imported. Please see examples in `eval/templates.py`." + ) + parser.add_argument( + "--prompt_prefix", + type=str, + default=None, + help="the specific prefix to use for instructing the model." + ) + parser.add_argument( + "--prompt_template", + type=str, + default=None, + help="the specific template to use for instructing the model." + ) + args = parser.parse_args() + + # update the prompt prefix depending on whether CoT is being used + if args.prompt_prefix is None: + args.prompt_prefix = DEFAULT_PROMPT_PREFIX_NO_COT if args.no_cot else DEFAULT_PROMPT_PREFIX_COT + + # update the prompt template depending on whether CoT is being used + if args.prompt_template is None: + args.prompt_template = DEFAULT_PROMPT_TEMPLATE_NO_COT if args.no_cot else DEFAULT_PROMPT_TEMPLATE_COT + + # model_name_or_path and openai_engine cannot be both None or both not None. + assert (args.model_name_or_path is None) != (args.openai_engine is None), "Either model_name_or_path or openai_engine should be specified." + main(args) \ No newline at end of file diff --git a/eval/MATH/utilities.py b/eval/MATH/utilities.py new file mode 100644 index 00000000..a6bc2cd6 --- /dev/null +++ b/eval/MATH/utilities.py @@ -0,0 +1,425 @@ +import multiprocessing +from copy import deepcopy +from math import isclose +import numpy as np +from typing import Union, Any, Dict + +from sympy import simplify, N +from sympy.parsing.sympy_parser import parse_expr +from sympy.parsing.latex import parse_latex +import re +import regex + +from eval.MATH.answer_extraction import extract_answer, extract_program_output, strip_string + +def is_correct(item, pred_key='prediction', prec=1e-3): + pred = item[pred_key] + ans = item['answer'] + if isinstance(pred, list) and isinstance(ans, list): + pred_matched = set() + ans_matched = set() + for i in range(len(pred)): + for j in range(len(ans)): + item_cpy = deepcopy(item) + item_cpy.update({ + pred_key: pred[i], + 'answer': ans[j] + }) + if is_correct(item_cpy, pred_key=pred_key, prec=prec): + pred_matched.add(i) + ans_matched.add(j) + if item_cpy[pred_key] == '2,3,4': + print(item, flush=True) + print("wtf", flush=True) + return len(pred_matched) == len(pred) and len(ans_matched) == len(ans) + elif isinstance(pred, str) and isinstance(ans, str): + if '\\cup' in pred and '\\cup' in ans: + item = deepcopy(item) + item.update({ + pred_key: pred.split('\\cup'), + 'answer': ans.split('\\cup'), + }) + return is_correct(item, pred_key=pred_key, prec=prec) + else: + label = False + try: + label = abs(float(regex.sub(r',', '', str(pred))) - float(regex.sub(r',', '', str(ans)))) < prec + except: + pass + label = label or (ans and pred == ans) or math_equal(pred, ans) + return label + else: + print(item, flush=True) + raise NotImplementedError() + +def eval_math(item, pred_key='prediction', prec=1e-3): + pred = item[pred_key] + if pred_key == 'program_output' and isinstance(pred, str): + pred = [pred] + ans = item['answer'] + if isinstance(pred, list) and isinstance(ans, list): + # for some questions in MATH, `reference` repeats answers + _ans = [] + for a in ans: + if a not in _ans: + _ans.append(a) + ans = _ans + # some predictions for MATH questions also repeats answers + _pred = [] + for a in pred: + if a not in _pred: + _pred.append(a) + # some predictions mistakenly box non-answer strings + pred = _pred[-len(ans):] + + item.update({ + pred_key: pred, + 'answer': ans + }) + return is_correct(item, pred_key=pred_key, prec=prec) + +def extract_program(result: str, 
last_only=True): + """ + extract the program after "```python", and before "```" + """ + program = "" + start = False + for line in result.split("\n"): + if line.startswith("```python"): + if last_only: + program = "" # only extract the last program + else: + program += "\n# ========\n" + start = True + elif line.startswith("```"): + start = False + elif start: + program += line + "\n" + return program + + +def parse_ground_truth(example: Dict[str, Any], data_name): + if 'gt_cot' in example: + return example['gt_cot'], strip_string(example['gt']) + + # parse ground truth + if data_name in ["math", 'ocw']: + gt_cot = example['solution'] + gt_ans = extract_answer(gt_cot) + elif data_name == "gsm8k": + gt_cot, gt_ans = example['answer'].split("####") + elif data_name == "gsm-hard": + gt_cot, gt_ans = example['code'], example['target'] + elif data_name == "svamp": + gt_cot, gt_ans = example['Equation'], example['Answer'] + elif data_name == "asdiv": + gt_cot = example['formula'] + gt_ans = re.sub(r"\(.*?\)", "", example['answer']) + elif data_name == "mawps": + gt_cot, gt_ans = None, example['target'] + elif data_name == "tabmwp": + gt_cot = example['solution'] + gt_ans = example['answer'] + if example['ans_type'] in ['integer_number', 'decimal_number']: + if '/' in gt_ans: + gt_ans = int(gt_ans.split('/')[0]) / int(gt_ans.split('/')[1]) + elif ',' in gt_ans: + gt_ans = float(gt_ans.replace(',', '')) + elif '%' in gt_ans: + gt_ans = float(gt_ans.split('%')[0]) / 100 + else: + gt_ans = float(gt_ans) + elif data_name == "bbh": + gt_cot, gt_ans = None, example['target'] + else: + raise NotImplementedError(data_name) + # post process + gt_cot = str(gt_cot).strip() + gt_ans = strip_string(gt_ans) + return gt_cot, gt_ans + + +def parse_question(example, data_name): + question = "" + if data_name == "asdiv": + question = f"{example['body'].strip()} {example['question'].strip()}" + elif data_name == "svamp": + body = example["Body"].strip() + if not body.endswith("."): + body = body + "." + question = f'{body} {example["Question"].strip()}' + elif data_name == "tabmwp": + title_str = f'regarding "{example["table_title"]}" ' if example['table_title'] else "" + question = f'Read the following table {title_str}and answer a question:\n' + question += f'{example["table"]}\n{example["question"]}' + if example['choices']: + question += f' Please select from the following options: {example["choices"]}' + else: + for key in ['question', 'problem', 'Question', 'input']: + if key in example: + question = example[key] + break + assert question != "" + return question.strip() + + +def run_execute(executor, result, prompt_type, execute=False): + if not result or result == 'error': + return None, None + report = None + + if "program_only" in prompt_type: + prediction = extract_program_output(result) + elif prompt_type in ["pot", "pal"] and execute: + code = extract_program(result) + prediction, report = executor.apply(code) + else: + prediction = extract_answer(result) + + prediction = strip_string(prediction) + return prediction, report + + +def parse_digits(num): + # format: 234.23 || 23% + num = regex.sub(',', '', str(num)) + try: + return float(num) + except: + if num.endswith('%'): + num = num[:-1] + if num.endswith('\\'): + num = num[:-1] + try: + return float(num) / 100 + except: + pass + return None + +def is_digit(num): + # paired with parse_digits + return parse_digits(num) is not None + + +def normalize_prediction(prediction): + try: # 1. 
numerical equal + if is_digit(prediction): + prediction = np.round(float(str(prediction).replace(",", "")), 6) + return str(prediction) + except: + pass + + # 2. symbolic equal + prediction = str(prediction).strip() + + ## deal with [], (), {} + brackets = [] + while prediction.startswith("[") and prediction.endswith("]") or (prediction.startswith("(") and prediction.endswith(")")): + bracket = prediction[0] + prediction = prediction[1:-1] + if brackets and ',' in prediction: + pred_parts = [normalize_prediction(part) for part in prediction.split(",")] + prediction = ",".join(pred_parts) + + if brackets: + for b in reversed(brackets): + if b == '[': + prediction = '[' + prediction + ']' + else: + assert b == '(' + prediction = '(' + prediction + ')' + + def _parse(s): + for f in [parse_latex, parse_expr]: + try: + return f(s) + except: + pass + return s + + prediction = _parse(prediction) + + for s in ['{', "}", "(", ")"]: + prediction = prediction.replace(s, "") + + return prediction + + +def math_equal(prediction: Union[bool, float, str], + reference: Union[float, str], + include_percentage: bool = True, + is_close: bool = True, + timeout: bool = False, + ) -> bool: + """ + Exact match of math if and only if: + 1. numerical equal: both can convert to float and are equal + 2. symbolic equal: both can convert to sympy expression and are equal + """ + if str(prediction) == str(reference): + return True + + try: # 1. numerical equal + if is_digit(prediction) and is_digit(reference): + prediction = parse_digits(prediction) + reference = parse_digits(reference) + # number questions + if include_percentage: + gt_result = [reference / 100, reference, reference * 100] + else: + gt_result = [reference] + for item in gt_result: + try: + if is_close: + if isclose(item, prediction, abs_tol=1e-3): + return True + else: + if item == prediction: + return True + except Exception: + continue + return False + except: + pass + + if not prediction and prediction not in [0, False]: + return False + + # 2. 
symbolic equal + reference = str(reference).strip() + prediction = str(prediction).strip() + + if regex.match(r'(\(|\[).+(\)|\])', prediction) is not None and regex.match(r'(\(|\[).+(\)|\])', reference) is not None: + pred_parts = prediction[1:-1].split(",") + ref_parts = reference[1:-1].split(",") + if len(pred_parts) == len(ref_parts): + if all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]): + return True + + if (prediction.startswith("\\begin{pmatrix}") or prediction.startswith("\\begin{bmatrix}")) and (prediction.endswith("\\end{pmatrix}") or prediction.endswith("\\end{bmatrix}")) and \ + (reference.startswith("\\begin{pmatrix}") or reference.startswith("\\begin{bmatrix}")) and (reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")): + pred_lines = [line.strip() for line in prediction[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()] + ref_lines = [line.strip() for line in reference[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()] + matched = True + if len(pred_lines) == len(ref_lines): + for pred_line, ref_line in zip(pred_lines, ref_lines): + pred_parts = pred_line.split("&") + ref_parts = ref_line.split("&") + if len(pred_parts) == len(ref_parts): + if not all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]): + matched = False + break + else: + matched = False + if not matched: + break + else: + matched = False + if matched: + return True + + if prediction.count('=') == 1 and reference.count('=') == 1: + pred = prediction.split('=') + pred = f"{pred[0].strip()} - ({pred[1].strip()})" + ref = reference.split('=') + ref = f"{ref[0].strip()} - ({ref[1].strip()})" + if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref): + return True + elif prediction.count('=') == 1 and len(prediction.split('=')[0].strip()) <= 2 and '=' not in reference: + if math_equal(prediction.split('=')[1], reference, include_percentage, is_close): + return True + elif reference.count('=') == 1 and len(reference.split('=')[0].strip()) <= 2 and '=' not in prediction: + if math_equal(prediction, reference.split('=')[1], include_percentage, is_close): + return True + + # symbolic equal with sympy + if timeout: + if call_with_timeout(symbolic_equal_process, prediction, reference): + return True + else: + if symbolic_equal(prediction, reference): + return True + + return False + + +def math_equal_process(param): + return math_equal(param[-2], param[-1]) + + +def symbolic_equal(a, b): + def _parse(s): + for f in [parse_latex, parse_expr]: + try: + return f(s) + except: + pass + return s + a = _parse(a) + b = _parse(b) + + try: + if simplify(a-b) == 0: + return True + except: + pass + + try: + if isclose(N(a), N(b), abs_tol=1e-3): + return True + except: + pass + return False + + +def symbolic_equal_process(a, b, output_queue): + result = symbolic_equal(a, b) + output_queue.put(result) + + +def call_with_timeout(func, *args, timeout=1, **kwargs): + output_queue = multiprocessing.Queue() + process_args = args + (output_queue,) + process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs) + process.start() + process.join(timeout) + + if process.is_alive(): + process.terminate() + process.join() + return False + + return output_queue.get() + +def last_boxed_only_string(string: str): + idx = string.rfind("\\boxed") + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return "" + i = idx + 
right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + if right_brace_idx == None: + retval = "" + else: + retval = string[idx:right_brace_idx + 1] + return retval + +def remove_boxed(s: str): + left = "\\boxed{" + try: + assert s[:len(left)] == left + assert s[-1] == "}" + return s[len(left):-1] + except: + return "" diff --git a/requirements.txt b/requirements.txt index 697e3518..74055456 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,4 +34,8 @@ alpaca-eval==0.5.3 # for human eval web app flask vllm -openpyxl \ No newline at end of file +openpyxl +# for math evaluations +antlr4-python3-runtime==4.11.0 +mpmath==1.3.0 +sympy==1.12.0 \ No newline at end of file diff --git a/scripts/eval/MATH.sh b/scripts/eval/MATH.sh new file mode 100644 index 00000000..54f6cfa9 --- /dev/null +++ b/scripts/eval/MATH.sh @@ -0,0 +1,93 @@ +# Here we use 1 GPU for demonstration, but you can use multiple GPUs and larger eval_batch_size to speed up the evaluation. +export CUDA_VISIBLE_DEVICES=0 + + +# Evaluating llama 7B model using chain-of-thought +python -m eval.MATH.run_eval \ + --data_dir data/eval/MATH/ \ + --max_num_examples 200 \ + --save_dir results/MATH/llama-7B-cot-4shot \ + --model ../hf_llama_models/7B \ + --tokenizer ../hf_llama_models/7B \ + --n_shot 4 \ + --use_vllm + + +# Evaluating llama 7B model using direct answering (no chain-of-thought) +python -m eval.MATH.run_eval \ + --data_dir data/eval/MATH/ \ + --max_num_examples 200 \ + --save_dir results/gsm/llama-7B-no-cot-4shot \ + --model ../hf_llama_models/7B \ + --tokenizer ../hf_llama_models/7B \ + --n_shot 4 \ + --no_cot \ + --use_vllm + + +# Evaluating tulu 7B model using chain-of-thought and chat format +python -m eval.MATH.run_eval \ + --data_dir data/eval/MATH/ \ + --max_num_examples 200 \ + --save_dir results/MATH/tulu-7B-cot-4shot \ + --model ../checkpoints/tulu_7B \ + --tokenizer ../checkpoints/tulu_7B \ + --n_shot 4 \ + --use_chat_format \ + --chat_formatting_function eval.templates.create_prompt_with_tulu_chat_format \ + --use_vllm + + +# Evaluating llama2 chat model using chain-of-thought and chat format +python -m eval.MATH.run_eval \ + --data_dir data/eval/MATH/ \ + --max_num_examples 200 \ + --save_dir results/MATH/llama2-chat-7B-cot-4shot \ + --model ../hf_llama2_models/7B-chat \ + --tokenizer ../hf_llama2_models/7B-chat \ + --n_shot 4 \ + --use_chat_format \ + --chat_formatting_function eval.templates.create_prompt_with_llama2_chat_format \ + --use_vllm + + +# Evaluating chatgpt using chain-of-thought +python -m eval.MATH.run_eval \ + --data_dir data/eval/MATH/ \ + --max_num_examples 200 \ + --save_dir results/MATH/chatgpt-cot \ + --openai_engine "gpt-3.5-turbo-0301" \ + --eval_batch_size 20 \ + --n_shot 4 + + +# Evaluating chatgpt using direct answering (no chain-of-thought) +python -m eval.MATH.run_eval \ + --data_dir data/eval/MATH/ \ + --max_num_examples 200 \ + --save_dir results/MATH/chatgpt-no-cot \ + --openai_engine "gpt-3.5-turbo-0301" \ + --eval_batch_size 20 \ + --n_shot 4 \ + --no_cot + + +# Evaluating gpt4 using chain-of-thought +python -m eval.MATH.run_eval \ + --data_dir data/eval/MATH/ \ + --max_num_examples 200 \ + --save_dir results/MATH/gpt4-cot \ + --openai_engine "gpt-4-0314" \ + --eval_batch_size 20 \ + --n_shot 4 + + +# Evaluating gpt4 using direct answering (no chain-of-thought) +python -m 
eval.MATH.run_eval \ + --data_dir data/eval/MATH/ \ + --max_num_examples 200 \ + --save_dir results/MATH/gpt4-no-cot \ + --openai_engine "gpt-4-0314" \ + --eval_batch_size 20 \ + --n_shot 4 \ + --no_cot diff --git a/scripts/prepare_eval_data.sh b/scripts/prepare_eval_data.sh index c8a2bc33..43b44569 100755 --- a/scripts/prepare_eval_data.sh +++ b/scripts/prepare_eval_data.sh @@ -24,6 +24,9 @@ wget -P data/eval/tydiqa/ https://storage.googleapis.com/tydiqa/v1.1/tydiqa-gold # GSM dataset wget -P data/eval/gsm/ https://github.com/openai/grade-school-math/raw/master/grade_school_math/data/test.jsonl +# MATH dataset +mkdir -p data/eval/MATH +wget -P data/eval/MATH/ https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/datasets/math/test.jsonl # Codex HumanEval wget -P data/eval/codex_humaneval https://github.com/openai/human-eval/raw/master/data/HumanEval.jsonl.gz diff --git a/scripts/submit_eval_jobs.py b/scripts/submit_eval_jobs.py index 37d2219e..00223205 100755 --- a/scripts/submit_eval_jobs.py +++ b/scripts/submit_eval_jobs.py @@ -45,6 +45,8 @@ "mmlu_5shot", "gsm_direct", "gsm_cot", + "MATH_direct", + "MATH_cot", "bbh_direct", "bbh_cot", "tydiqa_goldp_1shot", @@ -155,6 +157,33 @@ --use_chat_format \ --chat_formatting_function eval.templates.create_prompt_with_tulu_chat_format ''' + elif experiment_group == "MATH_direct": + task_spec['arguments'][0] = ''' + python -m eval.MATH.run_eval \ + --data_dir /data/MATH/ \ + --max_num_examples 200 \ + --save_dir /output/ \ + --use_vllm \ + --model_name_or_path /model \ + --tokenizer_name_or_path /model \ + --n_shot 4 \ + --no_cot \ + --use_chat_format \ + --chat_formatting_function eval.templates.create_prompt_with_tulu_chat_format + ''' + elif experiment_group == "MATH_cot": + task_spec['arguments'][0] = ''' + python -m eval.MATH.run_eval \ + --data_dir /data/MATH/ \ + --max_num_examples 200 \ + --save_dir /output/ \ + --use_vllm \ + --model_name_or_path /model \ + --tokenizer_name_or_path /model \ + --n_shot 4 \ + --use_chat_format \ + --chat_formatting_function eval.templates.create_prompt_with_tulu_chat_format + ''' elif experiment_group == "tydiqa_goldp_1shot": task_spec["arguments"][0] = ''' python -m eval.tydiqa.run_eval \ From 0d7d7d1a252c0cd739811423c0e11770a16f0edd Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Sat, 6 Jul 2024 03:29:18 +0000 Subject: [PATCH 2/9] fix eval --- eval/MATH/run_eval.py | 60 ++++++++++++++++++++++++++++-------- scripts/prepare_eval_data.sh | 4 +-- 2 files changed, 49 insertions(+), 15 deletions(-) diff --git a/eval/MATH/run_eval.py b/eval/MATH/run_eval.py index ae35ec6c..5e1e4905 100644 --- a/eval/MATH/run_eval.py +++ b/eval/MATH/run_eval.py @@ -31,8 +31,9 @@ def main(args): for line in fin: example = json.loads(line) test_data.append({ - "question": example["question"], - "answer": extract_answer(example["answer"]) + "question": example["problem"], + "answer": extract_answer(example["solution"]), + "type": example["type"] }) if args.max_num_examples and len(test_data) > args.max_num_examples: @@ -53,7 +54,7 @@ def main(args): ) else: demonstrations.append( - "Question:" + "\n" + example["question"] + "\n" + "Solution: " + "\n" + example["cot_answer"] + "\n" + "Final Answer: " + f"The final answer is ${example['short_answer']}$." + "Problem:" + "\n" + example["question"] + "\n\n" + "Solution: " + "\n" + example["cot_answer"] + "\n" + "Final Answer: " + f"The final answer is ${example['short_answer']}$. I hope it is correct." 
) prompt_prefix = args.prompt_prefix + "\n\n" + "\n\n".join(demonstrations) + "\n\n" else: @@ -62,9 +63,9 @@ def main(args): if args.use_chat_format: chat_formatting_function = dynamic_import_function(args.chat_formatting_function) def apply_chat_format(example, tokenizer): - messages = [{"role": "user", "content": prompt_prefix + "Question: " + example["question"].strip()}] + messages = [{"role": "user", "content": prompt_prefix + "Problem: " + example["question"].strip()}] prompt = chat_formatting_function(messages, tokenizer, add_bos=False) - prompt += "Answer:" if prompt[-1] in ["\n", " "] else " Answer:" + prompt += "Solution:" if prompt[-1] in ["\n", " "] else " Solution:" return prompt if args.model_name_or_path: @@ -81,18 +82,23 @@ def apply_chat_format(example, tokenizer): tokenizer_mode="slow" if args.use_slow_tokenizer else "auto", tensor_parallel_size=torch.cuda.device_count(), ) + stop_strings = args.additional_stop_sequence + # we only use stop token for non-chat format (usually applied to vanilla pretrained language models). + # For chat format, we will rely on the model knows when to stop. + if not args.use_chat_format: + stop_strings += ["\n"] sampling_params = vllm.SamplingParams( temperature=0, max_tokens=512, - stop=["\n"] if not args.use_chat_format else None, # we only use stop token for non-chat format (usually applied to vanilla pretrained language models). For chat format, we will rely on the model knows when to stop. + stop=stop_strings, ) if args.use_chat_format: prompts = [apply_chat_format(example, tokenizer) for example in test_data] else: if args.no_cot: - prompts = [prompt_prefix + "Question: " + example["question"].strip() + "\nAnswer:" for example in test_data] + prompts = [prompt_prefix + "Problem: " + example["question"].strip() + "\Solution:" for example in test_data] else: - prompts = [prompt_prefix + "Question: " + "\n" + example["question"].strip() + "\nSolution: " + "\n" for example in test_data] + prompts = [prompt_prefix + "Problem: " + "\n" + example["question"].strip() + "\n\nSolution: " + "\n" for example in test_data] generations = model.generate(prompts, sampling_params) prompt_to_output = { g.prompt: g.outputs[0].text for g in generations @@ -116,14 +122,18 @@ def apply_chat_format(example, tokenizer): prompts = [prompt_prefix + "Question: " + example["question"].strip() + "\nAnswer:" for example in test_data] else: prompts = [prompt_prefix + "Question: " + "\n" + example["question"].strip() + "\nSolution: " + "\n" for example in test_data] - new_line_token = tokenizer.encode("\n", add_special_tokens=False)[-1] # get the last token because the tokenizer may add space tokens at the start. + # we only use stop token for non-chat format (usually applied to vanilla pretrained language models). For chat format, we will rely on the model knows when to stop. + stop_tokens = [[tokenizer.encode(stop_seq, add_special_tokens=False)[-1]] for stop_seq in args.additional_stop_sequence] + if not args.use_chat_format: + new_line_token = tokenizer.encode("\n", add_special_tokens=False)[-1] # get the last token because the tokenizer may add space tokens at the start. + stop_tokens += [[new_line_token]] outputs = generate_completions( model=model, tokenizer=tokenizer, prompts=prompts, max_new_tokens=512, batch_size=args.eval_batch_size, - stop_id_sequences=[[new_line_token]] if not args.use_chat_format else None, # we only use stop token for non-chat format (usually applied to vanilla pretrained language models). 
For chat format, we will rely on the model knows when to stop. + stop_id_sequences=stop_tokens, do_sample=False, ) else: @@ -158,15 +168,32 @@ def apply_chat_format(example, tokenizer): correct_list.append(correct) accuracy = round(sum(correct_list) / len(correct_list), ndigits=4) print(f"Accuracy: {accuracy}") + metrics = { + "accuracy": accuracy + } + + # calculate per-type accuracy + type_correct = {} + type_total = {} + for pred, sample in zip(predictions, test_data): + type_ = sample["type"] + if type_ not in type_correct: + type_correct[type_] = 0 + type_total[type_] = 0 + type_correct[type_] += eval_math(pred) + type_total[type_] += 1 + type_accuracy = {type_: round(type_correct[type_] / type_total[type_], ndigits=4) for type_ in type_correct} + print("Per-type accuracy:") + for type_, acc in type_accuracy.items(): + print(f"{type_}: {acc}") + metrics["per_type_accuracy"] = type_accuracy with open(os.path.join(args.save_dir, f"predictions.jsonl"), "w") as fout: for prediction in predictions: fout.write(json.dumps(prediction) + "\n") with open(os.path.join(args.save_dir, "metrics.json"), "w") as fout: - json.dump({ - "accuracy": accuracy - }, fout, indent=4) + json.dump(metrics, fout, indent=4) if __name__ == "__main__": @@ -265,6 +292,13 @@ def apply_chat_format(example, tokenizer): default=None, help="the specific template to use for instructing the model." ) + parser.add_argument( + '--additional_stop_sequence', + type=str, + nargs="+", + default=[], + help="Additional stop sequences to use when generating completions. Useful for e.g. llama-3-instruct." + ) args = parser.parse_args() # update the prompt prefix depending on whether CoT is being used diff --git a/scripts/prepare_eval_data.sh b/scripts/prepare_eval_data.sh index 5a0ad0aa..866e44b3 100755 --- a/scripts/prepare_eval_data.sh +++ b/scripts/prepare_eval_data.sh @@ -26,7 +26,7 @@ wget -P data/eval/gsm/ https://github.com/openai/grade-school-math/raw/master/gr # MATH dataset mkdir -p data/eval/MATH -wget -P data/eval/MATH/ https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/datasets/math/test.jsonl +wget -P data/eval/MATH/ https://raw.githubusercontent.com/deepseek-ai/DeepSeek-Math/main/evaluation/datasets/math/test.jsonl # Codex HumanEval wget -P data/eval/codex_humaneval https://github.com/openai/human-eval/raw/master/data/HumanEval.jsonl.gz @@ -59,4 +59,4 @@ wget -P data/eval/xstest https://github.com/paul-rottger/exaggerated-safety/raw/ # we use self-instruct test set, and vicuna test set for our human evaluation mkdir -p data/eval/creative_tasks wget -O data/eval/creative_tasks/self_instruct_test.jsonl https://github.com/yizhongw/self-instruct/raw/main/human_eval/user_oriented_instructions.jsonl -wget -O data/eval/creative_tasks/vicuna_test.jsonl https://github.com/lm-sys/FastChat/raw/main/fastchat/eval/table/question.jsonl \ No newline at end of file +wget -O data/eval/creative_tasks/vicuna_test.jsonl https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/vicuna_bench/question.jsonl \ No newline at end of file From 86f0bbc453fd0a57590431ca4ebc6f75646defe9 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Sat, 6 Jul 2024 03:30:39 +0000 Subject: [PATCH 3/9] more consistent minerva prompt --- eval/MATH/run_eval.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/eval/MATH/run_eval.py b/eval/MATH/run_eval.py index 5e1e4905..c87b7429 100644 --- a/eval/MATH/run_eval.py +++ b/eval/MATH/run_eval.py @@ -119,9 +119,9 @@ def apply_chat_format(example, tokenizer): 
prompts = [apply_chat_format(example, tokenizer) for example in test_data] else: if args.no_cot: - prompts = [prompt_prefix + "Question: " + example["question"].strip() + "\nAnswer:" for example in test_data] + prompts = [prompt_prefix + "Problem: " + example["question"].strip() + "\nSolution:" for example in test_data] else: - prompts = [prompt_prefix + "Question: " + "\n" + example["question"].strip() + "\nSolution: " + "\n" for example in test_data] + prompts = [prompt_prefix + "Problem: " + "\n" + example["question"].strip() + "\nSolution: " + "\n" for example in test_data] # we only use stop token for non-chat format (usually applied to vanilla pretrained language models). For chat format, we will rely on the model knows when to stop. stop_tokens = [[tokenizer.encode(stop_seq, add_special_tokens=False)[-1]] for stop_seq in args.additional_stop_sequence] if not args.use_chat_format: @@ -137,7 +137,7 @@ def apply_chat_format(example, tokenizer): do_sample=False, ) else: - prompts = [prompt_prefix + "Question: " + example["question"].strip() + "\nAnswer:" for example in test_data] + prompts = [prompt_prefix + "Problem: " + example["question"].strip() + "\nSolution:" for example in test_data] instances = [{"id": prompt, "prompt": prompt} for _, prompt in enumerate(prompts)] results = query_openai_chat_model( engine=args.openai_engine, From be83419c1c9d88372c17a94bc6abb74ce94d4223 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Sat, 6 Jul 2024 03:48:21 +0000 Subject: [PATCH 4/9] configurable gen length --- eval/MATH/run_eval.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/eval/MATH/run_eval.py b/eval/MATH/run_eval.py index c87b7429..8e2358d6 100644 --- a/eval/MATH/run_eval.py +++ b/eval/MATH/run_eval.py @@ -89,7 +89,7 @@ def apply_chat_format(example, tokenizer): stop_strings += ["\n"] sampling_params = vllm.SamplingParams( temperature=0, - max_tokens=512, + max_tokens=args.max_new_tokens, stop=stop_strings, ) if args.use_chat_format: @@ -248,6 +248,12 @@ def apply_chat_format(example, tokenizer): action="store_true", help="If given, we're evaluating a model without chain-of-thought." ) + parser.add_argument( + '--max_new_tokens', + type=int, + default=2048, + help="maximum number of tokens to generate for each prompt." 
+ ) parser.add_argument( "--eval_batch_size", type=int, From 13ecda20d53656f21ad42e9916bdee258070c750 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Sat, 6 Jul 2024 03:49:52 +0000 Subject: [PATCH 5/9] fix eval script --- scripts/submit_eval_jobs.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/scripts/submit_eval_jobs.py b/scripts/submit_eval_jobs.py index 1bedbc72..b220cc5c 100755 --- a/scripts/submit_eval_jobs.py +++ b/scripts/submit_eval_jobs.py @@ -123,6 +123,8 @@ def adjust_gpus(task_spec, experiment_group, model_name, gpu_multiplier): "gsm_cot", "MATH_direct", "MATH_cot", + "MATH_direct", + "MATH_cot", "bbh_direct", "bbh_cot", "tydiqa_goldp_1shot", @@ -253,7 +255,7 @@ def adjust_gpus(task_spec, experiment_group, model_name, gpu_multiplier): --n_shot 4 \ --no_cot \ --use_chat_format \ - --chat_formatting_function eval.templates.create_prompt_with_tulu_chat_format + --chat_formatting_function eval.templates.create_prompt_with_tulu_chat_format \ ''' elif experiment_group == "MATH_cot": task_spec['arguments'][0] = ''' @@ -266,7 +268,7 @@ def adjust_gpus(task_spec, experiment_group, model_name, gpu_multiplier): --tokenizer_name_or_path /model \ --n_shot 4 \ --use_chat_format \ - --chat_formatting_function eval.templates.create_prompt_with_tulu_chat_format + --chat_formatting_function eval.templates.create_prompt_with_tulu_chat_format \ ''' elif experiment_group == "tydiqa_goldp_1shot": task_spec["arguments"][0] = ''' From 7f5a16bcd5f586bbd395684ee8b0d045c4e0e15c Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Sat, 6 Jul 2024 18:08:46 +0000 Subject: [PATCH 6/9] nit: allow openai key unset --- eval/dispatch_openai_requests.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/eval/dispatch_openai_requests.py b/eval/dispatch_openai_requests.py index 4724ab84..dedb9ce7 100644 --- a/eval/dispatch_openai_requests.py +++ b/eval/dispatch_openai_requests.py @@ -4,9 +4,13 @@ ''' import asyncio from typing import Any, List, Dict -from openai import AsyncOpenAI +from openai import AsyncOpenAI, OpenAIError -aclient = AsyncOpenAI() +try: + aclient = AsyncOpenAI() +except OpenAIError as e: + print(f"Error initializing OpenAI client: {e}") + print("If you are running an eval without OpenAI models, this is okay.") async def dispatch_openai_chat_requesets( messages_list: List[List[Dict[str,Any]]], From a4221e3d8e0168ae789106987033f2baca4a800f Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Sat, 6 Jul 2024 18:22:12 +0000 Subject: [PATCH 7/9] fix up the prompt a bit --- eval/MATH/run_eval.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/eval/MATH/run_eval.py b/eval/MATH/run_eval.py index 8e2358d6..0704cce4 100644 --- a/eval/MATH/run_eval.py +++ b/eval/MATH/run_eval.py @@ -50,11 +50,11 @@ def main(args): for example in MATH_EXAMPLARS: if args.no_cot: demonstrations.append( - "Question: " + example["question"] + "\n" + "Answer: " + example["short_answer"] + "Problem:\n" + example["question"] + "\n\n" + "Solution:\n" + example["short_answer"] ) else: demonstrations.append( - "Problem:" + "\n" + example["question"] + "\n\n" + "Solution: " + "\n" + example["cot_answer"] + "\n" + "Final Answer: " + f"The final answer is ${example['short_answer']}$. I hope it is correct." 
+ "Problem:\n" + example["question"] + "\n\n" + "Solution:\n" + example["cot_answer"] ) prompt_prefix = args.prompt_prefix + "\n\n" + "\n\n".join(demonstrations) + "\n\n" else: @@ -63,7 +63,7 @@ def main(args): if args.use_chat_format: chat_formatting_function = dynamic_import_function(args.chat_formatting_function) def apply_chat_format(example, tokenizer): - messages = [{"role": "user", "content": prompt_prefix + "Problem: " + example["question"].strip()}] + messages = [{"role": "user", "content": prompt_prefix + "Problem:\n" + example["question"].strip()}] prompt = chat_formatting_function(messages, tokenizer, add_bos=False) prompt += "Solution:" if prompt[-1] in ["\n", " "] else " Solution:" return prompt @@ -93,12 +93,12 @@ def apply_chat_format(example, tokenizer): stop=stop_strings, ) if args.use_chat_format: - prompts = [apply_chat_format(example, tokenizer) for example in test_data] + prompts = [apply_chat_format(demonstrations, example, tokenizer) for example in test_data] else: if args.no_cot: - prompts = [prompt_prefix + "Problem: " + example["question"].strip() + "\Solution:" for example in test_data] + prompts = [prompt_prefix + "Problem:\n" + example["question"].strip() + "\n\nSolution:\n" for example in test_data] else: - prompts = [prompt_prefix + "Problem: " + "\n" + example["question"].strip() + "\n\nSolution: " + "\n" for example in test_data] + prompts = [prompt_prefix + "Problem:\n" + example["question"].strip() + "\n\nSolution:\n" for example in test_data] generations = model.generate(prompts, sampling_params) prompt_to_output = { g.prompt: g.outputs[0].text for g in generations @@ -119,9 +119,9 @@ def apply_chat_format(example, tokenizer): prompts = [apply_chat_format(example, tokenizer) for example in test_data] else: if args.no_cot: - prompts = [prompt_prefix + "Problem: " + example["question"].strip() + "\nSolution:" for example in test_data] + prompts = [prompt_prefix + "Problem:\n" + example["question"].strip() + "\n\nSolution:\n" for example in test_data] else: - prompts = [prompt_prefix + "Problem: " + "\n" + example["question"].strip() + "\nSolution: " + "\n" for example in test_data] + prompts = [prompt_prefix + "Problem:\n" + example["question"].strip() + "\n\nSolution:\n" for example in test_data] # we only use stop token for non-chat format (usually applied to vanilla pretrained language models). For chat format, we will rely on the model knows when to stop. stop_tokens = [[tokenizer.encode(stop_seq, add_special_tokens=False)[-1]] for stop_seq in args.additional_stop_sequence] if not args.use_chat_format: @@ -251,7 +251,7 @@ def apply_chat_format(example, tokenizer): parser.add_argument( '--max_new_tokens', type=int, - default=2048, + default=1024, help="maximum number of tokens to generate for each prompt." 
) parser.add_argument( From 826cdaa1f3d4754fe21383ffcbec48eb38655bbd Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 8 Jul 2024 18:08:25 -0700 Subject: [PATCH 8/9] fix up math --- eval/MATH/answer_extraction.py | 271 ----------------------- eval/MATH/minerva_utils.py | 309 ++++++++++++++++++++++++++ eval/MATH/run_eval.py | 51 ++--- eval/MATH/utilities.py | 393 --------------------------------- 4 files changed, 335 insertions(+), 689 deletions(-) delete mode 100644 eval/MATH/answer_extraction.py create mode 100644 eval/MATH/minerva_utils.py diff --git a/eval/MATH/answer_extraction.py b/eval/MATH/answer_extraction.py deleted file mode 100644 index aba16ac6..00000000 --- a/eval/MATH/answer_extraction.py +++ /dev/null @@ -1,271 +0,0 @@ -import re -import regex - -def _fix_fracs(string): - substrs = string.split("\\frac") - new_str = substrs[0] - if len(substrs) > 1: - substrs = substrs[1:] - for substr in substrs: - new_str += "\\frac" - if len(substr) > 0 and substr[0] == "{": - new_str += substr - else: - try: - assert len(substr) >= 2 - except: - return string - a = substr[0] - b = substr[1] - if b != "{": - if len(substr) > 2: - post_substr = substr[2:] - new_str += "{" + a + "}{" + b + "}" + post_substr - else: - new_str += "{" + a + "}{" + b + "}" - else: - if len(substr) > 2: - post_substr = substr[2:] - new_str += "{" + a + "}" + b + post_substr - else: - new_str += "{" + a + "}" + b - string = new_str - return string - - -def _fix_a_slash_b(string): - if len(string.split("/")) != 2: - return string - a = string.split("/")[0] - b = string.split("/")[1] - try: - if "sqrt" not in a: - a = int(a) - if "sqrt" not in b: - b = int(b) - assert string == "{}/{}".format(a, b) - new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" - return new_string - except: - return string - - -def _fix_sqrt(string): - _string = re.sub(r"\\sqrt(-?[0-9.a-zA-Z]+)", r"\\sqrt{\1}", string) - _string = re.sub(r"\\sqrt\s+(\w+)$", r"\\sqrt{\1}", _string) - return _string - - -def _fix_tan(string): - _string = re.sub(r"\\tan(-?[0-9.a-zA-Z]+)", r"\\tan{\1}", string) - _string = re.sub(r"\\tan\s+(\w+)$", r"\\tan{\1}", _string) - return _string - - -def strip_string(string): - string = str(string).strip() - # linebreaks - string = string.replace("\n", "") - - # right "." 
- string = string.rstrip(".") - - # remove inverse spaces - string = string.replace("\\!", "") - # string = string.replace("\\ ", "") - - # replace \\ with \ - # string = string.replace("\\\\", "\\") - # string = string.replace("\\\\", "\\") - - if string.startswith("\\text{") and string.endswith("}"): - string = string.split("{", 1)[1][:-1] - - # replace tfrac and dfrac with frac - string = string.replace("tfrac", "frac") - string = string.replace("dfrac", "frac") - string = string.replace("cfrac", "frac") - - # remove \left and \right - string = string.replace("\\left", "") - string = string.replace("\\right", "") - - # Remove unit: miles, dollars if after is not none - _string = re.sub(r"\\text{.*?}$", "", string).strip() - if _string != "" and _string != string: - # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string)) - string = _string - - # Remove circ (degrees) - string = string.replace("^{\\circ}", "").strip() - string = string.replace("^\\circ", "").strip() - - string = regex.sub(r"\{(c|m)?m\}(\^(2|3))?", "", string).strip() - string = regex.sub(r"p\.m\.$", "", string).strip() - string = regex.sub(r"(\d)\s*t$", r"\1", string).strip() - - # remove dollar signs - string = string.replace("\\$", "") - string = string.replace("$", "") - - # string = string.replace("\\text", "") - string = string.replace("x\\in", "") - - # remove percentage - string = string.replace("\\%", "%") - string = string.replace("\%", "%") - # string = string.replace("%", "") - - # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string - string = string.replace(" .", " 0.") - string = string.replace("{.", "{0.") - - # cdot - string = string.replace("\\cdot", "") - - # inf - string = string.replace("infinity", "\\infty") - if "\\infty" not in string: - string = string.replace("inf", "\\infty") - string = string.replace("+\\inity", "\\infty") - - # and - # string = string.replace("and", "") - string = string.replace("\\mathbf", "") - string = string.replace("\\mathrm", "") - - # use regex to remove \mbox{...} - string = re.sub(r"\\mbox{.*?}", "", string) - - # quote - string.replace("'", "") - string.replace("\"", "") - - # i, j - if "j" in string and "i" not in string: - string = string.replace("j", "i") - - # replace a.000b where b is not number or b is end, with ab, use regex - string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string) - string = re.sub(r"(\d+)\.0+$", r"\1", string) - - # if empty, return empty string - if len(string) == 0: - return string - if string[0] == ".": - string = "0" + string - - # to consider: get rid of e.g. "k = " or "q = " at beginning - # if len(string.split("=")) == 2: - # if len(string.split("=")[0]) <= 2: - # string = string.split("=")[1] - - string = _fix_sqrt(string) - string = _fix_tan(string) - string = string.replace(" ", "") - - # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). 
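Applied by hand to a typical MATH-style answer, a few of the replacements above behave as in this standalone sketch (it re-applies the string operations directly rather than importing the deleted module).

answer = r"\left( 3, \frac{\pi}{2} \right)^\circ"
answer = answer.replace("\\left", "").replace("\\right", "")      # drop sizing commands
answer = answer.replace("^{\\circ}", "").replace("^\\circ", "")   # drop degree marks
answer = answer.replace(" ", "")                                   # drop whitespace
print(answer)  # (3,\frac{\pi}{2})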
Also does a/b --> \\frac{a}{b} - string = _fix_fracs(string) - - # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y - string = _fix_a_slash_b(string) - - string = regex.sub(r"(\\|,|\.)+$", "", string) - - return string - -def extract_boxed_answers(text): - answers = [] - for piece in text.split('boxed{')[1:]: - n = 0 - for i in range(len(piece)): - if piece[i] == '{': - n += 1 - elif piece[i] == '}': - n -= 1 - if n < 0: - if i + 1 < len(piece) and piece[i + 1] == '%': - answers.append(piece[: i + 1]) - else: - answers.append(piece[:i]) - break - return answers - -def extract_program_output(pred_str): - """ - extract output between the last ```output\n...\n``` - """ - if "```output" not in pred_str: - return "" - if '```output' in pred_str: - pred_str = pred_str.split('```output')[-1] - if '```' in pred_str: - pred_str = pred_str.split('```')[0] - output = pred_str.strip() - return output - -def extract_answer(pred_str, exhaust=False): - pred = [] - if 'final answer is $' in pred_str and '$. I hope' in pred_str: - tmp = pred_str.split('final answer is $', 1)[1] - pred = [tmp.split('$. I hope', 1)[0].strip()] - elif 'boxed' in pred_str: - pred = extract_boxed_answers(pred_str) - elif ('he answer is' in pred_str): - pred = [pred_str.split('he answer is')[-1].strip()] - else: - program_output = extract_program_output(pred_str) - if program_output != "": - # fall back to program - pred.append(program_output) - else: # use the last number - pattern = '-?\d*\.?\d+' - ans = re.findall(pattern, pred_str.replace(",", "")) - if(len(ans) >= 1): - ans = ans[-1] - else: - ans = '' - if ans: - pred.append(ans) - - # multiple line - _pred = [] - for ans in pred: - ans = ans.strip().split("\n")[0] - ans = ans.lstrip(":") - ans = ans.rstrip(".") - ans = ans.rstrip("/") - ans = strip_string(ans) - _pred.append(ans) - if exhaust: - return _pred - else: - return _pred[-1] if _pred else "" - -def extract_math_answer(question, reasoning, task): - answer = [] - for ans in extract_answer(reasoning, exhaust=True): - if 'separated by commas' in question and all(ch not in ans for ch in '()[]'): - answer.extend([a.strip() for a in ans.split(",")]) - elif regex.search(r"\\text\{\s*and\s*\}", ans): - answer.extend([a.strip() for a in regex.sub(r"\\text\{\s*and\s*\}", "[SEP]", ans).split("[SEP]")]) - else: - answer.append(ans.strip()) - return answer - -def extract_math_few_shot_cot_answer(question, reasoning, task): - if 'Problem:' in reasoning: - reasoning = reasoning.split("Problem:", 1)[0] - return extract_math_answer(question, reasoning, task) - -def extract_last_single_answer(question, reasoning, task): - return extract_answer(reasoning, exhaust=False) - -def extract_gsm_few_shot_cot_answer(question, reasoning, task): - if 'Q: ' in reasoning: - reasoning = reasoning.split("Q: ", 1)[0] - pred = [s for s in regex.findall(r'-?\d+\.?\d*', reasoning)] - if pred: - return pred[-1] - else: - return "[invalid]" diff --git a/eval/MATH/minerva_utils.py b/eval/MATH/minerva_utils.py new file mode 100644 index 00000000..2d17af14 --- /dev/null +++ b/eval/MATH/minerva_utils.py @@ -0,0 +1,309 @@ +''' +Utils from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/minerva_math/utils.py +''' +import re + +SUBSTITUTIONS = [ + ("an ", ""), + ("a ", ""), + (".$", "$"), + ("\\$", ""), + (r"\ ", ""), + (" ", ""), + ("mbox", "text"), + (",\\text{and}", ","), + ("\\text{and}", ","), + ("\\text{m}", "\\text{}"), +] +REMOVED_EXPRESSIONS = [ + "square", + "ways", + 
"integers", + "dollars", + "mph", + "inches", + "ft", + "hours", + "km", + "units", + "\\ldots", + "sue", + "points", + "feet", + "minutes", + "digits", + "cents", + "degrees", + "cm", + "gm", + "pounds", + "meters", + "meals", + "edges", + "students", + "childrentickets", + "multiples", + "\\text{s}", + "\\text{.}", + "\\text{\ns}", + "\\text{}^2", + "\\text{}^3", + "\\text{\n}", + "\\text{}", + r"\mathrm{th}", + r"^\circ", + r"^{\circ}", + r"\;", + r",\!", + "{,}", + '"', + "\\dots", +] + +def normalize_final_answer(final_answer: str) -> str: + """ + Normalize a final answer to a quantitative reasoning question. + + Copied character for character from appendix D of Lewkowycz et al. (2022) + """ + final_answer = final_answer.split("=")[-1] + + for before, after in SUBSTITUTIONS: + final_answer = final_answer.replace(before, after) + for expr in REMOVED_EXPRESSIONS: + final_answer = final_answer.replace(expr, "") + + # Extract answer that is in LaTeX math, is bold, + # is surrounded by a box, etc. + final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) + final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) + + # Normalize shorthand TeX: + # \fracab -> \frac{a}{b} + # \frac{abc}{bef} -> \frac{abc}{bef} + # \fracabc -> \frac{a}{b}c + # \sqrta -> \sqrt{a} + # \sqrtab -> sqrt{a}b + final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) + final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) + final_answer = final_answer.replace("$", "") + + # Normalize 100,000 -> 100000 + if final_answer.replace(",", "").isdigit(): + final_answer = final_answer.replace(",", "") + + return final_answer + +def get_unnormalized_answer(text: str) -> str: + INVALID_ANSWER = "[invalidanswer]" + end_seq = "I hope it is correct." + text += end_seq + match = re.search( + r"Final Answer: The final answer is(.*?). 
I hope it is correct.", + text, + ) + if match: + return match.group(1).strip() + else: + return INVALID_ANSWER + +# string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py +def is_equiv(str1, str2, verbose=False): + if str1 is None and str2 is None: + print("WARNING: Both None") + return True + if str1 is None or str2 is None: + return False + + try: + ss1 = strip_string(str1) + ss2 = strip_string(str2) + if verbose: + print(ss1, ss2) + return ss1 == ss2 + except Exception: + return str1 == str2 + + +def remove_boxed(s): + if "\\boxed " in s: + left = "\\boxed " + assert s[: len(left)] == left + return s[len(left) :] + + left = "\\boxed{" + + assert s[: len(left)] == left + assert s[-1] == "}" + + return s[len(left) : -1] + + +def last_boxed_only_string(string): + idx = string.rfind("\\boxed") + if "\\boxed " in string: + return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + if right_brace_idx is None: + retval = None + else: + retval = string[idx : right_brace_idx + 1] + + return retval + + +def fix_fracs(string): + substrs = string.split("\\frac") + new_str = substrs[0] + if len(substrs) > 1: + substrs = substrs[1:] + for substr in substrs: + new_str += "\\frac" + if substr[0] == "{": + new_str += substr + else: + try: + assert len(substr) >= 2 + except AssertionError: + return string + a = substr[0] + b = substr[1] + if b != "{": + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + new_str += "{" + a + "}{" + b + "}" + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}" + b + post_substr + else: + new_str += "{" + a + "}" + b + string = new_str + return string + + +def fix_a_slash_b(string): + if len(string.split("/")) != 2: + return string + a = string.split("/")[0] + b = string.split("/")[1] + try: + a = int(a) + b = int(b) + assert string == "{}/{}".format(a, b) + new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" + return new_string + except AssertionError: + return string + + +def remove_right_units(string): + # "\\text{ " only ever occurs (at least in the val set) when describing units + if "\\text{ " in string: + splits = string.split("\\text{ ") + assert len(splits) == 2 + return splits[0] + else: + return string + + +def fix_sqrt(string): + if "\\sqrt" not in string: + return string + splits = string.split("\\sqrt") + new_string = splits[0] + for split in splits[1:]: + if split[0] != "{": + a = split[0] + new_substr = "\\sqrt{" + a + "}" + split[1:] + else: + new_substr = "\\sqrt" + split + new_string += new_substr + return new_string + + +def strip_string(string): + # linebreaks + string = string.replace("\n", "") + + # remove inverse spaces + string = string.replace("\\!", "") + + # replace \\ with \ + string = string.replace("\\\\", "\\") + + # replace tfrac and dfrac with frac + string = string.replace("tfrac", "frac") + string = string.replace("dfrac", "frac") + + # remove \left and \right + string = string.replace("\\left", "") + string = string.replace("\\right", "") + + # Remove circ (degrees) + string = string.replace("^{\\circ}", "") + string = string.replace("^\\circ", 
"") + + # remove dollar signs + string = string.replace("\\$", "") + + # remove units (on the right) + string = remove_right_units(string) + + # remove percentage + string = string.replace("\\%", "") + string = string.replace("\%", "") # noqa: W605 + + # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string + string = string.replace(" .", " 0.") + string = string.replace("{.", "{0.") + # if empty, return empty string + if len(string) == 0: + return string + if string[0] == ".": + string = "0" + string + + # to consider: get rid of e.g. "k = " or "q = " at beginning + if len(string.split("=")) == 2: + if len(string.split("=")[0]) <= 2: + string = string.split("=")[1] + + # fix sqrt3 --> sqrt{3} + string = fix_sqrt(string) + + # remove spaces + string = string.replace(" ", "") + + # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} + string = fix_fracs(string) + + # manually change 0.5 --> \frac{1}{2} + if string == "0.5": + string = "\\frac{1}{2}" + + # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y + string = fix_a_slash_b(string) + + return string \ No newline at end of file diff --git a/eval/MATH/run_eval.py b/eval/MATH/run_eval.py index 0704cce4..cbde5394 100644 --- a/eval/MATH/run_eval.py +++ b/eval/MATH/run_eval.py @@ -12,9 +12,9 @@ dynamic_import_function, load_hf_tokenizer ) -from eval.MATH.answer_extraction import extract_answer from eval.MATH.examplars import EXAMPLARS as MATH_EXAMPLARS -from eval.MATH.utilities import last_boxed_only_string, remove_boxed, eval_math +from eval.MATH.utilities import last_boxed_only_string, remove_boxed +from eval.MATH.minerva_utils import normalize_final_answer, get_unnormalized_answer, is_equiv DEFAULT_PROMPT_PREFIX_COT = "Solve the question below by reasoning step by step, and put the final answer within \\boxed{}." DEFAULT_PROMPT_PREFIX_NO_COT = "Answer the following question." @@ -32,7 +32,7 @@ def main(args): example = json.loads(line) test_data.append({ "question": example["problem"], - "answer": extract_answer(example["solution"]), + "answer": normalize_final_answer(remove_boxed(last_boxed_only_string((example["solution"])))), "type": example["type"] }) @@ -50,22 +50,25 @@ def main(args): for example in MATH_EXAMPLARS: if args.no_cot: demonstrations.append( - "Problem:\n" + example["question"] + "\n\n" + "Solution:\n" + example["short_answer"] + ("Problem:\n" + example["question"] + "\n\n" + "Solution:", example["short_answer"]) ) else: demonstrations.append( - "Problem:\n" + example["question"] + "\n\n" + "Solution:\n" + example["cot_answer"] + ("Problem:\n" + example["question"] + "\n\n" + "Solution:", example["cot_answer"] + "\n" + "Final Answer: " + f"The final answer is ${example['short_answer']}$. 
I hope it is correct.") ) - prompt_prefix = args.prompt_prefix + "\n\n" + "\n\n".join(demonstrations) + "\n\n" + initial_demonstrations = "\n\n".join(["\n".join(d) for d in demonstrations]) else: - prompt_prefix = args.prompt_prefix + "\n\n" - + demonstrations = [] + if args.use_chat_format: chat_formatting_function = dynamic_import_function(args.chat_formatting_function) - def apply_chat_format(example, tokenizer): - messages = [{"role": "user", "content": prompt_prefix + "Problem:\n" + example["question"].strip()}] + def apply_chat_format(example, demonstrations, tokenizer): + messages = [] + for user_turn, assistant_turn in demonstrations: + messages.append({"role": "user", "content": user_turn}) + messages.append({"role": "assistant", "content": assistant_turn}) + messages += [{"role": "user", "content": "Problem:\n" + example["question"].strip() + "\n\nSolution:"}] prompt = chat_formatting_function(messages, tokenizer, add_bos=False) - prompt += "Solution:" if prompt[-1] in ["\n", " "] else " Solution:" return prompt if args.model_name_or_path: @@ -82,7 +85,7 @@ def apply_chat_format(example, tokenizer): tokenizer_mode="slow" if args.use_slow_tokenizer else "auto", tensor_parallel_size=torch.cuda.device_count(), ) - stop_strings = args.additional_stop_sequence + stop_strings = args.additional_stop_sequence + ["Problem:"] # we only use stop token for non-chat format (usually applied to vanilla pretrained language models). # For chat format, we will rely on the model knows when to stop. if not args.use_chat_format: @@ -93,12 +96,12 @@ def apply_chat_format(example, tokenizer): stop=stop_strings, ) if args.use_chat_format: - prompts = [apply_chat_format(demonstrations, example, tokenizer) for example in test_data] + prompts = [apply_chat_format(example, demonstrations, tokenizer) for example in test_data] else: if args.no_cot: - prompts = [prompt_prefix + "Problem:\n" + example["question"].strip() + "\n\nSolution:\n" for example in test_data] + prompts = [initial_demonstrations + "\n\nProblem:\n" + example["question"].strip() + "\n\nSolution:\n" for example in test_data] else: - prompts = [prompt_prefix + "Problem:\n" + example["question"].strip() + "\n\nSolution:\n" for example in test_data] + prompts = [initial_demonstrations + "\n\nProblem:\n" + example["question"].strip() + "\n\nSolution:\n" for example in test_data] generations = model.generate(prompts, sampling_params) prompt_to_output = { g.prompt: g.outputs[0].text for g in generations @@ -116,12 +119,12 @@ def apply_chat_format(example, tokenizer): tokenizer.model_max_length = model.config.max_position_embeddings print("Set tokenizer.model_max_length to model.config.max_position_embeddings: {}".format(model.config.max_position_embeddings)) if args.use_chat_format: - prompts = [apply_chat_format(example, tokenizer) for example in test_data] + prompts = [apply_chat_format(example, demonstrations, tokenizer) for example in test_data] else: if args.no_cot: - prompts = [prompt_prefix + "Problem:\n" + example["question"].strip() + "\n\nSolution:\n" for example in test_data] + prompts = [initial_demonstrations + "Problem:\n" + example["question"].strip() + "\n\nSolution:\n" for example in test_data] else: - prompts = [prompt_prefix + "Problem:\n" + example["question"].strip() + "\n\nSolution:\n" for example in test_data] + prompts = [initial_demonstrations + "Problem:\n" + example["question"].strip() + "\n\nSolution:\n" for example in test_data] # we only use stop token for non-chat format (usually applied to vanilla pretrained 
language models). For chat format, we will rely on the model knows when to stop. stop_tokens = [[tokenizer.encode(stop_seq, add_special_tokens=False)[-1]] for stop_seq in args.additional_stop_sequence] if not args.use_chat_format: @@ -137,7 +140,7 @@ def apply_chat_format(example, tokenizer): do_sample=False, ) else: - prompts = [prompt_prefix + "Problem: " + example["question"].strip() + "\nSolution:" for example in test_data] + prompts = [initial_demonstrations + "Problem: " + example["question"].strip() + "\nSolution:" for example in test_data] instances = [{"id": prompt, "prompt": prompt} for _, prompt in enumerate(prompts)] results = query_openai_chat_model( engine=args.openai_engine, @@ -149,10 +152,8 @@ def apply_chat_format(example, tokenizer): predictions = [] for output in outputs: - last_boxed_output = last_boxed_only_string(output) - if last_boxed_output: - output = remove_boxed(last_boxed_output) - predictions.append(output) + output = get_unnormalized_answer(output) + predictions.append(normalize_final_answer(output)) predictions = [{ "question": example["question"], @@ -164,7 +165,7 @@ def apply_chat_format(example, tokenizer): print("Calculating accuracy...") correct_list = [] for pred in predictions: - correct = eval_math(pred) + correct = 1 if is_equiv(pred['prediction'], pred['answer']) else 0 correct_list.append(correct) accuracy = round(sum(correct_list) / len(correct_list), ndigits=4) print(f"Accuracy: {accuracy}") @@ -180,7 +181,7 @@ def apply_chat_format(example, tokenizer): if type_ not in type_correct: type_correct[type_] = 0 type_total[type_] = 0 - type_correct[type_] += eval_math(pred) + type_correct[type_] += 1 if is_equiv(pred["prediction"], pred["answer"]) else 0 type_total[type_] += 1 type_accuracy = {type_: round(type_correct[type_] / type_total[type_], ndigits=4) for type_ in type_correct} print("Per-type accuracy:") diff --git a/eval/MATH/utilities.py b/eval/MATH/utilities.py index a6bc2cd6..264f6c2b 100644 --- a/eval/MATH/utilities.py +++ b/eval/MATH/utilities.py @@ -1,396 +1,3 @@ -import multiprocessing -from copy import deepcopy -from math import isclose -import numpy as np -from typing import Union, Any, Dict - -from sympy import simplify, N -from sympy.parsing.sympy_parser import parse_expr -from sympy.parsing.latex import parse_latex -import re -import regex - -from eval.MATH.answer_extraction import extract_answer, extract_program_output, strip_string - -def is_correct(item, pred_key='prediction', prec=1e-3): - pred = item[pred_key] - ans = item['answer'] - if isinstance(pred, list) and isinstance(ans, list): - pred_matched = set() - ans_matched = set() - for i in range(len(pred)): - for j in range(len(ans)): - item_cpy = deepcopy(item) - item_cpy.update({ - pred_key: pred[i], - 'answer': ans[j] - }) - if is_correct(item_cpy, pred_key=pred_key, prec=prec): - pred_matched.add(i) - ans_matched.add(j) - if item_cpy[pred_key] == '2,3,4': - print(item, flush=True) - print("wtf", flush=True) - return len(pred_matched) == len(pred) and len(ans_matched) == len(ans) - elif isinstance(pred, str) and isinstance(ans, str): - if '\\cup' in pred and '\\cup' in ans: - item = deepcopy(item) - item.update({ - pred_key: pred.split('\\cup'), - 'answer': ans.split('\\cup'), - }) - return is_correct(item, pred_key=pred_key, prec=prec) - else: - label = False - try: - label = abs(float(regex.sub(r',', '', str(pred))) - float(regex.sub(r',', '', str(ans)))) < prec - except: - pass - label = label or (ans and pred == ans) or math_equal(pred, ans) - return label - else: - 
print(item, flush=True) - raise NotImplementedError() - -def eval_math(item, pred_key='prediction', prec=1e-3): - pred = item[pred_key] - if pred_key == 'program_output' and isinstance(pred, str): - pred = [pred] - ans = item['answer'] - if isinstance(pred, list) and isinstance(ans, list): - # for some questions in MATH, `reference` repeats answers - _ans = [] - for a in ans: - if a not in _ans: - _ans.append(a) - ans = _ans - # some predictions for MATH questions also repeats answers - _pred = [] - for a in pred: - if a not in _pred: - _pred.append(a) - # some predictions mistakenly box non-answer strings - pred = _pred[-len(ans):] - - item.update({ - pred_key: pred, - 'answer': ans - }) - return is_correct(item, pred_key=pred_key, prec=prec) - -def extract_program(result: str, last_only=True): - """ - extract the program after "```python", and before "```" - """ - program = "" - start = False - for line in result.split("\n"): - if line.startswith("```python"): - if last_only: - program = "" # only extract the last program - else: - program += "\n# ========\n" - start = True - elif line.startswith("```"): - start = False - elif start: - program += line + "\n" - return program - - -def parse_ground_truth(example: Dict[str, Any], data_name): - if 'gt_cot' in example: - return example['gt_cot'], strip_string(example['gt']) - - # parse ground truth - if data_name in ["math", 'ocw']: - gt_cot = example['solution'] - gt_ans = extract_answer(gt_cot) - elif data_name == "gsm8k": - gt_cot, gt_ans = example['answer'].split("####") - elif data_name == "gsm-hard": - gt_cot, gt_ans = example['code'], example['target'] - elif data_name == "svamp": - gt_cot, gt_ans = example['Equation'], example['Answer'] - elif data_name == "asdiv": - gt_cot = example['formula'] - gt_ans = re.sub(r"\(.*?\)", "", example['answer']) - elif data_name == "mawps": - gt_cot, gt_ans = None, example['target'] - elif data_name == "tabmwp": - gt_cot = example['solution'] - gt_ans = example['answer'] - if example['ans_type'] in ['integer_number', 'decimal_number']: - if '/' in gt_ans: - gt_ans = int(gt_ans.split('/')[0]) / int(gt_ans.split('/')[1]) - elif ',' in gt_ans: - gt_ans = float(gt_ans.replace(',', '')) - elif '%' in gt_ans: - gt_ans = float(gt_ans.split('%')[0]) / 100 - else: - gt_ans = float(gt_ans) - elif data_name == "bbh": - gt_cot, gt_ans = None, example['target'] - else: - raise NotImplementedError(data_name) - # post process - gt_cot = str(gt_cot).strip() - gt_ans = strip_string(gt_ans) - return gt_cot, gt_ans - - -def parse_question(example, data_name): - question = "" - if data_name == "asdiv": - question = f"{example['body'].strip()} {example['question'].strip()}" - elif data_name == "svamp": - body = example["Body"].strip() - if not body.endswith("."): - body = body + "." 
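Pulled together, the new scoring path in run_eval.py reduces to the chain below; a sketch assuming the repository root is on PYTHONPATH so eval.MATH is importable.

from eval.MATH.utilities import last_boxed_only_string, remove_boxed
from eval.MATH.minerva_utils import normalize_final_answer, get_unnormalized_answer, is_equiv

# gold answer: pull the last \boxed{...} out of the reference solution, then normalize
solution = "Halving one gives $\\boxed{\\frac{1}{2}}$."
gold = normalize_final_answer(remove_boxed(last_boxed_only_string(solution)))

# model answer: parse the Minerva-style "Final Answer:" line, then normalize;
# outputs without that line come back as "[invalidanswer]" and simply fail to match
output = "Halving one gives 1/2.\nFinal Answer: The final answer is $\\frac{1}{2}$. I hope it is correct."
pred = normalize_final_answer(get_unnormalized_answer(output))

print(gold, pred, is_equiv(pred, gold))  # \frac{1}{2} \frac{1}{2} True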
- question = f'{body} {example["Question"].strip()}' - elif data_name == "tabmwp": - title_str = f'regarding "{example["table_title"]}" ' if example['table_title'] else "" - question = f'Read the following table {title_str}and answer a question:\n' - question += f'{example["table"]}\n{example["question"]}' - if example['choices']: - question += f' Please select from the following options: {example["choices"]}' - else: - for key in ['question', 'problem', 'Question', 'input']: - if key in example: - question = example[key] - break - assert question != "" - return question.strip() - - -def run_execute(executor, result, prompt_type, execute=False): - if not result or result == 'error': - return None, None - report = None - - if "program_only" in prompt_type: - prediction = extract_program_output(result) - elif prompt_type in ["pot", "pal"] and execute: - code = extract_program(result) - prediction, report = executor.apply(code) - else: - prediction = extract_answer(result) - - prediction = strip_string(prediction) - return prediction, report - - -def parse_digits(num): - # format: 234.23 || 23% - num = regex.sub(',', '', str(num)) - try: - return float(num) - except: - if num.endswith('%'): - num = num[:-1] - if num.endswith('\\'): - num = num[:-1] - try: - return float(num) / 100 - except: - pass - return None - -def is_digit(num): - # paired with parse_digits - return parse_digits(num) is not None - - -def normalize_prediction(prediction): - try: # 1. numerical equal - if is_digit(prediction): - prediction = np.round(float(str(prediction).replace(",", "")), 6) - return str(prediction) - except: - pass - - # 2. symbolic equal - prediction = str(prediction).strip() - - ## deal with [], (), {} - brackets = [] - while prediction.startswith("[") and prediction.endswith("]") or (prediction.startswith("(") and prediction.endswith(")")): - bracket = prediction[0] - prediction = prediction[1:-1] - if brackets and ',' in prediction: - pred_parts = [normalize_prediction(part) for part in prediction.split(",")] - prediction = ",".join(pred_parts) - - if brackets: - for b in reversed(brackets): - if b == '[': - prediction = '[' + prediction + ']' - else: - assert b == '(' - prediction = '(' + prediction + ')' - - def _parse(s): - for f in [parse_latex, parse_expr]: - try: - return f(s) - except: - pass - return s - - prediction = _parse(prediction) - - for s in ['{', "}", "(", ")"]: - prediction = prediction.replace(s, "") - - return prediction - - -def math_equal(prediction: Union[bool, float, str], - reference: Union[float, str], - include_percentage: bool = True, - is_close: bool = True, - timeout: bool = False, - ) -> bool: - """ - Exact match of math if and only if: - 1. numerical equal: both can convert to float and are equal - 2. symbolic equal: both can convert to sympy expression and are equal - """ - if str(prediction) == str(reference): - return True - - try: # 1. numerical equal - if is_digit(prediction) and is_digit(reference): - prediction = parse_digits(prediction) - reference = parse_digits(reference) - # number questions - if include_percentage: - gt_result = [reference / 100, reference, reference * 100] - else: - gt_result = [reference] - for item in gt_result: - try: - if is_close: - if isclose(item, prediction, abs_tol=1e-3): - return True - else: - if item == prediction: - return True - except Exception: - continue - return False - except: - pass - - if not prediction and prediction not in [0, False]: - return False - - # 2. 
symbolic equal - reference = str(reference).strip() - prediction = str(prediction).strip() - - if regex.match(r'(\(|\[).+(\)|\])', prediction) is not None and regex.match(r'(\(|\[).+(\)|\])', reference) is not None: - pred_parts = prediction[1:-1].split(",") - ref_parts = reference[1:-1].split(",") - if len(pred_parts) == len(ref_parts): - if all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]): - return True - - if (prediction.startswith("\\begin{pmatrix}") or prediction.startswith("\\begin{bmatrix}")) and (prediction.endswith("\\end{pmatrix}") or prediction.endswith("\\end{bmatrix}")) and \ - (reference.startswith("\\begin{pmatrix}") or reference.startswith("\\begin{bmatrix}")) and (reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")): - pred_lines = [line.strip() for line in prediction[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()] - ref_lines = [line.strip() for line in reference[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()] - matched = True - if len(pred_lines) == len(ref_lines): - for pred_line, ref_line in zip(pred_lines, ref_lines): - pred_parts = pred_line.split("&") - ref_parts = ref_line.split("&") - if len(pred_parts) == len(ref_parts): - if not all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]): - matched = False - break - else: - matched = False - if not matched: - break - else: - matched = False - if matched: - return True - - if prediction.count('=') == 1 and reference.count('=') == 1: - pred = prediction.split('=') - pred = f"{pred[0].strip()} - ({pred[1].strip()})" - ref = reference.split('=') - ref = f"{ref[0].strip()} - ({ref[1].strip()})" - if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref): - return True - elif prediction.count('=') == 1 and len(prediction.split('=')[0].strip()) <= 2 and '=' not in reference: - if math_equal(prediction.split('=')[1], reference, include_percentage, is_close): - return True - elif reference.count('=') == 1 and len(reference.split('=')[0].strip()) <= 2 and '=' not in prediction: - if math_equal(prediction, reference.split('=')[1], include_percentage, is_close): - return True - - # symbolic equal with sympy - if timeout: - if call_with_timeout(symbolic_equal_process, prediction, reference): - return True - else: - if symbolic_equal(prediction, reference): - return True - - return False - - -def math_equal_process(param): - return math_equal(param[-2], param[-1]) - - -def symbolic_equal(a, b): - def _parse(s): - for f in [parse_latex, parse_expr]: - try: - return f(s) - except: - pass - return s - a = _parse(a) - b = _parse(b) - - try: - if simplify(a-b) == 0: - return True - except: - pass - - try: - if isclose(N(a), N(b), abs_tol=1e-3): - return True - except: - pass - return False - - -def symbolic_equal_process(a, b, output_queue): - result = symbolic_equal(a, b) - output_queue.put(result) - - -def call_with_timeout(func, *args, timeout=1, **kwargs): - output_queue = multiprocessing.Queue() - process_args = args + (output_queue,) - process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs) - process.start() - process.join(timeout) - - if process.is_alive(): - process.terminate() - process.join() - return False - - return output_queue.get() - def last_boxed_only_string(string: str): idx = string.rfind("\\boxed") if idx < 0: From 21a6a46ae4e05f9456fb08997b04b5c4ae9a331d Mon Sep 17 00:00:00 2001 
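With the sympy-based math_equal/symbolic_equal path deleted, correctness now comes down to string equality after strip_string; a few probes of is_equiv from eval/MATH/minerva_utils.py (again assuming the repository root is importable).

from eval.MATH.minerva_utils import is_equiv

print(is_equiv("\\frac{1}{2}", "0.5"))        # True: strip_string maps 0.5 to \frac{1}{2}
print(is_equiv("\\tfrac12", "\\frac{1}{2}"))  # True: tfrac -> frac, \frac12 -> \frac{1}{2}
print(is_equiv("0.499", "\\frac{1}{2}"))      # False: exact string match, no numeric tolerance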
From: Hamish Ivison Date: Tue, 9 Jul 2024 01:36:37 +0000 Subject: [PATCH 9/9] remove math direct from eval list --- scripts/submit_eval_jobs.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/scripts/submit_eval_jobs.py b/scripts/submit_eval_jobs.py index b220cc5c..ceff70d3 100755 --- a/scripts/submit_eval_jobs.py +++ b/scripts/submit_eval_jobs.py @@ -121,9 +121,6 @@ def adjust_gpus(task_spec, experiment_group, model_name, gpu_multiplier): "mmlu_5shot", "gsm_direct", "gsm_cot", - "MATH_direct", - "MATH_cot", - "MATH_direct", "MATH_cot", "bbh_direct", "bbh_cot", @@ -243,20 +240,6 @@ def adjust_gpus(task_spec, experiment_group, model_name, gpu_multiplier): ''' if args.gsm_stop_at_double_newline: task_spec['arguments'][0] += " --stop_at_double_newline" - elif experiment_group == "MATH_direct": - task_spec['arguments'][0] = ''' - python -m eval.MATH.run_eval \ - --data_dir /data/MATH/ \ - --max_num_examples 200 \ - --save_dir /output/ \ - --use_vllm \ - --model_name_or_path /model \ - --tokenizer_name_or_path /model \ - --n_shot 4 \ - --no_cot \ - --use_chat_format \ - --chat_formatting_function eval.templates.create_prompt_with_tulu_chat_format \ - ''' elif experiment_group == "MATH_cot": task_spec['arguments'][0] = ''' python -m eval.MATH.run_eval \
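Finally, the accuracy bookkeeping that the remaining MATH_cot job feeds into (from the run_eval.py changes earlier in this series) can be sketched as below, assuming each prediction record carries "prediction", "answer", and "type" keys as written out by the eval script.

from eval.MATH.minerva_utils import is_equiv

predictions = [
    {"prediction": "\\frac{1}{2}", "answer": "\\frac{1}{2}", "type": "Algebra"},
    {"prediction": "7", "answer": "8", "type": "Number Theory"},
]

# overall accuracy
correct_list = [1 if is_equiv(p["prediction"], p["answer"]) else 0 for p in predictions]
accuracy = round(sum(correct_list) / len(correct_list), ndigits=4)

# per-type accuracy
type_correct, type_total = {}, {}
for p, c in zip(predictions, correct_list):
    type_correct[p["type"]] = type_correct.get(p["type"], 0) + c
    type_total[p["type"]] = type_total.get(p["type"], 0) + 1
type_accuracy = {t: round(type_correct[t] / type_total[t], ndigits=4) for t in type_correct}

print(accuracy)       # 0.5
print(type_accuracy)  # {'Algebra': 1.0, 'Number Theory': 0.0}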