Last active: March 8, 2024 02:37
-
-
Save rahulunair/b1f2dd45141054aaeda3eda30b7ad850 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import warnings | |
| warnings.simplefilter(action="ignore") | |
| from transformers.utils import logging | |
| logging.set_verbosity_error() | |
| import torch | |
| import intel_extension_for_pytorch as ipex | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
def _resolve_device(device):
    """Map the ``device`` argument to a ``torch.device``, falling back to CPU if XPU is unavailable."""
    if device == "xpu":
        # torch.xpu exists natively in newer torch or once IPEX is imported;
        # guard with getattr so a missing backend degrades to CPU instead of crashing.
        xpu = getattr(torch, "xpu", None)
        if xpu is not None and xpu.is_available():
            return torch.device("xpu")
        print("XPU requested but not available; falling back to CPU.")
    return torch.device("cpu")


def chat_with_model(device="xpu", max_new_tokens=150):
    """Run an interactive console chat with ``microsoft/DialoGPT-small``.

    Loads the tokenizer and model, casts the model to bfloat16, and moves it
    to the selected device. It then loops: read a line from stdin, generate a
    sampled reply, and print it. Each turn is independent; no conversation
    history is fed back into the model.

    Args:
        device: ``"xpu"`` to run on an Intel GPU; any other value selects the
            CPU. If ``"xpu"`` is requested but unavailable, the CPU is used.
        max_new_tokens: Maximum number of tokens generated per reply. Unlike
            ``max_length``, this excludes the prompt, so long user messages
            still receive a reply.

    The loop ends when the user types ``quit`` (case-insensitive, surrounding
    whitespace ignored), presses Ctrl-C, or closes stdin (Ctrl-D).
    """
    model_path = "microsoft/DialoGPT-small"
    tokenizer_path = "microsoft/DialoGPT-small"
    model_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, padding_side="left")
    model = AutoModelForCausalLM.from_pretrained(model_path).to(torch.bfloat16)

    device = _resolve_device(device)
    model.to(device)

    print("Chatbot Ready! Start chatting! (Type 'quit' to exit)")
    while True:
        try:
            user_input = input("You: ")
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C: leave the loop cleanly instead of
            # dumping a traceback.
            print()
            break
        if user_input.strip().lower() == "quit":
            break
        if not user_input.strip():
            print("Please enter a valid message.")
            continue
        # Calling the tokenizer (rather than .encode) also yields an
        # attention mask, which generate() cannot infer here because
        # DialoGPT's pad token is the same as its eos token.
        inputs = model_tokenizer(
            user_input + model_tokenizer.eos_token, return_tensors="pt"
        ).to(device)
        prompt_len = inputs["input_ids"].shape[-1]
        response_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=model_tokenizer.eos_token_id,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
        )
        # generate() returns prompt + continuation; decode only the new tokens.
        response = model_tokenizer.decode(
            response_ids[0, prompt_len:], skip_special_tokens=True
        )
        print("Chatbot: ", response)
if __name__ == "__main__":
    # Script entry point: start the console chat on the default (XPU) device.
    chat_with_model(device="xpu")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.