Skip to content

Instantly share code, notes, and snippets.

@jakobhuss
Last active February 22, 2026 15:06
Show Gist options
  • Select an option

  • Save jakobhuss/ae71037f79f0850c06ab53df515b8c7f to your computer and use it in GitHub Desktop.

Select an option

Save jakobhuss/ae71037f79f0850c06ab53df515b8c7f to your computer and use it in GitHub Desktop.
Script to search for optimal llama-server parameters using optuna
#!/usr/bin/env python3
import sys
import os
import stat
import signal
import optuna
import subprocess
import time
import uuid
import multiprocessing
import argparse
import socket
import json
import math
import urllib.request
import urllib.error
import tempfile
import re
import shutil
import statistics
from http.client import RemoteDisconnected
from datetime import datetime
def get_free_port():
    """Ask the OS for an unused TCP port and return its number."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Binding to port 0 makes the kernel pick any free ephemeral port.
        sock.bind(("", 0))
        return sock.getsockname()[1]
    finally:
        sock.close()
# Baseline llama-server configuration used to seed the very first Optuna
# trial; values for frozen/banned keys are also pinned to these defaults.
initial_state = {
    # CPU topology
    "threads": max(1, multiprocessing.cpu_count() // 2),
    "threads_batch": max(1, multiprocessing.cpu_count() // 2),
    # Batching
    "batch_size": 2048,
    "ubatch_size": 512,
    "cont_batching": True,
    # Attention / KV cache
    "flash_attn": "auto",
    "kv_offload": True,
    "repack": True,
    "ctk": "f16",
    "ctv": "f16",
    "kv_unified": True,
    "swa_full": False,
    # OS-level behavior
    "mmap": True,
    "direct_io": False,
    "numa": "none",
    "split_mode": "layer",
    # Speculative decoding
    "spec_type": "none",
    "spec_ngram_size_n": 12,
    "spec_ngram_size_m": 48,
    # Scheduling
    "poll": 50,
    "prio": 0,
    "cpu_strict": 0,
}
# Maps each llama-server CLI flag to the tuner's internal parameter name.
# Used when seeding the baseline from the user's command line, and by the
# self-healing logic to ban flags the installed binary rejects.
flag_to_param = {
    "-t": "threads",
    "-tb": "threads_batch",
    "-b": "batch_size",
    "-ub": "ubatch_size",
    "-fa": "flash_attn",
    "--kv-offload": "kv_offload",
    "--no-kv-offload": "kv_offload",
    "--repack": "repack",
    "--no-repack": "repack",
    "-ctk": "ctk",
    "-ctv": "ctv",
    "--mmap": "mmap",
    "--no-mmap": "mmap",
    "--direct-io": "direct_io",
    "--no-direct-io": "direct_io",
    "--numa": "numa",
    "-sm": "split_mode",
    "--kv-unified": "kv_unified",
    "--no-kv-unified": "kv_unified",
    "--cont-batching": "cont_batching",
    "--no-cont-batching": "cont_batching",
    "--spec-type": "spec_type",
    "--spec-ngram-size-n": "spec_ngram_size_n",
    "--spec-ngram-size-m": "spec_ngram_size_m",
    "--poll": "poll",
    "--prio": "prio",
    "--cpu-strict": "cpu_strict",
    "--swa-full": "swa_full",
}

# Parameter names disabled by self-healing after the server rejected their
# flags; starts empty and grows across trials.
banned_params = set()
def resolve_llama_server(user_cmd):
    """Locate the llama-server executable referenced by *user_cmd*.

    Returns a ``(resolved_path, remaining_args)`` pair, or ``(None,
    user_cmd)`` when no executable can be found anywhere.
    """
    if not user_cmd:
        return None, []
    candidate = user_cmd[0]
    # Treat the first token as the binary only if it isn't a flag and either
    # names llama-server or points at an existing path.
    looks_like_binary = not candidate.startswith("-") and (
        "llama-server" in candidate or os.path.exists(candidate)
    )
    if looks_like_binary:
        exe, rest = candidate, user_cmd[1:]
    else:
        exe, rest = "llama-server", user_cmd
    # 1) PATH lookup (also resolves explicit relative/absolute paths).
    found = shutil.which(exe)
    if found:
        return found, rest
    # 2) Direct executable file check.
    if os.path.isfile(exe) and os.access(exe, os.X_OK):
        return os.path.abspath(exe), rest
    # 3) Common llama.cpp build-tree locations.
    for fallback in ("./llama-server", "build/bin/llama-server", "../build/bin/llama-server"):
        if os.path.isfile(fallback) and os.access(fallback, os.X_OK):
            return os.path.abspath(fallback), rest
    return None, user_cmd
def parse_tuner_args(argv=None):
    """Parse the tuner's own CLI options.

    Args:
        argv: optional argument vector to parse; defaults to
            ``sys.argv[1:]`` (unchanged behavior for existing callers).
            Added so the parser can be exercised programmatically/tested.

    Returns:
        argparse.Namespace with tuner settings and ``server_cmd`` holding
        the user's raw llama-server command.

    Exits with status 1 (after printing help) when no server command was
    supplied.
    """
    parser = argparse.ArgumentParser(description="Llama.cpp Native Performance Auto-Tuner")
    parser.add_argument("--n-trials", type=int, default=200, help="Number of tuning trials to run")
    parser.add_argument("--n-prompts", type=int, default=3, help="Number of times to run the benchmark prompt per trial. Score is the median of these runs.")
    parser.add_argument("--n-tokens", type=int, default=128, help="Tokens to generate per benchmark")
    parser.add_argument("--warmup-tokens", type=int, default=1, help="Tokens for memory warmup ping (0 to disable)")
    parser.add_argument("--timeout", type=float, default=300.0, help="Max time (seconds) to wait for a generation before pruning the config")
    parser.add_argument("--metric", choices=["tps", "prompt", "total", "balanced"], default="tps",
                        help="Optimize for: tps, prompt, total, or balanced")
    parser.add_argument("--min-quality", type=float, default=0.9,
                        help="Minimum acceptable quality threshold (0.0 to 1.0, default: 0.9 for 90%%). Trials below this score 0.")
    parser.add_argument("--frozen", type=str, default="",
                        help="Comma-separated keys to NOT tune (e.g., 'threads,batch_size' or '-t,-b')")
    # REMAINDER swallows everything from the first unknown token onward.
    parser.add_argument("server_cmd", nargs=argparse.REMAINDER, help="Your standard llama-server command")
    args = parser.parse_args(argv)
    if not args.server_cmd:
        parser.print_help()
        sys.exit(1)
    # argparse.REMAINDER can keep a leading "--" separator; strip it.
    if args.server_cmd[0] == "--":
        args.server_cmd = args.server_cmd[1:]
    return args
def build_trial_cmd(base_cmd, params, user_port, active_port, output_mode=False):
    """Assemble the full llama-server argv for one parameter configuration.

    Args:
        base_cmd: resolved binary path plus pass-through user args.
        params: tunable parameter values. May be PARTIAL — Optuna's
            ``trial.params`` (used by the record callback) omits frozen and
            banned keys; missing keys fall back to ``initial_state``.
        user_port: port explicitly requested by the user, or None.
        active_port: auto-assigned benchmarking port.
        output_mode: when True (saving the final script), omit the
            auto-assigned port so the generated script stays portable.

    Returns:
        list[str] suitable for subprocess.Popen.
    """
    cmd = base_cmd.copy()
    if user_port is not None:
        cmd.extend(["--port", str(user_port)])
    elif not output_mode:
        cmd.extend(["--port", str(active_port)])

    def lookup(param_name):
        # BUGFIX: trial.params does not contain frozen/banned keys, so the
        # original eager params[...] lookups crashed the record callback
        # with KeyError whenever --frozen or self-healing was in effect.
        return params[param_name] if param_name in params else initial_state[param_name]

    def add_val(flag, param_name):
        if param_name not in banned_params:
            cmd.extend([flag, str(lookup(param_name))])

    def add_bool(flag_true, flag_false, param_name):
        if param_name not in banned_params:
            cmd.append(flag_true if lookup(param_name) else flag_false)

    add_val("-t", "threads")
    add_val("-tb", "threads_batch")
    add_val("-b", "batch_size")
    add_val("-ub", "ubatch_size")
    add_val("-fa", "flash_attn")
    add_bool("--kv-offload", "--no-kv-offload", "kv_offload")
    add_bool("--repack", "--no-repack", "repack")
    add_val("-ctk", "ctk")
    add_val("-ctv", "ctv")
    add_bool("--mmap", "--no-mmap", "mmap")
    add_bool("--direct-io", "--no-direct-io", "direct_io")
    # --numa has no "none" value flag; only emit it when a policy is chosen.
    if lookup("numa") != "none" and "numa" not in banned_params:
        cmd.extend(["--numa", lookup("numa")])
    add_val("-sm", "split_mode")
    add_bool("--kv-unified", "--no-kv-unified", "kv_unified")
    add_bool("--cont-batching", "--no-cont-batching", "cont_batching")
    add_val("--spec-type", "spec_type")
    # ngram sizes only make sense when speculative decoding is enabled.
    if lookup("spec_type") != "none":
        add_val("--spec-ngram-size-n", "spec_ngram_size_n")
        add_val("--spec-ngram-size-m", "spec_ngram_size_m")
    add_val("--poll", "poll")
    add_val("--prio", "prio")
    add_val("--cpu-strict", "cpu_strict")
    # --swa-full is a positive-only toggle (no --no-swa-full counterpart).
    if lookup("swa_full") and "swa_full" not in banned_params:
        cmd.append("--swa-full")
    return cmd
def run_benchmark(server_url, payload_prompt, n_tokens, timeout=300.0):
    """POST one completion request to the server and measure it.

    Tries the llama.cpp-native endpoint first, then OpenAI-compatible
    fallbacks, and extracts both timing stats and a crude quality score
    (mean per-token probability of the sampled tokens, in percent).

    Args:
        server_url: base URL, e.g. "http://127.0.0.1:8080".
        payload_prompt: prompt text to send.
        n_tokens: number of tokens to generate.
        timeout: per-request socket timeout in seconds.

    Returns:
        (timings, avg_quality): the server's "timings" dict (empty dict if
        absent) and the average token confidence as a percentage (0.0 when
        no probabilities could be extracted).

    Raises:
        urllib.error.URLError: if every candidate endpoint 404s, or on
            network failure; non-404 HTTP errors propagate unchanged.
    """
    payload = {
        "model": "auto-tuner",
        "prompt": payload_prompt,
        "max_tokens": n_tokens,
        "n_predict": n_tokens,
        "cache_prompt": False,  # keep runs comparable: no prompt-cache reuse
        "stream": False,
        "n_probs": 1, # Native llama.cpp quality extraction
        "logprobs": 1 # Standard OpenAI quality extraction
    }
    data = json.dumps(payload).encode('utf-8')
    # Priority #1: Llama Native. This guarantees the JSON has the native "completion_probabilities" format
    endpoints_to_try = ["/completion", "/v1/completions", "/completions"]
    result = None
    for ep in endpoints_to_try:
        try:
            req = urllib.request.Request(f"{server_url}{ep}", data=data, headers={'Content-Type': 'application/json'})
            with urllib.request.urlopen(req, timeout=timeout) as response:
                result = json.loads(response.read().decode('utf-8'))
            break
        except urllib.error.HTTPError as e:
            if e.code == 404:
                # Endpoint not exposed by this server build; try the next one.
                continue
            raise
    if result is None:
        raise urllib.error.URLError("Failed to find a valid text completions endpoint.")
    probs = []
    # 1. Try parsing OpenAI standard logprobs format (if hitting /v1/completions)
    if "choices" in result and result["choices"]:
        logprobs_data = result["choices"][0].get("logprobs")
        if isinstance(logprobs_data, dict):
            # Format A: Legacy text completions
            if "token_logprobs" in logprobs_data and logprobs_data["token_logprobs"]:
                for lp in logprobs_data["token_logprobs"]:
                    if lp is not None:
                        probs.append(math.exp(lp))
            # Format B: Modern Chat completions format
            elif "content" in logprobs_data and isinstance(logprobs_data["content"], list):
                for lp_item in logprobs_data["content"]:
                    if "logprob" in lp_item and lp_item["logprob"] is not None:
                        probs.append(math.exp(lp_item["logprob"]))
    # 2. Fallback to parsing Llama.cpp native format (if hitting /completion)
    if not probs and "completion_probabilities" in result:
        for token_data in result.get("completion_probabilities", []):
            if "logprob" in token_data and token_data["logprob"] is not None:
                probs.append(math.exp(token_data["logprob"]))
            else:
                # Older native schema: nested top-k "probs" list per token.
                t_probs = token_data.get("probs", [])
                if t_probs:
                    probs.append(t_probs[0].get("prob", 0.0))
    # Mean token probability expressed as a percentage.
    avg_quality = (sum(probs) / len(probs)) * 100.0 if probs else 0.0
    return result.get("timings", {}), avg_quality
def parse_invalid_args(stderr_text):
    """Return the set of CLI flags llama-server rejected in *stderr_text*.

    Scans crash logs for "invalid argument: X" / "unknown argument: X"
    messages (case-insensitive) and collects the offending flags.
    """
    pattern = r"(?:invalid|unknown)\s+argument:\s*([^\s\n]+)"
    matches = re.findall(pattern, stderr_text, re.IGNORECASE)
    return set(matches)
def objective(trial, base_cmd, host, user_port, active_port, frozen_keys, tuner_args):
    """Optuna objective: launch llama-server with trial params and score it.

    Samples a parameter configuration, boots a server with it, benchmarks
    generation `tuner_args.n_prompts` times, and returns the median score
    for the chosen metric. Returns 0.0 on crash, startup timeout, hang, or
    quality regression; raises TrialPruned for inconsistent parameter
    combinations and for self-healing retries.

    Side effects: mutates the module-level `banned_params` set and the
    shared `base_cmd` list when the server rejects arguments.
    """
    params = {}
    max_cpus = multiprocessing.cpu_count()

    def get_val(key, dist_func, *args):
        # Frozen/banned keys are pinned to the baseline instead of sampled.
        if key in frozen_keys or key in banned_params:
            return initial_state[key]
        return dist_func(key, *args)

    # 1. Topology
    params["threads"] = get_val("threads", trial.suggest_int, 4, max_cpus)
    params["threads_batch"] = get_val("threads_batch", trial.suggest_int, 4, max_cpus)
    params["poll"] = get_val("poll", trial.suggest_categorical, [0, 50, 100])
    params["prio"] = get_val("prio", trial.suggest_categorical, [0, 1, 2])
    params["cpu_strict"] = get_val("cpu_strict", trial.suggest_categorical, [0, 1])
    # 2. Batching Limitations
    params["batch_size"] = get_val("batch_size", trial.suggest_categorical, [512, 1024, 2048, 4096])
    params["ubatch_size"] = get_val("ubatch_size", trial.suggest_categorical, [64, 128, 256, 512, 1024, 2048])
    if params["ubatch_size"] > params["batch_size"]:
        # ubatch must not exceed batch; skip impossible combinations.
        raise optuna.exceptions.TrialPruned()
    params["cont_batching"] = get_val("cont_batching", trial.suggest_categorical, [True, False])
    # 3. Model Engine & Caching
    params["flash_attn"] = get_val("flash_attn", trial.suggest_categorical, ["on", "off", "auto"])
    params["kv_offload"] = get_val("kv_offload", trial.suggest_categorical, [True, False])
    params["repack"] = get_val("repack", trial.suggest_categorical, [True, False])
    params["kv_unified"] = get_val("kv_unified", trial.suggest_categorical, [True, False])
    params["swa_full"] = get_val("swa_full", trial.suggest_categorical, [True, False])
    cache_types = ["f16", "q8_0", "q4_0", "q5_0"]
    params["ctk"] = get_val("ctk", trial.suggest_categorical, cache_types)
    params["ctv"] = get_val("ctv", trial.suggest_categorical, cache_types)
    # 4. OS Level Overrides
    params["mmap"] = get_val("mmap", trial.suggest_categorical, [True, False])
    params["direct_io"] = get_val("direct_io", trial.suggest_categorical, [True, False])
    params["numa"] = get_val("numa", trial.suggest_categorical, ["none", "distribute", "isolate"])
    params["split_mode"] = get_val("split_mode", trial.suggest_categorical, ["layer", "row", "none"])
    # 5. Speculative Decoding
    params["spec_type"] = get_val("spec_type", trial.suggest_categorical, ["none", "ngram-cache", "ngram-map-k"])
    if params["spec_type"] != "none":
        params["spec_ngram_size_n"] = get_val("spec_ngram_size_n", trial.suggest_int, 8, 16)
        params["spec_ngram_size_m"] = get_val("spec_ngram_size_m", trial.suggest_int, 16, 64)
        if params["spec_ngram_size_m"] <= params["spec_ngram_size_n"]:
            # ngram map size must exceed ngram size; skip invalid combos.
            raise optuna.exceptions.TrialPruned()
    server_url = f"http://{host}:{active_port}"
    cmd = build_trial_cmd(base_cmd, params, user_port, active_port, output_mode=False)
    # Brief summary of key parameters for logging
    key_params = f"bs={params['batch_size']}, fa={params['flash_attn']}, ctk={params['ctk']}, ctv={params['ctv']}"
    print(f"\n[Trial {trial.number}] Starting server... ({key_params})")
    process = None
    try:
        with tempfile.TemporaryFile(mode="w+", encoding="utf-8") as stderr_file:
            # start_new_session puts the server in its own process group so
            # the whole tree can be killed with killpg in the finally block.
            kwargs = {'start_new_session': True} if os.name == 'posix' else {}
            process = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=stderr_file, **kwargs)
            start_time = time.time()
            server_ready = False
            # Poll /health until ready, the server dies, or startup times out.
            while time.time() - start_time < min(120, tuner_args.timeout):
                if process.poll() is not None:
                    # Server exited during startup: inspect stderr to decide
                    # between self-healing (bad flags) and a plain failure.
                    stderr_file.seek(0)
                    err_logs = stderr_file.read()
                    invalid_args = parse_invalid_args(err_logs)
                    if invalid_args:
                        for bad_arg in invalid_args:
                            if bad_arg in flag_to_param:
                                banned_params.add(flag_to_param[bad_arg])
                                print(f" [!] Auto-healing: Banning unsupported tunable '{bad_arg}'")
                            elif bad_arg in base_cmd:
                                base_cmd.remove(bad_arg)
                                print(f" [!] Auto-healing: Stripped invalid user arg '{bad_arg}' from baseline.")
                        raise optuna.exceptions.TrialPruned("Self-healing triggered. Retrying without invalid args.")
                    else:
                        print(f" [!] Server crashed with exit code {process.returncode} (OOM / Bad params).")
                        print(" └─ Score: 0.00")
                        return 0.0
                try:
                    req = urllib.request.Request(f"{server_url}/health")
                    with urllib.request.urlopen(req, timeout=2.0) as res:
                        if res.getcode() == 200:
                            server_ready = True
                            break
                except (urllib.error.URLError, TimeoutError, socket.timeout):
                    # Not listening yet; keep polling.
                    pass
                time.sleep(3)
            if not server_ready:
                print(" [!] Server timeout during startup.")
                print(" └─ Score: 0.00")
                return 0.0
            if tuner_args.warmup_tokens > 0:
                # Tiny warmup request to page in weights before timing runs.
                run_benchmark(server_url, "Warmup ping", n_tokens=tuner_args.warmup_tokens, timeout=min(60.0, tuner_args.timeout))
            scores = []
            gen_tps_list = []
            prompt_tps_list = []
            quality_list = []
            for p_idx in range(tuner_args.n_prompts):
                # Unique prompt per run to defeat any server-side caching.
                unique_prompt = f"Provide a comprehensive architectural breakdown of modern concurrent systems. Sequence: {p_idx} | UUID: {uuid.uuid4()}"
                timings, quality = run_benchmark(server_url, unique_prompt, n_tokens=tuner_args.n_tokens, timeout=tuner_args.timeout)
                gen_tps = timings.get("predicted_per_second", 0.0)
                prompt_tps = timings.get("prompt_per_second", 0.0)
                if tuner_args.metric == "tps":
                    run_score = gen_tps
                elif tuner_args.metric == "prompt":
                    run_score = prompt_tps
                elif tuner_args.metric == "total":
                    # End-to-end tokens/sec across prompt + generation phases.
                    tt = (timings.get("prompt_ms", 0) + timings.get("predicted_ms", 0)) / 1000.0
                    run_score = (timings.get("prompt_n", 0) + timings.get("predicted_n", 0)) / tt if tt > 0 else 0.0
                else:
                    # "balanced": harmonic mean of prompt and generation TPS.
                    run_score = (2 * prompt_tps * gen_tps) / (prompt_tps + gen_tps) if (prompt_tps + gen_tps) > 0 else 0.0
                scores.append(run_score)
                gen_tps_list.append(gen_tps)
                prompt_tps_list.append(prompt_tps)
                quality_list.append(quality)
            if not scores:
                return 0.0
            # Median across runs resists one-off scheduler/thermal noise.
            med_score = statistics.median(scores)
            med_gen_tps = statistics.median(gen_tps_list)
            med_prompt_tps = statistics.median(prompt_tps_list)
            med_quality = statistics.median(quality_list)
            trial.set_user_attr("quality", med_quality)
            # Target percentage threshold (e.g. 0.9 * 100 = 90.0%)
            min_q_pct = tuner_args.min_quality * 100.0
            if med_quality < min_q_pct:
                print(f" [!] Quality Drop Detected ({med_quality:.1f}% < {min_q_pct:.1f}%). Rejecting configuration.")
                return 0.0
            prefix = "Med " if tuner_args.n_prompts > 1 else ""
            print(f" └─ {prefix}Score ({tuner_args.metric}): {med_score:.2f} | {prefix}P-TPS: {med_prompt_tps:.2f} | {prefix}G-TPS: {med_gen_tps:.2f} | {prefix}Quality: {med_quality:.1f}%")
            return med_score
    except (urllib.error.URLError, urllib.error.HTTPError, TimeoutError, socket.timeout, ConnectionError, RemoteDisconnected) as e:
        print(f" [!] Trial failed during generation: {e} (Deadlock/Hang/Crash)")
        print(" └─ Score: 0.00")
        return 0.0
    finally:
        # Always tear the server (and its process group on POSIX) down.
        if process and process.poll() is None:
            if os.name == 'posix':
                try:
                    os.killpg(os.getpgid(process.pid), signal.SIGTERM)
                    process.wait(timeout=5)
                except (subprocess.TimeoutExpired, ProcessLookupError):
                    try:
                        os.killpg(os.getpgid(process.pid), signal.SIGKILL)
                        process.wait()
                    except ProcessLookupError:
                        pass
            else:
                process.terminate()
                process.wait()
        # Give the OS a moment to release the port before the next trial.
        time.sleep(2)
def make_record_callback(base_cmd, user_port, active_port, output_file, metric):
    """Build an Optuna callback that snapshots each new best trial.

    The returned callback writes a runnable bash script containing the
    winning llama-server command whenever a completed trial sets a new
    positive record, and marks the script owner-executable.
    """
    def callback(study, trial):
        # Only react to completed trials that set a new positive record.
        if trial.state != optuna.trial.TrialState.COMPLETE:
            return
        if study.best_trial.number != trial.number or not trial.value > 0.0:
            return
        best_cmd = build_trial_cmd(base_cmd, trial.params, user_port, active_port, output_mode=True)
        quality = trial.user_attrs.get("quality", 0.0)
        banner = "🚀"*15
        print("\n" + banner)
        print(f" NEW RECORD! Score ({metric}): {trial.value:.2f} (Quality: {quality:.1f}%)")
        print(f" Saved optimal configuration to: {output_file}")
        print(banner)
        with open(output_file, "w") as f:
            f.write("#!/bin/bash\n")
            f.write(f"# Auto-generated by Llama Native Tuner on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"# Metric: {metric.upper()} | Best Score: {trial.value:.2f} | Confidence Quality: {quality:.1f}%\n")
            f.write(f"# \n# Tuned Parameters:\n")
            for key, val in trial.params.items():
                if key not in banned_params:
                    f.write(f"# {key}: {val}\n")
            f.write("\n")
            f.write(" ".join(best_cmd) + "\n")
        # Make the generated script executable for its owner.
        mode = os.stat(output_file).st_mode
        os.chmod(output_file, mode | stat.S_IEXEC)
    return callback
if __name__ == "__main__":
    tuner_args = parse_tuner_args()
    exe_path, trailing_args = resolve_llama_server(tuner_args.server_cmd)
    if not exe_path:
        print("Error: Could not locate 'llama-server' executable in path or current directory.")
        sys.exit(1)
    full_resolved_cmd = [exe_path] + trailing_args

    # Split the user's command into pass-through args (base_cmd), an explicit
    # host/port, and known tunable flags whose values seed the baseline.
    base_cmd = []
    user_port = None
    host = "127.0.0.1"
    # Toggle flags take NO value; every other flag_to_param entry is "flag value".
    bool_flag_values = {
        "--kv-offload": ("kv_offload", True), "--no-kv-offload": ("kv_offload", False),
        "--repack": ("repack", True), "--no-repack": ("repack", False),
        "--mmap": ("mmap", True), "--no-mmap": ("mmap", False),
        "--direct-io": ("direct_io", True), "--no-direct-io": ("direct_io", False),
        "--kv-unified": ("kv_unified", True), "--no-kv-unified": ("kv_unified", False),
        "--cont-batching": ("cont_batching", True), "--no-cont-batching": ("cont_batching", False),
        "--swa-full": ("swa_full", True),
    }
    i = 0
    while i < len(full_resolved_cmd):
        arg = full_resolved_cmd[i]
        has_next = i + 1 < len(full_resolved_cmd)
        if arg in ("--port", "-port") and has_next:
            user_port = int(full_resolved_cmd[i + 1])
            i += 2
        elif arg in ("--host", "-host") and has_next:
            host = full_resolved_cmd[i + 1]
            base_cmd.extend([arg, full_resolved_cmd[i + 1]])
            i += 2
        elif arg in bool_flag_values:
            # BUGFIX: boolean toggles were previously parsed as value flags,
            # swallowing the next token and coercing it with bool(), which is
            # True for any non-empty string — so "--no-mmap" set mmap=True
            # AND consumed the following argument.
            p_key, p_val = bool_flag_values[arg]
            initial_state[p_key] = p_val
            i += 1
        elif arg in flag_to_param and has_next:
            # Seed the tuner baseline with the user's explicit value, coerced
            # to the type of the current default.
            p_key = flag_to_param[arg]
            initial_state[p_key] = type(initial_state[p_key])(full_resolved_cmd[i + 1])
            i += 2
        elif arg in flag_to_param:
            # Known value flag with nothing after it (end of command); drop it.
            i += 1
        else:
            base_cmd.append(arg)
            i += 1

    # Parameters the user asked us not to tune (accepted by key or CLI flag).
    frozen_keys = set()
    if tuner_args.frozen:
        for f_arg in [x.strip() for x in tuner_args.frozen.split(",")]:
            if f_arg in flag_to_param:
                frozen_keys.add(flag_to_param[f_arg])
            elif f_arg in initial_state:
                frozen_keys.add(f_arg)

    active_port = user_port if user_port is not None else get_free_port()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_script_name = f"optimal_llama_cmd_{timestamp}.sh"

    print("="*60)
    print(" Native Llama.cpp Performance Auto-Tuner")
    print("="*60)
    print(f" Binary Found : {exe_path}")
    print(f" Optimize Metric : {tuner_args.metric.upper()}")
    print(f" Min Quality Cut : {tuner_args.min_quality * 100:.1f}%")
    print(f" Bench Tokens : {tuner_args.n_tokens}")
    print(f" Prompts/Trial : {tuner_args.n_prompts}")
    print(f" Max Timeout : {tuner_args.timeout}s")
    print(f" Total Trials : {tuner_args.n_trials}")
    print(f" Frozen Settings : {', '.join(frozen_keys) if frozen_keys else 'None'}")
    print(f" Output Script : {output_script_name}")
    print("="*60 + "\n")

    optuna.logging.set_verbosity(optuna.logging.WARNING)
    study = optuna.create_study(direction="maximize", study_name="llama-server-tps")
    # Seed trial #0 with the (possibly user-adjusted) baseline configuration.
    enqueue_dict = {k: v for k, v in initial_state.items() if k not in frozen_keys}
    if "spec_type" in enqueue_dict and enqueue_dict["spec_type"] == "none":
        # ngram sizes are only suggested when speculation is on; enqueuing
        # them with spec_type="none" would mismatch the search space.
        enqueue_dict.pop("spec_ngram_size_n", None)
        enqueue_dict.pop("spec_ngram_size_m", None)
    if enqueue_dict:
        study.enqueue_trial(enqueue_dict)

    record_callback = make_record_callback(base_cmd, user_port, active_port, output_script_name, tuner_args.metric)
    try:
        study.optimize(
            lambda t: objective(t, base_cmd, host, user_port, active_port, frozen_keys, tuner_args),
            n_trials=tuner_args.n_trials,
            callbacks=[record_callback]
        )
    except KeyboardInterrupt:
        print("\n[!] Optimization manually interrupted.")

    print("\n" + "="*50)
    print("🎯 OPTIMIZATION COMPLETE 🎯")
    print("="*50)
    try:
        best_trial = study.best_trial
        if best_trial.value > 0.0:
            print(f"Absolute Best Score ({tuner_args.metric}): {best_trial.value:.2f}")
            print(f"Maintained Output Quality: {best_trial.user_attrs.get('quality', 0):.1f}% Avg Confidence")
            print(f"Your ready-to-use script is successfully exported to: ./{output_script_name}")
        else:
            print("All trials failed or fell below the minimum quality threshold.")
    except ValueError:
        # study.best_trial raises ValueError when no trial ever completed.
        print("No successful trials finished. Ensure your model fits in memory and base parameters are valid.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment