Skip to content

Instantly share code, notes, and snippets.

@hppRC
Created December 10, 2024 00:34
Show Gist options
  • Select an option

  • Save hppRC/b820aa6fb63c6b1021b8338ada32381c to your computer and use it in GitHub Desktop.

Select an option

Save hppRC/b820aa6fb63c6b1021b8338ada32381c to your computer and use it in GitHub Desktop.
from transformers import PreTrainedTokenizer
from vllm import LLM, SamplingParams
from vllm.outputs import RequestOutput
import datasets as ds
def build_input_text(text: str, tokenizer: PreTrainedTokenizer) -> str:
    """Wrap one English passage in a translation prompt and chat-format it.

    Builds a single-turn "user" message carrying the translation
    instructions plus the passage, then runs the tokenizer's chat template
    so the returned string matches the model's expected input format.
    """
    passage = text.strip()
    # Instruction prompt asking the model to translate the passage into Japanese.
    prompt = f"""
You are an outstanding English--Japanese translator. Please follow the instructions below to translate the given English Wikipedia text into Japanese.
### Instructions
1. Translate the English Wikipedia text into Japanese with the same meaning.
2. You can change the word order or the sequence of sentences, but you cannot add or omit any content.
3. Output should be fluent, polished and natural Japanese.
4. Output only the translated passage in Japanese.
Passage: {passage}
""".strip()
    # apply_chat_template expects a list of {"role", "content"} dicts.
    chat = [{"role": "user", "content": prompt}]
    # tokenize=False returns the formatted prompt as a string;
    # add_generation_prompt=True appends the assistant-turn marker so the
    # model continues with its answer.
    return tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )
def generate_text(texts: list[str], llm: LLM) -> list[str]:
    """Translate each input text with the LLM and return the generated strings.

    Args:
        texts: Raw English passages to translate.
        llm: A vLLM ``LLM`` instance; its tokenizer supplies the chat template.

    Returns:
        One generated (translated) string per input, in input order.
        (Fixed: the annotation previously claimed ``ds.Dataset``, but the
        function has always returned a plain ``list[str]``.)
    """
    tokenizer: PreTrainedTokenizer = llm.get_tokenizer()
    # Apply the chat template so every prompt matches the model's format.
    input_texts = [build_input_text(t, tokenizer) for t in texts]
    # Sampling parameters: higher temperature means more random output.
    sampling_params = SamplingParams(
        temperature=0.99,
        top_p=0.95,
        max_tokens=2048,
    )
    # vLLM batches requests and parallelizes across GPUs automatically.
    responses: list[RequestOutput] = llm.generate(
        input_texts,
        sampling_params=sampling_params,
    )
    # Keep only the first candidate completion of each request.
    output_texts: list[str] = [response.outputs[0].text for response in responses]
    return output_texts
def main():
    """Load the English Wikipedia sample set, translate it, print one example."""
    llm = LLM(
        "google/gemma-2-9b-it",
        trust_remote_code=True,
        tensor_parallel_size=2,  # split the model across 2 GPUs
        quantization=None,
        dtype="bfloat16",
        gpu_memory_utilization=0.95,
        enforce_eager=False,
        max_model_len=4096,
    )
    # English Wikipedia passages to translate.
    dataset = ds.load_dataset("hpprc/enwiki-100", split="train")
    source_texts = dataset["text"]
    translations = generate_text(source_texts, llm)
    # Show the first source passage next to its translation.
    print("=" * 80)
    print(source_texts[0])
    print(translations[0])


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment