Skip to content

Instantly share code, notes, and snippets.

@H2CO3
Last active December 30, 2025 18:23
Show Gist options
  • Select an option

  • Save H2CO3/d65fecc169cfc09476058572864e68b4 to your computer and use it in GitHub Desktop.

Select an option

Save H2CO3/d65fecc169cfc09476058572864e68b4 to your computer and use it in GitHub Desktop.
import numpy as np
import pandas as pd
from argparse import ArgumentParser
from pathlib import Path
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split, cross_validate, cross_val_score, StratifiedGroupKFold
def _build_argparser():
    """Create the command-line parser for this script.

    Options:
      -i/--input   path to the CSV file to analyse (required)
      -s/--single  keep only one randomly chosen move per game
      -r/--random  integer seed used for every random choice (default 133742)
    """
    parser = ArgumentParser()
    parser.add_argument(
        '-i', '--input',
        type=Path,
        required=True,
        help='Path to input CSV created by the Rust program',
    )
    parser.add_argument(
        '-s', '--single',
        action='store_true',
        help='Only select a single random move from each game to avoid spurious correlations',
    )
    parser.add_argument(
        '-r', '--random',
        type=int,
        default=133742,
        help='Random seed for reproducibility',
    )
    return parser


argparser = _build_argparser()
args = argparser.parse_args()
# Load the per-move table (one row per evaluated move; the columns used below
# are `game_id`, `result` and `pawns`).
df = pd.read_csv(args.input)
print(f'Total data size: {df.shape}')

# Rows with result == 0 are neither a white win (1) nor a black win (-1);
# drop them so that the target is binary.
df = df.query('result != 0')
print(f'Data size without ties: {df.shape}')

if args.single:
    # Keep exactly one uniformly chosen move per game. The rows are selected
    # by index label instead of rebuilding the frame from row Series: that
    # rebuild upcast every column to float64 and lost the columns altogether
    # when no rows were left. The RNG is consumed in the same order as before
    # (one draw per game, games visited in groupby order), so the same rows
    # are chosen for a given seed.
    # NOTE: relies on the index being unique, which holds for the default
    # index produced by `read_csv` and preserved by `query`.
    rng = np.random.default_rng(seed=args.random)
    picked = [
        tmpdf.index[rng.choice(len(tmpdf))]
        for _, tmpdf in df.groupby('game_id')
    ]
    df = df.loc[picked]
    print(f'Data size with 1 random move per game: {df.shape}')
# Cross-validate a logistic regression of the game result on the evaluation.
# C=np.inf is the inverse regularisation strength, i.e. effectively no penalty.
model = LogisticRegression(C=np.inf)
centipawns = df[['pawns']] * 100.0  # convert to centipawns
metric_names = ['accuracy', 'balanced_accuracy', 'precision', 'recall', 'matthews_corrcoef']
# Grouping by game keeps all moves of one game on the same side of each split.
splitter = StratifiedGroupKFold(shuffle=True, random_state=args.random)

cv_results = cross_validate(
    model,
    X=centipawns,
    y=df['result'],
    groups=df['game_id'],
    scoring=metric_names,
    n_jobs=-1,
    verbose=1,
    cv=splitter,
    return_estimator=True,
)

# Separate the fitted estimators from the per-fold timings and scores, and
# tabulate the single slope and intercept learned in each fold.
estimators = cv_results.pop('estimator')
coefs = pd.DataFrame({
    'slope': [m.coef_[0, 0] for m in estimators],
    'intercept': [m.intercept_[0] for m in estimators],
})
cv_results = pd.DataFrame(cv_results)

print('CV metrics:\n----')
print(cv_results)
print()
print('CV fitted model parameters:\n----')
print(coefs)
use std::io::{BufRead, Write};
use regex::Regex;
/// Position of the parser within the tag section of the current game.
///
/// `main` advances through the variants strictly in declaration order: each
/// state waits for one specific tag line and ignores every other line until
/// that tag is found. A line starting with `[Event` restarts the sequence at
/// [`State::Event`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// No `[Event` line has been seen yet; all input is ignored.
    Init,
    /// An `[Event` line was seen; waiting for the `Result` tag.
    Event,
    /// The `Result` tag was parsed; waiting for the `WhiteElo` tag.
    Result,
    /// The `WhiteElo` tag was parsed; waiting for the `BlackElo` tag.
    WhiteElo,
    /// The `BlackElo` tag was parsed; waiting for the `Termination` tag.
    BlackElo,
    /// The `Termination` tag was parsed; following lines are scanned for
    /// `[%eval ...]` annotations.
    Termination,
}
fn main() {
let rx_event = Regex::new(r#"^\[Event"#).unwrap();
let rx_result = Regex::new(r#"^\[Result\s+"([^"]+)"\]"#).unwrap();
let rx_welo = Regex::new(r#"^\[WhiteElo\s+"(\d+)""#).unwrap();
let rx_belo = Regex::new(r#"^\[BlackElo\s+"(\d+)"\]"#).unwrap();
let rx_term = Regex::new(r#"^\[Termination\s+"([^"]+)"\]"#).unwrap();
let rx_move = Regex::new(r#"\d+\.\s+.*?\[%eval\s+([\+-]?\d+\.\d+)\]"#).unwrap();
let stdin = std::io::stdin();
let stdin = stdin.lock();
let stdout = std::io::stdout();
let mut stdout = stdout.lock();
writeln!(stdout, "game_id,result,white_elo,black_elo,pawns").ok();
let mut game_id: usize = 1;
let mut state = State::Init;
let mut is_valid = false;
let mut result = None::<isize>;
let mut white_elo = None::<usize>;
let mut black_elo = None::<usize>;
let mut pawns = Vec::<f64>::new();
for line in stdin.lines() {
let line = line.unwrap();
if rx_event.is_match(&line) {
if let (
true,
State::Termination,
Some(result),
Some(white_elo),
Some(black_elo),
) = (
is_valid,
state,
result,
white_elo,
black_elo,
) {
for pawn in &pawns {
writeln!(stdout, "{game_id},{result},{white_elo},{black_elo},{pawn}").ok();
}
}
game_id += 1;
state = State::Event;
is_valid = false;
result = None;
white_elo = None;
black_elo = None;
pawns.clear();
continue;
}
match state {
State::Init => {}
State::Event => {
if let Some(captures) = rx_result.captures(&line) {
result = Some(match &captures[1] {
"1-0" => 1,
"0-1" => -1,
_ => 0,
});
state = State::Result;
}
}
State::Result => {
if let Some(captures) = rx_welo.captures(&line) {
white_elo = Some(captures[1].parse().unwrap());
state = State::WhiteElo;
}
}
State::WhiteElo => {
if let Some(captures) = rx_belo.captures(&line) {
black_elo = Some(captures[1].parse().unwrap());
state = State::BlackElo;
}
}
State::BlackElo => {
if let Some(captures) = rx_term.captures(&line) {
if &captures[1] == "Normal" {
is_valid = true;
}
state = State::Termination;
}
}
State::Termination => {
for captures in rx_move.captures_iter(&line) {
pawns.push(captures[1].parse().unwrap());
}
}
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment