I'm trying to collect inference-time statistics for several code-completion models on the HumanEval dataset. Because timing is central to this project, I don't want to include the time the model spends generating irrelevant tokens. I therefore want to apply a StoppingCriteria to the code-completion models, specifically models from the CodeGen, Code Llama, and WizardCoder families.
Currently, when the model has produced the full answer but has not yet reached the maximum number of new tokens (200 here), it sometimes ends with an `<|endoftext|>` token. More often, though, it emits a blank line (two newlines) and keeps generating irrelevant text. This significantly distorts the timing.
Therefore, I want generation to stop as soon as the model first produces a `"\n\n"` token, or two consecutive `"\n"` tokens (`["\n", "\n"]`). How can I implement this?
To simplify testing, I use a batch size of 1 for each generation. I'd appreciate it if the solution also works when `num_return_sequences` is set to k, so that I can compute pass@k statistics.
The environment uses the Hugging Face transformers main branch as of 2023-08-29 (v4.33). The GitHub repository is: https://github.com/huggingface/transformers
The Python version should be 3.8.0 or later. To test different model checkpoints, use the checkpoint names given in the comments. I recommend testing with the smaller models if you don't have enough GPU VRAM.
"""Time code-completion models on HumanEval prompts, stopping at the first blank line.

Generation for each prompt stops as soon as every returned sample contains one
of ``stop_words`` in its newly generated text (or has emitted EOS), so the
reported timings exclude the irrelevant text models tend to produce after
finishing a function.
"""
import argparse
import time
from typing import Optional, Sequence

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", type=str, default="Salesforce/codegen-2B-mono", help="Model path")
parser.add_argument("--max_new_tokens", type=int, default=200, help="Upper bound on generated tokens")
parser.add_argument(
    "--num_return_sequences",
    type=int,
    default=1,
    help="Samples per prompt (set to k to compute pass@k)",
)

# WizardCoder Family
# WizardLM/WizardCoder-Python-34B-V1.0
# WizardLM/WizardCoder-Python-13B-V1.0
# WizardLM/WizardCoder-15B-V1.0
# WizardLM/WizardCoder-3B-V1.0
# WizardLM/WizardCoder-1B-V1.0
# Code Llama Family
# codellama/CodeLlama-7b-hf
# codellama/CodeLlama-13b-hf
# codellama/CodeLlama-34b-hf
# Salesforce CodeGen Family
# Salesforce/codegen-350M-mono
# Salesforce/codegen-2B-mono
# Salesforce/codegen-6B-mono
# Salesforce/codegen-16B-mono

# Stop strings are matched against *decoded text*, so "\n\n" catches both a
# single "\n\n" token and two consecutive "\n" tokens, whatever the tokenizer.
# NOTE: a blank line inside a function body will also trigger a stop.
stop_words = ["\n\n"]

# HumanEval prompts (4-space indentation as in the dataset; the pasted copy had
# its whitespace collapsed).
# HumanEval Q0
prompt_0 = "from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n    \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n    given threshold.\n    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n    False\n    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n    True\n    \"\"\"\n"
# HumanEval Q31
prompt_31 = "\n\ndef is_prime(n):\n    \"\"\"Return true if a given number is prime, and false otherwise.\n    >>> is_prime(6)\n    False\n    >>> is_prime(101)\n    True\n    >>> is_prime(11)\n    True\n    >>> is_prime(13441)\n    True\n    >>> is_prime(61)\n    True\n    >>> is_prime(4)\n    False\n    >>> is_prime(1)\n    False\n    \"\"\"\n"
# HumanEval Q35
prompt_35 = "\n\ndef max_element(l: list):\n    \"\"\"Return maximum element in the list.\n    >>> max_element([1, 2, 3])\n    3\n    >>> max_element([5, 3, -5, 2, -3, 3, 9, 0, 123, 1, -10])\n    123\n    \"\"\"\n"
# HumanEval Q161
prompt_161 = "\ndef solve(s):\n    \"\"\"You are given a string s.\n    if s[i] is a letter, reverse its case from lower to upper or vise versa, \n    otherwise keep it as it is.\n    If the string contains no letters, reverse the string.\n    The function should return the resulted string.\n    Examples\n    solve(\"1234\") = \"4321\"\n    solve(\"ab\") = \"AB\"\n    solve(\"#a@C\") = \"#A@c\"\n    \"\"\"\n"


class StopOnStrings(StoppingCriteria):
    """Stop ``generate`` once every sequence in the batch has finished.

    A sequence is finished when its newly generated text (everything after the
    first ``prompt_len`` tokens) contains any of ``stop_strings``, or when it
    contains ``eos_token_id``.

    In transformers 4.33 a stopping criterion returns a single bool for the
    whole batch, so with ``num_return_sequences=k`` generation continues until
    all k samples are finished. Cut each sample at its own stop string
    afterwards with ``truncate_at_stop``.
    """

    def __init__(
        self,
        tokenizer,
        stop_strings: Sequence[str],
        prompt_len: int,
        eos_token_id: Optional[int] = None,
    ):
        self.tokenizer = tokenizer
        # An empty stop string would match immediately; ignore it.
        self.stop_strings = tuple(s for s in stop_strings if s)
        self.prompt_len = prompt_len
        self.eos_token_id = eos_token_id

    def _is_finished(self, new_ids: torch.LongTensor) -> bool:
        """Return True if one sequence's generated ids hit EOS or a stop string."""
        # After EOS, generate() pads the sequence with pad_token_id (set to EOS
        # in main), so it can never produce a stop string. It must count as
        # finished, or the whole batch would run on to max_new_tokens.
        if self.eos_token_id is not None and bool((new_ids == self.eos_token_id).any()):
            return True
        # Decoding ~max_new_tokens ids per step is negligible next to a forward
        # pass, but note that it is included in the measured generation time.
        text = self.tokenizer.decode(new_ids, skip_special_tokens=True)
        return any(stop in text for stop in self.stop_strings)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # all() short-circuits at the first unfinished sequence.
        return all(self._is_finished(seq[self.prompt_len:]) for seq in input_ids)


def truncate_at_stop(text: str, stop_strings: Sequence[str]) -> str:
    """Return ``text`` cut just before the earliest occurrence of any stop string."""
    cut = len(text)
    for stop in stop_strings:
        if not stop:
            continue
        idx = text.find(stop)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut]


def _sync_cuda() -> None:
    """Wait for queued CUDA work so wall-clock timings include it (no-op without CUDA)."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def main(args):
    """Load ``args.checkpoint`` and time sampling on the selected HumanEval prompts.

    For each prompt, prints the generation wall-clock time, the time per
    decoding step, the first sample's generated tokens, and every sample's
    completion truncated at its first stop string.

    Args:
        args: namespace with ``checkpoint`` and, optionally,
            ``max_new_tokens`` (default 200) and ``num_return_sequences``
            (default 1).
    """
    checkpoint = args.checkpoint
    max_new_tokens = getattr(args, "max_new_tokens", 200)
    num_return_sequences = getattr(args, "num_return_sequences", 1)

    # Initialize model and tokenizer. The tokenizer runs on the CPU;
    # device_map only applies to the model.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    start_load_model = time.perf_counter()
    model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")
    print(f"Time to load model {checkpoint} is {time.perf_counter() - start_load_model}")

    # Generate the selected prompts
    for prompt in [prompt_0, prompt_31, prompt_35, prompt_161]:
        # Inputs must live on the same device as the model's first layer.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_len = inputs.input_ids.shape[1]
        stopping_criteria = StoppingCriteriaList(
            [StopOnStrings(tokenizer, stop_words, prompt_len, tokenizer.eos_token_id)]
        )

        _sync_cuda()
        start_generating = time.perf_counter()
        generated_ids = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.8,
            num_beams=1,
            num_return_sequences=num_return_sequences,
            stopping_criteria=stopping_criteria,
        )
        _sync_cuda()
        # Read the clock once, so both printed figures agree and exclude the
        # decoding/printing below.
        elapsed = time.perf_counter() - start_generating

        # All returned sequences share one length (finished ones are padded),
        # so this is the number of decoding steps actually performed.
        generated_len = generated_ids.shape[1] - prompt_len
        decoded_list = [tokenizer.decode(int(token_id)) for token_id in generated_ids[0, prompt_len:]]

        # Strip the prompt by its decoded text rather than decoding only the new
        # ids: some tokenizers (e.g. SentencePiece) drop a leading space when
        # decoding a suffix on its own. Truncation must only see generated
        # text, because the prompts themselves contain "\n\n".
        prompt_text = tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True)
        full_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        completions = [truncate_at_stop(text[len(prompt_text):], stop_words) for text in full_texts]

        # Print outputs
        print(f"Time to generate is {elapsed}")
        if generated_len > 0:
            print(f"per token time is {elapsed / generated_len}")
        print(f"decoded_list is {decoded_list}")
        for i, completion in enumerate(completions):
            print(f"\ngenerated_text [{i}] is:\n{prompt_text}{completion}")


if __name__ == "__main__":
    FLAGS = parser.parse_args()
    main(FLAGS)