import random
import uuid
from pathlib import Path

import click
import datasets as ds
from vllm import LLM, SamplingParams
from vllm.outputs import RequestOutput

from src.data.common import normalize_text

def make_input_text(passage: str, tokenizer) -> str:
    """Build a chat-formatted paraphrase prompt for a single passage.

    The user message (in Japanese) says: "Following the instructions below,
    paraphrase the given Japanese Wikipedia passage. 1. Rewrite it as a
    different passage with the same meaning. 2. The paraphrase must differ
    from the original, but the meaning must not change at all. 3. You may
    change word order and sentence order, but you may not add or delete
    content. 4. Reorder sentences as much as possible, provided the overall
    content and meaning stay the same. 5. Output only the paraphrased text."
    """
    passage = passage.strip()
    messages = [
        {
            "role": "system",
            # "You are a helpful AI assistant."
            "content": "あなたは親切なAIアシスタントです。",
        },
        {
            "role": "user",
            "content": f"""
以下の指示に従い、与えられた日本語Wikipediaの文章を言い換えてください。
### 指示
1. 日本語Wikipediaの文章を同じ意味の別の文章に言い換えてください
2. 言い換えた文章は元の文章と異なるものである必要があります。ただし、意味が全く変化しないようにしてください
3. 語順や文の順番を変えても構いませんが、内容を追加したり削除したりすることはできません
4. 文章全体で内容と意味が変化しないのであれば、文の順番をできるだけ入れ替えてください。
5. 言い換えた文章のみを出力してください
文章: {passage}
""".strip(),
        },
    ]
    # Render the messages with the model's chat template, leaving the
    # assistant turn open so the model generates the paraphrase.
    input_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    return input_text

def create_dataset(texts: list[str], llm: LLM, tokenizer) -> list[str]:
    """Generate one paraphrase per input text and return the normalized strings."""
    # High-temperature nucleus sampling to encourage diverse rewordings.
    sampling_params = SamplingParams(
        temperature=0.99,
        top_p=0.95,
        max_tokens=8192,
    )
    inputs_text = [make_input_text(t, tokenizer) for t in texts]
    responses: list[RequestOutput] = llm.generate(
        inputs_text,
        sampling_params=sampling_params,
    )
    # Keep only the first candidate of each request, then normalize it.
    output_texts: list[str] = [response.outputs[0].text.strip() for response in responses]
    output_texts = [normalize_text(t) for t in output_texts]
    return output_texts

@click.command()
@click.option("--dtype", type=str, default="bf16")
@click.option("--tp", type=int, default=1)
@click.option("--bs", type=int, default=10000)
def main(dtype: str, tp: int, bs: int):
    model_name = "cyberagent/calm3-22b-chat"
    root_dir = Path("datasets/wiki_paraphrase4/calm3_22b")
    max_file_size = 1_000_000

    # Map the CLI flag to vLLM's dtype names; prefix caching is only
    # enabled for the bf16 configuration here.
    if dtype == "bf16":
        dtype = "bfloat16"
        enable_prefix_caching = True
    elif dtype == "fp16":
        dtype = "float16"
        enable_prefix_caching = False
    else:
        raise ValueError(f"Invalid dtype: {dtype}")

    # Seed from OS entropy so repeated runs sample different paraphrases.
    rng = random.SystemRandom()
    seed = rng.randint(0, 2**32 - 1)

    llm = LLM(
        model_name,
        trust_remote_code=True,
        tensor_parallel_size=tp,
        quantization=None,
        dtype=dtype,
        gpu_memory_utilization=0.95,
        seed=seed,
        enforce_eager=False,
        enable_prefix_caching=enable_prefix_caching,
    )
    tokenizer = llm.get_tokenizer()

    # Server-dependent processing starts here.
    ...


if __name__ == "__main__":
    main()
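
# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original gist). The elided
# "server-dependent" section above presumably feeds Wikipedia passages to
# create_dataset() in batches of --bs and writes shards under root_dir; the
# otherwise-unused `uuid`, `ds`, and `max_file_size` names suggest uniquely
# named dataset shards. With `load_passages` as an assumed placeholder, a
# minimal driver could look like:
#
#   passages: list[str] = load_passages()  # assumption: loader not shown
#   for start in range(0, len(passages), bs):
#       batch = passages[start : start + bs]
#       paraphrases = create_dataset(batch, llm, tokenizer)
#       shard = ds.Dataset.from_dict({"text": batch, "paraphrase": paraphrases})
#       shard.save_to_disk(str(root_dir / str(uuid.uuid4())))
#
# Invocation (script filename assumed):
#   python paraphrase.py --dtype bf16 --tp 4 --bs 10000
# ---------------------------------------------------------------------------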