from transformers import PreTrainedTokenizer
from vllm import LLM, SamplingParams
from vllm.outputs import RequestOutput

import datasets as ds

def build_input_text(text: str, tokenizer: PreTrainedTokenizer) -> str:
    text = text.strip()
    # Prompt for the translation task.
    # apply_chat_template expects a list of dicts with "role" and "content" keys.
    messages = [
        {
            "role": "user",
            "content": f"""
You are an outstanding English--Japanese translator. Please follow the instructions below to translate the given English Wikipedia text into Japanese.
### Instructions
1. Translate the English Wikipedia text into Japanese with the same meaning.
2. You can change the word order or the sequence of sentences, but you cannot add or omit any content.
3. Output should be fluent, polished and natural Japanese.
4. Output only the translated passage in Japanese.
Passage: {text}
""".strip(),
        },
    ]
    # apply_chat_template renders the messages into the model's expected prompt format.
    input_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    return input_text
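
# For reference, an assumption about gemma-2's chat template (not output captured
# from this script): the rendered prompt should look roughly like
#   <bos><start_of_turn>user
#   {prompt}<end_of_turn>
#   <start_of_turn>model
# Other model families use different markup, which is why the template is applied
# via the tokenizer rather than hard-coded.
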
def generate_text(texts: list[str], llm: LLM) -> list[str]:
    tokenizer: PreTrainedTokenizer = llm.get_tokenizer()
    # Apply the chat template to format each input text.
    input_texts = [build_input_text(t, tokenizer) for t in texts]
    # Generation parameters: the higher the temperature, the more random the output.
    sampling_params = SamplingParams(
        temperature=0.99,
        top_p=0.95,
        max_tokens=2048,
    )
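    # Note (general vLLM behavior, not something the original gist sets):
    # temperature=0.0 switches to greedy decoding, if fully deterministic
    # translations are preferred over the diverse sampling configured above.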
    # Generate with the LLM; vLLM handles GPU parallelism and batching automatically.
    responses: list[RequestOutput] = llm.generate(
        input_texts,
        sampling_params=sampling_params,
    )
    # Extract the generated text; outputs[0] is the single sample per request
    # (SamplingParams defaults to n=1).
    output_texts: list[str] = [response.outputs[0].text for response in responses]
    return output_texts

def main():
    model_name = "google/gemma-2-9b-it"
    llm = LLM(
        model_name,
        trust_remote_code=True,
        tensor_parallel_size=2,  # parallelize across 2 GPUs
        quantization=None,
        dtype="bfloat16",
        gpu_memory_utilization=0.95,
        enforce_eager=False,
        max_model_len=4096,
    )
    dataset = ds.load_dataset("hpprc/enwiki-100", split="train")
    texts = dataset["text"]
    output_texts = generate_text(texts, llm)
    print("=" * 80)
    print(texts[0])
    print(output_texts[0])
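    # Optional extension (a sketch, not part of the original gist): pack the
    # source/translation pairs into a datasets.Dataset so they can be saved to
    # disk or pushed to the Hub. Column names and the path are arbitrary.
    # pairs = ds.Dataset.from_dict({"en": texts, "ja": output_texts})
    # pairs.save_to_disk("enwiki-100-ja")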

if __name__ == "__main__":
    main()