#!/usr/bin/env python3
"""
Generic Dataset Quality Filter for Tool-Calling Datasets

This script filters out problematic patterns that can cause models to develop
bad habits during training. It works with any synthetic tool-calling dataset.

Key features:
1. Auto-detection mode: discovers problematic patterns from the data itself
2. Schema-agnostic: works with any tool-calling dataset (Blender, Kubernetes, GitHub, etc.)
3. Configurable via YAML or CLI arguments
4. Removes samples with broken/placeholder responses
5. Filters samples with excessive calls to the same tool
6. Detects and removes recovery patterns (tool A followed by tool B)
7. Deduplicates near-identical samples based on question similarity
8. Balances tool call distribution to prevent over-representation

Usage:
    python filter_tool_dataset.py input.jsonl output.jsonl [options]

Examples:
    # Analyze dataset and auto-detect problematic patterns
    python filter_tool_dataset.py --analyze input.jsonl

    # Filter with auto-detected settings
    python filter_tool_dataset.py input.jsonl output.jsonl --auto

    # Filter with explicit tool balancing
    python filter_tool_dataset.py input.jsonl output.jsonl --balance-tools "list_pods,get_deployment"

    # Filter with custom recovery patterns
    python filter_tool_dataset.py input.jsonl output.jsonl --recovery-patterns "get_pod:list_pods"

    # Use a YAML config file
    python filter_tool_dataset.py input.jsonl output.jsonl --config filter.yaml
"""
import argparse
import json
import re
import sys
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from difflib import SequenceMatcher
from pathlib import Path
from typing import Any

try:
    import yaml

    YAML_AVAILABLE = True
except ImportError:
    YAML_AVAILABLE = False

# Default thresholds for auto-detection
DEFAULT_FIRST_TOOL_THRESHOLD = 0.15  # Flag tools appearing first in >15% of samples
DEFAULT_SEQUENCE_MIN_COUNT = 10  # Flag sequences appearing at least 10 times
DEFAULT_SEQUENCE_MIN_PERCENTAGE = 0.05  # Flag sequences in at least 5% of samples (conservative)

# Patterns that are likely legitimate workflows, not recovery patterns.
# These are heuristics: status checks, polling, and category browsing are normal.
LEGITIMATE_SEQUENCE_PATTERNS = [
    # Status check -> action is a normal workflow
    (r"get_.*_status$", r".*"),
    # Category browsing -> search is a normal workflow
    (r"get_.*_categories$", r"search_.*"),
    # Search -> search (refining results) can be legitimate
    # Generate -> poll (waiting for completion) is normal
    (r"generate_.*", r"poll_.*"),
    # Poll -> import (completion workflow) is normal
    (r"poll_.*", r"import_.*"),
]
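# For example, a pair like ("generate_model", "poll_job_status") would match the
# generate -> poll rule above. (These tool names are illustrative only, not taken
# from any particular dataset.)
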
@dataclass
class FilterConfig:
    """Configuration for dataset filtering."""

    # Maximum times the same tool can be called in a single sample
    max_same_tool_calls: int = 3
    # Maximum percentage of samples that can start with a given tool
    max_first_tool_percentage: float = 0.15
    # Patterns that indicate broken/placeholder responses
    broken_response_patterns: list = field(
        default_factory=lambda: [
            r"\{\{[^}]+\}\}",  # Template placeholders like {{objectName}}
        ]
    )
    # Similarity threshold for deduplication (0.0 - 1.0)
    similarity_threshold: float = 0.85
    # Tools to downsample if over-represented (empty = no balancing)
    tools_to_balance: list = field(default_factory=list)
    # Target percentage for balanced tools
    balance_target_percentage: float = 0.10
    # Whether to remove samples where a "recovery" pattern is detected
    remove_recovery_patterns: bool = True
    # Recovery patterns: tool A followed by tool B indicates recovery/fallback.
    # Empty list = no recovery pattern filtering
    recovery_patterns: list = field(default_factory=list)
    # Auto-detection mode: discover patterns from the data
    auto_detect: bool = False
    # Thresholds for auto-detection
    auto_first_tool_threshold: float = DEFAULT_FIRST_TOOL_THRESHOLD
    auto_sequence_min_count: int = DEFAULT_SEQUENCE_MIN_COUNT
    auto_sequence_min_percentage: float = DEFAULT_SEQUENCE_MIN_PERCENTAGE

@dataclass
class FilterStats:
    """Statistics from the filtering process."""

    total_input: int = 0
    removed_broken_responses: int = 0
    removed_excessive_same_tool: int = 0
    removed_recovery_patterns: int = 0
    removed_duplicates: int = 0
    removed_for_balance: int = 0
    total_output: int = 0

    def summary(self) -> str:
        """Generate a summary of filtering results."""
        lines = [
            "=" * 60,
            "Dataset Filtering Summary",
            "=" * 60,
            f"Total input samples: {self.total_input}",
            f"Removed (broken responses): {self.removed_broken_responses}",
            f"Removed (excessive same tool): {self.removed_excessive_same_tool}",
            f"Removed (recovery patterns): {self.removed_recovery_patterns}",
            f"Removed (duplicates): {self.removed_duplicates}",
            f"Removed (for balance): {self.removed_for_balance}",
            "-" * 60,
            f"Total output samples: {self.total_output}",
            f"Retention rate: {self.total_output / max(self.total_input, 1) * 100:.1f}%",
            "=" * 60,
        ]
        return "\n".join(lines)

@dataclass
class DetectedPatterns:
    """Patterns detected via auto-detection."""

    overrepresented_first_tools: list = field(default_factory=list)
    recovery_sequences: list = field(default_factory=list)
    broken_pattern_count: int = 0
    excessive_same_tool_count: int = 0

def extract_tool_calls(sample: dict) -> list[tuple[str, str]]:
    """Extract tool calls from a sample as (tool_name, arguments) tuples."""
    tool_calls = []
    messages = sample.get("messages", [])
    for msg in messages:
        if msg.get("role") == "assistant" and msg.get("tool_calls"):
            for tc in msg["tool_calls"]:
                func = tc.get("function", {})
                tool_calls.append((func.get("name", ""), func.get("arguments", "")))
    return tool_calls


def extract_tool_results(sample: dict) -> list[str]:
    """Extract tool results from a sample."""
    results = []
    messages = sample.get("messages", [])
    for msg in messages:
        if msg.get("role") == "tool":
            results.append(msg.get("content", ""))
    return results


def get_user_question(sample: dict) -> str:
    """Extract the user's question from a sample."""
    messages = sample.get("messages", [])
    for msg in messages:
        if msg.get("role") == "user":
            return msg.get("content", "")
    return sample.get("tool_context", {}).get("question", "")

def has_broken_responses(sample: dict, patterns: list[str]) -> bool:
    """Check if the sample contains broken/placeholder responses."""
    if not patterns:
        return False
    sample_str = json.dumps(sample)
    return any(re.search(pattern, sample_str) for pattern in patterns)


def has_excessive_same_tool(tool_calls: list[tuple[str, str]], max_calls: int) -> bool:
    """Check if any tool is called more than max_calls times."""
    tool_counts = Counter(name for name, _ in tool_calls)
    return any(count > max_calls for count in tool_counts.values())


def has_recovery_pattern(
    tool_calls: list[tuple[str, str]], recovery_patterns: list[tuple[str, str]]
) -> bool:
    """Check if the sample shows a recovery pattern (tool A followed by tool B)."""
    if not recovery_patterns:
        return False
    tool_names = [name for name, _ in tool_calls]
    for i in range(len(tool_names) - 1):
        for pattern_a, pattern_b in recovery_patterns:
            if tool_names[i] == pattern_a and tool_names[i + 1] == pattern_b:
                return True
    return False


def get_first_tool(tool_calls: list[tuple[str, str]]) -> str | None:
    """Get the name of the first tool called."""
    if tool_calls:
        return tool_calls[0][0]
    return None


def get_tool_sequences(tool_calls: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Extract consecutive tool pairs from a sample."""
    tool_names = [name for name, _ in tool_calls]
    sequences = []
    for i in range(len(tool_names) - 1):
        sequences.append((tool_names[i], tool_names[i + 1]))
    return sequences
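# For example, tool calls [("a", ...), ("b", ...), ("c", ...)] yield the
# consecutive pairs [("a", "b"), ("b", "c")].
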
def is_legitimate_sequence(tool_a: str, tool_b: str) -> bool:
    """
    Check if a tool sequence is a legitimate workflow pattern.

    Some sequences are normal and should NOT be flagged as recovery patterns:
    - status checks followed by actions
    - generate followed by poll (async workflow)
    - poll followed by import (completion workflow)
    """
    for pattern_a, pattern_b in LEGITIMATE_SEQUENCE_PATTERNS:
        if re.match(pattern_a, tool_a) and re.match(pattern_b, tool_b):
            return True
    return False


def is_suspicious_sequence(tool_a: str, tool_b: str) -> bool:
    """
    Check if a tool sequence is suspicious and likely indicates recovery behavior.

    Suspicious patterns:
    - the same tool called twice in a row (potential loop), EXCEPT for
      execute/run/search tools
    - an info/getter tool followed by a "get_scene*" or "list_*" tool
      (fallback behavior)
    """
    # Same tool twice is suspicious for info/getter tools, but ok for execute/run tools
    if tool_a == tool_b:
        # Execute/run tools can legitimately be called multiple times
        if tool_a.startswith("execute") or tool_a.startswith("run"):
            return False
        # Search tools can legitimately refine results
        return not tool_a.startswith("search")
    # An info tool followed by a scene/list tool is suspicious (fallback behavior)
    return tool_a.startswith("get_") and (
        tool_b.startswith("get_scene") or tool_b.startswith("list_")
    )


def similarity(a: str, b: str) -> float:
    """Calculate the similarity ratio between two strings (case-insensitive)."""
    return SequenceMatcher(None, a.lower(), b.lower()).ratio()
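# For example, similarity("List Pods", "list pods") == 1.0 after lowercasing,
# while unrelated questions score well below the default 0.85 threshold.
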
def detect_patterns(samples: list[dict], config: FilterConfig) -> DetectedPatterns:
    """
    Auto-detect problematic patterns in the dataset.

    Analyzes the dataset to find:
    - over-represented first tools (tools that start too many samples)
    - recovery sequences (tool A followed by tool B patterns)
    - samples with broken responses
    - samples with excessive same-tool calls
    """
    detected = DetectedPatterns()
    first_tool_counts: Counter = Counter()
    sequence_counts: Counter = Counter()
    total_samples = len(samples)

    for sample in samples:
        tool_calls = extract_tool_calls(sample)

        # Count first tools
        first_tool = get_first_tool(tool_calls)
        if first_tool:
            first_tool_counts[first_tool] += 1

        # Count sequences
        sequences = get_tool_sequences(tool_calls)
        for seq in sequences:
            sequence_counts[seq] += 1

        # Count broken responses
        if has_broken_responses(sample, config.broken_response_patterns):
            detected.broken_pattern_count += 1

        # Count excessive same-tool calls
        if has_excessive_same_tool(tool_calls, config.max_same_tool_calls):
            detected.excessive_same_tool_count += 1

    # Find over-represented first tools
    for tool, count in first_tool_counts.items():
        percentage = count / max(total_samples, 1)
        if percentage > config.auto_first_tool_threshold:
            detected.overrepresented_first_tools.append({
                "tool": tool,
                "count": count,
                "percentage": percentage,
            })

    # Sort by percentage descending
    detected.overrepresented_first_tools.sort(
        key=lambda x: x["percentage"], reverse=True
    )

    # Find recovery sequences (only suspicious ones, not legitimate workflows)
    for (tool_a, tool_b), count in sequence_counts.items():
        percentage = count / max(total_samples, 1)
        if (count >= config.auto_sequence_min_count
                and percentage >= config.auto_sequence_min_percentage):
            # Skip legitimate workflow patterns
            if is_legitimate_sequence(tool_a, tool_b):
                continue
            # Record the sequence and mark whether it looks suspicious
            is_suspicious = is_suspicious_sequence(tool_a, tool_b)
            detected.recovery_sequences.append({
                "from_tool": tool_a,
                "to_tool": tool_b,
                "count": count,
                "percentage": percentage,
                "suspicious": is_suspicious,
            })

    # Sort by count descending
    detected.recovery_sequences.sort(key=lambda x: x["count"], reverse=True)
    return detected

def load_yaml_config(config_path: Path) -> dict[str, Any]:
    """Load configuration from a YAML file."""
    if not YAML_AVAILABLE:
        print("Error: PyYAML is required for --config. Install with: pip install pyyaml")
        sys.exit(1)
    with open(config_path) as f:
        return yaml.safe_load(f) or {}

def config_from_yaml(yaml_config: dict[str, Any]) -> FilterConfig:
    """Create a FilterConfig from YAML configuration."""
    config = FilterConfig()
    if "broken_patterns" in yaml_config:
        config.broken_response_patterns = yaml_config["broken_patterns"]
    if "recovery_sequences" in yaml_config:
        # Convert from a list of lists to a list of tuples
        config.recovery_patterns = [
            tuple(seq) for seq in yaml_config["recovery_sequences"]
        ]
        config.remove_recovery_patterns = True
    if "balance_tools" in yaml_config:
        # Support both a simple list and a pattern-based config
        balance_config = yaml_config["balance_tools"]
        if isinstance(balance_config, list):
            # Simple list of tool names
            if all(isinstance(t, str) for t in balance_config):
                config.tools_to_balance = balance_config
            # List of dicts with patterns
            elif all(isinstance(t, dict) for t in balance_config):
                # For now, just extract tool names from the entries
                # TODO: Support regex patterns
                config.tools_to_balance = []
                for item in balance_config:
                    if "tool" in item:
                        config.tools_to_balance.append(item["tool"])
                    if "max_first_call_pct" in item:
                        config.max_first_tool_percentage = item["max_first_call_pct"]
    if "max_same_tool_calls" in yaml_config:
        config.max_same_tool_calls = yaml_config["max_same_tool_calls"]
    if "similarity_threshold" in yaml_config:
        config.similarity_threshold = yaml_config["similarity_threshold"]
    if "balance_target_percentage" in yaml_config:
        config.balance_target_percentage = yaml_config["balance_target_percentage"]
    return config
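# A minimal example filter.yaml, using only the keys read above. The tool names
# and values here are illustrative (borrowed from the CLI examples in this file),
# not taken from any real dataset:
#
#   broken_patterns:
#     - '\{\{[^}]+\}\}'
#   recovery_sequences:
#     - [get_object, get_scene]
#   balance_tools:
#     - get_scene_info
#   max_same_tool_calls: 3
#   similarity_threshold: 0.85
#   balance_target_percentage: 0.10
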
def filter_dataset(
    input_path: Path, output_path: Path, config: FilterConfig
) -> FilterStats:
    """Filter the dataset according to the configuration."""
    stats = FilterStats()

    # First pass: load all samples
    samples = []
    with open(input_path) as f:
        for raw_line in f:
            stripped = raw_line.strip()
            if not stripped:
                continue
            try:
                sample = json.loads(stripped)
                stats.total_input += 1
                samples.append(sample)
            except json.JSONDecodeError:
                continue

    # Auto-detection mode: discover patterns from the data
    if config.auto_detect:
        detected = detect_patterns(samples, config)
        # Apply detected patterns to the config
        if detected.overrepresented_first_tools:
            config.tools_to_balance = [
                item["tool"] for item in detected.overrepresented_first_tools
            ]
        # Only apply suspicious sequences (not legitimate workflow patterns)
        suspicious_sequences = [
            s for s in detected.recovery_sequences if s.get("suspicious")
        ]
        if suspicious_sequences:
            config.recovery_patterns = [
                (item["from_tool"], item["to_tool"])
                for item in suspicious_sequences
            ]
            config.remove_recovery_patterns = True

    # Apply filters
    filtered_samples = []
    for sample in samples:
        # Check for broken responses
        if has_broken_responses(sample, config.broken_response_patterns):
            stats.removed_broken_responses += 1
            continue

        tool_calls = extract_tool_calls(sample)

        # Check for excessive same-tool calls
        if has_excessive_same_tool(tool_calls, config.max_same_tool_calls):
            stats.removed_excessive_same_tool += 1
            continue

        # Check for recovery patterns
        if config.remove_recovery_patterns and has_recovery_pattern(
            tool_calls, config.recovery_patterns
        ):
            stats.removed_recovery_patterns += 1
            continue

        filtered_samples.append(sample)

    # Deduplication based on question similarity (O(n^2) pairwise comparison)
    deduplicated_samples = []
    seen_questions = []
    for sample in filtered_samples:
        question = get_user_question(sample)
        is_duplicate = False
        for seen_q in seen_questions:
            if similarity(question, seen_q) > config.similarity_threshold:
                is_duplicate = True
                stats.removed_duplicates += 1
                break
        if not is_duplicate:
            deduplicated_samples.append(sample)
            seen_questions.append(question)
    # Balance tool distribution
    if config.tools_to_balance:
        balanced_samples = []
        tool_sample_map = defaultdict(list)
        for sample in deduplicated_samples:
            tool_calls = extract_tool_calls(sample)
            first_tool = get_first_tool(tool_calls)
            if first_tool in config.tools_to_balance:
                tool_sample_map[first_tool].append(sample)
            else:
                balanced_samples.append(sample)

        # Calculate the target count for each balanced tool
        non_balanced_count = len(balanced_samples)
        if non_balanced_count > 0:
            target_count = int(
                non_balanced_count * config.balance_target_percentage
                / max(0.01, 1 - config.balance_target_percentage * len(config.tools_to_balance))
            )
        else:
            target_count = 0
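        # Derivation of the formula above: if each of the k balanced tools should be
        # at most fraction p of the final set, and the final set has
        # non_balanced + k * target samples, then target = p * (non_balanced + k * target),
        # which solves to target = p * non_balanced / (1 - p * k). The max(0.01, ...)
        # guards against division by zero as p * k approaches 1. For example, with
        # p = 0.10, k = 2, and 800 non-balanced samples: target = 80 / 0.8 = 100, so
        # each balanced tool ends up at most 100 / 1000 = 10% of the output.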
        for _tool, tool_samples in tool_sample_map.items():
            if len(tool_samples) > target_count:
                stats.removed_for_balance += len(tool_samples) - target_count
                balanced_samples.extend(tool_samples[:target_count])
            else:
                balanced_samples.extend(tool_samples)
        final_samples = balanced_samples
    else:
        final_samples = deduplicated_samples

    stats.total_output = len(final_samples)

    # Write output
    with open(output_path, "w") as f:
        for sample in final_samples:
            f.write(json.dumps(sample) + "\n")
    return stats

def analyze_dataset(input_path: Path, config: FilterConfig) -> None:
    """Analyze the dataset and print statistics with auto-detected patterns."""
    samples = []
    with open(input_path) as f:
        for raw_line in f:
            stripped = raw_line.strip()
            if not stripped:
                continue
            try:
                samples.append(json.loads(stripped))
            except json.JSONDecodeError:
                continue

    print(f"\nDataset Analysis: {input_path}")
    print("=" * 60)
    print(f"Total samples: {len(samples)}")

    # Tool call distribution
    all_tool_calls = []
    first_tool_calls = []
    for sample in samples:
        tool_calls = extract_tool_calls(sample)
        all_tool_calls.extend(name for name, _ in tool_calls)
        first_tool = get_first_tool(tool_calls)
        if first_tool:
            first_tool_calls.append(first_tool)

    if all_tool_calls:
        print("\nTool Call Distribution (all calls):")
        for tool, count in Counter(all_tool_calls).most_common(15):
            print(f"  {tool}: {count} ({count / len(all_tool_calls) * 100:.1f}%)")

    if first_tool_calls:
        print("\nFirst Tool Call Distribution:")
        for tool, count in Counter(first_tool_calls).most_common(15):
            pct = count / len(first_tool_calls) * 100
            flag = " [*]" if pct > config.auto_first_tool_threshold * 100 else ""
            print(f"  {tool}: {count} ({pct:.1f}%){flag}")

    # Auto-detect patterns
    print("\n" + "=" * 60)
    print("Auto-Detected Patterns")
    print("=" * 60)
    detected = detect_patterns(samples, config)

    # Over-represented first tools
    if detected.overrepresented_first_tools:
        print(f"\nOver-represented first tools (>{config.auto_first_tool_threshold * 100:.0f}% of samples):")
        tools_for_cli = []
        for item in detected.overrepresented_first_tools:
            print(f"  {item['tool']}: {item['count']} ({item['percentage'] * 100:.1f}%)")
            tools_for_cli.append(item['tool'])
        print(f"\n  Suggested: --balance-tools \"{','.join(tools_for_cli)}\"")
    else:
        print("\nNo over-represented first tools detected.")

    # Recovery sequences
    if detected.recovery_sequences:
        suspicious_sequences = [s for s in detected.recovery_sequences if s.get("suspicious")]
        other_sequences = [s for s in detected.recovery_sequences if not s.get("suspicious")]
        if suspicious_sequences:
            print("\nSuspicious sequences (likely recovery patterns):")
            patterns_for_cli = []
            for item in suspicious_sequences:
                marker = "[LOOP]" if item["from_tool"] == item["to_tool"] else "[FALLBACK]"
                print(f"  {marker} {item['from_tool']} -> {item['to_tool']}: {item['count']} ({item['percentage'] * 100:.1f}%)")
                patterns_for_cli.append(f"{item['from_tool']}:{item['to_tool']}")
            print(f"\n  Suggested: --recovery-patterns \"{','.join(patterns_for_cli)}\"")
        if other_sequences:
            print("\nOther frequent sequences (may be normal workflow):")
            for item in other_sequences[:5]:  # Show the top 5 only
                print(f"  {item['from_tool']} -> {item['to_tool']}: {item['count']} ({item['percentage'] * 100:.1f}%)")
    else:
        print("\nNo significant recovery sequences detected.")

    # Broken responses
    print(f"\nSamples with template placeholders: {detected.broken_pattern_count}")
    print(f"Samples with excessive same-tool calls (>{config.max_same_tool_calls}): {detected.excessive_same_tool_count}")
    # Estimate the impact of filtering
    print("\n" + "=" * 60)
    print("Estimated Filtering Impact")
    print("=" * 60)

    # Count only suspicious sequences here, since those are the ones that
    # --auto (and the recommended command below) would actually filter on
    suspicious_sequences = [s for s in detected.recovery_sequences if s.get("suspicious")]
    recovery_count = 0
    if suspicious_sequences:
        recovery_patterns = [
            (item["from_tool"], item["to_tool"])
            for item in suspicious_sequences
        ]
        for sample in samples:
            tool_calls = extract_tool_calls(sample)
            if has_recovery_pattern(tool_calls, recovery_patterns):
                recovery_count += 1

    removable = detected.broken_pattern_count + detected.excessive_same_tool_count + recovery_count
    print(f"Samples that would be removed: ~{removable} ({removable / max(len(samples), 1) * 100:.1f}%)")
    print(f"Samples that would remain: ~{len(samples) - removable}")

    # Print the recommended command
    print("\n" + "=" * 60)
    print("Recommended Command")
    print("=" * 60)
    cmd_parts = ["python filter_tool_dataset.py", str(input_path), "output.jsonl"]
    if detected.overrepresented_first_tools:
        tools = ",".join(item["tool"] for item in detected.overrepresented_first_tools)
        cmd_parts.append(f'--balance-tools "{tools}"')
    # Only include suspicious sequences in the recommended command
    if suspicious_sequences:
        patterns = ",".join(
            f"{item['from_tool']}:{item['to_tool']}"
            for item in suspicious_sequences
        )
        cmd_parts.append(f'--recovery-patterns "{patterns}"')
    print(f"\n  {' '.join(cmd_parts)}")
    print("\nOr use auto-detection mode:")
    print(f"  python filter_tool_dataset.py {input_path} output.jsonl --auto")

def parse_recovery_patterns(patterns_str: str) -> list[tuple[str, str]]:
    """Parse recovery patterns from the CLI string format 'A:B,C:D'."""
    if not patterns_str:
        return []
    patterns = []
    for raw_pair in patterns_str.split(","):
        stripped_pair = raw_pair.strip()
        if ":" in stripped_pair:
            a, b = stripped_pair.split(":", 1)
            patterns.append((a.strip(), b.strip()))
    return patterns
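# For example, parse_recovery_patterns("get_pod:list_pods,get_object:get_scene")
# returns [("get_pod", "list_pods"), ("get_object", "get_scene")].
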
def parse_balance_tools(tools_str: str) -> list[str]:
    """Parse balance tools from the CLI string format 'tool1,tool2'."""
    if not tools_str:
        return []
    return [t.strip() for t in tools_str.split(",") if t.strip()]

def main():
    parser = argparse.ArgumentParser(
        description="Generic filter for tool-calling datasets",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Analyze dataset and auto-detect problematic patterns
  %(prog)s --analyze input.jsonl

  # Filter with auto-detected settings
  %(prog)s input.jsonl output.jsonl --auto

  # Filter with explicit tool balancing
  %(prog)s input.jsonl output.jsonl --balance-tools "list_pods,get_deployment"

  # Filter with custom recovery patterns
  %(prog)s input.jsonl output.jsonl --recovery-patterns "get_pod:list_pods,get_object:get_scene"

  # Use a YAML config file
  %(prog)s input.jsonl output.jsonl --config filter.yaml

  # Combine: config file with CLI overrides
  %(prog)s input.jsonl output.jsonl --config filter.yaml --max-same-tool 2
""",
    )
    parser.add_argument("input", type=Path, help="Input JSONL file")
    parser.add_argument(
        "output", type=Path, nargs="?", help="Output JSONL file (optional for analyze)"
    )
    parser.add_argument(
        "--analyze", action="store_true", help="Only analyze, don't filter"
    )
    parser.add_argument(
        "--auto", action="store_true",
        help="Auto-detect and apply patterns from the data"
    )
    parser.add_argument(
        "--config", type=Path,
        help="YAML config file for domain-specific rules"
    )
    # Defaults are None so that values from a YAML config are only overridden
    # when the flag is given explicitly on the command line
    parser.add_argument(
        "--max-same-tool",
        type=int,
        default=None,
        help="Maximum calls to the same tool per sample (default: 3)",
    )
    parser.add_argument(
        "--similarity-threshold",
        type=float,
        default=None,
        help="Similarity threshold for deduplication (default: 0.85)",
    )
    parser.add_argument(
        "--balance-target",
        type=float,
        default=None,
        help="Target percentage for balanced tools (default: 0.10)",
    )
    parser.add_argument(
        "--balance-tools",
        type=str,
        help="Comma-separated list of tools to balance (e.g., 'get_scene_info,list_pods')"
    )
    parser.add_argument(
        "--recovery-patterns",
        type=str,
        help="Comma-separated recovery patterns as 'A:B' pairs (e.g., 'get_object:get_scene')"
    )
    parser.add_argument(
        "--first-tool-threshold",
        type=float,
        default=DEFAULT_FIRST_TOOL_THRESHOLD,
        help=f"Threshold for flagging over-represented first tools (default: {DEFAULT_FIRST_TOOL_THRESHOLD})"
    )
    parser.add_argument(
        "--no-balance", action="store_true", help="Skip tool balancing"
    )
    parser.add_argument(
        "--keep-recovery",
        action="store_true",
        help="Keep samples with recovery patterns",
    )
    parser.add_argument(
        "--keep-broken",
        action="store_true",
        help="Keep samples with broken/placeholder responses",
    )
    args = parser.parse_args()

    if not args.input.exists():
        print(f"Error: Input file not found: {args.input}")
        sys.exit(1)

    # Build the configuration
    config = FilterConfig()

    # Load the YAML config first (if provided)
    if args.config:
        if not args.config.exists():
            print(f"Error: Config file not found: {args.config}")
            sys.exit(1)
        yaml_config = load_yaml_config(args.config)
        config = config_from_yaml(yaml_config)

    # Apply CLI overrides only when the flags were actually given, so they
    # don't silently clobber values loaded from the YAML config
    if args.max_same_tool is not None:
        config.max_same_tool_calls = args.max_same_tool
    if args.similarity_threshold is not None:
        config.similarity_threshold = args.similarity_threshold
    if args.balance_target is not None:
        config.balance_target_percentage = args.balance_target
    config.auto_first_tool_threshold = args.first_tool_threshold
    config.auto_detect = args.auto

    # CLI-specified patterns override the config
    if args.balance_tools:
        config.tools_to_balance = parse_balance_tools(args.balance_tools)
    elif args.no_balance:
        config.tools_to_balance = []
    if args.recovery_patterns:
        config.recovery_patterns = parse_recovery_patterns(args.recovery_patterns)
        config.remove_recovery_patterns = True
    elif args.keep_recovery:
        config.remove_recovery_patterns = False
    if args.keep_broken:
        config.broken_response_patterns = []

    # Analyze mode
    if args.analyze:
        analyze_dataset(args.input, config)
        return

    # Filter mode
    if not args.output:
        print("Error: Output file required when not using --analyze")
        sys.exit(1)

    print(f"Filtering {args.input} -> {args.output}")
    if config.auto_detect:
        print("Auto-detection mode enabled - discovering patterns from data...")
    if config.tools_to_balance:
        print(f"Balancing tools: {', '.join(config.tools_to_balance)}")
    if config.recovery_patterns:
        patterns_str = ", ".join(f"{a}->{b}" for a, b in config.recovery_patterns)
        print(f"Recovery patterns: {patterns_str}")

    stats = filter_dataset(args.input, args.output, config)
    print(stats.summary())


if __name__ == "__main__":
    main()