@lukehinds
Created January 8, 2026 17:03
#!/usr/bin/env python3
"""
Generic Dataset Quality Filter for Tool-Calling Datasets
This script filters out problematic patterns from ANY synthetic tool-calling dataset
that can cause models to develop bad habits during training.
Key features:
1. Auto-detection mode: Discovers problematic patterns from the data itself
2. Schema-agnostic: Works with any tool-calling dataset (Blender, Kubernetes, GitHub, etc.)
3. Configurable via YAML or CLI arguments
4. Removes samples with broken/placeholder responses
5. Filters samples with excessive calls to the same tool
6. Detects and removes recovery patterns (tool A followed by tool B)
7. Deduplicates near-identical samples based on question similarity
8. Balances tool call distribution to prevent over-representation
Usage:
python filter_tool_dataset.py input.jsonl output.jsonl [options]
Examples:
# Analyze dataset and auto-detect problematic patterns
python filter_tool_dataset.py --analyze input.jsonl
# Filter with auto-detected settings
python filter_tool_dataset.py input.jsonl output.jsonl --auto
# Filter with explicit tool balancing
python filter_tool_dataset.py input.jsonl output.jsonl --balance-tools "list_pods,get_deployment"
# Filter with custom recovery patterns
python filter_tool_dataset.py input.jsonl output.jsonl --recovery-patterns "get_pod:list_pods"
# Use YAML config file
python filter_tool_dataset.py input.jsonl output.jsonl --config filter.yaml
"""
import argparse
import json
import re
import sys
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from difflib import SequenceMatcher
from pathlib import Path
from typing import Any
try:
import yaml
YAML_AVAILABLE = True
except ImportError:
YAML_AVAILABLE = False
# Default thresholds for auto-detection
DEFAULT_FIRST_TOOL_THRESHOLD = 0.15 # Flag tools appearing first in >15% of samples
DEFAULT_SEQUENCE_MIN_COUNT = 10  # Flag sequences appearing at least 10 times
DEFAULT_SEQUENCE_MIN_PERCENTAGE = 0.05  # Flag sequences in at least 5% of samples (conservative)
# Patterns that are likely legitimate workflows, not recovery patterns
# These are heuristics: status checks, polling, category browsing are normal
LEGITIMATE_SEQUENCE_PATTERNS = [
# Status check -> action is normal workflow
(r"get_.*_status$", r".*"),
# Category browsing -> search is normal workflow
(r"get_.*_categories$", r"search_.*"),
# Search -> search (refining results) can be legitimate
# Generate -> poll (waiting for completion) is normal
(r"generate_.*", r"poll_.*"),
# Poll -> import (completion workflow) is normal
(r"poll_.*", r"import_.*"),
]
@dataclass
class FilterConfig:
"""Configuration for dataset filtering."""
# Maximum times the same tool can be called in a single sample
max_same_tool_calls: int = 3
# Maximum percentage of samples that can start with a given tool
max_first_tool_percentage: float = 0.15
# Patterns that indicate broken/placeholder responses
broken_response_patterns: list = field(
default_factory=lambda: [
r"\{\{[^}]+\}\}", # Template placeholders like {{objectName}}
]
)
# Similarity threshold for deduplication (0.0 - 1.0)
similarity_threshold: float = 0.85
# Tools to downsample if over-represented (empty = no balancing)
tools_to_balance: list = field(default_factory=list)
# Target percentage for balanced tools
balance_target_percentage: float = 0.10
# Whether to remove samples where a "recovery" pattern is detected
remove_recovery_patterns: bool = True
# Recovery patterns: tool A followed by tool B indicates recovery/fallback
# Empty list = no recovery pattern filtering
recovery_patterns: list = field(default_factory=list)
# Auto-detection mode: discover patterns from data
auto_detect: bool = False
# Thresholds for auto-detection
auto_first_tool_threshold: float = DEFAULT_FIRST_TOOL_THRESHOLD
auto_sequence_min_count: int = DEFAULT_SEQUENCE_MIN_COUNT
auto_sequence_min_percentage: float = DEFAULT_SEQUENCE_MIN_PERCENTAGE
@dataclass
class FilterStats:
"""Statistics from the filtering process."""
total_input: int = 0
removed_broken_responses: int = 0
removed_excessive_same_tool: int = 0
removed_recovery_patterns: int = 0
removed_duplicates: int = 0
removed_for_balance: int = 0
total_output: int = 0
def summary(self) -> str:
"""Generate a summary of filtering results."""
lines = [
"=" * 60,
"Dataset Filtering Summary",
"=" * 60,
f"Total input samples: {self.total_input}",
f"Removed (broken responses): {self.removed_broken_responses}",
f"Removed (excessive same tool): {self.removed_excessive_same_tool}",
f"Removed (recovery patterns): {self.removed_recovery_patterns}",
f"Removed (duplicates): {self.removed_duplicates}",
f"Removed (for balance): {self.removed_for_balance}",
"-" * 60,
f"Total output samples: {self.total_output}",
f"Retention rate: {self.total_output/max(self.total_input,1)*100:.1f}%",
"=" * 60,
]
return "\n".join(lines)
@dataclass
class DetectedPatterns:
"""Patterns detected via auto-detection."""
overrepresented_first_tools: list = field(default_factory=list)
recovery_sequences: list = field(default_factory=list)
broken_pattern_count: int = 0
excessive_same_tool_count: int = 0
def extract_tool_calls(sample: dict) -> list[tuple[str, str]]:
"""Extract tool calls from a sample as (tool_name, arguments) tuples."""
tool_calls = []
messages = sample.get("messages", [])
for msg in messages:
if msg.get("role") == "assistant" and msg.get("tool_calls"):
for tc in msg["tool_calls"]:
func = tc.get("function", {})
tool_calls.append((func.get("name", ""), func.get("arguments", "")))
return tool_calls
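# Illustrative sketch of the sample shape the extractors here assume: an OpenAI-style
# chat transcript under "messages", with assistant tool calls nested under
# "tool_calls" -> "function" -> "name"/"arguments". The tool name, arguments and
# contents below are hypothetical examples, not part of any specific dataset:
#
# {"messages": [
#   {"role": "user", "content": "How many pods are running?"},
#   {"role": "assistant", "tool_calls": [
#     {"function": {"name": "list_pods", "arguments": "{\"namespace\": \"default\"}"}}]},
#   {"role": "tool", "content": "pod-a Running, pod-b Running"}
# ]}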
def extract_tool_results(sample: dict) -> list[str]:
"""Extract tool results from a sample."""
results = []
messages = sample.get("messages", [])
for msg in messages:
if msg.get("role") == "tool":
results.append(msg.get("content", ""))
return results
def get_user_question(sample: dict) -> str:
"""Extract the user's question from a sample."""
messages = sample.get("messages", [])
for msg in messages:
if msg.get("role") == "user":
return msg.get("content", "")
return sample.get("tool_context", {}).get("question", "")
def has_broken_responses(sample: dict, patterns: list[str]) -> bool:
"""Check if sample contains broken/placeholder responses."""
if not patterns:
return False
sample_str = json.dumps(sample)
return any(re.search(pattern, sample_str) for pattern in patterns)
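# For example, a tool result containing the literal text "Created {{objectName}}"
# matches the default template-placeholder pattern and marks the whole sample as broken.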
def has_excessive_same_tool(tool_calls: list[tuple[str, str]], max_calls: int) -> bool:
"""Check if any tool is called more than max_calls times."""
tool_counts = Counter(name for name, _ in tool_calls)
return any(count > max_calls for count in tool_counts.values())
def has_recovery_pattern(
tool_calls: list[tuple[str, str]], recovery_patterns: list[tuple[str, str]]
) -> bool:
"""Check if the sample shows a recovery pattern (tool A followed by tool B)."""
if not recovery_patterns:
return False
tool_names = [name for name, _ in tool_calls]
for i in range(len(tool_names) - 1):
for pattern_a, pattern_b in recovery_patterns:
if tool_names[i] == pattern_a and tool_names[i + 1] == pattern_b:
return True
return False
def get_first_tool(tool_calls: list[tuple[str, str]]) -> str | None:
"""Get the name of the first tool called."""
if tool_calls:
return tool_calls[0][0]
return None
def get_tool_sequences(tool_calls: list[tuple[str, str]]) -> list[tuple[str, str]]:
"""Extract consecutive tool pairs from a sample."""
tool_names = [name for name, _ in tool_calls]
sequences = []
for i in range(len(tool_names) - 1):
sequences.append((tool_names[i], tool_names[i + 1]))
return sequences
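# For example, calls to tools A, B, C in that order yield the pairs [(A, B), (B, C)].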
def is_legitimate_sequence(tool_a: str, tool_b: str) -> bool:
"""
Check if a tool sequence is a legitimate workflow pattern.
Some sequences are normal and should NOT be flagged as recovery patterns:
- status checks followed by actions
- generate followed by poll (async workflow)
- poll followed by import (completion workflow)
"""
for pattern_a, pattern_b in LEGITIMATE_SEQUENCE_PATTERNS:
if re.match(pattern_a, tool_a) and re.match(pattern_b, tool_b):
return True
return False
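# Example (hypothetical tool names): ("generate_model", "poll_job") matches the
# generate_* -> poll_* rule above and is treated as a normal async workflow, while
# ("get_object", "get_scene") matches none of the patterns and falls through to the
# suspicious-sequence check below.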
def is_suspicious_sequence(tool_a: str, tool_b: str) -> bool:
"""
Check if a tool sequence is suspicious and likely indicates recovery behavior.
Suspicious patterns:
- Same tool called twice in a row (potential loop) - EXCEPT for execute/run tools
- Info/getter tool followed by another info/getter (fallback behavior)
- Any tool followed by a "get_scene" or "list_*" tool (common recovery)
"""
# Same tool twice - suspicious for info/getter tools, but ok for execute/run tools
if tool_a == tool_b:
# Execute/run tools can legitimately be called multiple times
if tool_a.startswith("execute") or tool_a.startswith("run"):
return False
# Search tools can legitimately refine results
return not tool_a.startswith("search")
# Info tool followed by scene/list tool is suspicious (fallback behavior)
return tool_a.startswith("get_") and (
tool_b.startswith("get_scene") or tool_b.startswith("list_")
)
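# A few illustrative cases under the rules above (tool names are hypothetical):
#   ("list_pods", "list_pods")         -> suspicious (same info tool twice, likely a loop)
#   ("execute_code", "execute_code")   -> not suspicious (execute/run tools may repeat)
#   ("search_assets", "search_assets") -> not suspicious (search refinement)
#   ("get_object", "list_objects")     -> suspicious (getter falling back to a list_* call)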
def similarity(a: str, b: str) -> float:
"""Calculate similarity ratio between two strings."""
return SequenceMatcher(None, a.lower(), b.lower()).ratio()
def detect_patterns(samples: list[dict], config: FilterConfig) -> DetectedPatterns:
"""
Auto-detect problematic patterns in the dataset.
Analyzes the dataset to find:
- Over-represented first tools (tools that start too many samples)
- Recovery sequences (tool A followed by tool B patterns)
- Samples with broken responses
- Samples with excessive same-tool calls
"""
detected = DetectedPatterns()
first_tool_counts: Counter = Counter()
sequence_counts: Counter = Counter()
total_samples = len(samples)
for sample in samples:
tool_calls = extract_tool_calls(sample)
# Count first tools
first_tool = get_first_tool(tool_calls)
if first_tool:
first_tool_counts[first_tool] += 1
# Count sequences
sequences = get_tool_sequences(tool_calls)
for seq in sequences:
sequence_counts[seq] += 1
# Count broken responses
if has_broken_responses(sample, config.broken_response_patterns):
detected.broken_pattern_count += 1
# Count excessive same-tool calls
if has_excessive_same_tool(tool_calls, config.max_same_tool_calls):
detected.excessive_same_tool_count += 1
# Find over-represented first tools
for tool, count in first_tool_counts.items():
percentage = count / max(total_samples, 1)
if percentage > config.auto_first_tool_threshold:
detected.overrepresented_first_tools.append({
"tool": tool,
"count": count,
"percentage": percentage,
})
# Sort by percentage descending
detected.overrepresented_first_tools.sort(
key=lambda x: x["percentage"], reverse=True
)
# Find recovery sequences (only suspicious ones, not legitimate workflows)
for (tool_a, tool_b), count in sequence_counts.items():
percentage = count / max(total_samples, 1)
if (count >= config.auto_sequence_min_count and
percentage >= config.auto_sequence_min_percentage):
# Skip legitimate workflow patterns
if is_legitimate_sequence(tool_a, tool_b):
continue
# Only flag suspicious patterns or explicitly mark them
is_suspicious = is_suspicious_sequence(tool_a, tool_b)
detected.recovery_sequences.append({
"from_tool": tool_a,
"to_tool": tool_b,
"count": count,
"percentage": percentage,
"suspicious": is_suspicious,
})
# Sort by count descending
detected.recovery_sequences.sort(key=lambda x: x["count"], reverse=True)
return detected
def load_yaml_config(config_path: Path) -> dict[str, Any]:
"""Load configuration from a YAML file."""
if not YAML_AVAILABLE:
print("Error: PyYAML is required for --config. Install with: pip install pyyaml")
sys.exit(1)
with open(config_path) as f:
return yaml.safe_load(f) or {}
def config_from_yaml(yaml_config: dict[str, Any]) -> FilterConfig:
"""Create FilterConfig from YAML configuration."""
config = FilterConfig()
if "broken_patterns" in yaml_config:
config.broken_response_patterns = yaml_config["broken_patterns"]
if "recovery_sequences" in yaml_config:
# Convert from list of lists to list of tuples
config.recovery_patterns = [
tuple(seq) for seq in yaml_config["recovery_sequences"]
]
config.remove_recovery_patterns = True
if "balance_tools" in yaml_config:
# Support both simple list and pattern-based config
balance_config = yaml_config["balance_tools"]
if isinstance(balance_config, list):
# Simple list of tool names
if all(isinstance(t, str) for t in balance_config):
config.tools_to_balance = balance_config
# List of dicts with patterns
elif all(isinstance(t, dict) for t in balance_config):
# For now, just extract tool names from patterns
# TODO: Support regex patterns
config.tools_to_balance = []
for item in balance_config:
if "tool" in item:
config.tools_to_balance.append(item["tool"])
if "max_first_call_pct" in item:
config.max_first_tool_percentage = item["max_first_call_pct"]
if "max_same_tool_calls" in yaml_config:
config.max_same_tool_calls = yaml_config["max_same_tool_calls"]
if "similarity_threshold" in yaml_config:
config.similarity_threshold = yaml_config["similarity_threshold"]
if "balance_target_percentage" in yaml_config:
config.balance_target_percentage = yaml_config["balance_target_percentage"]
return config
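# A minimal illustrative filter.yaml for --config, using the keys read above.
# The tool names, patterns, and values are hypothetical placeholders:
#
#   broken_patterns:
#     - '\{\{[^}]+\}\}'          # template placeholders like {{objectName}}
#   recovery_sequences:
#     - ["get_object", "get_scene"]
#     - ["get_pod", "list_pods"]
#   balance_tools:
#     - get_scene_info
#     - list_pods
#   max_same_tool_calls: 3
#   similarity_threshold: 0.85
#   balance_target_percentage: 0.10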
def filter_dataset(
input_path: Path, output_path: Path, config: FilterConfig
) -> FilterStats:
"""Filter the dataset according to the configuration."""
stats = FilterStats()
# First pass: load all samples
samples = []
with open(input_path) as f:
for raw_line in f:
stripped = raw_line.strip()
if not stripped:
continue
try:
sample = json.loads(stripped)
stats.total_input += 1
samples.append(sample)
except json.JSONDecodeError:
continue
# Auto-detection mode: discover patterns from data
if config.auto_detect:
detected = detect_patterns(samples, config)
# Apply detected patterns to config
if detected.overrepresented_first_tools:
config.tools_to_balance = [
item["tool"] for item in detected.overrepresented_first_tools
]
# Only apply suspicious sequences (not legitimate workflow patterns)
suspicious_sequences = [
s for s in detected.recovery_sequences if s.get("suspicious")
]
if suspicious_sequences:
config.recovery_patterns = [
(item["from_tool"], item["to_tool"])
for item in suspicious_sequences
]
config.remove_recovery_patterns = True
# Apply filters
filtered_samples = []
first_tool_counts = Counter()
for sample in samples:
# Check for broken responses
if has_broken_responses(sample, config.broken_response_patterns):
stats.removed_broken_responses += 1
continue
tool_calls = extract_tool_calls(sample)
# Check for excessive same tool calls
if has_excessive_same_tool(tool_calls, config.max_same_tool_calls):
stats.removed_excessive_same_tool += 1
continue
# Check for recovery patterns
if config.remove_recovery_patterns and has_recovery_pattern(
tool_calls, config.recovery_patterns
):
stats.removed_recovery_patterns += 1
continue
first_tool = get_first_tool(tool_calls)
if first_tool:
first_tool_counts[first_tool] += 1
filtered_samples.append(sample)
# Deduplication based on question similarity
deduplicated_samples = []
seen_questions = []
for sample in filtered_samples:
question = get_user_question(sample)
is_duplicate = False
for seen_q in seen_questions:
if similarity(question, seen_q) > config.similarity_threshold:
is_duplicate = True
stats.removed_duplicates += 1
break
if not is_duplicate:
deduplicated_samples.append(sample)
seen_questions.append(question)
# Balance tool distribution
if config.tools_to_balance:
balanced_samples = []
tool_sample_map = defaultdict(list)
for sample in deduplicated_samples:
tool_calls = extract_tool_calls(sample)
first_tool = get_first_tool(tool_calls)
if first_tool in config.tools_to_balance:
tool_sample_map[first_tool].append(sample)
else:
balanced_samples.append(sample)
# Calculate target count for balanced tools
non_balanced_count = len(balanced_samples)
if non_balanced_count > 0:
target_count = int(
non_balanced_count * config.balance_target_percentage
/ max(0.01, 1 - config.balance_target_percentage * len(config.tools_to_balance))
)
else:
target_count = 0
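        # Worked example for the target_count formula above (hypothetical numbers):
        # with 1000 non-balanced samples, balance_target_percentage = 0.10 and two
        # balanced tools, the denominator is max(0.01, 1 - 0.10 * 2) = 0.8, so
        # target_count = int(1000 * 0.10 / 0.8) = 125 per tool. The final set is then
        # roughly 1000 + 2 * 125 = 1250 samples, and 125 / 1250 = 10% per balanced tool.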
for _tool, tool_samples in tool_sample_map.items():
if len(tool_samples) > target_count:
stats.removed_for_balance += len(tool_samples) - target_count
balanced_samples.extend(tool_samples[:target_count])
else:
balanced_samples.extend(tool_samples)
final_samples = balanced_samples
else:
final_samples = deduplicated_samples
stats.total_output = len(final_samples)
# Write output
with open(output_path, "w") as f:
for sample in final_samples:
f.write(json.dumps(sample) + "\n")
return stats
def analyze_dataset(input_path: Path, config: FilterConfig) -> None:
"""Analyze dataset and print statistics with auto-detected patterns."""
samples = []
with open(input_path) as f:
for raw_line in f:
stripped = raw_line.strip()
if not stripped:
continue
try:
samples.append(json.loads(stripped))
except json.JSONDecodeError:
continue
print(f"\nDataset Analysis: {input_path}")
print("=" * 60)
print(f"Total samples: {len(samples)}")
# Tool call distribution
all_tool_calls = []
first_tool_calls = []
for sample in samples:
tool_calls = extract_tool_calls(sample)
all_tool_calls.extend(name for name, _ in tool_calls)
first_tool = get_first_tool(tool_calls)
if first_tool:
first_tool_calls.append(first_tool)
if all_tool_calls:
print("\nTool Call Distribution (all calls):")
for tool, count in Counter(all_tool_calls).most_common(15):
print(f" {tool}: {count} ({count/len(all_tool_calls)*100:.1f}%)")
if first_tool_calls:
print("\nFirst Tool Call Distribution:")
for tool, count in Counter(first_tool_calls).most_common(15):
pct = count / len(first_tool_calls) * 100
flag = " [*]" if pct > config.auto_first_tool_threshold * 100 else ""
print(f" {tool}: {count} ({pct:.1f}%){flag}")
# Auto-detect patterns
print("\n" + "=" * 60)
print("Auto-Detected Patterns")
print("=" * 60)
detected = detect_patterns(samples, config)
# Over-represented first tools
if detected.overrepresented_first_tools:
print(f"\nOver-represented first tools (>{config.auto_first_tool_threshold*100:.0f}% of samples):")
tools_for_cli = []
for item in detected.overrepresented_first_tools:
print(f" {item['tool']}: {item['count']} ({item['percentage']*100:.1f}%)")
tools_for_cli.append(item['tool'])
print(f"\n Suggested: --balance-tools \"{','.join(tools_for_cli)}\"")
else:
print("\nNo over-represented first tools detected.")
# Recovery sequences
if detected.recovery_sequences:
suspicious_sequences = [s for s in detected.recovery_sequences if s.get("suspicious")]
other_sequences = [s for s in detected.recovery_sequences if not s.get("suspicious")]
if suspicious_sequences:
print("\nSuspicious sequences (likely recovery patterns):")
patterns_for_cli = []
for item in suspicious_sequences:
marker = "[LOOP]" if item["from_tool"] == item["to_tool"] else "[FALLBACK]"
print(f" {marker} {item['from_tool']} -> {item['to_tool']}: {item['count']} ({item['percentage']*100:.1f}%)")
patterns_for_cli.append(f"{item['from_tool']}:{item['to_tool']}")
print(f"\n Suggested: --recovery-patterns \"{','.join(patterns_for_cli)}\"")
if other_sequences:
print("\nOther frequent sequences (may be normal workflow):")
for item in other_sequences[:5]: # Show top 5 only
print(f" {item['from_tool']} -> {item['to_tool']}: {item['count']} ({item['percentage']*100:.1f}%)")
else:
print("\nNo significant recovery sequences detected.")
# Broken responses
print(f"\nSamples with template placeholders: {detected.broken_pattern_count}")
print(f"Samples with excessive same-tool calls (>{config.max_same_tool_calls}): {detected.excessive_same_tool_count}")
# Estimate impact of filtering
print("\n" + "=" * 60)
print("Estimated Filtering Impact")
print("=" * 60)
# Calculate with auto-detected patterns
recovery_count = 0
if detected.recovery_sequences:
recovery_patterns = [
(item["from_tool"], item["to_tool"])
for item in detected.recovery_sequences
]
for sample in samples:
tool_calls = extract_tool_calls(sample)
if has_recovery_pattern(tool_calls, recovery_patterns):
recovery_count += 1
removable = detected.broken_pattern_count + detected.excessive_same_tool_count + recovery_count
print(f"Samples that would be removed: ~{removable} ({removable/max(len(samples),1)*100:.1f}%)")
print(f"Samples that would remain: ~{len(samples) - removable}")
# Print recommended command
print("\n" + "=" * 60)
print("Recommended Command")
print("=" * 60)
cmd_parts = ["python filter_tool_dataset.py", str(input_path), "output.jsonl"]
if detected.overrepresented_first_tools:
tools = ",".join(item["tool"] for item in detected.overrepresented_first_tools)
cmd_parts.append(f'--balance-tools "{tools}"')
# Only include suspicious sequences in the recommended command
suspicious_sequences = [s for s in detected.recovery_sequences if s.get("suspicious")]
if suspicious_sequences:
patterns = ",".join(
f"{item['from_tool']}:{item['to_tool']}"
for item in suspicious_sequences
)
cmd_parts.append(f'--recovery-patterns "{patterns}"')
print(f"\n {' '.join(cmd_parts)}")
print("\nOr use auto-detection mode:")
print(f" python filter_tool_dataset.py {input_path} output.jsonl --auto")
def parse_recovery_patterns(patterns_str: str) -> list[tuple[str, str]]:
"""Parse recovery patterns from CLI string format 'A:B,C:D'."""
if not patterns_str:
return []
patterns = []
for raw_pair in patterns_str.split(","):
stripped_pair = raw_pair.strip()
if ":" in stripped_pair:
a, b = stripped_pair.split(":", 1)
patterns.append((a.strip(), b.strip()))
return patterns
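# For example, parse_recovery_patterns("get_pod:list_pods, get_object:get_scene")
# returns [("get_pod", "list_pods"), ("get_object", "get_scene")].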
def parse_balance_tools(tools_str: str) -> list[str]:
"""Parse balance tools from CLI string format 'tool1,tool2'."""
if not tools_str:
return []
return [t.strip() for t in tools_str.split(",") if t.strip()]
def main():
parser = argparse.ArgumentParser(
description="Generic filter for tool-calling datasets",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Analyze dataset and auto-detect problematic patterns
%(prog)s --analyze input.jsonl
# Filter with auto-detected settings
%(prog)s input.jsonl output.jsonl --auto
# Filter with explicit tool balancing
%(prog)s input.jsonl output.jsonl --balance-tools "list_pods,get_deployment"
# Filter with custom recovery patterns
%(prog)s input.jsonl output.jsonl --recovery-patterns "get_pod:list_pods,get_object:get_scene"
# Use YAML config file
%(prog)s input.jsonl output.jsonl --config filter.yaml
# Combine: config file with CLI overrides
%(prog)s input.jsonl output.jsonl --config filter.yaml --max-same-tool 2
""",
)
parser.add_argument("input", type=Path, help="Input JSONL file")
parser.add_argument(
"output", type=Path, nargs="?", help="Output JSONL file (optional for analyze)"
)
parser.add_argument(
"--analyze", action="store_true", help="Only analyze, don't filter"
)
parser.add_argument(
"--auto", action="store_true",
help="Auto-detect and apply patterns from the data"
)
parser.add_argument(
"--config", type=Path,
help="YAML config file for domain-specific rules"
)
    parser.add_argument(
        "--max-same-tool",
        type=int,
        default=None,
        help="Maximum calls to same tool per sample (default: 3)",
    )
    parser.add_argument(
        "--similarity-threshold",
        type=float,
        default=None,
        help="Similarity threshold for deduplication (default: 0.85)",
    )
    parser.add_argument(
        "--balance-target",
        type=float,
        default=None,
        help="Target percentage for balanced tools (default: 0.10)",
    )
parser.add_argument(
"--balance-tools",
type=str,
help="Comma-separated list of tools to balance (e.g., 'get_scene_info,list_pods')"
)
parser.add_argument(
"--recovery-patterns",
type=str,
help="Comma-separated recovery patterns as 'A:B' pairs (e.g., 'get_object:get_scene')"
)
parser.add_argument(
"--first-tool-threshold",
type=float,
default=DEFAULT_FIRST_TOOL_THRESHOLD,
help=f"Threshold for flagging over-represented first tools (default: {DEFAULT_FIRST_TOOL_THRESHOLD})"
)
parser.add_argument(
"--no-balance", action="store_true", help="Skip tool balancing"
)
parser.add_argument(
"--keep-recovery",
action="store_true",
help="Keep samples with recovery patterns",
)
parser.add_argument(
"--keep-broken",
action="store_true",
help="Keep samples with broken/placeholder responses",
)
args = parser.parse_args()
if not args.input.exists():
print(f"Error: Input file not found: {args.input}")
sys.exit(1)
# Build configuration
config = FilterConfig()
# Load YAML config first (if provided)
if args.config:
if not args.config.exists():
print(f"Error: Config file not found: {args.config}")
sys.exit(1)
yaml_config = load_yaml_config(args.config)
config = config_from_yaml(yaml_config)
    # Apply CLI overrides only when the flag was actually passed, so values from a
    # YAML config are not silently clobbered by argparse defaults
    if args.max_same_tool is not None:
        config.max_same_tool_calls = args.max_same_tool
    if args.similarity_threshold is not None:
        config.similarity_threshold = args.similarity_threshold
    if args.balance_target is not None:
        config.balance_target_percentage = args.balance_target
config.auto_first_tool_threshold = args.first_tool_threshold
config.auto_detect = args.auto
# CLI-specified patterns override config
if args.balance_tools:
config.tools_to_balance = parse_balance_tools(args.balance_tools)
elif args.no_balance:
config.tools_to_balance = []
if args.recovery_patterns:
config.recovery_patterns = parse_recovery_patterns(args.recovery_patterns)
config.remove_recovery_patterns = True
elif args.keep_recovery:
config.remove_recovery_patterns = False
if args.keep_broken:
config.broken_response_patterns = []
# Analyze mode
if args.analyze:
analyze_dataset(args.input, config)
return
# Filter mode
if not args.output:
print("Error: Output file required when not using --analyze")
sys.exit(1)
print(f"Filtering {args.input} -> {args.output}")
if config.auto_detect:
print("Auto-detection mode enabled - discovering patterns from data...")
if config.tools_to_balance:
print(f"Balancing tools: {', '.join(config.tools_to_balance)}")
if config.recovery_patterns:
patterns_str = ", ".join(f"{a}->{b}" for a, b in config.recovery_patterns)
print(f"Recovery patterns: {patterns_str}")
stats = filter_dataset(args.input, args.output, config)
print(stats.summary())
if __name__ == "__main__":
main()