@teelinsan
Created March 25, 2021 13:55
Simple script to calculate the "multiple_choice_grade" accuracy for BIG-bench. Run the script from the main folder of the project.
from transformers import pipeline
from tqdm import tqdm
import json


def check_mnli_benchmark(benchmark_json_path, classifier, base_path='bigbench/benchmark_tasks/', use_prefix_out=True):
    """
    Compute the "multiple_choice_grade" accuracy of the provided classifier.
    The function saves a result JSON next to the task file with the accuracy and the
    per-example list of results.
    Args:
        benchmark_json_path: str, path of the JSON file with the task
        classifier: transformers.pipeline classifier used to predict
        base_path: str, base path of the benchmarks
        use_prefix_out: bool, whether to include the possible answers in the prompt
    Returns:
        None
    """
    json_benchmark_path = base_path + benchmark_json_path
    with open(json_benchmark_path) as json_file:
        data = json.load(json_file)
    total_results = []
    task_prefix = data['task_prefix']
    for example in tqdm(data['examples']):
        example_input = example['input']
        target_scores = list(example['target_scores'].keys())
        # Build the "choice: ... A:" suffix listing the candidate answers
        # (leading spaces keep the prompt pieces from fusing together)
        output_choice = ''
        for choice in target_scores:
            output_choice += ' choice: ' + choice
        output_choice += ' A: '
        if use_prefix_out:
            tot_input = task_prefix + example_input + output_choice
        else:
            tot_input = task_prefix + example_input
        # Zero-shot classification over the candidate answers; the returned
        # 'labels' list is sorted by score, so labels[0] is the top prediction
        dict_res = classifier(tot_input, target_scores)
        if example['target_scores'][dict_res['labels'][0]] == 1:
            total_results.append(1)
        else:
            total_results.append(0)
    # Avoid division by zero on an empty task
    tot_elem = len(total_results) if total_results else 1
    accuracy = sum(total_results) / tot_elem
    print(f'Accuracy on {json_benchmark_path}: {accuracy}')
    dict_to_save = {
        'task': json_benchmark_path,
        'accuracy': accuracy,
        'result_list': total_results
    }
    path_to_save = base_path + benchmark_json_path.split('.')[0] + '_results_mnli.json'
    with open(path_to_save, 'w') as outfile:
        json.dump(dict_to_save, outfile)


if __name__ == '__main__':
    model_name = "facebook/bart-large-mnli"
    # model_name = "roberta-large-mnli"
    list_path = ['SIB-plain', 'SIB-tricky', 'SIB-adversarial', 'SIB-emoji-agnostic', 'SIB-name-agnostic']
    classifier = pipeline("zero-shot-classification", model=model_name, framework='pt', device=0)
    for elm in list_path:
        path = elm + '/task.json'
        check_mnli_benchmark(path, classifier)
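
For context, the script assumes each task.json follows the BIG-bench multiple-choice layout: a task_prefix string plus a list of examples whose target_scores map each candidate answer to a score, and it counts a hit when the top-ranked label has target score 1. A minimal sketch of such a file, with made-up content and a hypothetical "dummy-task" folder, purely to illustrate the fields the script reads:

    # Illustrative only: a minimal task.json with the fields the script reads
    # (task_prefix, examples[*].input, examples[*].target_scores); the content
    # and the "dummy-task" folder are hypothetical, not an actual BIG-bench task.
    import json
    import os

    dummy_task = {
        "task_prefix": "Is the following sentence plausible? ",
        "examples": [
            {
                "input": "The cat chased the laser pointer across the room.",
                "target_scores": {"yes": 1, "no": 0}
            }
        ]
    }

    os.makedirs("bigbench/benchmark_tasks/dummy-task", exist_ok=True)
    with open("bigbench/benchmark_tasks/dummy-task/task.json", "w") as f:
        json.dump(dummy_task, f, indent=2)

With a file like this in place, calling check_mnli_benchmark('dummy-task/task.json', classifier) would print the accuracy and write dummy-task/task_results_mnli.json alongside the task file.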