Simple script to compute the "multiple_choice_grade" accuracy for BIG-bench tasks. Run the script from the main folder of the project.
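The default base_path is bigbench/benchmark_tasks/, so the script assumes a layout along these lines (the task folders shown are the ones listed in list_path below; adapt them to your checkout):

bigbench/benchmark_tasks/SIB-plain/task.json
bigbench/benchmark_tasks/SIB-tricky/task.json
bigbench/benchmark_tasks/SIB-adversarial/task.json
...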
from transformers import pipeline
from tqdm import tqdm
import json


def check_mnli_benchmark(benchmark_json_path, classifier, base_path='bigbench/benchmark_tasks/', use_prefix_out=True):
    """
    Compute the "multiple_choice_grade" accuracy of the provided classifier.
    Saves a result JSON next to the task file with the overall score and the
    per-example list of results.
    Args:
        benchmark_json_path: str, path of the JSON file with the task, relative to base_path
        classifier: transformers.pipeline classifier used to predict
        base_path: str, base path of the benchmark tasks
        use_prefix_out: bool, whether to include the possible answers in the prompt
    Returns:
        None
    """
    json_benchmark_path = base_path + benchmark_json_path
    with open(json_benchmark_path) as json_file:
        data = json.load(json_file)
    total_results = []
    task_prefix = data['task_prefix']
    for example in tqdm(data['examples']):
        input = example['input']
        target_scores = list(example['target_scores'].keys())
        # Optionally append the candidate answers to the prompt, BIG-bench style
        output_choice = ''
        for choice in target_scores:
            output_choice += ' choice: ' + choice
        output_choice += ' A: '
        if use_prefix_out:
            tot_input = task_prefix + input + output_choice
        else:
            tot_input = task_prefix + input
        # Zero-shot classification over the candidate answers; labels come back sorted by score
        dict_res = classifier(tot_input, target_scores)
        if example['target_scores'][dict_res['labels'][0]] == 1:
            total_results.append(1)
        else:
            total_results.append(0)
    tot_elem = len(total_results) if total_results else 1  # avoid division by zero on empty tasks
    accuracy = sum(total_results) / tot_elem
    print(f'Accuracy on the {json_benchmark_path}: {accuracy}')
    dict_to_save = {
        'task': json_benchmark_path,
        'accuracy': accuracy,
        'result_list': total_results
    }
    path_to_save = base_path + benchmark_json_path.split('.')[0] + '_results_mnli.json'
    with open(path_to_save, 'w') as outfile:
        json.dump(dict_to_save, outfile)


if __name__ == '__main__':
    model_name = "facebook/bart-large-mnli"
    # model_name = "roberta-large-mnli"
    list_path = ['SIB-plain', 'SIB-tricky', 'SIB-adversarial', 'SIB-emoji-agnostic', 'SIB-name-agnostic']
    classifier = pipeline("zero-shot-classification", model=model_name, framework='pt', device=0)
    for elm in list_path:
        path = elm + '/task.json'
        check_mnli_benchmark(path, classifier)
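The function only relies on the task_prefix, examples, input, and target_scores keys of the BIG-bench task file. Below is a minimal, self-contained sketch of a run on a made-up toy task (the /tmp/toy_tasks/ path and the toy examples are purely illustrative, and check_mnli_benchmark is assumed to be in scope, e.g. by importing it from the script above); device=-1 keeps the pipeline on CPU:

import json
import os
from transformers import pipeline

# Hypothetical toy task mirroring the BIG-bench multiple-choice JSON format
toy_task = {
    "task_prefix": "Determine whether the statement below is plausible.\n",
    "examples": [
        {"input": "Q: The sun rises in the east.", "target_scores": {"Yes": 1, "No": 0}},
        {"input": "Q: Cats can breathe under water.", "target_scores": {"Yes": 0, "No": 1}},
    ],
}

# Write the toy task where check_mnli_benchmark expects to find it (illustrative path)
base_path = "/tmp/toy_tasks/"
os.makedirs(base_path + "toy", exist_ok=True)
with open(base_path + "toy/task.json", "w") as f:
    json.dump(toy_task, f)

clf = pipeline("zero-shot-classification", model="facebook/bart-large-mnli", framework="pt", device=-1)
check_mnli_benchmark("toy/task.json", clf, base_path=base_path)
# Prints the accuracy and writes /tmp/toy_tasks/toy/task_results_mnli.json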