Skip to content

Instantly share code, notes, and snippets.

@naufalso
Last active January 14, 2026 13:19
Show Gist options
  • Select an option

  • Save naufalso/919d3183272f98842cd7d06ae510b9cf to your computer and use it in GitHub Desktop.

Select an option

Save naufalso/919d3183272f98842cd7d06ae510b9cf to your computer and use it in GitHub Desktop.
Script to Sync Hugging Face trainer_state.json to a new WandB Run
import json
import wandb
import gradio as gr
import os
import tempfile
def sync_to_wandb(json_file, api_key, project_name, run_name, entity):
    """
    Core logic to authenticate with WandB and log HF trainer state data.

    Args:
        json_file: The uploaded trainer_state.json. Either a filesystem path
            (``str`` / ``os.PathLike``) or an object exposing that path as
            ``.name``.
        api_key: WandB API key; surrounding whitespace is ignored.
        project_name: Name of the WandB project to log into.
        run_name: Name for the new run; "imported-hf-run" is used when empty.
        entity: WandB entity; ``None`` is passed to WandB when empty.

    Returns:
        A newline-joined status log on success, or a single "❌ ..." message
        when an input is missing or any step raises.
    """
    if not json_file:
        return "❌ Please upload a trainer_state.json file."
    # Strip before validating so a whitespace-only key is rejected here
    # instead of being exported as an empty credential.
    api_key = (api_key or "").strip()
    if not api_key:
        return "❌ Please provide your WandB API Key."
    if not project_name:
        return "❌ Please provide a project name."

    try:
        # 1. Authenticate
        os.environ["WANDB_API_KEY"] = api_key

        # 2. Load JSON
        # The upload may arrive as a plain path or as a file wrapper whose
        # path lives in `.name`; accept both.
        if isinstance(json_file, (str, os.PathLike)):
            json_path = json_file
        else:
            json_path = json_file.name
        with open(json_path, "r", encoding="utf-8") as f:
            state_data = json.load(f)

        # 3. Initialize Run
        # Used as a context manager so the run is always finished, and is
        # marked as failed if anything below raises.
        with wandb.init(
            project=project_name,
            name=run_name if run_name else "imported-hf-run",
            entity=entity if entity else None,
            reinit=True,
        ) as run:
            run_url = run.get_url()
            status_logs = [
                f"🚀 Started WandB run: {run.name}",
                f"🔗 View at: {run_url}",
            ]

            # 4. Log Metadata/Summary
            summary_keys = [
                "best_global_step", "best_metric", "best_model_checkpoint",
                "epoch", "global_step", "max_steps", "num_train_epochs",
                "total_flos", "trial_name", "trial_params",
            ]
            for key in summary_keys:
                if key in state_data:
                    run.summary[key] = state_data[key]

            # 5. Log History
            if "log_history" in state_data:
                history = state_data["log_history"]
                status_logs.append(f"📊 Found {len(history)} log entries. Syncing...")
                for entry in history:
                    log_payload = entry.copy()
                    step = log_payload.get("step")
                    if step is not None:
                        # Reuse the recorded step as the WandB step.
                        run.log(log_payload, step=int(step))
                    else:
                        run.log(log_payload)
                status_logs.append("✅ Successfully synced all history.")
            else:
                status_logs.append("⚠️ Warning: No 'log_history' found in JSON.")

        # 6. Finish (the run was closed on leaving the `with` block)
        status_logs.append("🏁 Sync complete. WandB run finished.")
        return "\n".join(status_logs)
    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return f"❌ Error: {str(e)}"
# Define the Gradio Interface
# NOTE(review): indentation was lost in this copy of the file, so the nesting of
# the `with` contexts below is not visible here. Presumably the two Columns sit
# inside the first Row and the click wiring lives inside the Blocks context --
# confirm against the original script before restoring the indentation.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
# Page heading and a short usage hint.
gr.Markdown(
"""
# 🧊 Hugging Face to WandB Syncer
Upload your `trainer_state.json` file to sync logs from a previous training session to Weights & Biases.
"""
)
with gr.Row():
with gr.Column():
# Input widgets: the JSON upload, the API key (masked), and the run settings.
file_input = gr.File(label="Upload trainer_state.json", file_types=[".json"])
api_key_input = gr.Textbox(
label="WandB API Key",
placeholder="Paste your API key here...",
type="password"
)
with gr.Row():
project_input = gr.Textbox(label="Project Name", value="hf-import")
run_name_input = gr.Textbox(label="Run Name", placeholder="e.g., llama-finetuning-v1")
entity_input = gr.Textbox(
label="Entity (Optional)",
placeholder="Team or Username"
)
sync_btn = gr.Button("Sync to WandB", variant="primary")
with gr.Column():
# Read-only box that displays the string returned by sync_to_wandb.
output_log = gr.Textbox(
label="Status / Logs",
interactive=False,
lines=15
)
# Clicking the button calls sync_to_wandb; the inputs are listed in the same
# order as the function's parameters and the result goes to output_log.
sync_btn.click(
fn=sync_to_wandb,
inputs=[file_input, api_key_input, project_input, run_name_input, entity_input],
outputs=output_log
)
# Footer note pointing to where the API key can be found.
gr.Markdown(
"--- \n *Note: You can find your API key at [wandb.ai/authorize](https://wandb.ai/authorize).*"
)
# Start the Gradio app only when this file is run as a script.
if __name__ == "__main__":
demo.launch()
import json
import wandb
import argparse
import os
import sys
def sync_trainer_state(json_path, project_name, run_name=None, entity=None):
    """
    Reads a Hugging Face trainer_state.json file and logs the history to WandB.

    Args:
        json_path: Path to the trainer_state.json file.
        project_name: Name of the WandB project to log into.
        run_name: Optional name for the new WandB run.
        entity: Optional WandB entity (username or team name).

    Exits the process with status 1 (after printing an error) when the file
    is missing or does not contain valid JSON.
    """
    # 1. Load the JSON content
    try:
        with open(json_path, "r", encoding="utf-8") as f:
            state_data = json.load(f)
    except FileNotFoundError:
        print(f"Error: The file '{json_path}' was not found.")
        sys.exit(1)
    except json.JSONDecodeError:
        print(f"Error: Failed to decode JSON from '{json_path}'.")
        sys.exit(1)

    # 2. Initialize WandB
    # We use reinit=True to ensure a clean run if called multiple times in a notebook.
    # The run is used as a context manager so it is always finished, and is
    # marked as failed if logging raises part-way through.
    with wandb.init(
        project=project_name,
        name=run_name,
        entity=entity,
        reinit=True,
    ) as run:
        print(f"Started WandB run: {run.name}")

        # 3. Log Top-Level Summary/Metadata
        # These fields from the file define the final state of training
        summary_keys = [
            "best_global_step",
            "best_metric",
            "best_model_checkpoint",
            "epoch",
            "global_step",
            "max_steps",
            "num_train_epochs",
            "total_flos",
            "trial_name",
            "trial_params",
        ]
        for key in summary_keys:
            if key in state_data:
                # Add to wandb summary (displayed at the top of the run page)
                run.summary[key] = state_data[key]

        # 4. Log History
        # The 'log_history' list contains the time-series data (loss, learning rate, etc.)
        if "log_history" in state_data:
            history = state_data["log_history"]
            print(f"Found {len(history)} log entries. Syncing...")
            for entry in history:
                # We copy the entry to avoid handing the original data to WandB
                log_payload = entry.copy()
                # Hugging Face logs usually contain a 'step' key.
                # We read it to use as the explicit x-axis for WandB.
                step = log_payload.get("step")
                if step is not None:
                    # Coerce to int, matching the web UI's sync routine.
                    run.log(log_payload, step=int(step))
                else:
                    run.log(log_payload)
        else:
            print("Warning: No 'log_history' key found in the JSON file.")

    # 5. The run was finished on leaving the `with` block
    print("Sync complete. You can view the run at the URL above.")
if __name__ == "__main__":
    # Command-line entry point: read the options, then run the sync once.
    cli = argparse.ArgumentParser(
        description="Sync Hugging Face trainer_state.json to a new WandB Run",
    )
    cli.add_argument(
        "json_file",
        help="Path to the trainer_state.json file",
    )
    cli.add_argument(
        "--project",
        required=True,
        help="Name of the WandB Project",
    )
    cli.add_argument(
        "--run_name",
        default="imported-hf-run",
        help="Name for the WandB Run",
    )
    cli.add_argument(
        "--entity",
        default=None,
        help="WandB Entity (username or team name)",
    )
    options = cli.parse_args()

    # Arguments are passed positionally in the order of the function signature:
    # json_path, project_name, run_name, entity.
    sync_trainer_state(
        options.json_file,
        options.project,
        options.run_name,
        options.entity,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment