Created
October 1, 2024 14:17
-
-
Save AMR-KELEG/3a2911ba977e9e89744bb81084f59416 to your computer and use it in GitHub Desktop.
Automatically estimate the ALDi score and dialect of sentences
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import re | |
| import torch | |
| import pandas as pd | |
| from tqdm import tqdm | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
# Country-level dialect labels. predict_top_p maps class index i to DIALECTS[i],
# so the order of this list must not be changed.
DIALECTS = [
    "Algeria", "Bahrain", "Egypt", "Iraq", "Jordan", "Kuwait",
    "Lebanon", "Libya", "Morocco", "Oman", "Palestine", "Qatar",
    "Saudi_Arabia", "Sudan", "Syria", "Tunisia", "UAE", "Yemen",
]
# Guard against a label being added or dropped by accident.
assert len(DIALECTS) == 18
# Dialect-identification model and tokenizer, loaded once at import time.
di_model_name = "AMR-KELEG/NADI2024-baseline"
di_tokenizer = AutoTokenizer.from_pretrained(di_model_name)
di_model = AutoModelForSequenceClassification.from_pretrained(di_model_name)
# Prefer the GPU, but fall back to the CPU so that importing this module does
# not crash on machines without CUDA.
di_model = di_model.to("cuda" if torch.cuda.is_available() else "cpu")
def predict_top_p(text, P=0.9):
    """Predict the top dialects with an accumulative confidence of at least P.

    Dialects are added in decreasing order of model probability until their
    summed probability reaches P. At least one dialect is always returned.

    Args:
        text: The sentence to classify.
        P: Cumulative probability threshold, between 0 and 1 (inclusive).

    Returns:
        The selected dialect names, listed in the order they appear in DIALECTS
        (not in order of probability).

    Raises:
        ValueError: If P is outside the range [0, 1].
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not 0 <= P <= 1:
        raise ValueError(f"P must be between 0 and 1 (inclusive), got {P!r}")

    # Send the inputs to whichever device the model lives on rather than
    # assuming CUDA.
    encoded = di_tokenizer(
        text, return_tensors="pt", max_length=512, truncation=True
    ).to(di_model.device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = di_model(**encoded).logits

    probabilities = torch.softmax(logits, dim=1).flatten().tolist()
    # Class indices ranked from most to least likely.
    n_classes = logits.shape[-1]
    ranked_indices = torch.topk(logits, n_classes).indices.flatten().tolist()

    selected = set()
    total_prob = 0.0
    for class_index in ranked_indices:
        total_prob += probabilities[class_index]
        selected.add(class_index)
        if total_prob >= P:
            break

    return [dialect for i, dialect in enumerate(DIALECTS) if i in selected]
# ALDi model and tokenizer, loaded once at import time.
aldi_model_name = "AMR-KELEG/Sentence-ALDi"
aldi_tokenizer = AutoTokenizer.from_pretrained(aldi_model_name)
aldi_model = AutoModelForSequenceClassification.from_pretrained(
    aldi_model_name, attn_implementation="sdpa"
)
# Prefer the GPU, but fall back to the CPU so that importing this module does
# not crash on machines without CUDA.
aldi_model = aldi_model.to("cuda" if torch.cuda.is_available() else "cpu")
def preprocess_text(arabic_text):
    """Strip URLs and ASCII Latin letters from the given Arabic text.

    URLs are removed first, then every remaining character in a-z / A-Z is
    dropped. Everything else (digits, punctuation, whitespace, Arabic script)
    is left untouched.

    Args:
        arabic_text: The Arabic text to be preprocessed.

    Returns:
        The preprocessed Arabic text.
    """
    url_pattern = re.compile(
        r"(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b",
        flags=re.MULTILINE,
    )
    latin_letter_pattern = re.compile(r"[a-zA-Z]")

    without_urls = url_pattern.sub("", arabic_text)
    return latin_letter_pattern.sub("", without_urls)
def compute_ALDi_batch(sentences):
    """Computes the ALDi score for the given sentences.

    Each sentence is converted to a string and passed through preprocess_text
    before tokenization. The returned scores are the raw model outputs; they
    are not clamped here.

    Args:
        sentences: A list of Arabic sentences.

    Returns:
        A list of ALDi scores for the given sentences (empty if `sentences`
        is empty).
    """
    preprocessed_sentences = [preprocess_text(str(sentence)) for sentence in sentences]
    # Nothing to score: skip the tokenizer and model entirely.
    if not preprocessed_sentences:
        return []

    # Send the inputs to whichever device the model lives on rather than
    # assuming CUDA.
    inputs = aldi_tokenizer(
        preprocessed_sentences,
        return_tensors="pt",
        padding=True,
        max_length=512,
        truncation=True,
    ).to(aldi_model.device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = aldi_model(**inputs).logits

    return logits.reshape(-1).tolist()
# TODO: Check if this sorting step is useful!
def compute_ALDi_after_sorting_and_batching(sentences, batch_size=32):
    """Compute clamped ALDi scores for the sentences, in their original order.

    Distinct sentences are scored once each, in batches sorted by length
    (presumably to keep padding within a batch small -- see the TODO above).
    Every position in the input, including repeated sentences, receives the
    score of its sentence.

    Args:
        sentences: A list of sentences. Non-string entries are converted with
            str(), matching what compute_ALDi_batch does.
        batch_size: Number of sentences scored per model call.

    Returns:
        A list with one score per input sentence, each clamped to [0, 1].

    Raises:
        ValueError: If batch_size is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    # Normalise to strings up front so that sorting by length cannot fail on
    # non-string entries (e.g. NaN read from a CSV).
    texts = [str(sentence) for sentence in sentences]
    if not texts:
        return []

    # dict.fromkeys de-duplicates while keeping first-seen order, which keeps
    # the (stable) length sort deterministic.
    unique_texts = sorted(dict.fromkeys(texts), key=len)

    score_by_text = {}
    for start in tqdm(range(0, len(unique_texts), batch_size)):
        batch = unique_texts[start : start + batch_size]
        score_by_text.update(zip(batch, compute_ALDi_batch(batch)))

    # Look scores up per input position so that duplicates all get their score
    # (a missing score raises KeyError instead of silently becoming 0).
    return [min(max(0, score_by_text[text]), 1) for text in texts]
if __name__ == "__main__":
    # Fill these in before running:
    #   input_filename   -- CSV file to read
    #   output_filename  -- CSV file to write
    #   text_column_name -- name of the column holding the text data
    input_filename = ""
    output_filename = ""
    text_column_name = ""

    frame = pd.read_csv(input_filename)

    # One ALDi score per row.
    sentences = frame[text_column_name].tolist()
    frame["ALDi"] = compute_ALDi_after_sorting_and_batching(sentences)

    # One list of predicted dialects per row.
    frame["dialects"] = frame[text_column_name].apply(predict_top_p)

    frame.to_csv(output_filename, index=False)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment