Last active
February 22, 2026 15:06
-
-
Save jakobhuss/ae71037f79f0850c06ab53df515b8c7f to your computer and use it in GitHub Desktop.
Script to search for optimal llama-server parameters using Optuna
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| import sys | |
| import os | |
| import stat | |
| import signal | |
| import optuna | |
| import subprocess | |
| import time | |
| import uuid | |
| import multiprocessing | |
| import argparse | |
| import socket | |
| import json | |
| import math | |
| import urllib.request | |
| import urllib.error | |
| import tempfile | |
| import re | |
| import shutil | |
| import statistics | |
| from http.client import RemoteDisconnected | |
| from datetime import datetime | |
def get_free_port():
    """Ask the OS for a currently unused ephemeral TCP port and return it."""
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Binding to port 0 makes the kernel pick any free port.
        probe.bind(("", 0))
        _, port = probe.getsockname()
    finally:
        probe.close()
    return port
# Default starting state for tuning parameters.
# Used three ways: seed values for Optuna trial #0 (enqueued in __main__),
# the fixed value for frozen/banned keys in objective(), and the fallback
# when a user pre-supplies a tunable on the command line.
initial_state = {
    # CPU topology: half the logical cores is a conservative starting point.
    "threads": max(1, multiprocessing.cpu_count() // 2),
    "threads_batch": max(1, multiprocessing.cpu_count() // 2),
    # Batching: logical batch (-b) and micro-batch (-ub) sizes.
    "batch_size": 2048,
    "ubatch_size": 512,
    # Attention / KV-cache behavior.
    "flash_attn": "auto",
    "kv_offload": True,
    "repack": True,
    "ctk": "f16",  # cache type passed as -ctk
    "ctv": "f16",  # cache type passed as -ctv
    # OS-level I/O and placement knobs.
    "mmap": True,
    "direct_io": False,
    "numa": "none",
    "split_mode": "layer",
    "kv_unified": True,
    "cont_batching": True,
    # Speculative decoding; the ngram sizes are only emitted when
    # spec_type != "none" (see build_trial_cmd / objective).
    "spec_type": "none",
    "spec_ngram_size_n": 12,
    "spec_ngram_size_m": 48,
    # Scheduling knobs passed through as plain integers.
    "poll": 50,
    "prio": 0,
    "cpu_strict": 0,
    "swa_full": False
}
# Mapping specific CLI flags to internal param keys for self-healing error detection.
# Both the positive and negative forms of a boolean flag map to the same key so a
# server-reported "invalid argument: <flag>" can be translated back to the tunable
# that produced it (see objective()'s crash handling).
flag_to_param = {
    "-t": "threads", "-tb": "threads_batch", "-b": "batch_size", "-ub": "ubatch_size",
    "-fa": "flash_attn", "--kv-offload": "kv_offload", "--no-kv-offload": "kv_offload",
    "--repack": "repack", "--no-repack": "repack", "-ctk": "ctk", "-ctv": "ctv",
    "--mmap": "mmap", "--no-mmap": "mmap", "--direct-io": "direct_io", "--no-direct-io": "direct_io",
    "--numa": "numa", "-sm": "split_mode", "--kv-unified": "kv_unified", "--no-kv-unified": "kv_unified",
    "--cont-batching": "cont_batching", "--no-cont-batching": "cont_batching",
    "--spec-type": "spec_type", "--spec-ngram-size-n": "spec_ngram_size_n", "--spec-ngram-size-m": "spec_ngram_size_m",
    "--poll": "poll", "--prio": "prio", "--cpu-strict": "cpu_strict", "--swa-full": "swa_full"
}
# Tunables this particular llama-server build rejected at startup. Populated at
# runtime by objective()'s self-healing path and consulted by build_trial_cmd()
# so rejected flags are never emitted again.
banned_params = set()
def resolve_llama_server(user_cmd):
    """Robustly find the llama-server executable.

    Returns a (path, remaining_args) pair. The path is taken from the first
    token of user_cmd when it plausibly names the binary, otherwise from PATH
    or a handful of common build locations; (None, user_cmd) when nothing
    executable can be located.
    """
    if not user_cmd:
        return None, []
    head = user_cmd[0]
    # Treat the first token as the binary only if it isn't a flag and either
    # mentions llama-server or points at an existing file.
    if not head.startswith("-") and ("llama-server" in head or os.path.exists(head)):
        candidate, rest = head, user_cmd[1:]
    else:
        candidate, rest = "llama-server", user_cmd
    found = shutil.which(candidate)
    if found:
        return found, rest
    if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
        return os.path.abspath(candidate), rest
    # Common llama.cpp build-tree locations.
    for fallback in ("./llama-server", "build/bin/llama-server", "../build/bin/llama-server"):
        if os.path.isfile(fallback) and os.access(fallback, os.X_OK):
            return os.path.abspath(fallback), rest
    return None, user_cmd
def parse_tuner_args():
    """Parse the tuner's own CLI options.

    Everything after the tuner options (optionally separated by '--') is
    captured verbatim as the user's llama-server command. Prints help and
    exits with status 1 when no server command was supplied.
    """
    parser = argparse.ArgumentParser(description="Llama.cpp Native Performance Auto-Tuner")
    add = parser.add_argument
    add("--n-trials", type=int, default=200, help="Number of tuning trials to run")
    add("--n-prompts", type=int, default=3, help="Number of times to run the benchmark prompt per trial. Score is the median of these runs.")
    add("--n-tokens", type=int, default=128, help="Tokens to generate per benchmark")
    add("--warmup-tokens", type=int, default=1, help="Tokens for memory warmup ping (0 to disable)")
    add("--timeout", type=float, default=300.0, help="Max time (seconds) to wait for a generation before pruning the config")
    add("--metric", choices=["tps", "prompt", "total", "balanced"], default="tps",
        help="Optimize for: tps, prompt, total, or balanced")
    add("--min-quality", type=float, default=0.9,
        help="Minimum acceptable quality threshold (0.0 to 1.0, default: 0.9 for 90%%). Trials below this score 0.")
    add("--frozen", type=str, default="",
        help="Comma-separated keys to NOT tune (e.g., 'threads,batch_size' or '-t,-b')")
    add("server_cmd", nargs=argparse.REMAINDER, help="Your standard llama-server command")
    parsed = parser.parse_args()
    if not parsed.server_cmd:
        parser.print_help()
        sys.exit(1)
    # REMAINDER keeps a leading "--" separator; strip it.
    if parsed.server_cmd[0] == "--":
        parsed.server_cmd = parsed.server_cmd[1:]
    return parsed
def build_trial_cmd(base_cmd, params, user_port, active_port, output_mode=False):
    """Assemble a full llama-server command line for one parameter set.

    base_cmd is copied, never mutated. A user-supplied port always wins;
    otherwise the tuner's active_port is appended unless output_mode is set
    (the exported script should not pin the tuner's random port). Any tunable
    present in the module-level banned_params set is omitted entirely.
    """
    cmd = list(base_cmd)
    if user_port is not None:
        cmd += ["--port", str(user_port)]
    elif not output_mode:
        cmd += ["--port", str(active_port)]

    def emit(flag, key, value):
        # Valued flag, skipped when the tunable has been banned.
        if key not in banned_params:
            cmd.extend([flag, str(value)])

    def emit_toggle(on_flag, off_flag, key, enabled):
        # Boolean flag pair, skipped when the tunable has been banned.
        if key not in banned_params:
            cmd.append(on_flag if enabled else off_flag)

    emit("-t", "threads", params["threads"])
    emit("-tb", "threads_batch", params["threads_batch"])
    emit("-b", "batch_size", params["batch_size"])
    emit("-ub", "ubatch_size", params["ubatch_size"])
    emit("-fa", "flash_attn", params["flash_attn"])
    emit_toggle("--kv-offload", "--no-kv-offload", "kv_offload", params["kv_offload"])
    emit_toggle("--repack", "--no-repack", "repack", params["repack"])
    emit("-ctk", "ctk", params["ctk"])
    emit("-ctv", "ctv", params["ctv"])
    emit_toggle("--mmap", "--no-mmap", "mmap", params["mmap"])
    emit_toggle("--direct-io", "--no-direct-io", "direct_io", params["direct_io"])
    if params["numa"] != "none" and "numa" not in banned_params:
        cmd.extend(["--numa", params["numa"]])
    emit("-sm", "split_mode", params["split_mode"])
    emit_toggle("--kv-unified", "--no-kv-unified", "kv_unified", params["kv_unified"])
    emit_toggle("--cont-batching", "--no-cont-batching", "cont_batching", params["cont_batching"])
    emit("--spec-type", "spec_type", params["spec_type"])
    if params["spec_type"] != "none":
        # ngram sizes are only meaningful when speculative decoding is on.
        emit("--spec-ngram-size-n", "spec_ngram_size_n", params["spec_ngram_size_n"])
        emit("--spec-ngram-size-m", "spec_ngram_size_m", params["spec_ngram_size_m"])
    emit("--poll", "poll", params["poll"])
    emit("--prio", "prio", params["prio"])
    emit("--cpu-strict", "cpu_strict", params["cpu_strict"])
    if params["swa_full"] and "swa_full" not in banned_params:
        cmd.append("--swa-full")
    return cmd
def run_benchmark(server_url, payload_prompt, n_tokens, timeout=300.0):
    """POST a single completion request and score the response.

    Returns (timings, avg_quality): the server's "timings" dict (empty when
    absent) and the mean generated-token probability as a percentage (0.0
    when no logprob data could be extracted). Raises URLError when no known
    completions endpoint answers.
    """
    payload = {
        "model": "auto-tuner",
        "prompt": payload_prompt,
        "max_tokens": n_tokens,
        "n_predict": n_tokens,
        "cache_prompt": False,
        "stream": False,
        "n_probs": 1,  # Native llama.cpp quality extraction
        "logprobs": 1  # Standard OpenAI quality extraction
    }
    body = json.dumps(payload).encode('utf-8')
    # Priority #1: Llama Native. This guarantees the JSON has the native "completion_probabilities" format
    result = None
    for endpoint in ("/completion", "/v1/completions", "/completions"):
        request = urllib.request.Request(
            f"{server_url}{endpoint}", data=body,
            headers={'Content-Type': 'application/json'}
        )
        try:
            with urllib.request.urlopen(request, timeout=timeout) as response:
                result = json.loads(response.read().decode('utf-8'))
            break
        except urllib.error.HTTPError as err:
            if err.code != 404:
                raise
    if result is None:
        raise urllib.error.URLError("Failed to find a valid text completions endpoint.")
    probs = []
    # 1. Try parsing OpenAI standard logprobs format (if hitting /v1/completions)
    choices = result.get("choices")
    if choices:
        logprobs_data = choices[0].get("logprobs")
        if isinstance(logprobs_data, dict):
            if logprobs_data.get("token_logprobs"):
                # Format A: Legacy text completions
                probs = [math.exp(lp) for lp in logprobs_data["token_logprobs"] if lp is not None]
            elif isinstance(logprobs_data.get("content"), list):
                # Format B: Modern Chat completions format
                probs = [math.exp(item["logprob"])
                         for item in logprobs_data["content"]
                         if item.get("logprob") is not None]
    # 2. Fallback to parsing Llama.cpp native format (if hitting /completion)
    if not probs and "completion_probabilities" in result:
        for token_data in result["completion_probabilities"]:
            if token_data.get("logprob") is not None:
                probs.append(math.exp(token_data["logprob"]))
            else:
                top_probs = token_data.get("probs", [])
                if top_probs:
                    probs.append(top_probs[0].get("prob", 0.0))
    avg_quality = (sum(probs) / len(probs)) * 100.0 if probs else 0.0
    return result.get("timings", {}), avg_quality
def parse_invalid_args(stderr_text):
    """Extracts invalid arguments from llama-server crash logs.

    Returns the set of flag tokens the server reported as
    'invalid argument: X' or 'unknown argument: X' (case-insensitive).
    """
    pattern = re.compile(r"(?:invalid|unknown)\s+argument:\s*([^\s\n]+)", re.IGNORECASE)
    return {match.group(1) for match in pattern.finditer(stderr_text)}
def objective(trial, base_cmd, host, user_port, active_port, frozen_keys, tuner_args):
    """Optuna objective: sample a llama-server config, boot it, benchmark it.

    Returns the median benchmark score across tuner_args.n_prompts runs, or
    0.0 when the server crashes, times out, or median quality falls below the
    cutoff. Higher is better (the study maximizes).

    Side effects: the self-healing path mutates the module-level banned_params
    set and removes tokens from the caller's shared base_cmd list, so later
    trials stop emitting flags this server build rejects.
    """
    params = {}
    max_cpus = multiprocessing.cpu_count()
    def get_val(key, dist_func, *args):
        # Frozen/banned keys keep their seed value and are never suggested,
        # so they do NOT appear in trial.params afterwards.
        if key in frozen_keys or key in banned_params:
            return initial_state[key]
        return dist_func(key, *args)
    # 1. Topology
    params["threads"] = get_val("threads", trial.suggest_int, 4, max_cpus)
    params["threads_batch"] = get_val("threads_batch", trial.suggest_int, 4, max_cpus)
    params["poll"] = get_val("poll", trial.suggest_categorical, [0, 50, 100])
    params["prio"] = get_val("prio", trial.suggest_categorical, [0, 1, 2])
    params["cpu_strict"] = get_val("cpu_strict", trial.suggest_categorical, [0, 1])
    # 2. Batching Limitations
    params["batch_size"] = get_val("batch_size", trial.suggest_categorical, [512, 1024, 2048, 4096])
    params["ubatch_size"] = get_val("ubatch_size", trial.suggest_categorical, [64, 128, 256, 512, 1024, 2048])
    if params["ubatch_size"] > params["batch_size"]:
        # Micro-batch must not exceed the logical batch; prune without booting.
        raise optuna.exceptions.TrialPruned()
    params["cont_batching"] = get_val("cont_batching", trial.suggest_categorical, [True, False])
    # 3. Model Engine & Caching
    params["flash_attn"] = get_val("flash_attn", trial.suggest_categorical, ["on", "off", "auto"])
    params["kv_offload"] = get_val("kv_offload", trial.suggest_categorical, [True, False])
    params["repack"] = get_val("repack", trial.suggest_categorical, [True, False])
    params["kv_unified"] = get_val("kv_unified", trial.suggest_categorical, [True, False])
    params["swa_full"] = get_val("swa_full", trial.suggest_categorical, [True, False])
    cache_types = ["f16", "q8_0", "q4_0", "q5_0"]
    params["ctk"] = get_val("ctk", trial.suggest_categorical, cache_types)
    params["ctv"] = get_val("ctv", trial.suggest_categorical, cache_types)
    # 4. OS Level Overrides
    params["mmap"] = get_val("mmap", trial.suggest_categorical, [True, False])
    params["direct_io"] = get_val("direct_io", trial.suggest_categorical, [True, False])
    params["numa"] = get_val("numa", trial.suggest_categorical, ["none", "distribute", "isolate"])
    params["split_mode"] = get_val("split_mode", trial.suggest_categorical, ["layer", "row", "none"])
    # 5. Speculative Decoding
    params["spec_type"] = get_val("spec_type", trial.suggest_categorical, ["none", "ngram-cache", "ngram-map-k"])
    if params["spec_type"] != "none":
        # Conditional search space: ngram sizes only exist when spec decoding is on.
        params["spec_ngram_size_n"] = get_val("spec_ngram_size_n", trial.suggest_int, 8, 16)
        params["spec_ngram_size_m"] = get_val("spec_ngram_size_m", trial.suggest_int, 16, 64)
        if params["spec_ngram_size_m"] <= params["spec_ngram_size_n"]:
            raise optuna.exceptions.TrialPruned()
    server_url = f"http://{host}:{active_port}"
    cmd = build_trial_cmd(base_cmd, params, user_port, active_port, output_mode=False)
    # Brief summary of key parameters for logging
    key_params = f"bs={params['batch_size']}, fa={params['flash_attn']}, ctk={params['ctk']}, ctv={params['ctv']}"
    print(f"\n[Trial {trial.number}] Starting server... ({key_params})")
    process = None
    try:
        with tempfile.TemporaryFile(mode="w+", encoding="utf-8") as stderr_file:
            # New session on POSIX so the whole process group can be killed in the
            # finally block (llama-server may spawn children).
            kwargs = {'start_new_session': True} if os.name == 'posix' else {}
            process = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=stderr_file, **kwargs)
            start_time = time.time()
            server_ready = False
            # Poll /health until ready, the process dies, or startup times out.
            while time.time() - start_time < min(120, tuner_args.timeout):
                if process.poll() is not None:
                    # Server exited during startup: inspect stderr for bad flags.
                    stderr_file.seek(0)
                    err_logs = stderr_file.read()
                    invalid_args = parse_invalid_args(err_logs)
                    if invalid_args:
                        for bad_arg in invalid_args:
                            if bad_arg in flag_to_param:
                                # Ban the tunable; build_trial_cmd skips it from now on.
                                banned_params.add(flag_to_param[bad_arg])
                                print(f" [!] Auto-healing: Banning unsupported tunable '{bad_arg}'")
                            elif bad_arg in base_cmd:
                                # Shared-list mutation: later trials also drop the token.
                                base_cmd.remove(bad_arg)
                                print(f" [!] Auto-healing: Stripped invalid user arg '{bad_arg}' from baseline.")
                        raise optuna.exceptions.TrialPruned("Self-healing triggered. Retrying without invalid args.")
                    else:
                        # Crash with no parseable flag error: likely OOM or bad combo.
                        print(f" [!] Server crashed with exit code {process.returncode} (OOM / Bad params).")
                        print(" └─ Score: 0.00")
                        return 0.0
                try:
                    req = urllib.request.Request(f"{server_url}/health")
                    with urllib.request.urlopen(req, timeout=2.0) as res:
                        if res.getcode() == 200:
                            server_ready = True
                            break
                except (urllib.error.URLError, TimeoutError, socket.timeout):
                    # Not listening yet; keep polling.
                    pass
                time.sleep(3)
            if not server_ready:
                print(" [!] Server timeout during startup.")
                print(" └─ Score: 0.00")
                return 0.0
            if tuner_args.warmup_tokens > 0:
                # Touch the model once so first-token latency doesn't skew run 1.
                run_benchmark(server_url, "Warmup ping", n_tokens=tuner_args.warmup_tokens, timeout=min(60.0, tuner_args.timeout))
            scores = []
            gen_tps_list = []
            prompt_tps_list = []
            quality_list = []
            for p_idx in range(tuner_args.n_prompts):
                # Unique prompt per run defeats any server-side prompt caching.
                unique_prompt = f"Provide a comprehensive architectural breakdown of modern concurrent systems. Sequence: {p_idx} | UUID: {uuid.uuid4()}"
                timings, quality = run_benchmark(server_url, unique_prompt, n_tokens=tuner_args.n_tokens, timeout=tuner_args.timeout)
                gen_tps = timings.get("predicted_per_second", 0.0)
                prompt_tps = timings.get("prompt_per_second", 0.0)
                if tuner_args.metric == "tps":
                    run_score = gen_tps
                elif tuner_args.metric == "prompt":
                    run_score = prompt_tps
                elif tuner_args.metric == "total":
                    # Overall throughput: all tokens / total wall time in seconds.
                    tt = (timings.get("prompt_ms", 0) + timings.get("predicted_ms", 0)) / 1000.0
                    run_score = (timings.get("prompt_n", 0) + timings.get("predicted_n", 0)) / tt if tt > 0 else 0.0
                else:
                    # "balanced": harmonic mean of prompt and generation TPS.
                    run_score = (2 * prompt_tps * gen_tps) / (prompt_tps + gen_tps) if (prompt_tps + gen_tps) > 0 else 0.0
                scores.append(run_score)
                gen_tps_list.append(gen_tps)
                prompt_tps_list.append(prompt_tps)
                quality_list.append(quality)
            if not scores:
                return 0.0
            # Median across runs resists one-off hiccups better than the mean.
            med_score = statistics.median(scores)
            med_gen_tps = statistics.median(gen_tps_list)
            med_prompt_tps = statistics.median(prompt_tps_list)
            med_quality = statistics.median(quality_list)
            trial.set_user_attr("quality", med_quality)
            # Target percentage threshold (e.g., 0.9 * 100 = 90.0%)
            min_q_pct = tuner_args.min_quality * 100.0
            if med_quality < min_q_pct:
                # Fast config with degraded output is worthless; reject it.
                print(f" [!] Quality Drop Detected ({med_quality:.1f}% < {min_q_pct:.1f}%). Rejecting configuration.")
                return 0.0
            prefix = "Med " if tuner_args.n_prompts > 1 else ""
            print(f" └─ {prefix}Score ({tuner_args.metric}): {med_score:.2f} | {prefix}P-TPS: {med_prompt_tps:.2f} | {prefix}G-TPS: {med_gen_tps:.2f} | {prefix}Quality: {med_quality:.1f}%")
            return med_score
    except (urllib.error.URLError, urllib.error.HTTPError, TimeoutError, socket.timeout, ConnectionError, RemoteDisconnected) as e:
        # Generation-phase failure (server hung, dropped the connection, ...).
        print(f" [!] Trial failed during generation: {e} (Deadlock/Hang/Crash)")
        print(" └─ Score: 0.00")
        return 0.0
    finally:
        # Always tear the server down; SIGTERM the whole group first, escalate
        # to SIGKILL if it doesn't exit within 5s.
        if process and process.poll() is None:
            if os.name == 'posix':
                try:
                    os.killpg(os.getpgid(process.pid), signal.SIGTERM)
                    process.wait(timeout=5)
                except (subprocess.TimeoutExpired, ProcessLookupError):
                    try:
                        os.killpg(os.getpgid(process.pid), signal.SIGKILL)
                        process.wait()
                    except ProcessLookupError:
                        pass
            else:
                process.terminate()
                process.wait()
        # Grace period so the OS releases the port/memory before the next trial.
        time.sleep(2)
def make_record_callback(base_cmd, user_port, active_port, output_file, metric):
    """Create an Optuna callback that exports the best config as a shell script.

    Each time a completed trial becomes the study-wide best (with a positive
    score), the callback writes output_file as an executable bash script
    containing the full llama-server command plus a comment header.
    """
    def callback(study, trial):
        if trial.state != optuna.trial.TrialState.COMPLETE:
            return
        if study.best_trial.number != trial.number or trial.value <= 0.0:
            return
        # BUGFIX: trial.params only contains keys Optuna actually sampled.
        # Frozen/banned tunables and the conditional spec-ngram sizes are
        # absent (objective()'s get_val returns seed values without calling
        # trial.suggest_*), so passing trial.params straight to
        # build_trial_cmd raised KeyError. Backfill from initial_state.
        full_params = {**initial_state, **trial.params}
        best_cmd = build_trial_cmd(base_cmd, full_params, user_port, active_port, output_mode=True)
        quality = trial.user_attrs.get("quality", 0.0)
        print("\n" + "🚀"*15)
        print(f" NEW RECORD! Score ({metric}): {trial.value:.2f} (Quality: {quality:.1f}%)")
        print(f" Saved optimal configuration to: {output_file}")
        print("🚀"*15)
        with open(output_file, "w") as f:
            f.write("#!/bin/bash\n")
            f.write(f"# Auto-generated by Llama Native Tuner on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"# Metric: {metric.upper()} | Best Score: {trial.value:.2f} | Confidence Quality: {quality:.1f}%\n")
            f.write(f"# \n# Tuned Parameters:\n")
            # Only list what was actually tuned this run (sampled, not banned).
            for k, v in trial.params.items():
                if k not in banned_params:
                    f.write(f"# {k}: {v}\n")
            f.write("\n")
            f.write(" ".join(best_cmd) + "\n")
        # Mark the exported script executable for the owner.
        st = os.stat(output_file)
        os.chmod(output_file, st.st_mode | stat.S_IEXEC)
    return callback
if __name__ == "__main__":
    tuner_args = parse_tuner_args()
    # Locate the binary and split it from its trailing arguments.
    exe_path, trailing_args = resolve_llama_server(tuner_args.server_cmd)
    if not exe_path:
        print("Error: Could not locate 'llama-server' executable in path or current directory.")
        sys.exit(1)
    full_resolved_cmd = [exe_path] + trailing_args
    base_cmd = []          # user command with all tunable/port flags stripped
    user_port = None       # explicit --port, honored verbatim if present
    host = "127.0.0.1"
    # Flags that take NO value on the llama-server CLI. Everything else in
    # flag_to_param is followed by exactly one value token.
    boolean_flags = {
        "--kv-offload", "--no-kv-offload", "--repack", "--no-repack",
        "--mmap", "--no-mmap", "--direct-io", "--no-direct-io",
        "--kv-unified", "--no-kv-unified",
        "--cont-batching", "--no-cont-batching", "--swa-full",
    }
    # Absorb user-supplied tunables into initial_state (so the seed trial
    # reflects them) and keep only non-tunable args in base_cmd.
    i = 0
    while i < len(full_resolved_cmd):
        arg = full_resolved_cmd[i]
        if arg in ("--port", "-port"):
            # The tuner manages the port; just remember the user's choice.
            user_port = int(full_resolved_cmd[i + 1])
            i += 2
        elif arg in ("--host", "-host"):
            host = full_resolved_cmd[i + 1]
            base_cmd.extend([arg, full_resolved_cmd[i + 1]])
            i += 2
        elif arg in boolean_flags:
            # BUGFIX: boolean flags carry no value. Previously every
            # flag_to_param key was treated as valued, so e.g. '--no-mmap -m X'
            # consumed '-m' as the value (bool("-m") == True), corrupting the
            # seed state AND dropping a real argument from the command line.
            initial_state[flag_to_param[arg]] = not arg.startswith("--no-")
            i += 1
        elif arg in flag_to_param:
            # Valued tunable: coerce to the default's type and strip it; the
            # tuner emits it itself via build_trial_cmd from here on.
            p_key = flag_to_param[arg]
            initial_state[p_key] = type(initial_state[p_key])(full_resolved_cmd[i + 1])
            i += 2
        else:
            base_cmd.append(arg)
            i += 1
    # Frozen keys: accept either CLI flags ('-t') or internal names ('threads').
    frozen_keys = set()
    if tuner_args.frozen:
        for f_arg in [x.strip() for x in tuner_args.frozen.split(",")]:
            if f_arg in flag_to_param:
                frozen_keys.add(flag_to_param[f_arg])
            elif f_arg in initial_state:
                frozen_keys.add(f_arg)
    active_port = user_port if user_port is not None else get_free_port()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_script_name = f"optimal_llama_cmd_{timestamp}.sh"
    print("="*60)
    print(" Native Llama.cpp Performance Auto-Tuner")
    print("="*60)
    print(f" Binary Found : {exe_path}")
    print(f" Optimize Metric : {tuner_args.metric.upper()}")
    print(f" Min Quality Cut : {tuner_args.min_quality * 100:.1f}%")
    print(f" Bench Tokens : {tuner_args.n_tokens}")
    print(f" Prompts/Trial : {tuner_args.n_prompts}")
    print(f" Max Timeout : {tuner_args.timeout}s")
    print(f" Total Trials : {tuner_args.n_trials}")
    print(f" Frozen Settings : {', '.join(frozen_keys) if frozen_keys else 'None'}")
    print(f" Output Script : {output_script_name}")
    print("="*60 + "\n")
    optuna.logging.set_verbosity(optuna.logging.WARNING)
    study = optuna.create_study(direction="maximize", study_name="llama-server-tps")
    # Seed trial #0 with the user's current/default configuration.
    enqueue_dict = {k: v for k, v in initial_state.items() if k not in frozen_keys}
    if "spec_type" in enqueue_dict and enqueue_dict["spec_type"] == "none":
        # These keys are only suggested when spec_type != "none"; enqueueing
        # them would desync Optuna's conditional search space.
        enqueue_dict.pop("spec_ngram_size_n", None)
        enqueue_dict.pop("spec_ngram_size_m", None)
    if enqueue_dict:
        study.enqueue_trial(enqueue_dict)
    record_callback = make_record_callback(base_cmd, user_port, active_port, output_script_name, tuner_args.metric)
    try:
        study.optimize(
            lambda t: objective(t, base_cmd, host, user_port, active_port, frozen_keys, tuner_args),
            n_trials=tuner_args.n_trials,
            callbacks=[record_callback]
        )
    except KeyboardInterrupt:
        # Partial results are still valid; fall through to the summary.
        print("\n[!] Optimization manually interrupted.")
    print("\n" + "="*50)
    print("🎯 OPTIMIZATION COMPLETE 🎯")
    print("="*50)
    try:
        best_trial = study.best_trial
        if best_trial.value > 0.0:
            print(f"Absolute Best Score ({tuner_args.metric}): {best_trial.value:.2f}")
            print(f"Maintained Output Quality: {best_trial.user_attrs.get('quality', 0):.1f}% Avg Confidence")
            print(f"Your ready-to-use script is successfully exported to: ./{output_script_name}")
        else:
            print("All trials failed or fell below the minimum quality threshold.")
    except ValueError:
        # study.best_trial raises ValueError when no trial completed.
        print("No successful trials finished. Ensure your model fits in memory and base parameters are valid.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment