@rahulunair
Last active March 8, 2024 02:37
import warnings

warnings.simplefilter(action="ignore")

from transformers.utils import logging

logging.set_verbosity_error()

import torch
import intel_extension_for_pytorch as ipex  # noqa: F401 -- importing IPEX registers the "xpu" device with PyTorch
from transformers import AutoModelForCausalLM, AutoTokenizer


def chat_with_model(device="xpu"):
    """Run a simple interactive chat loop with DialoGPT on an Intel XPU (or CPU)."""
    model_path = "microsoft/DialoGPT-small"
    tokenizer_path = "microsoft/DialoGPT-small"
    model_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, padding_side="left")
    # Load the model weights and cast them to bfloat16 for lighter-weight inference.
    model = AutoModelForCausalLM.from_pretrained(model_path).to(torch.bfloat16)
    if device == "xpu":
        device = torch.device("xpu")
    else:
        device = torch.device("cpu")
    model.to(device)
    model.eval()  # inference mode: disables dropout
    print("Chatbot Ready! Start chatting! (Type 'quit' to exit)")
    while True:
        user_input = input("You: ")
        if user_input.lower() == "quit":
            break
        if not user_input.strip():
            print("Please enter a valid message.")
            continue
        # Append the EOS token so DialoGPT knows the user's turn has ended,
        # then move the encoded ids to the same device as the model.
        new_input_ids = model_tokenizer.encode(
            user_input + model_tokenizer.eos_token, return_tensors="pt"
        ).to(device)
        # Sample a reply with nucleus sampling (temperature + top-p).
        response_ids = model.generate(
            new_input_ids,
            max_length=150,
            pad_token_id=model_tokenizer.eos_token_id,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
        )
        # Decode only the newly generated tokens, skipping the echoed prompt.
        response = model_tokenizer.decode(
            response_ids[:, new_input_ids.shape[-1]:][0], skip_special_tokens=True
        )
        print("Chatbot: ", response)


if __name__ == "__main__":
    chat_with_model()
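Note: the ipex import above is only used for its side effect of enabling the "xpu" device. IPEX also exposes an ipex.optimize() entry point that applies inference-time graph and weight optimizations; a minimal sketch of how it could slot in right after model.eval() is shown below (whether it helps depends on the model and your IPEX/driver versions, so treat it as optional):

    # optional IPEX inference optimization (sketch; verify against your IPEX version)
    model = ipex.optimize(model, dtype=torch.bfloat16)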