Skip to content

Instantly share code, notes, and snippets.

@903124
Last active January 11, 2026 04:13
Show Gist options
  • Select an option

  • Save 903124/706c6aa128d3ef5fa21dab13b02ff1d8 to your computer and use it in GitHub Desktop.

Select an option

Save 903124/706c6aa128d3ef5fa21dab13b02ff1d8 to your computer and use it in GitHub Desktop.
import polars as pl
import numpy as np
import json
import fastplotlib as fpl
from fastplotlib.ui import EdgeWindow
from imgui_bundle import imgui
import matplotlib.pyplot as plt
import pyperclip # For clipboard functionality
# Print numpy floats with 8 digits of precision
np.set_printoptions(precision=8)

# ---------------------------------------------------------------------------
# Tracking data
# ---------------------------------------------------------------------------
file_path = "practice_data/2024_West_Practice_1.snappy-009.parquet"
df = pl.read_parquet(file_path)

# The parquet holds duplicated rows with unreliable labels: keep the copy
# labelled session 1 and throw the label columns away (they are re-derived
# from the session timestamp table further down).
df = df.filter(pl.col("session_id") == 1)
df = df.drop(["session_id", "drill_type"], strict=False)

# This parquet is assumed to contain a single dataset_id
dataset_id = df.select(pl.first("dataset_id")).item()

# ---------------------------------------------------------------------------
# Session lookup table
# ---------------------------------------------------------------------------
# Sessions of this dataset only, with exact duplicate entries collapsed
session_df = pl.read_csv("session_timestamps.csv")
session_df = session_df.filter(pl.col("dataset_id") == dataset_id)
session_df = session_df.unique(
    subset=["dataset_id", "session_id", "session_metadata_json"]
)

# Fields decoded from the session_metadata_json column
json_schema = pl.Struct({
    'startTs': pl.Int64,
    'endTs': pl.Int64,
    'drillType': pl.Utf8,
    'session_id': pl.Int64,
    'state': pl.Utf8,
    'drillId': pl.Int64,
})

session_df = session_df.with_columns(
    pl.col("session_metadata_json").str.json_decode(json_schema).alias("metadata")
)
# startTs / endTs are scaled by 1/1000 so they can be compared with the
# ts_seconds column built below (presumably milliseconds -> seconds).
session_df = session_df.with_columns(
    (pl.col("metadata").struct.field("startTs") / 1000).alias("start_ts"),
    (pl.col("metadata").struct.field("endTs") / 1000).alias("end_ts"),
    pl.col("metadata").struct.field("drillType").fill_null("").alias("drill_type"),
    pl.col("metadata").struct.field("session_id").alias("json_session_id"),
)
session_df = session_df.select(
    ["session_id", "start_ts", "end_ts", "drill_type", "json_session_id"]
)
session_df = session_df.sort("start_ts")

print("Session lookup table:")
print(session_df)

# ---------------------------------------------------------------------------
# Label every tracking row with its session
# ---------------------------------------------------------------------------
# Row timestamps as fractional epoch seconds (temporary helper column)
df = df.with_columns(
    (pl.col("ts").str.to_datetime().dt.epoch(time_unit="ns") / 1e9).alias("ts_seconds")
)

# One filtered slice per session: rows whose time falls inside the session's
# [start_ts, end_ts] window (both ends inclusive) get that session's labels.
# Rows outside every window are dropped by construction.
result_chunks = [
    df.filter(
        (pl.col("ts_seconds") >= session["start_ts"])
        & (pl.col("ts_seconds") <= session["end_ts"])
    ).with_columns(
        pl.lit(session["session_id"]).alias("session_id"),
        pl.lit(session["drill_type"]).alias("drill_type"),
    )
    for session in session_df.iter_rows(named=True)
]
df = pl.concat(result_chunks).drop("ts_seconds")

print(f"\nFiltered shape: {df.height}")
print(f"Unique sessions: {df['session_id'].n_unique()}")
print("\nSession distribution:")
print(df.group_by("session_id").len().sort("session_id"))
# ---------------------------------------------------------------------------
# Player reference tables
# ---------------------------------------------------------------------------
player_data = pl.read_parquet("shrine_bowl_players.parquet")
player_college_data = pl.read_csv("shrine_bowl_players_college_stats.csv")

# Bring both join keys to Int64; tracking rows without an id get id 0
df = df.with_columns(pl.col("gsis_id").fill_null(0).cast(pl.Int64))
player_college_data = player_college_data.with_columns(
    pl.col("college_gsis_id").cast(pl.Int64)
)

# One row per player: the row of the highest `season` value, which supplies
# the player's most recent position
player_last_season = (
    player_college_data.sort("season").group_by("college_gsis_id").last()
)

# Attach position and name to every tracking row (unmatched ids stay null)
_position_lookup = player_last_season.select(
    ["college_gsis_id", "position", "player_name"]
)
df = df.join(
    _position_lookup,
    left_on="gsis_id",
    right_on="college_gsis_id",
    how="left",
)

# Distinct (session_id, drill_type) pairs drive the session dropdown
session_info = df.select(["session_id", "drill_type"]).unique().sort("session_id")
session_list = session_info.to_dicts()

print("\nAvailable sessions:")
print(session_info)
# ---------------------------------------------------------------------------
# Figure and static field markings
# ---------------------------------------------------------------------------
figure = fpl.Figure(size=(1200, 700))

# Field extent in the same units as the x / y tracking columns (yards)
FIELD_WIDTH = 53.3
FIELD_LENGTH = 120

# Outer boundary: bottom, top, left and right edges, drawn in that order
for _start, _end in (
    ([0, 0, 0], [FIELD_LENGTH, 0, 0]),
    ([0, FIELD_WIDTH, 0], [FIELD_LENGTH, FIELD_WIDTH, 0]),
    ([0, 0, 0], [0, FIELD_WIDTH, 0]),
    ([FIELD_LENGTH, 0, 0], [FIELD_LENGTH, FIELD_WIDTH, 0]),
):
    figure[0, 0].add_line(np.array([_start, _end]), thickness=2, colors="white")

# A fainter line across the field every 10 yards, boundaries excluded
for yard in range(10, int(FIELD_LENGTH), 10):
    _yard_line = np.array([[yard, 0, 0], [yard, FIELD_WIDTH, 0]])
    figure[0, 0].add_line(_yard_line, thickness=1, colors="gray", alpha=0.5)

# Axis captions, and a camera placed above the centre of the field
figure[0, 0].axes.x.name = "X Position (yards)"
figure[0, 0].axes.y.name = "Y Position (yards)"
figure[0, 0].camera.position = (60, 26.65, 100)
class TimelapseController(EdgeWindow):
    """ImGui side panel that steps or plays tracking frames, one session at a time.

    The controller owns one scatter graphic per player (keyed by ``gsis_id``)
    plus an optional ball scatter, all drawn in ``figure[0, 0]``. Moving to
    another frame discards every graphic and re-creates it from the rows whose
    ``ts`` equals that frame's timestamp.

    Parameters
    ----------
    figure
        The fastplotlib figure to draw into and to receive events from.
    df
        polars DataFrame read through the columns ``session_id``, ``ts``,
        ``gsis_id``, ``entity_type``, ``x`` and ``y``. ``z`` is added as 0.0
        when missing; ``jersey_number`` and ``position`` are optional and only
        used for the click read-out.
    session_list
        List of dicts with the keys ``'session_id'`` and ``'drill_type'``;
        one dropdown entry per element.
    size, location, title
        Forwarded unchanged to ``EdgeWindow``.
    """

    # Lower-cased key names that step one frame forward / backward
    _FORWARD_KEYS = frozenset({'w', 'arrowup', 'd', 'arrowright'})
    _BACKWARD_KEYS = frozenset({'s', 'arrowdown', 'a', 'arrowleft'})

    # A click only selects a player closer than this many world units
    _CLICK_RADIUS = 10

    def __init__(self, figure, df, session_list, size, location, title):
        super().__init__(figure=figure, size=size, location=location, title=title)

        # Full dataset; per-session slices are derived in _initialize_session
        self._full_df = df
        self._session_list = session_list
        self._current_session_idx = 0  # index into session_list

        # Build the graphics for the first session
        self._initialize_session(self._current_session_idx)

        # Animation state
        self._current_frame = 0
        self._is_playing = False
        self._play_speed = 50
        self._frame_counter = 0
        self._skip_frames = 1

        # Selected player info
        self._selected_player_id = None
        self._selected_jersey = None
        self._clicked_player_info = ""

        # Event handlers
        self._setup_click_handler()
        self._setup_keyboard_handler()

    # ------------------------------------------------------------------
    # Graphic helpers
    # ------------------------------------------------------------------
    def _discard_graphic(self, graphic):
        """Delete ``graphic`` from the subplot on a best-effort basis.

        ``delete_graphic`` is used rather than ``remove_graphic``: graphics
        are re-created on every frame change and never added back, and
        ``remove_graphic`` is documented as not garbage-collecting them.
        """
        if graphic is None:
            return
        try:
            self._figure[0, 0].delete_graphic(graphic)
        except Exception as exc:
            # Deliberately non-fatal: a graphic that is already gone must not
            # abort a frame update. KeyboardInterrupt/SystemExit still propagate.
            print(f"Warning: could not delete graphic: {exc}")

    def _add_frame_scatter(self, frame_data, size, color, alpha, metadata):
        """Add and return a scatter for the rows of one entity at one timestamp.

        When ``frame_data`` is empty a fully transparent placeholder point is
        added far off the field instead, so the entity keeps a graphic.
        """
        subplot = self._figure[0, 0]
        if len(frame_data) > 0:
            points = frame_data.select(['x', 'y', 'z']).to_numpy().astype(np.float32)
            scatter = subplot.add_scatter(
                data=points,
                sizes=size,
                colors=color,
                alpha=alpha
            )
        else:
            scatter = subplot.add_scatter(
                data=np.array([[-1000, -1000, 0]], dtype=np.float32),
                sizes=0.1,
                colors=color,
                alpha=0.0
            )
        scatter.metadata = metadata
        return scatter

    def _step_frame(self, delta):
        """Move ``delta`` frames and redraw, unless that would leave the session."""
        target = self._current_frame + delta
        if 0 <= target < len(self._timestamps):
            self._current_frame = target
            self._update_positions()

    # ------------------------------------------------------------------
    # Session handling
    # ------------------------------------------------------------------
    def _initialize_session(self, session_idx):
        """Initialize or switch to the session at ``session_idx`` of the session list.

        Filters the full dataset to that session, discards the graphics of the
        previous session (if any), creates one scatter per player and one for
        the ball at their first available timestamp, and stores the
        per-session lookup tables on ``self``.
        """
        session = self._session_list[session_idx]
        session_id = session['session_id']
        drill_type = session['drill_type']
        print(f"\nLoading Session {session_id}: {drill_type}")

        # Rows of this session only
        df_session = self._full_df.filter(pl.col('session_id') == session_id)

        # Scatter data is always built from x, y, z; supply a flat z if absent
        if 'z' not in df_session.columns:
            df_session = df_session.with_columns(pl.lit(0.0).alias('z'))

        df_sorted = df_session.sort('ts')
        # One frame per distinct timestamp, in ascending order
        timestamps = df_sorted['ts'].unique().sort().to_list()

        # Every non-ball entity id gets its own scatter
        player_id_col = 'gsis_id'
        unique_players = (
            df_sorted
            .filter(pl.col('entity_type') != 'ball')
            .select(player_id_col)
            .drop_nulls()
            .unique()
            .to_series()
            .to_list()
        )
        n_players_found = len(unique_players)
        has_ball = (df_sorted['entity_type'] == 'ball').any()

        print(f"Found {n_players_found} unique players")
        print(f"Found {len(timestamps)} timestamps")
        if has_ball:
            print("Ball tracking data found")

        # 40-colour palette (tab20 + tab20b); reused cyclically beyond 40 players
        palette = np.vstack([
            plt.cm.tab20(np.linspace(0, 1, 20)),
            plt.cm.tab20b(np.linspace(0, 1, 20)),
        ])
        player_colors = {
            pid: palette[i % len(palette)] for i, pid in enumerate(unique_players)
        }

        # Discard graphics left over from a previously loaded session. On the
        # very first call the attributes do not exist yet.
        for old_scatter in getattr(self, '_player_scatters', {}).values():
            self._discard_graphic(old_scatter)
        self._discard_graphic(getattr(self, '_ball_scatter', None))

        player_scatters = {}
        player_data_dict = {}
        ball_scatter = None
        ball_data = None

        print("Initializing player visualizations...")
        for i, pid in enumerate(unique_players):
            if i % 10 == 0:
                print(f" Processing player {i+1}/{n_players_found}")
            try:
                player_df = df_sorted.filter(pl.col(player_id_col) == pid)
                if len(player_df) == 0:
                    continue
                # Start each player at the earliest timestamp it has data for
                first_frame = player_df.filter(pl.col('ts') == player_df['ts'].min())
                if len(first_frame) == 0:
                    continue
                player_scatters[pid] = self._add_frame_scatter(
                    first_frame, 10, player_colors[pid], 0.8, {'player_id': pid}
                )
                player_data_dict[pid] = player_df
            except Exception as e:
                # One broken player must not prevent the others from loading
                import traceback
                print(f"\nERROR INITIALIZING PLAYER {pid}: {str(e)}")
                traceback.print_exc()
                continue

        if has_ball:
            print("Initializing ball visualization...")
            ball_data = df_sorted.filter(pl.col('entity_type') == 'ball')
            if len(ball_data) > 0:
                first_ball_frame = ball_data.filter(
                    pl.col('ts') == ball_data['ts'].min()
                )
                if len(first_ball_frame) > 0:
                    ball_scatter = self._add_frame_scatter(
                        first_ball_frame, 15, 'orange', 1.0, {'is_ball': True}
                    )

        # Store session data
        self._timestamps = timestamps
        self._player_scatters = player_scatters
        self._player_data_dict = player_data_dict
        self._player_colors = player_colors
        self._df_sorted = df_sorted
        self._ball_scatter = ball_scatter
        self._ball_data = ball_data
        self._current_session_id = session_id
        self._current_drill_type = drill_type

        print(f"Initialized {len(player_scatters)} player visualizations")
        if ball_scatter:
            print("Ball visualization initialized")

    # ------------------------------------------------------------------
    # Event handlers
    # ------------------------------------------------------------------
    def _setup_click_handler(self):
        """Register the click handler that selects the player nearest the cursor."""

        def on_click(ev):
            world_pos = self._figure[0, 0].map_screen_to_world(ev)
            if world_pos is None:
                # The click landed outside the subplot (for example on the
                # control panel): there is nothing to select.
                return
            click_x, click_y = world_pos[0], world_pos[1]
            current_ts = self._timestamps[self._current_frame]

            # Nearest player that has a sample at the current timestamp
            min_distance = float('inf')
            closest_pid = None
            closest_frame = None
            for pid in self._player_scatters:
                frame_data = self._player_data_dict[pid].filter(pl.col('ts') == current_ts)
                if len(frame_data) == 0:
                    continue
                distance = np.sqrt(
                    (frame_data['x'][0] - click_x) ** 2
                    + (frame_data['y'][0] - click_y) ** 2
                )
                if distance < min_distance:
                    min_distance = distance
                    closest_pid = pid
                    closest_frame = frame_data

            if closest_pid is None or not min_distance < self._CLICK_RADIUS:
                return

            # Optional columns and null values are both shown as 'N/A'
            jersey = closest_frame['jersey_number'][0] if 'jersey_number' in closest_frame.columns else 'N/A'
            position = closest_frame['position'][0] if 'position' in closest_frame.columns else 'N/A'
            if jersey is None:
                jersey = 'N/A'
            if position is None:
                position = 'N/A'

            self._selected_player_id = closest_pid
            self._selected_jersey = jersey
            self._clicked_player_info = f"GSIS ID: {closest_pid}\nJersey: {jersey}\nPosition: {position}"
            print(f"\nClicked Player - {self._clicked_player_info}")

        self._figure.renderer.add_event_handler(on_click, "click")

    def _setup_keyboard_handler(self):
        """Register key handlers for single-step and hold-to-repeat navigation."""
        # Keys currently held down, and the number of update() calls since the
        # last key event; repeating starts once the counter exceeds the delay
        self._keys_pressed = set()
        self._key_hold_counter = 0
        self._key_hold_delay = 10  # update() calls to wait before repeating

        def on_key_down(ev):
            key = getattr(ev, 'key', None)
            if not key:
                return
            key = key.lower()
            self._keys_pressed.add(key)
            self._key_hold_counter = 0
            # Immediate single-frame move on the press itself
            if key in self._FORWARD_KEYS:
                self._step_frame(1)
            elif key in self._BACKWARD_KEYS:
                self._step_frame(-1)

        def on_key_up(ev):
            key = getattr(ev, 'key', None)
            if not key:
                return
            key = key.lower()
            if key in self._keys_pressed:
                self._keys_pressed.discard(key)
                self._key_hold_counter = 0

        self._figure.renderer.add_event_handler(on_key_down, "key_down")
        self._figure.renderer.add_event_handler(on_key_up, "key_up")

    # ------------------------------------------------------------------
    # GUI
    # ------------------------------------------------------------------
    def update(self):
        """Draw the panel widgets and advance key-repeat and auto-play state.

        Any exception is printed with its traceback, shown in the panel, and
        stops playback instead of propagating.
        """
        try:
            # Held navigation keys: after the initial delay, step on every call
            if self._keys_pressed:
                self._key_hold_counter += 1
                if self._key_hold_counter > self._key_hold_delay:
                    if self._keys_pressed & self._FORWARD_KEYS:
                        self._step_frame(1)
                    elif self._keys_pressed & self._BACKWARD_KEYS:
                        self._step_frame(-1)

            # Session selector
            imgui.text_colored((0.3, 0.7, 1.0, 1.0), "Session Selection:")
            session_labels = [
                f"Session {s['session_id']}: {s['drill_type'] if s['drill_type'] else 'Unknown'}"
                for s in self._session_list
            ]
            changed, new_idx = imgui.combo(
                "##session",
                self._current_session_idx,
                session_labels
            )
            if changed and new_idx != self._current_session_idx:
                # Reset playback and selection before loading the new session
                self._current_session_idx = new_idx
                self._current_frame = 0
                self._is_playing = False
                self._clicked_player_info = ""
                self._selected_player_id = None
                self._initialize_session(new_idx)
            imgui.text(f"Drill Type: {self._current_drill_type if self._current_drill_type else 'Unknown'}")
            imgui.separator()

            # Current timestamp, with a copy-to-clipboard button
            current_ts = self._timestamps[self._current_frame]
            imgui.text(f"Timestamp: {current_ts}")
            imgui.same_line()
            if imgui.button("Copy TS"):
                try:
                    pyperclip.copy(str(current_ts))
                    print(f"Copied timestamp to clipboard: {current_ts}")
                except Exception as e:
                    print(f"Failed to copy to clipboard: {e}")
            imgui.text(f"Frame: {self._current_frame + 1:,} / {len(self._timestamps):,}")
            imgui.separator()

            # Selected player read-out
            if self._clicked_player_info:
                imgui.text_colored((0.2, 0.8, 0.2, 1.0), "Selected Player:")
                imgui.text(self._clicked_player_info)
                if imgui.button("Copy GSIS ID"):
                    try:
                        pyperclip.copy(str(self._selected_player_id))
                        print(f"Copied GSIS ID to clipboard: {self._selected_player_id}")
                    except Exception as e:
                        print(f"Failed to copy GSIS ID to clipboard: {e}")
                imgui.same_line()
                if imgui.button("Clear Selection"):
                    self._clicked_player_info = ""
                    self._selected_player_id = None
                    self._selected_jersey = None
                imgui.separator()

            # Frame slider
            changed, new_frame = imgui.slider_int(
                "Frame",
                v=self._current_frame,
                v_min=0,
                v_max=len(self._timestamps) - 1
            )
            if changed:
                self._current_frame = new_frame
                self._update_positions()

            # Skip frames slider: frames advanced per auto-play step
            changed, skip = imgui.slider_int(
                "Skip Frames",
                v=self._skip_frames,
                v_min=1,
                v_max=100
            )
            if changed:
                self._skip_frames = skip
            imgui.text(f"(Play every {self._skip_frames} frame(s))")
            imgui.separator()

            # Play/Pause, Reset, End
            if self._is_playing:
                if imgui.button("Pause"):
                    self._is_playing = False
            else:
                if imgui.button("Play"):
                    self._is_playing = True
            imgui.same_line()
            if imgui.button("Reset"):
                self._current_frame = 0
                self._is_playing = False
                self._update_positions()
            imgui.same_line()
            if imgui.button("End"):
                self._current_frame = len(self._timestamps) - 1
                self._update_positions()

            # Play speed slider
            changed, speed = imgui.slider_int(
                "Speed (ms/frame)",
                v=self._play_speed,
                v_min=10,
                v_max=500
            )
            if changed:
                self._play_speed = speed
            imgui.separator()

            imgui.text(f"Total Players: {len(self._player_scatters)}")
            imgui.text(f"Total Frames: {len(self._timestamps):,}")
            imgui.separator()

            imgui.text_colored((0.7, 0.7, 0.7, 1.0), "Keyboard Controls:")
            imgui.text("W/↑ or D/→: Next frame")
            imgui.text("S/↓ or A/←: Previous frame")
            imgui.text("(Hold key to fast-forward)")

            # Auto-play: the ms/frame setting is turned into a number of
            # update() calls by dividing by 16 (presumably ~16 ms per render
            # at 60 fps - confirm against the actual render rate).
            if self._is_playing:
                self._frame_counter += 1
                if self._frame_counter >= max(1, self._play_speed // 16):
                    self._frame_counter = 0
                    last_frame = len(self._timestamps) - 1
                    self._current_frame = min(
                        self._current_frame + self._skip_frames,
                        last_frame
                    )
                    self._update_positions()
                    # Stop automatically on the final frame
                    if self._current_frame >= last_frame:
                        self._is_playing = False
        except Exception as e:
            import traceback
            print(f"\nERROR IN GUI UPDATE: {str(e)}")
            traceback.print_exc()
            imgui.text(f"Error: {str(e)}")
            imgui.text("Check console for details")
            self._is_playing = False

    def _update_positions(self):
        """Re-create every player scatter and the ball scatter for the current frame.

        Entities without a row at the current timestamp get an invisible
        placeholder. Errors are printed with their traceback and re-raised.
        """
        try:
            current_ts = self._timestamps[self._current_frame]

            # Iterate over a snapshot because the dict is rewritten in the loop
            for pid, old_scatter in list(self._player_scatters.items()):
                frame_data = self._player_data_dict[pid].filter(pl.col('ts') == current_ts)
                self._discard_graphic(old_scatter)
                self._player_scatters[pid] = self._add_frame_scatter(
                    frame_data, 10, self._player_colors[pid], 0.8, {'player_id': pid}
                )

            if self._ball_scatter is not None and self._ball_data is not None:
                ball_frame_data = self._ball_data.filter(pl.col('ts') == current_ts)
                self._discard_graphic(self._ball_scatter)
                self._ball_scatter = self._add_frame_scatter(
                    ball_frame_data, 15, 'orange', 1.0, {'is_ball': True}
                )
        except Exception as e:
            import traceback
            print(f"\nERROR IN _update_positions(): {str(e)}")
            traceback.print_exc()
            raise
# ---------------------------------------------------------------------------
# Control panel, window and usage notes
# ---------------------------------------------------------------------------
# Dock the controller on the right edge of the figure
gui = TimelapseController(
    figure, df, session_list, size=320, location="right", title="Timelapse Controls"
)
figure.add_gui(gui)

# Open the window
figure.show()

print("\nVisualization ready!")
print(f"Total sessions available: {len(session_list)}")

# Usage notes, one console line per entry
_feature_notes = (
    "\nNew Features:",
    " - Select different sessions from the dropdown menu",
    " - View drill type for each session",
    " - Click on any player to see their GSIS ID, Jersey Number, and Position",
    " - Click 'Copy TS' button to copy current timestamp to clipboard",
    " - Click 'Copy GSIS ID' button to copy selected player's GSIS ID to clipboard",
    " - Use WASD or Arrow keys to navigate frames",
    " - HOLD any navigation key to fast-forward frame-by-frame",
)
_usage_tips = (
    "\nTips:",
    " - Use 'Skip Frames' slider to speed up playback for large datasets",
    " - Switch sessions using the dropdown at the top of the control panel",
    " - Focus the visualization window to use keyboard controls",
    " - Hold navigation keys for continuous frame advancement",
)
for _note in _feature_notes + _usage_tips:
    print(_note)

# Start the event loop only when executed as a standalone script
if __name__ == "__main__":
    fpl.loop.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment