diff --git a/README.md b/README.md
index 1c5b04ed..079b0ead 100644
--- a/README.md
+++ b/README.md
@@ -144,6 +144,7 @@ We provide the scripts for running evaluation of Huggingface/OpenAI models on a
 
 - [MMLU](https://github.com/hendrycks/test)
 - [Grade School Math (GSM)](https://github.com/openai/grade-school-math)
+- [MATH](https://github.com/hendrycks/math)
 - [Big-Bench Hard (BBH)](https://github.com/suzgunmirac/BIG-Bench-Hard/tree/main)
 - [TydiQA](https://github.com/google-research-datasets/tydiqa)
 - [Codex HumanEval](https://github.com/openai/human-eval/tree/master)
diff --git a/eval/MATH/examplars.py b/eval/MATH/examplars.py
new file mode 100644
index 00000000..2c1e20a7
--- /dev/null
+++ b/eval/MATH/examplars.py
@@ -0,0 +1,23 @@
+# These examplars are from the DeepSeekMath GitHub repository (https://github.com/deepseek-ai/DeepSeek-Math/tree/main/evaluation/few_shot_prompts)
+EXAMPLARS = [
+    {
+        "question": "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.}",
+        "cot_answer": "The expressions inside each square root must be non-negative.\nTherefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$.\nAlso, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.\nTherefore, the domain of the expression is $\\boxed{[2,5)}$.",
+        "short_answer": "[2,5)"
+    },
+    {
+        "question": "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$",
+        "cot_answer": "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$",
+        "short_answer": "24"
+    },
+    {
+        "question": "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?",
+        "cot_answer": "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*}\n30n&=480\\\\\n\\Rightarrow\\qquad n&=480/30=\\boxed{16}\n\\end{align*}",
+        "short_answer": "16"
+    },
+    {
+        "question": "If the system of equations\n\n\\begin{align*}\n6x-4y&=a,\\\\\n6y-9x &=b.\n\\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b},$ assuming $b$ is nonzero.",
+        "cot_answer": "If we multiply the first equation by $-\\frac{3}{2}$, we obtain\n\n$$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have\n\n$$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$",
+        "short_answer": "-\\frac{2}{3}"
+    }
+]
\ No newline at end of file
diff --git a/eval/MATH/minerva_utils.py b/eval/MATH/minerva_utils.py
new file mode 100644
index 00000000..2d17af14
--- /dev/null
+++ b/eval/MATH/minerva_utils.py
@@ -0,0 +1,309 @@
+'''
+Utils from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/minerva_math/utils.py
+'''
+import re
+
+SUBSTITUTIONS = [
+    ("an ", ""),
+    ("a ", ""),
+    (".$", "$"),
+    ("\\$", ""),
+    (r"\ ", ""),
+    (" ", ""),
+    ("mbox", "text"),
+    (",\\text{and}", ","),
+    ("\\text{and}", ","),
+    ("\\text{m}", "\\text{}"),
+]
+REMOVED_EXPRESSIONS = [
+    "square",
+    "ways",
+    "integers",
+    "dollars",
+    "mph",
+    "inches",
+    "ft",
+    "hours",
+    "km",
+    "units",
+    "\\ldots",
+    "sue",
+    "points",
+    "feet",
+    "minutes",
+    "digits",
+    "cents",
+    "degrees",
+    "cm",
+    "gm",
+    "pounds",
+    "meters",
+    "meals",
+    "edges",
+    "students",
+    "childrentickets",
+    "multiples",
+    "\\text{s}",
+    "\\text{.}",
+    "\\text{\ns}",
+    "\\text{}^2",
+    "\\text{}^3",
+    "\\text{\n}",
+    "\\text{}",
+    r"\mathrm{th}",
+    r"^\circ",
+    r"^{\circ}",
+    r"\;",
+    r",\!",
+    "{,}",
+    '"',
+    "\\dots",
+]
+
+def normalize_final_answer(final_answer: str) -> str:
+    """
+    Normalize a final answer to a quantitative reasoning question.
+
+    Copied character for character from appendix D of Lewkowycz et al. (2022)
+    """
+    final_answer = final_answer.split("=")[-1]
+
+    for before, after in SUBSTITUTIONS:
+        final_answer = final_answer.replace(before, after)
+    for expr in REMOVED_EXPRESSIONS:
+        final_answer = final_answer.replace(expr, "")
+
+    # Extract answer that is in LaTeX math, is bold,
+    # is surrounded by a box, etc.
+    final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
+    final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
+    final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
+    final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
+    final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
+
+    # Normalize shorthand TeX:
+    #  \fracab -> \frac{a}{b}
+    #  \frac{abc}{bef} -> \frac{abc}{bef}
+    #  \fracabc -> \frac{a}{b}c
+    #  \sqrta -> \sqrt{a}
+    #  \sqrtab -> sqrt{a}b
+    final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
+    final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
+    final_answer = final_answer.replace("$", "")
+
+    # Normalize 100,000 -> 100000
+    if final_answer.replace(",", "").isdigit():
+        final_answer = final_answer.replace(",", "")
+
+    return final_answer
+
+def get_unnormalized_answer(text: str) -> str:
+    INVALID_ANSWER = "[invalidanswer]"
+    end_seq = "I hope it is correct."
+    text += end_seq
+    match = re.search(
+        r"Final Answer: The final answer is(.*?). I hope it is correct.",
+        text,
+    )
+    if match:
+        return match.group(1).strip()
+    else:
+        return INVALID_ANSWER
+
+# string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py
+def is_equiv(str1, str2, verbose=False):
+    if str1 is None and str2 is None:
+        print("WARNING: Both None")
+        return True
+    if str1 is None or str2 is None:
+        return False
+
+    try:
+        ss1 = strip_string(str1)
+        ss2 = strip_string(str2)
+        if verbose:
+            print(ss1, ss2)
+        return ss1 == ss2
+    except Exception:
+        return str1 == str2
+
+
+def remove_boxed(s):
+    if "\\boxed " in s:
+        left = "\\boxed "
+        assert s[: len(left)] == left
+        return s[len(left) :]
+
+    left = "\\boxed{"
+
+    assert s[: len(left)] == left
+    assert s[-1] == "}"
+
+    return s[len(left) : -1]
+
+
+def last_boxed_only_string(string):
+    idx = string.rfind("\\boxed")
+    if "\\boxed " in string:
+        return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
+    if idx < 0:
+        idx = string.rfind("\\fbox")
+        if idx < 0:
+            return None
+
+    i = idx
+    right_brace_idx = None
+    num_left_braces_open = 0
+    while i < len(string):
+        if string[i] == "{":
+            num_left_braces_open += 1
+        if string[i] == "}":
+            num_left_braces_open -= 1
+            if num_left_braces_open == 0:
+                right_brace_idx = i
+                break
+        i += 1
+
+    if right_brace_idx is None:
+        retval = None
+    else:
+        retval = string[idx : right_brace_idx + 1]
+
+    return retval
+
+
+def fix_fracs(string):
+    substrs = string.split("\\frac")
+    new_str = substrs[0]
+    if len(substrs) > 1:
+        substrs = substrs[1:]
+        for substr in substrs:
+            new_str += "\\frac"
+            if substr[0] == "{":
+                new_str += substr
+            else:
+                try:
+                    assert len(substr) >= 2
+                except AssertionError:
+                    return string
+                a = substr[0]
+                b = substr[1]
+                if b != "{":
+                    if len(substr) > 2:
+                        post_substr = substr[2:]
+                        new_str += "{" + a + "}{" + b + "}" + post_substr
+                    else:
+                        new_str += "{" + a + "}{" + b + "}"
+                else:
+                    if len(substr) > 2:
+                        post_substr = substr[2:]
+                        new_str += "{" + a + "}" + b + post_substr
+                    else:
+                        new_str += "{" + a + "}" + b
+    string = new_str
+    return string
+
+
+def fix_a_slash_b(string):
+    if len(string.split("/")) != 2:
+        return string
+    a = string.split("/")[0]
+    b = string.split("/")[1]
+    try:
+        a = int(a)
+        b = int(b)
+        assert string == "{}/{}".format(a, b)
+        new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
+        return new_string
+    except AssertionError:
+        return string
+
+
+def remove_right_units(string):
+    # "\\text{ " only ever occurs (at least in the val set) when describing units
+    if "\\text{ " in string:
+        splits = string.split("\\text{ ")
+        assert len(splits) == 2
+        return splits[0]
+    else:
+        return string
+
+
+def fix_sqrt(string):
+    if "\\sqrt" not in string:
+        return string
+    splits = string.split("\\sqrt")
+    new_string = splits[0]
+    for split in splits[1:]:
+        if split[0] != "{":
+            a = split[0]
+            new_substr = "\\sqrt{" + a + "}" + split[1:]
+        else:
+            new_substr = "\\sqrt" + split
+        new_string += new_substr
+    return new_string
+
+
+def strip_string(string):
+    # linebreaks
+    string = string.replace("\n", "")
+
+    # remove inverse spaces
+    string = string.replace("\\!", "")
+
+    # replace \\ with \
+    string = string.replace("\\\\", "\\")
+
+    # replace tfrac and dfrac with frac
+    string = string.replace("tfrac", "frac")
+    string = string.replace("dfrac", "frac")
+
+    # remove \left and \right
+    string = string.replace("\\left", "")
+    string = string.replace("\\right", "")
+
+    # Remove circ (degrees)
+    string = string.replace("^{\\circ}", "")
+    string = string.replace("^\\circ", "")
+
+    # remove dollar signs
+    string = string.replace("\\$", "")
+
+    # remove units (on the right)
+    string = remove_right_units(string)
+
+    # remove percentage
+    string = string.replace("\\%", "")
+    string = string.replace("\%", "") # noqa: W605
+
+    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
+    string = string.replace(" .", " 0.")
+    string = string.replace("{.", "{0.")
+    # if empty, return empty string
+    if len(string) == 0:
+        return string
+    if string[0] == ".":
+        string = "0" + string
+
+    # to consider: get rid of e.g. "k = " or "q = " at beginning
+    if len(string.split("=")) == 2:
+        if len(string.split("=")[0]) <= 2:
+            string = string.split("=")[1]
+
+    # fix sqrt3 --> sqrt{3}
+    string = fix_sqrt(string)
+
+    # remove spaces
+    string = string.replace(" ", "")
+
+    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
+    string = fix_fracs(string)
+
+    # manually change 0.5 --> \frac{1}{2}
+    if string == "0.5":
+        string = "\\frac{1}{2}"
+
+    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
+    string = fix_a_slash_b(string)
+
+    return string
\ No newline at end of file
diff --git a/eval/MATH/run_eval.py b/eval/MATH/run_eval.py
new file mode 100644
index 00000000..cbde5394
--- /dev/null
+++ b/eval/MATH/run_eval.py
@@ -0,0 +1,321 @@
+import argparse
+import json
+import os
+import random
+import torch
+import vllm
+
+from eval.utils import (
+    generate_completions,
+    load_hf_lm,
+    query_openai_chat_model,
+    dynamic_import_function,
+    load_hf_tokenizer
+)
+from eval.MATH.examplars import EXAMPLARS as MATH_EXAMPLARS
+from eval.MATH.utilities import last_boxed_only_string, remove_boxed
+from eval.MATH.minerva_utils import normalize_final_answer, get_unnormalized_answer, is_equiv
+
+DEFAULT_PROMPT_PREFIX_COT = "Solve the question below by reasoning step by step, and put the final answer within \\boxed{}."
+DEFAULT_PROMPT_PREFIX_NO_COT = "Answer the following question."
+
+DEFAULT_PROMPT_TEMPLATE_COT = """Question: %s\nSolution: %s"""
+DEFAULT_PROMPT_TEMPLATE_NO_COT = """Question: %s\nAnswer: %s"""
+
+def main(args):
+    random.seed(42)
+
+    print("Loading data...")
+    test_data = []
+    with open(os.path.join(args.data_dir, f"test.jsonl")) as fin:
+        for line in fin:
+            example = json.loads(line)
+            test_data.append({
+                "question": example["problem"],
+                "answer": normalize_final_answer(remove_boxed(last_boxed_only_string((example["solution"])))),
+                "type": example["type"]
+            })
+
+    if args.max_num_examples and len(test_data) > args.max_num_examples:
+        test_data = random.sample(test_data, args.max_num_examples)
+
+    if not os.path.exists(args.save_dir):
+        os.makedirs(args.save_dir, exist_ok=True)
+
+    global MATH_EXAMPLARS
+    if args.n_shot:
+        if len(MATH_EXAMPLARS) > args.n_shot:
+            MATH_EXAMPLARS = random.sample(MATH_EXAMPLARS, args.n_shot)
+        demonstrations = []
+        for example in MATH_EXAMPLARS:
+            if args.no_cot:
+                demonstrations.append(
+                    ("Problem:\n" + example["question"] + "\n\n" + "Solution:", example["short_answer"])
+                )
+            else:
+                demonstrations.append(
+                    ("Problem:\n" + example["question"] + "\n\n" + "Solution:", example["cot_answer"] + "\n" + "Final Answer: " + f"The final answer is ${example['short_answer']}$. I hope it is correct.")
+                )
+        initial_demonstrations = "\n\n".join(["\n".join(d) for d in demonstrations])
+    else:
+        demonstrations = []
+
+    if args.use_chat_format:
+        chat_formatting_function = dynamic_import_function(args.chat_formatting_function)
+        def apply_chat_format(example, demonstrations, tokenizer):
+            messages = []
+            for user_turn, assistant_turn in demonstrations:
+                messages.append({"role": "user", "content": user_turn})
+                messages.append({"role": "assistant", "content": assistant_turn})
+            messages += [{"role": "user", "content": "Problem:\n" + example["question"].strip() + "\n\nSolution:"}]
+            prompt = chat_formatting_function(messages, tokenizer, add_bos=False)
+            return prompt
+
+    if args.model_name_or_path:
+        print("Loading model and tokenizer...")
+        tokenizer = load_hf_tokenizer(
+            model_name_or_path=args.model_name_or_path,
+            tokenizer_name_or_path=args.tokenizer_name_or_path,
+            use_fast_tokenizer=not args.use_slow_tokenizer,
+        )
+        if args.use_vllm:
+            model = vllm.LLM(
+                model=args.model_name_or_path,
+                tokenizer=args.tokenizer_name_or_path if args.tokenizer_name_or_path else args.model_name_or_path,
+                tokenizer_mode="slow" if args.use_slow_tokenizer else "auto",
+                tensor_parallel_size=torch.cuda.device_count(),
+            )
+            stop_strings = args.additional_stop_sequence + ["Problem:"]
+            # we only use stop token for non-chat format (usually applied to vanilla pretrained language models).
+            # For chat format, we will rely on the model to know when to stop.
+            if not args.use_chat_format:
+                stop_strings += ["\n"]
+            sampling_params = vllm.SamplingParams(
+                temperature=0,
+                max_tokens=args.max_new_tokens,
+                stop=stop_strings,
+            )
+            if args.use_chat_format:
+                prompts = [apply_chat_format(example, demonstrations, tokenizer) for example in test_data]
+            else:
+                if args.no_cot:
+                    prompts = [initial_demonstrations + "\n\nProblem:\n" + example["question"].strip() + "\n\nSolution:\n" for example in test_data]
+                else:
+                    prompts = [initial_demonstrations + "\n\nProblem:\n" + example["question"].strip() + "\n\nSolution:\n" for example in test_data]
+            generations = model.generate(prompts, sampling_params)
+            prompt_to_output = {
+                g.prompt: g.outputs[0].text for g in generations
+            }
+            outputs = [prompt_to_output[prompt] if prompt in prompt_to_output else "" for prompt in prompts]
+        else:
+            model = load_hf_lm(
+                model_name_or_path=args.model_name_or_path,
+                load_in_8bit=args.load_in_8bit,
+                device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto",
+                gptq_model=args.gptq,
+            )
+            from transformers import GPTNeoXForCausalLM, OPTForCausalLM
+            if isinstance(model, GPTNeoXForCausalLM) or isinstance(model, OPTForCausalLM):
+                tokenizer.model_max_length = model.config.max_position_embeddings
+                print("Set tokenizer.model_max_length to model.config.max_position_embeddings: {}".format(model.config.max_position_embeddings))
+            if args.use_chat_format:
+                prompts = [apply_chat_format(example, demonstrations, tokenizer) for example in test_data]
+            else:
+                if args.no_cot:
+                    prompts = [initial_demonstrations + "Problem:\n" + example["question"].strip() + "\n\nSolution:\n" for example in test_data]
+                else:
+                    prompts = [initial_demonstrations + "Problem:\n" + example["question"].strip() + "\n\nSolution:\n" for example in test_data]
+            # we only use stop token for non-chat format (usually applied to vanilla pretrained language models). For chat format, we will rely on the model to know when to stop.
+            stop_tokens = [[tokenizer.encode(stop_seq, add_special_tokens=False)[-1]] for stop_seq in args.additional_stop_sequence]
+            if not args.use_chat_format:
+                new_line_token = tokenizer.encode("\n", add_special_tokens=False)[-1] # get the last token because the tokenizer may add space tokens at the start.
+                stop_tokens += [[new_line_token]]
+            outputs = generate_completions(
+                model=model,
+                tokenizer=tokenizer,
+                prompts=prompts,
+                max_new_tokens=512,
+                batch_size=args.eval_batch_size,
+                stop_id_sequences=stop_tokens,
+                do_sample=False,
+            )
+    else:
+        prompts = [initial_demonstrations + "Problem: " + example["question"].strip() + "\nSolution:" for example in test_data]
+        instances = [{"id": prompt, "prompt": prompt} for _, prompt in enumerate(prompts)]
+        results = query_openai_chat_model(
+            engine=args.openai_engine,
+            instances=instances,
+            batch_size=args.eval_batch_size if args.eval_batch_size else 10,
+            output_path=os.path.join(args.save_dir, f"openai_results.jsonl"),
+        )
+        outputs = [result["output"] for result in results]
+
+    predictions = []
+    for output in outputs:
+        output = get_unnormalized_answer(output)
+        predictions.append(normalize_final_answer(output))
+
+    predictions = [{
+        "question": example["question"],
+        "answer": example["answer"],
+        "model_output": output,
+        "prediction": pred
+    } for example, output, pred in zip(test_data, outputs, predictions)]
+
+    print("Calculating accuracy...")
+    correct_list = []
+    for pred in predictions:
+        correct = 1 if is_equiv(pred['prediction'], pred['answer']) else 0
+        correct_list.append(correct)
+    accuracy = round(sum(correct_list) / len(correct_list), ndigits=4)
+    print(f"Accuracy: {accuracy}")
+    metrics = {
+        "accuracy": accuracy
+    }
+
+    # calculate per-type accuracy
+    type_correct = {}
+    type_total = {}
+    for pred, sample in zip(predictions, test_data):
+        type_ = sample["type"]
+        if type_ not in type_correct:
+            type_correct[type_] = 0
+            type_total[type_] = 0
+        type_correct[type_] += 1 if is_equiv(pred["prediction"], pred["answer"]) else 0
+        type_total[type_] += 1
+    type_accuracy = {type_: round(type_correct[type_] / type_total[type_], ndigits=4) for type_ in type_correct}
+    print("Per-type accuracy:")
+    for type_, acc in type_accuracy.items():
+        print(f"{type_}: {acc}")
+    metrics["per_type_accuracy"] = type_accuracy
+
+    with open(os.path.join(args.save_dir, f"predictions.jsonl"), "w") as fout:
+        for prediction in predictions:
+            fout.write(json.dumps(prediction) + "\n")
+
+    with open(os.path.join(args.save_dir, "metrics.json"), "w") as fout:
+        json.dump(metrics, fout, indent=4)
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--data_dir",
+        type=str,
+        default="data/gsm"
+    )
+    parser.add_argument(
+        "--max_num_examples",
+        type=int,
+        default=None,
+        help="maximum number of examples to evaluate."
+    )
+    parser.add_argument(
+        "--save_dir",
+        type=str,
+        default="results/gsm"
+    )
+    parser.add_argument(
+        "--model_name_or_path",
+        type=str,
+        default=None,
+        help="if specified, we will load the model to generate the predictions."
+    )
+    parser.add_argument(
+        "--tokenizer_name_or_path",
+        type=str,
+        default=None,
+        help="if specified, we will load the tokenizer from here."
+    )
+    parser.add_argument(
+        "--use_slow_tokenizer",
+        action="store_true",
+        help="If given, we will use the slow tokenizer."
+    )
+    parser.add_argument(
+        "--openai_engine",
+        type=str,
+        default=None, help="if specified, we will use the OpenAI API to generate the predictions."
+    )
+    parser.add_argument(
+        "--n_shot",
+        type=int,
+        default=8,
+        help="max number of examples to use for demonstration."
+    )
+    parser.add_argument(
+        "--no_cot",
+        action="store_true",
+        help="If given, we're evaluating a model without chain-of-thought."
+    )
+    parser.add_argument(
+        '--max_new_tokens',
+        type=int,
+        default=1024,
+        help="maximum number of tokens to generate for each prompt."
+    )
+    parser.add_argument(
+        "--eval_batch_size",
+        type=int,
+        default=1,
+        help="batch size for evaluation."
+    )
+    parser.add_argument(
+        "--load_in_8bit",
+        action="store_true",
+        help="load model in 8bit mode, which will reduce memory and speed up inference."
+    )
+    parser.add_argument(
+        "--gptq",
+        action="store_true",
+        help="If given, we're evaluating a 4-bit quantized GPTQ model."
+    )
+    parser.add_argument(
+        "--use_vllm",
+        action="store_true",
+        help="If given, we will use the vllm library, which will likely increase the inference throughput."
+    )
+    parser.add_argument(
+        "--use_chat_format",
+        action="store_true",
+        help="If given, we will use the chat format for the prompts."
+    )
+    parser.add_argument(
+        "--chat_formatting_function",
+        type=str,
+        default="eval.templates.create_prompt_with_tulu_chat_format",
+        help="The function to use to create the chat format. This function will be dynamically imported. Please see examples in `eval/templates.py`."
+    )
+    parser.add_argument(
+        "--prompt_prefix",
+        type=str,
+        default=None,
+        help="the specific prefix to use for instructing the model."
+    )
+    parser.add_argument(
+        "--prompt_template",
+        type=str,
+        default=None,
+        help="the specific template to use for instructing the model."
+    )
+    parser.add_argument(
+        '--additional_stop_sequence',
+        type=str,
+        nargs="+",
+        default=[],
+        help="Additional stop sequences to use when generating completions. Useful for e.g. llama-3-instruct."
+    )
+    args = parser.parse_args()
+
+    # update the prompt prefix depending on whether CoT is being used
+    if args.prompt_prefix is None:
+        args.prompt_prefix = DEFAULT_PROMPT_PREFIX_NO_COT if args.no_cot else DEFAULT_PROMPT_PREFIX_COT
+
+    # update the prompt template depending on whether CoT is being used
+    if args.prompt_template is None:
+        args.prompt_template = DEFAULT_PROMPT_TEMPLATE_NO_COT if args.no_cot else DEFAULT_PROMPT_TEMPLATE_COT
+
+    # model_name_or_path and openai_engine cannot be both None or both not None.
+    assert (args.model_name_or_path is None) != (args.openai_engine is None), "Either model_name_or_path or openai_engine should be specified."
+    main(args)
\ No newline at end of file
diff --git a/eval/MATH/utilities.py b/eval/MATH/utilities.py
new file mode 100644
index 00000000..264f6c2b
--- /dev/null
+++ b/eval/MATH/utilities.py
@@ -0,0 +1,32 @@
+def last_boxed_only_string(string: str):
+    idx = string.rfind("\\boxed")
+    if idx < 0:
+        idx = string.rfind("\\fbox")
+        if idx < 0:
+            return ""
+    i = idx
+    right_brace_idx = None
+    num_left_braces_open = 0
+    while i < len(string):
+        if string[i] == "{":
+            num_left_braces_open += 1
+        if string[i] == "}":
+            num_left_braces_open -= 1
+            if num_left_braces_open == 0:
+                right_brace_idx = i
+                break
+        i += 1
+    if right_brace_idx == None:
+        retval = ""
+    else:
+        retval = string[idx:right_brace_idx + 1]
+    return retval
+
+def remove_boxed(s: str):
+    left = "\\boxed{"
+    try:
+        assert s[:len(left)] == left
+        assert s[-1] == "}"
+        return s[len(left):-1]
+    except:
+        return ""
diff --git a/eval/dispatch_openai_requests.py b/eval/dispatch_openai_requests.py
index 4724ab84..dedb9ce7 100644
--- a/eval/dispatch_openai_requests.py
+++ b/eval/dispatch_openai_requests.py
@@ -4,9 +4,13 @@
 '''
 import asyncio
 from typing import Any, List, Dict
-from openai import AsyncOpenAI
+from openai import AsyncOpenAI, OpenAIError
 
-aclient = AsyncOpenAI()
+try:
+    aclient = AsyncOpenAI()
+except OpenAIError as e:
+    print(f"Error initializing OpenAI client: {e}")
+    print("If you are running an eval without OpenAI models, this is okay.")
 
 async def dispatch_openai_chat_requesets(
     messages_list: List[List[Dict[str,Any]]],
diff --git a/requirements.txt b/requirements.txt
index 3dac90c5..61962d61 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -36,4 +36,8 @@ openpyxl
 # for ifeval
 nltk
 langdetect
-immutabledict
\ No newline at end of file
+immutabledict
+# for math evaluations
+antlr4-python3-runtime==4.11.0
+mpmath==1.3.0
+sympy==1.12.0
\ No newline at end of file
diff --git a/scripts/eval/MATH.sh b/scripts/eval/MATH.sh
new file mode 100644
index 00000000..54f6cfa9
--- /dev/null
+++ b/scripts/eval/MATH.sh
@@ -0,0 +1,93 @@
+# Here we use 1 GPU for demonstration, but you can use multiple GPUs and larger eval_batch_size to speed up the evaluation.
+export CUDA_VISIBLE_DEVICES=0
+
+
+# Evaluating llama 7B model using chain-of-thought
+python -m eval.MATH.run_eval \
+    --data_dir data/eval/MATH/ \
+    --max_num_examples 200 \
+    --save_dir results/MATH/llama-7B-cot-4shot \
+    --model ../hf_llama_models/7B \
+    --tokenizer ../hf_llama_models/7B \
+    --n_shot 4 \
+    --use_vllm
+
+
+# Evaluating llama 7B model using direct answering (no chain-of-thought)
+python -m eval.MATH.run_eval \
+    --data_dir data/eval/MATH/ \
+    --max_num_examples 200 \
+    --save_dir results/MATH/llama-7B-no-cot-4shot \
+    --model ../hf_llama_models/7B \
+    --tokenizer ../hf_llama_models/7B \
+    --n_shot 4 \
+    --no_cot \
+    --use_vllm
+
+
+# Evaluating tulu 7B model using chain-of-thought and chat format
+python -m eval.MATH.run_eval \
+    --data_dir data/eval/MATH/ \
+    --max_num_examples 200 \
+    --save_dir results/MATH/tulu-7B-cot-4shot \
+    --model ../checkpoints/tulu_7B \
+    --tokenizer ../checkpoints/tulu_7B \
+    --n_shot 4 \
+    --use_chat_format \
+    --chat_formatting_function eval.templates.create_prompt_with_tulu_chat_format \
+    --use_vllm
+
+
+# Evaluating llama2 chat model using chain-of-thought and chat format
+python -m eval.MATH.run_eval \
+    --data_dir data/eval/MATH/ \
+    --max_num_examples 200 \
+    --save_dir results/MATH/llama2-chat-7B-cot-4shot \
+    --model ../hf_llama2_models/7B-chat \
+    --tokenizer ../hf_llama2_models/7B-chat \
+    --n_shot 4 \
+    --use_chat_format \
+    --chat_formatting_function eval.templates.create_prompt_with_llama2_chat_format \
+    --use_vllm
+
+
+# Evaluating chatgpt using chain-of-thought
+python -m eval.MATH.run_eval \
+    --data_dir data/eval/MATH/ \
+    --max_num_examples 200 \
+    --save_dir results/MATH/chatgpt-cot \
+    --openai_engine "gpt-3.5-turbo-0301" \
+    --eval_batch_size 20 \
+    --n_shot 4
+
+
+# Evaluating chatgpt using direct answering (no chain-of-thought)
+python -m eval.MATH.run_eval \
+    --data_dir data/eval/MATH/ \
+    --max_num_examples 200 \
+    --save_dir results/MATH/chatgpt-no-cot \
+    --openai_engine "gpt-3.5-turbo-0301" \
+    --eval_batch_size 20 \
+    --n_shot 4 \
+    --no_cot
+
+
+# Evaluating gpt4 using chain-of-thought
+python -m eval.MATH.run_eval \
+    --data_dir data/eval/MATH/ \
+    --max_num_examples 200 \
+    --save_dir results/MATH/gpt4-cot \
+    --openai_engine "gpt-4-0314" \
+    --eval_batch_size 20 \
+    --n_shot 4
+
+
+# Evaluating gpt4 using direct answering (no chain-of-thought)
+python -m eval.MATH.run_eval \
+    --data_dir data/eval/MATH/ \
+    --max_num_examples 200 \
+    --save_dir results/MATH/gpt4-no-cot \
+    --openai_engine "gpt-4-0314" \
+    --eval_batch_size 20 \
+    --n_shot 4 \
+    --no_cot
diff --git a/scripts/prepare_eval_data.sh b/scripts/prepare_eval_data.sh
index f1b53f3b..866e44b3 100755
--- a/scripts/prepare_eval_data.sh
+++ b/scripts/prepare_eval_data.sh
@@ -24,6 +24,9 @@ wget -P data/eval/tydiqa/ https://storage.googleapis.com/tydiqa/v1.1/tydiqa-gold
 
 # GSM dataset
 wget -P data/eval/gsm/ https://github.com/openai/grade-school-math/raw/master/grade_school_math/data/test.jsonl
 
+# MATH dataset
+mkdir -p data/eval/MATH
+wget -P data/eval/MATH/ https://raw.githubusercontent.com/deepseek-ai/DeepSeek-Math/main/evaluation/datasets/math/test.jsonl
 # Codex HumanEval
 wget -P data/eval/codex_humaneval https://github.com/openai/human-eval/raw/master/data/HumanEval.jsonl.gz
@@ -56,4 +59,4 @@ wget -P data/eval/xstest https://github.com/paul-rottger/exaggerated-safety/raw/
 # we use self-instruct test set, and vicuna test set for our human evaluation
 mkdir -p data/eval/creative_tasks
 wget -O data/eval/creative_tasks/self_instruct_test.jsonl https://github.com/yizhongw/self-instruct/raw/main/human_eval/user_oriented_instructions.jsonl
-wget -O data/eval/creative_tasks/vicuna_test.jsonl https://github.com/lm-sys/FastChat/raw/main/fastchat/eval/table/question.jsonl
\ No newline at end of file
+wget -O data/eval/creative_tasks/vicuna_test.jsonl https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/vicuna_bench/question.jsonl
\ No newline at end of file
diff --git a/scripts/submit_eval_jobs.py b/scripts/submit_eval_jobs.py
index cd223b6d..ceff70d3 100755
--- a/scripts/submit_eval_jobs.py
+++ b/scripts/submit_eval_jobs.py
@@ -121,6 +121,7 @@ def adjust_gpus(task_spec, experiment_group, model_name, gpu_multiplier):
     "mmlu_5shot",
     "gsm_direct",
     "gsm_cot",
+    "MATH_cot",
     "bbh_direct",
     "bbh_cot",
     "tydiqa_goldp_1shot",
@@ -239,6 +240,19 @@ def adjust_gpus(task_spec, experiment_group, model_name, gpu_multiplier):
         '''
         if args.gsm_stop_at_double_newline:
             task_spec['arguments'][0] += " --stop_at_double_newline"
+    elif experiment_group == "MATH_cot":
+        task_spec['arguments'][0] = '''
+            python -m eval.MATH.run_eval \
+            --data_dir /data/MATH/ \
+            --max_num_examples 200 \
+            --save_dir /output/ \
+            --use_vllm \
+            --model_name_or_path /model \
+            --tokenizer_name_or_path /model \
+            --n_shot 4 \
+            --use_chat_format \
+            --chat_formatting_function eval.templates.create_prompt_with_tulu_chat_format \
+        '''
     elif experiment_group == "tydiqa_goldp_1shot":
         task_spec["arguments"][0] = '''
             python -m eval.tydiqa.run_eval \