Last active
January 14, 2026 13:19
-
-
Save naufalso/919d3183272f98842cd7d06ae510b9cf to your computer and use it in GitHub Desktop.
Script to Sync Hugging Face trainer_state.json to a new WandB Run
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import json | |
| import wandb | |
| import gradio as gr | |
| import os | |
| import tempfile | |
def sync_to_wandb(json_file, api_key, project_name, run_name, entity):
    """
    Core logic to authenticate with WandB and log HF trainer state data.

    Args:
        json_file: The uploaded trainer_state.json, given either as a
            filesystem path (``str`` / ``os.PathLike``) or as an object whose
            ``.name`` attribute holds that path.
        api_key: WandB API key; surrounding whitespace is stripped.
        project_name: Name of the WandB project to log into.
        run_name: Name for the new run; "imported-hf-run" is used when empty.
        entity: WandB entity; an empty value is passed on as ``None``.

    Returns:
        str: A newline-separated status log on success, or a one-line message
        naming the missing input or the error that occurred. Errors are
        reported through the return value instead of being raised.
    """
    if not json_file:
        return "β Please upload a trainer_state.json file."
    if not api_key:
        return "β Please provide your WandB API Key."
    if not project_name:
        return "β Please provide a project name."

    env_var = "WANDB_API_KEY"
    saved_key = os.environ.get(env_var)
    key_exported = False
    try:
        # 1. Load JSON
        # Done before anything else so a bad upload fails without creating a
        # run or touching the environment. The upload may arrive as a plain
        # path or as a file wrapper exposing the path as `.name`.
        if isinstance(json_file, (str, os.PathLike)):
            json_path = json_file
        else:
            json_path = json_file.name
        with open(json_path, "r", encoding="utf-8") as f:
            state_data = json.load(f)
        if not isinstance(state_data, dict):
            raise ValueError(
                "trainer_state.json must contain a JSON object at the top level."
            )

        # 2. Authenticate
        os.environ[env_var] = api_key.strip()
        key_exported = True

        # 3. Initialize Run
        # Used as a context manager so the run is always finished, even when
        # logging fails half-way through.
        with wandb.init(
            project=project_name,
            name=run_name if run_name else "imported-hf-run",
            entity=entity if entity else None,
            reinit=True,
        ) as run:
            status_logs = [
                f"π Started WandB run: {run.name}",
                f"π View at: {run.url}",
            ]

            # 4. Log Metadata/Summary
            summary_keys = [
                "best_global_step", "best_metric", "best_model_checkpoint",
                "epoch", "global_step", "max_steps", "num_train_epochs",
                "total_flos", "trial_name", "trial_params",
            ]
            for key in summary_keys:
                if key in state_data:
                    run.summary[key] = state_data[key]

            # 5. Log History
            if "log_history" in state_data:
                history = state_data["log_history"]
                status_logs.append(f"π Found {len(history)} log entries. Syncing...")
                for entry in history:
                    log_payload = entry.copy()
                    step = log_payload.get("step")
                    if step is not None:
                        run.log(log_payload, step=int(step))
                    else:
                        run.log(log_payload)
                status_logs.append("β Successfully synced all history.")
            else:
                status_logs.append("β οΈ Warning: No 'log_history' found in JSON.")

        # 6. Finish (the run was closed when the `with` block exited)
        status_logs.append("π Sync complete. WandB run finished.")
        return "\n".join(status_logs)
    except Exception as e:
        return f"β Error: {str(e)}"
    finally:
        # Do not leave the caller-supplied key in the process environment
        # once the request is done; put back whatever was there before.
        if key_exported:
            if saved_key is None:
                os.environ.pop(env_var, None)
            else:
                os.environ[env_var] = saved_key
# Define the Gradio Interface
# NOTE(review): the nesting below is reconstructed from a listing that lost
# its indentation — confirm the layout against the original file.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header.
    gr.Markdown(
        """
        # π§ Hugging Face to WandB Syncer
        Upload your `trainer_state.json` file to sync logs from a previous training session to Weights & Biases.
        """
    )
    with gr.Row():
        # Input side: file upload, credentials, run settings and the trigger button.
        with gr.Column():
            file_input = gr.File(label="Upload trainer_state.json", file_types=[".json"])
            # type="password" so the key is not shown in clear text while typing.
            api_key_input = gr.Textbox(
                label="WandB API Key",
                placeholder="Paste your API key here...",
                type="password"
            )
            with gr.Row():
                project_input = gr.Textbox(label="Project Name", value="hf-import")
                run_name_input = gr.Textbox(label="Run Name", placeholder="e.g., llama-finetuning-v1")
            entity_input = gr.Textbox(
                label="Entity (Optional)",
                placeholder="Team or Username"
            )
            sync_btn = gr.Button("Sync to WandB", variant="primary")
        # Output side: read-only box that receives the string returned by sync_to_wandb.
        with gr.Column():
            output_log = gr.Textbox(
                label="Status / Logs",
                interactive=False,
                lines=15
            )
    # The order of `inputs` must match the positional parameters of sync_to_wandb.
    sync_btn.click(
        fn=sync_to_wandb,
        inputs=[file_input, api_key_input, project_input, run_name_input, entity_input],
        outputs=output_log
    )
    # Footer note pointing to the page where the API key can be found.
    gr.Markdown(
        "--- \n *Note: You can find your API key at [wandb.ai/authorize](https://wandb.ai/authorize).*"
    )
if __name__ == "__main__":
    # Launch the app only when this file is executed as a script.
    demo.launch()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import json | |
| import wandb | |
| import argparse | |
| import os | |
| import sys | |
def sync_trainer_state(json_path, project_name, run_name=None, entity=None):
    """
    Reads a Hugging Face trainer_state.json file and logs the history to WandB.

    Args:
        json_path: Path to the trainer_state.json file.
        project_name: Name of the WandB project to log into.
        run_name: Optional name for the new run.
        entity: Optional WandB entity (username or team name).

    Raises:
        SystemExit: With exit code 1 when the file is missing, cannot be read,
            is not valid JSON, or does not hold a JSON object at the top
            level. The reason is printed to stderr first.
    """
    # 1. Load the JSON content
    try:
        with open(json_path, "r", encoding="utf-8") as f:
            state_data = json.load(f)
    except FileNotFoundError:
        print(f"Error: The file '{json_path}' was not found.", file=sys.stderr)
        sys.exit(1)
    except json.JSONDecodeError:
        print(f"Error: Failed to decode JSON from '{json_path}'.", file=sys.stderr)
        sys.exit(1)
    except OSError as exc:
        print(f"Error: Could not read '{json_path}': {exc}", file=sys.stderr)
        sys.exit(1)

    # Everything below looks keys up by name, so anything but an object is
    # rejected before a run is created.
    if not isinstance(state_data, dict):
        print(
            f"Error: '{json_path}' does not contain a JSON object at the top level.",
            file=sys.stderr,
        )
        sys.exit(1)

    # 2. Initialize WandB
    # We use reinit=True to ensure a clean run if called multiple times in a notebook.
    # The run is used as a context manager so it is finished even if logging
    # raises half-way through.
    with wandb.init(
        project=project_name,
        name=run_name,
        entity=entity,
        reinit=True,
    ) as run:
        print(f"Started WandB run: {run.name}")

        # 3. Log Top-Level Summary/Metadata
        # These fields from the file define the final state of training
        summary_keys = [
            "best_global_step",
            "best_metric",
            "best_model_checkpoint",
            "epoch",
            "global_step",
            "max_steps",
            "num_train_epochs",
            "total_flos",
            "trial_name",
            "trial_params",
        ]
        for key in summary_keys:
            if key in state_data:
                # Add to wandb summary (displayed at the top of the run page)
                run.summary[key] = state_data[key]

        # 4. Log History
        # The 'log_history' list contains the time-series data (loss, learning rate, etc.)
        if "log_history" in state_data:
            history = state_data["log_history"]
            print(f"Found {len(history)} log entries. Syncing...")
            for entry in history:
                # We copy the entry to avoid modifying the original data
                log_payload = entry.copy()
                # Hugging Face logs usually contain a 'step' key.
                # We extract it to use as the explicit x-axis for WandB.
                step = log_payload.get("step")
                if step is not None:
                    # Coerced to int, matching how the web UI variant logs steps.
                    run.log(log_payload, step=int(step))
                else:
                    run.log(log_payload)
        else:
            print("Warning: No 'log_history' key found in the JSON file.")

    # 5. The run was finished when the `with` block exited
    print("Sync complete. You can view the run at the URL above.")
if __name__ == "__main__":
    # Command-line entry point: describe the arguments, parse them, run one sync.
    cli = argparse.ArgumentParser(
        description="Sync Hugging Face trainer_state.json to a new WandB Run",
    )
    # Positional: the file to import.
    cli.add_argument("json_file", help="Path to the trainer_state.json file")
    # Options: where the run goes and what it is called.
    cli.add_argument(
        "--project",
        required=True,
        help="Name of the WandB Project",
    )
    cli.add_argument(
        "--run_name",
        default="imported-hf-run",
        help="Name for the WandB Run",
    )
    cli.add_argument(
        "--entity",
        default=None,
        help="WandB Entity (username or team name)",
    )
    options = cli.parse_args()

    sync_trainer_state(
        options.json_file,
        options.project,
        run_name=options.run_name,
        entity=options.entity,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment