Last active
December 30, 2025 18:23
-
-
Save H2CO3/d65fecc169cfc09476058572864e68b4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy as np | |
| import pandas as pd | |
| from argparse import ArgumentParser | |
| from pathlib import Path | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.model_selection import train_test_split, cross_validate, cross_val_score, StratifiedGroupKFold | |
def select_one_move_per_game(df, rng):
    """Return one uniformly random row per ``game_id`` from ``df``.

    Keeping a single move per game avoids the spurious correlations caused by
    many dependent positions taken from the same game.

    Groups are visited in ascending ``game_id`` order (the ``groupby``
    default) and ``rng.choice(len(group))`` is drawn exactly once per group,
    so a given seed selects the same rows as a per-group ``iloc`` loop.

    Args:
        df: Frame with a ``game_id`` column. Its index is assumed to be
            unique (true for the default index produced by ``pd.read_csv``).
        rng: ``numpy.random.Generator`` used for the draws.

    Returns:
        A row subset of ``df`` with its original index, columns and dtypes.
        An empty input yields an empty frame with the same columns.
    """
    picks = [group.index[rng.choice(len(group))] for _, group in df.groupby('game_id')]
    # Selecting by label keeps per-column dtypes. Rebuilding a frame from
    # per-row Series would upcast all integer columns to float64.
    return df.loc[picks]


def main():
    """Cross-validate a logistic model of game result vs. engine evaluation.

    Reads the CSV written by the Rust converter (columns ``game_id``,
    ``result``, ``white_elo``, ``black_elo``, ``pawns``) and drops ties
    (``result == 0``). It then optionally keeps one random move per game and
    cross-validates ``LogisticRegression(C=np.inf)`` of ``result`` on the
    evaluation in centipawns. Folds are grouped by game and stratified by
    result. Prints the per-fold metrics and fitted parameters.
    """
    argparser = ArgumentParser()
    argparser.add_argument('-i', '--input', type=Path, required=True,
                           help='Path to input CSV created by the Rust program')
    argparser.add_argument('-s', '--single', action='store_true',
                           help='Only select a single random move from each game to avoid spurious correlations')
    argparser.add_argument('-r', '--random', type=int, default=133742,
                           help='Random seed for reproducibility')
    args = argparser.parse_args()

    df = pd.read_csv(args.input)
    print(f'Total data size: {df.shape}')
    df = df.query('result != 0')
    print(f'Data size without ties: {df.shape}')

    if args.single:
        rng = np.random.default_rng(seed=args.random)
        df = select_one_move_per_game(df, rng)
        print(f'Data size with 1 random move per game: {df.shape}')

    cv_results = cross_validate(
        LogisticRegression(C=np.inf),  # C is the inverse regularization strength; inf = unpenalized
        X=df[['pawns']] * 100.0,  # convert to centipawns
        y=df['result'],
        groups=df['game_id'],  # all moves of one game land in the same fold
        scoring=['accuracy', 'balanced_accuracy', 'precision', 'recall', 'matthews_corrcoef'],
        n_jobs=-1,
        verbose=1,
        cv=StratifiedGroupKFold(shuffle=True, random_state=args.random),
        return_estimator=True,
    )
    # Pop the fitted per-fold models so the remaining dict is a flat metrics table.
    coefs = pd.DataFrame([
        {'slope': m.coef_[0, 0], 'intercept': m.intercept_[0]}
        for m in cv_results.pop('estimator')
    ])
    cv_results = pd.DataFrame(cv_results)

    print('CV metrics:\n----')
    print(cv_results)
    print()
    print('CV fitted model parameters:\n----')
    print(coefs)


if __name__ == '__main__':
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use std::io::{BufRead, Write}; | |
| use regex::Regex; | |
/// Progress of the PGN header parser through the current game.
///
/// Each variant names the last required header that has been parsed. `main`
/// advances through them strictly in declaration order and resets to `Event`
/// whenever a new `[Event ...]` line starts a game.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Before the first `[Event ...]` line of the input.
    Init,
    /// `[Event ...]` seen; waiting for `[Result ...]`.
    Event,
    /// Result parsed; waiting for `[WhiteElo ...]`.
    Result,
    /// White's rating parsed; waiting for `[BlackElo ...]`.
    WhiteElo,
    /// Black's rating parsed; waiting for `[Termination ...]`.
    BlackElo,
    /// All required headers parsed; further lines are scanned for move evaluations.
    Termination,
}
| fn main() { | |
| let rx_event = Regex::new(r#"^\[Event"#).unwrap(); | |
| let rx_result = Regex::new(r#"^\[Result\s+"([^"]+)"\]"#).unwrap(); | |
| let rx_welo = Regex::new(r#"^\[WhiteElo\s+"(\d+)""#).unwrap(); | |
| let rx_belo = Regex::new(r#"^\[BlackElo\s+"(\d+)"\]"#).unwrap(); | |
| let rx_term = Regex::new(r#"^\[Termination\s+"([^"]+)"\]"#).unwrap(); | |
| let rx_move = Regex::new(r#"\d+\.\s+.*?\[%eval\s+([\+-]?\d+\.\d+)\]"#).unwrap(); | |
| let stdin = std::io::stdin(); | |
| let stdin = stdin.lock(); | |
| let stdout = std::io::stdout(); | |
| let mut stdout = stdout.lock(); | |
| writeln!(stdout, "game_id,result,white_elo,black_elo,pawns").ok(); | |
| let mut game_id: usize = 1; | |
| let mut state = State::Init; | |
| let mut is_valid = false; | |
| let mut result = None::<isize>; | |
| let mut white_elo = None::<usize>; | |
| let mut black_elo = None::<usize>; | |
| let mut pawns = Vec::<f64>::new(); | |
| for line in stdin.lines() { | |
| let line = line.unwrap(); | |
| if rx_event.is_match(&line) { | |
| if let ( | |
| true, | |
| State::Termination, | |
| Some(result), | |
| Some(white_elo), | |
| Some(black_elo), | |
| ) = ( | |
| is_valid, | |
| state, | |
| result, | |
| white_elo, | |
| black_elo, | |
| ) { | |
| for pawn in &pawns { | |
| writeln!(stdout, "{game_id},{result},{white_elo},{black_elo},{pawn}").ok(); | |
| } | |
| } | |
| game_id += 1; | |
| state = State::Event; | |
| is_valid = false; | |
| result = None; | |
| white_elo = None; | |
| black_elo = None; | |
| pawns.clear(); | |
| continue; | |
| } | |
| match state { | |
| State::Init => {} | |
| State::Event => { | |
| if let Some(captures) = rx_result.captures(&line) { | |
| result = Some(match &captures[1] { | |
| "1-0" => 1, | |
| "0-1" => -1, | |
| _ => 0, | |
| }); | |
| state = State::Result; | |
| } | |
| } | |
| State::Result => { | |
| if let Some(captures) = rx_welo.captures(&line) { | |
| white_elo = Some(captures[1].parse().unwrap()); | |
| state = State::WhiteElo; | |
| } | |
| } | |
| State::WhiteElo => { | |
| if let Some(captures) = rx_belo.captures(&line) { | |
| black_elo = Some(captures[1].parse().unwrap()); | |
| state = State::BlackElo; | |
| } | |
| } | |
| State::BlackElo => { | |
| if let Some(captures) = rx_term.captures(&line) { | |
| if &captures[1] == "Normal" { | |
| is_valid = true; | |
| } | |
| state = State::Termination; | |
| } | |
| } | |
| State::Termination => { | |
| for captures in rx_move.captures_iter(&line) { | |
| pawns.push(captures[1].parse().unwrap()); | |
| } | |
| } | |
| } | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment