Created
October 5, 2025 16:30
-
-
Save EvilFreelancer/5b391cc49224d648aaf749c8607f5d2f to your computer and use it in GitHub Desktop.
OpenRLHF-like dataset converter to TXT file
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| Build a full calibration text that includes both the prompt (rendered via chat_template) | |
| and the expected assistant reply parsed from `label`. | |
| - Reads mem-agent JSONL with: | |
| - context_messages: list of {role, content} | |
| - label: string that contains separators and, at the end, the gold answer | |
| - Renders via tokenizer.apply_chat_template with add_generation_prompt=False | |
| and appends the assistant turn (parsed from label) as the final message. | |
| - By default writes multi-line blocks separated by \n\n---\n\n | |
| (use --one-per-line to flatten). | |
| """ | |
| import argparse | |
| import io | |
| import json | |
| import os | |
| import re | |
| from typing import Any, Dict, Iterable, List, Optional | |
| from urllib.parse import urlparse | |
| import requests | |
| from transformers import AutoTokenizer | |
| ROLE_MAP = { | |
| "human": "user", | |
| "user": "user", | |
| "assistant": "assistant", | |
| "bot": "assistant", | |
| "system": "system", | |
| } | |
| SEP_RE = re.compile(r"\n\s*~[~/\s]+~\s*\n", re.MULTILINE) | |
| def is_url(path: str) -> bool: | |
| try: | |
| return urlparse(path).scheme in ("http", "https") | |
| except Exception: | |
| return False | |
| def iter_jsonl(path_or_url: str): | |
| if is_url(path_or_url): | |
| with requests.get(path_or_url, stream=True, timeout=60) as r: | |
| r.raise_for_status() | |
| for line in r.iter_lines(decode_unicode=True): | |
| if line: | |
| yield json.loads(line) | |
| else: | |
| with io.open(path_or_url, "r", encoding="utf-8") as f: | |
| for line in f: | |
| if line.strip(): | |
| yield json.loads(line) | |
| def to_hf_messages(obj: Dict[str, Any], input_key: str) -> List[Dict[str, str]]: | |
| # Prefer context_messages; fall back to messages / input-output | |
| if isinstance(obj.get(input_key), list): | |
| msgs = [] | |
| for m in obj[input_key]: | |
| if not isinstance(m, dict): | |
| continue | |
| role = ROLE_MAP.get(str(m.get("role", "")).lower(), "user") | |
| content = m.get("content") or "" | |
| if content.strip(): | |
| msgs.append({"role": role, "content": str(content)}) | |
| if msgs: | |
| return msgs | |
| if isinstance(obj.get("messages"), list): | |
| msgs = [] | |
| for m in obj["messages"]: | |
| if not isinstance(m, dict): | |
| continue | |
| role = ROLE_MAP.get(str(m.get("role", "")).lower(), "user") | |
| content = m.get("content") or "" | |
| if content.strip(): | |
| msgs.append({"role": role, "content": str(content)}) | |
| if msgs: | |
| return msgs | |
| # Fallback from input/output for completeness | |
| inp = obj.get("input") or obj.get("prompt") or obj.get("query") | |
| out = obj.get("output") or obj.get("response") or obj.get("answer") | |
| msgs = [] | |
| if inp: | |
| msgs.append({"role": "user", "content": str(inp)}) | |
| if out: | |
| msgs.append({"role": "assistant", "content": str(out)}) | |
| return msgs | |
| def parse_label_answer(label: Any) -> Optional[str]: | |
| if not isinstance(label, str): | |
| return None | |
| parts = SEP_RE.split(label) | |
| candidate = parts[-1].strip() if parts else label.strip() | |
| return candidate or None | |
| def render_full_text( | |
| tok, | |
| messages: List[Dict[str, str]], | |
| assistant_text: str, | |
| ) -> str: | |
| """ | |
| Render prompt via chat_template (no generation prompt), | |
| then append assistant as the last message and re-render the full conversation. | |
| """ | |
| # Ensure the last message is not already assistant; if it is, drop it | |
| msgs = messages[:] | |
| while msgs and msgs[-1]["role"] == "assistant": | |
| msgs.pop() | |
| # Rebuild full conversation with gold assistant reply | |
| msgs.append({"role": "assistant", "content": assistant_text}) | |
| # Render the entire conversation exactly as the model expects | |
| text = tok.apply_chat_template( | |
| msgs, | |
| tokenize=False, | |
| add_generation_prompt=False, | |
| ) | |
| return text.rstrip() + "\n" | |
| def main(): | |
| ap = argparse.ArgumentParser(description="Generate full chat text with gold assistant answers from `label`.") | |
| ap.add_argument("--input", required=True, help="Path/URL to JSONL (e.g., valid.jsonl).") | |
| ap.add_argument("--input-key", default="context_messages", help="Key that holds chat messages.") | |
| ap.add_argument("--model", default="driaforall/mem-agent", help="HF model id for tokenizer/chat_template.") | |
| ap.add_argument("--output", default="calib_full.txt", help="Output text file.") | |
| ap.add_argument("--require-label", action="store_true", | |
| help="Fail on samples without a parseable label answer.") | |
| ap.add_argument("--one-per-line", action="store_true", | |
| help="Flatten each sample to a single line (replace newlines with spaces).") | |
| args = ap.parse_args() | |
| tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) | |
| rendered: List[str] = [] | |
| for obj in iter_jsonl(args.input): | |
| msgs = to_hf_messages(obj, args.input_key) | |
| if not msgs: | |
| if args.require_label: | |
| continue | |
| else: | |
| # Skip if no messages | |
| continue | |
| answer = parse_label_answer(obj.get("label")) | |
| if not answer: | |
| if args.require_label: | |
| continue | |
| else: | |
| # If no label, skip | |
| continue | |
| text = render_full_text(tok, msgs, answer) | |
| if args.one_per_line: | |
| text = text.replace("\n", " ").strip() + "\n" | |
| rendered.append(text) | |
| os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True) | |
| with io.open(args.output, "w", encoding="utf-8", newline="\n") as f: | |
| if args.one_per_line: | |
| for r in rendered: | |
| f.write(r if r.endswith("\n") else r + "\n") | |
| else: | |
| f.write("".join(x.strip() for x in rendered) + "\n") | |
| print(f"Wrote {len(rendered)} samples to: {args.output}") | |
| if __name__ == "__main__": | |
| main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment