@hppRC
Last active February 12, 2025 11:41
import random
import uuid
from pathlib import Path
import click
from vllm import LLM, SamplingParams
from vllm.outputs import RequestOutput
import datasets as ds
from src.data.common import normalize_text


def make_input_text(passage: str, tokenizer) -> str:
    """Build a chat-formatted prompt asking the model to paraphrase a Japanese Wikipedia passage."""
    passage = passage.strip()
    # System prompt: "You are a helpful AI assistant."
    # User prompt (in Japanese): paraphrase the given Japanese Wikipedia passage into a different
    # sentence with exactly the same meaning; reordering words and sentences is allowed, but nothing
    # may be added or removed, and only the paraphrased text should be output.
    messages = [
        {
            "role": "system",
            "content": "あなたは親切なAIアシスタントです。",
        },
        {
            "role": "user",
            "content": f"""
以下の指示に従い、与えられた日本語Wikipediaの文章を言い換えてください。
### 指示
1. 日本語Wikipediaの文章を同じ意味の別の文章に言い換えてください
2. 言い換えた文章は元の文章と異なるものである必要があります。ただし、意味が全く変化しないようにしてください
3. 語順や文の順番を変えても構いませんが、内容を追加したり削除したりすることはできません
4. 文章全体で内容と意味が変化しないのであれば、文の順番をできるだけ入れ替えてください。
5. 言い換えた文章のみを出力してください
文章: {passage}
""".strip(),
        },
    ]
    input_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    return input_text


def create_dataset(texts: list[str], llm: LLM, tokenizer) -> list[str]:
    """Generate one paraphrase per input passage and return the normalized outputs."""
    sampling_params = SamplingParams(
        temperature=0.99,
        top_p=0.95,
        max_tokens=8192,
    )
    inputs_text = [make_input_text(t, tokenizer) for t in texts]
    responses: list[RequestOutput] = llm.generate(
        inputs_text,
        sampling_params=sampling_params,
    )
    output_texts: list[str] = [response.outputs[0].text.strip() for response in responses]
    output_texts = [normalize_text(t) for t in output_texts]
    return output_texts
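
# Example (illustrative): one normalized paraphrase is returned per input passage,
# in the same order as the inputs, e.g.
#   paraphrases = create_dataset(["東京は日本の首都である。"], llm, tokenizer)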


@click.command()
@click.option("--dtype", type=str, default="bf16")
@click.option("--tp", type=int, default=1)
@click.option("--bs", type=int, default=10000)
def main(dtype: str, tp: int, bs: int):
    model_name = "cyberagent/calm3-22b-chat"
    root_dir = Path("datasets/wiki_paraphrase4/calm3_22b")
    max_file_size = 1_000_000

    if dtype == "bf16":
        dtype = "bfloat16"
        enable_prefix_caching = True
    elif dtype == "fp16":
        dtype = "float16"
        enable_prefix_caching = False
    else:
        raise ValueError(f"Invalid dtype: {dtype}")

    # Use a fresh random seed per run so repeated runs produce different paraphrases.
    rng = random.SystemRandom()
    seed = rng.randint(0, 2**32 - 1)

    llm = LLM(
        model_name,
        trust_remote_code=True,
        tensor_parallel_size=tp,
        quantization=None,
        dtype=dtype,
        gpu_memory_utilization=0.95,
        seed=seed,
        enforce_eager=False,
        enable_prefix_caching=enable_prefix_caching,
    )
    tokenizer = llm.get_tokenizer()
    # Server-dependent processing starts here.
    ...
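    # A minimal sketch of how the omitted part might be wired up (all names below are
    # assumptions, not from the original gist: the source of the wiki passages, the output
    # format, and the column names are illustrative only):
    #
    #     passages: list[str] = load_wiki_passages()  # hypothetical helper
    #     root_dir.mkdir(parents=True, exist_ok=True)
    #     for start in range(0, len(passages), bs):
    #         chunk = passages[start : start + bs]
    #         paraphrases = create_dataset(chunk, llm, tokenizer)
    #         out = ds.Dataset.from_dict({"text": chunk, "paraphrase": paraphrases})
    #         out.to_json(root_dir / f"{uuid.uuid4()}.jsonl", force_ascii=False)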


if __name__ == "__main__":
    main()
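
# Example invocation (the script file name and GPU count are illustrative):
#   python wiki_paraphrase_calm3.py --dtype bf16 --tp 2 --bs 10000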