Last active
January 11, 2026 04:13
-
-
Save 903124/706c6aa128d3ef5fa21dab13b02ff1d8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
import traceback

import polars as pl
import numpy as np
import fastplotlib as fpl
from fastplotlib.ui import EdgeWindow
from imgui_bundle import imgui
import matplotlib.pyplot as plt
import pyperclip  # For clipboard functionality
# Print numpy floats with 8 digits of precision.
np.set_printoptions(precision=8)

# Read the main tracking data file.
file_path = "practice_data/2024_West_Practice_1.snappy-009.parquet"
df = pl.read_parquet(file_path)

# Bad parquet duplicates: keep the session_id == 1 copy and drop the wrong
# session labels; session_id / drill_type are re-derived from the session
# table further down.
df = (
    df.filter(pl.col("session_id") == 1)
    .drop(["session_id", "drill_type"], strict=False)
)

# This script handles one dataset at a time. Use the first NON-null dataset_id
# (a leading null would otherwise match no session rows at all) and say so
# loudly when the file holds more than one dataset instead of silently mixing.
dataset_ids = df.get_column("dataset_id").drop_nulls().unique()
if len(dataset_ids) == 0:
    raise ValueError(
        f"{file_path}: no rows with a dataset_id after keeping session_id == 1"
    )
dataset_id = df.select(pl.col("dataset_id").drop_nulls().first()).item()
if len(dataset_ids) > 1:
    print(
        f"Warning: {file_path} contains {len(dataset_ids)} dataset_ids; "
        f"using {dataset_id!r}"
    )
# Session lookup table for this dataset: drop exact duplicate rows, decode the
# metadata JSON once, and keep each session's window plus its drill type.
# startTs / endTs are divided by 1000 so they are comparable with the tracking
# timestamps, which are converted to epoch seconds below.
json_schema = pl.Struct({
    "startTs": pl.Int64,
    "endTs": pl.Int64,
    "drillType": pl.Utf8,
    "session_id": pl.Int64,
    "state": pl.Utf8,
    "drillId": pl.Int64,
})
_meta = pl.col("metadata").struct
session_df = (
    pl.read_csv("session_timestamps.csv")
    .filter(pl.col("dataset_id") == dataset_id)
    .unique(subset=["dataset_id", "session_id", "session_metadata_json"])
    .with_columns(
        metadata=pl.col("session_metadata_json").str.json_decode(json_schema)
    )
    .select(
        "session_id",
        start_ts=_meta.field("startTs") / 1000,
        end_ts=_meta.field("endTs") / 1000,
        drill_type=_meta.field("drillType").fill_null(""),
        json_session_id=_meta.field("session_id"),
    )
    .sort("start_ts")
)
print("Session lookup table:")
print(session_df)
# Convert tracking timestamps to epoch seconds so they compare directly with the
# session windows (start_ts / end_ts) in session_df.
df = df.with_columns(
    (pl.col("ts").str.to_datetime().dt.epoch(time_unit="ns") / 1e9).alias("ts_seconds")
)

# Assign session_id + drill_type by filtering one session window at a time
# (memory-efficient: no cross join). Rows outside every window are dropped; a
# row inside overlapping windows is kept once per matching session.
# The literals are given explicit dtypes so every chunk shares one schema: an
# untyped pl.lit infers an integer width from each value, and chunks with
# different widths cannot be concatenated.
session_id_dtype = session_df.schema["session_id"]
result_chunks = []
for row in session_df.iter_rows(named=True):
    chunk = df.filter(
        (pl.col("ts_seconds") >= row["start_ts"]) &
        (pl.col("ts_seconds") <= row["end_ts"])
    ).with_columns(
        pl.lit(row["session_id"], dtype=session_id_dtype).alias("session_id"),
        pl.lit(row["drill_type"], dtype=pl.Utf8).alias("drill_type"),
    )
    result_chunks.append(chunk)

if not result_chunks:
    # pl.concat([]) would fail with an unhelpful message; explain the cause.
    raise ValueError(
        f"No sessions for dataset_id={dataset_id!r} in session_timestamps.csv; "
        "cannot assign session_id to tracking rows"
    )
df = pl.concat(result_chunks).drop("ts_seconds")

print(f"\nFiltered shape: {len(df)}")
print(f"Unique sessions: {df['session_id'].n_unique()}")
print("\nSession distribution:")
print(df.group_by('session_id').len().sort('session_id'))
# Player reference tables.
player_data = pl.read_parquet("shrine_bowl_players.parquet")
player_college_data = pl.read_csv("shrine_bowl_players_college_stats.csv")

# Align the join keys as Int64; tracking rows with no gsis_id are set to 0.
df = df.with_columns(pl.col("gsis_id").fill_null(0).cast(pl.Int64))
player_college_data = player_college_data.with_columns(
    pl.col("college_gsis_id").cast(pl.Int64)
)

# One row per player taken from their latest season, which carries the most
# recent position (group_by keeps each group's rows in the sorted order).
player_last_season = (
    player_college_data.sort("season").group_by("college_gsis_id").last()
)
_last_season_cols = player_last_season.select(
    "college_gsis_id", "position", "player_name"
)
df = df.join(
    _last_season_cols,
    left_on="gsis_id",
    right_on="college_gsis_id",
    how="left",
)

# Distinct (session_id, drill_type) pairs feed the session dropdown.
session_info = (
    df.select("session_id", "drill_type").unique().sort("session_id")
)
session_list = session_info.to_dicts()
print("\nAvailable sessions:")
print(session_info)
# Single-subplot figure that shows the field in yards (x along its length).
figure = fpl.Figure(size=(1200, 700))
FIELD_WIDTH = 53.3
FIELD_LENGTH = 120

field_subplot = figure[0, 0]

# Field boundary: the sidelines at y=0 and y=FIELD_WIDTH, then the end lines
# at x=0 and x=FIELD_LENGTH (all drawn in the z=0 plane).
_boundary_segments = (
    ((0, 0), (FIELD_LENGTH, 0)),
    ((0, FIELD_WIDTH), (FIELD_LENGTH, FIELD_WIDTH)),
    ((0, 0), (0, FIELD_WIDTH)),
    ((FIELD_LENGTH, 0), (FIELD_LENGTH, FIELD_WIDTH)),
)
for (_x0, _y0), (_x1, _y1) in _boundary_segments:
    field_subplot.add_line(
        np.array([[_x0, _y0, 0], [_x1, _y1, 0]]), thickness=2, colors="white"
    )

# Faint yard line every 10 yards, spanning the full width.
for yard in range(10, int(FIELD_LENGTH), 10):
    field_subplot.add_line(
        np.array([[yard, 0, 0], [yard, FIELD_WIDTH, 0]]),
        thickness=1,
        colors="gray",
        alpha=0.5,
    )

# Axis labels and a camera looking down at the middle of the field.
field_subplot.axes.x.name = "X Position (yards)"
field_subplot.axes.y.name = "Y Position (yards)"
field_subplot.camera.position = (60, 26.65, 100)
class TimelapseController(EdgeWindow):
    """imgui side panel that replays one session's tracking data frame by frame.

    A "frame" is one distinct value of the ``ts`` column inside the selected
    session. Each player (distinct non-null ``gsis_id`` among non-ball rows)
    gets its own scatter marker; rows with ``entity_type == 'ball'`` share one
    orange marker. Positions are precomputed per entity once per session and
    markers are moved in place; a scatter is only recreated when the number of
    points it must show changes (several rows sharing one id at a timestamp).

    Parameters
    ----------
    figure : fastplotlib Figure; markers are drawn in subplot ``[0, 0]``.
    df : polars DataFrame with ``session_id``, ``ts``, ``entity_type``,
        ``gsis_id``, ``x`` and ``y`` columns (``z`` defaults to 0.0); the
        optional ``jersey_number`` / ``position`` columns are shown on click.
    session_list : non-empty list of dicts with ``session_id`` and
        ``drill_type`` keys; index 0 is loaded first.
    size, location, title : forwarded to ``EdgeWindow``.
    """

    # Keys (compared as ev.key.lower()) that step one frame forward / backward.
    _FORWARD_KEYS = frozenset({'w', 'arrowup', 'd', 'arrowright'})
    _BACKWARD_KEYS = frozenset({'s', 'arrowdown', 'a', 'arrowleft'})
    # Off-field parking spot for markers whose entity has no row in the current frame.
    _HIDDEN_POINT = np.array([[-1000, -1000, 0]], dtype=np.float32)
    # A click selects the nearest player only if it lands within this distance (data units).
    _CLICK_RADIUS = 10
    _BALL_STYLE = {'sizes': 15, 'colors': 'orange', 'alpha': 1.0}

    def __init__(self, figure, df, session_list, size, location, title):
        if not session_list:
            raise ValueError("session_list is empty: there is no session to display")
        super().__init__(figure=figure, size=size, location=location, title=title)
        # Store full dataset
        self._full_df = df
        self._session_list = session_list
        self._current_session_idx = 0  # Index into session_list
        # Dropdown labels depend only on session_list, so build them once.
        self._session_labels = [
            f"Session {s['session_id']}: {s['drill_type'] if s['drill_type'] else 'Unknown'}"
            for s in session_list
        ]
        # Animation state (set before loading, which draws frame 0).
        self._current_frame = 0
        self._is_playing = False
        self._play_speed = 50
        self._frame_counter = 0
        self._skip_frames = 1
        # Selected player info
        self._selected_player_id = None
        self._selected_jersey = None
        self._clicked_player_info = ""
        # Session-scoped state, filled in by _initialize_session.
        self._timestamps = []
        self._player_scatters = {}
        self._player_data_dict = {}
        self._player_tracks = {}
        self._player_colors = {}
        self._df_sorted = None
        self._ball_scatter = None
        self._ball_data = None
        self._ball_track = None
        self._current_session_id = None
        self._current_drill_type = None

        self._initialize_session(self._current_session_idx)
        # Set up event handlers
        self._setup_click_handler()
        self._setup_keyboard_handler()

    # ------------------------------------------------------------------ data

    @staticmethod
    def _build_track(entity_df, frame_of_ts):
        """Return ``(frames, xyz)`` for one entity: frame indices (sorted) and float32 positions."""
        frames = np.fromiter(
            (frame_of_ts[ts] for ts in entity_df['ts'].to_list()),
            dtype=np.int64,
            count=len(entity_df),
        )
        xyz = entity_df.select(['x', 'y', 'z']).to_numpy().astype(np.float32)
        order = np.argsort(frames, kind='stable')
        return frames[order], xyz[order]

    @staticmethod
    def _points_at(track, frame):
        """Return the (k, 3) positions ``track`` has at ``frame``; k is 0 when absent."""
        frames, xyz = track
        start, stop = np.searchsorted(frames, (frame, frame + 1))
        return xyz[start:stop]

    @staticmethod
    def _player_style(color):
        """Scatter keyword arguments for a player marker of the given color."""
        return {'sizes': 10, 'colors': color, 'alpha': 0.8}

    # -------------------------------------------------------------- graphics

    def _make_scatter(self, points, style, metadata):
        """Add a scatter of ``points`` to subplot [0, 0] and tag it with ``metadata``."""
        # Copy in case add_scatter wraps the array without copying: the graphic
        # must not share memory with a track array or with _HIDDEN_POINT.
        scatter = self._figure[0, 0].add_scatter(
            data=np.array(points, dtype=np.float32, copy=True), **style
        )
        scatter.metadata = metadata
        return scatter

    def _remove_scatter(self, scatter):
        """Best-effort removal of a scatter; failures are reported, not silently ignored."""
        try:
            self._figure[0, 0].remove_graphic(scatter)
        except Exception as err:
            print(f"Warning: could not remove graphic: {err}")

    def _show_points(self, scatter, points, style, metadata):
        """Display ``points`` using ``scatter`` and return the scatter now in use.

        Empty ``points`` park the marker at ``_HIDDEN_POINT``. The existing
        scatter is updated in place when its point count matches; otherwise it
        is replaced by a new scatter built with ``style`` and ``metadata``.
        """
        if len(points) == 0:
            points = self._HIDDEN_POINT
        if scatter is not None and len(scatter.data[:]) == len(points):
            scatter.data[:] = points
            return scatter
        if scatter is not None:
            self._remove_scatter(scatter)
        return self._make_scatter(points, style, metadata)

    def _clear_graphics(self):
        """Remove the markers of the previously loaded session, if any."""
        for scatter in self._player_scatters.values():
            self._remove_scatter(scatter)
        if self._ball_scatter is not None:
            self._remove_scatter(self._ball_scatter)
        self._player_scatters = {}
        self._ball_scatter = None

    # --------------------------------------------------------------- session

    def _initialize_session(self, session_idx):
        """Load ``session_list[session_idx]``, rebuild its markers and draw frame 0."""
        session = self._session_list[session_idx]
        session_id = session['session_id']
        drill_type = session['drill_type']
        print(f"\nLoading Session {session_id}: {drill_type}")

        # Filter data for this session
        df_session = self._full_df.filter(pl.col('session_id') == session_id)
        # Add z coordinate if not present
        if 'z' not in df_session.columns:
            df_session = df_session.with_columns(pl.lit(0.0).alias('z'))
        df_sorted = df_session.sort('ts')
        timestamps = df_sorted['ts'].unique().sort().to_list()
        # Frame index of every timestamp; tracks store frame indices, not timestamps.
        frame_of_ts = {ts: idx for idx, ts in enumerate(timestamps)}

        # Get unique players and assign colors
        player_id_col = 'gsis_id'
        non_ball = pl.col('entity_type') != 'ball'
        unique_players = (
            df_sorted
            .filter(non_ball)
            .select(player_id_col)
            .drop_nulls()
            .unique()
            .to_series()
            .to_list()
        )
        n_players_found = len(unique_players)
        has_ball = bool((df_sorted['entity_type'] == 'ball').any())
        print(f"Found {n_players_found} unique players")
        print(f"Found {len(timestamps)} timestamps")
        if has_ball:
            print("Ball tracking data found")

        # 40-color palette (tab20 + tab20b), cycled when there are more players.
        colors_combined = np.vstack([
            plt.cm.tab20(np.linspace(0, 1, 20)),
            plt.cm.tab20b(np.linspace(0, 1, 20)),
        ])
        player_colors = {
            pid: colors_combined[i % len(colors_combined)]
            for i, pid in enumerate(unique_players)
        }

        self._clear_graphics()

        player_scatters = {}
        player_data_dict = {}
        player_tracks = {}
        print("Initializing player visualizations...")
        for i, pid in enumerate(unique_players):
            if i % 10 == 0:
                print(f" Processing player {i+1}/{n_players_found}")
            try:
                # Exclude ball rows: if the ball carries the same gsis_id as a
                # player (e.g. a filled-in placeholder) it must not be drawn as that player.
                player_df = df_sorted.filter(non_ball & (pl.col(player_id_col) == pid))
                if len(player_df) == 0:
                    continue
                track = self._build_track(player_df, frame_of_ts)
                scatter = self._make_scatter(
                    self._HIDDEN_POINT,
                    self._player_style(player_colors[pid]),
                    {'player_id': pid},
                )
            except Exception as e:
                print(f"\nERROR INITIALIZING PLAYER {pid}: {str(e)}")
                traceback.print_exc()
                continue
            player_data_dict[pid] = player_df
            player_tracks[pid] = track
            player_scatters[pid] = scatter

        ball_scatter = None
        ball_data = None
        ball_track = None
        if has_ball:
            print("Initializing ball visualization...")
            ball_data = df_sorted.filter(pl.col('entity_type') == 'ball')
            ball_track = self._build_track(ball_data, frame_of_ts)
            ball_scatter = self._make_scatter(
                self._HIDDEN_POINT, self._BALL_STYLE, {'is_ball': True}
            )

        # Store session data
        self._timestamps = timestamps
        self._player_scatters = player_scatters
        self._player_data_dict = player_data_dict
        self._player_tracks = player_tracks
        self._player_colors = player_colors
        self._df_sorted = df_sorted
        self._ball_scatter = ball_scatter
        self._ball_data = ball_data
        self._ball_track = ball_track
        self._current_session_id = session_id
        self._current_drill_type = drill_type

        # Markers were created parked off-field; place them for frame 0 so the
        # first view matches what navigation shows for that frame.
        self._current_frame = 0
        self._update_positions()

        print(f"Initialized {len(player_scatters)} player visualizations")
        if ball_scatter is not None:
            print("Ball visualization initialized")

    def _switch_session(self, session_idx):
        """Stop playback, clear the selection and load another session."""
        self._current_session_idx = session_idx
        self._current_frame = 0
        self._is_playing = False
        self._frame_counter = 0
        self._clear_selection()
        self._initialize_session(session_idx)

    # ------------------------------------------------------------- selection

    def _clear_selection(self):
        """Forget the currently selected player."""
        self._clicked_player_info = ""
        self._selected_player_id = None
        self._selected_jersey = None

    def _select_player(self, pid):
        """Select ``pid`` and build the info text from its row at the current frame."""
        self._selected_player_id = pid
        current_ts = self._timestamps[self._current_frame]
        frame_data = self._player_data_dict[pid].filter(pl.col('ts') == current_ts)
        if len(frame_data) > 0:
            jersey = frame_data['jersey_number'][0] if 'jersey_number' in frame_data.columns else None
            position = frame_data['position'][0] if 'position' in frame_data.columns else None
            jersey = 'N/A' if jersey is None else jersey
            position = 'N/A' if position is None else position
            self._selected_jersey = jersey
            self._clicked_player_info = f"GSIS ID: {pid}\nJersey: {jersey}\nPosition: {position}"
        else:
            self._selected_jersey = None
            self._clicked_player_info = f"GSIS ID: {pid}\nJersey: N/A\nPosition: N/A"
        print(f"\nClicked Player - {self._clicked_player_info}")

    def _setup_click_handler(self):
        """Select the player nearest to a click on the figure (within _CLICK_RADIUS)."""
        def on_click(ev):
            if not self._timestamps:
                return
            world = self._figure[0, 0].map_screen_to_world(ev)
            # NOTE(review): fastplotlib returns None for clicks outside the
            # subplot viewport — confirm for the installed version.
            if world is None:
                return
            click_x, click_y = world[0], world[1]

            closest_pid = None
            min_distance = float('inf')
            for pid in self._player_scatters:
                points = self._points_at(self._player_tracks[pid], self._current_frame)
                if len(points) == 0:
                    continue
                distance = float(np.hypot(points[0, 0] - click_x, points[0, 1] - click_y))
                if distance < min_distance:
                    min_distance = distance
                    closest_pid = pid

            if closest_pid is not None and min_distance < self._CLICK_RADIUS:
                self._select_player(closest_pid)

        self._figure.renderer.add_event_handler(on_click, "click")

    # ------------------------------------------------------------ navigation

    @staticmethod
    def _normalize_key(ev):
        """Return the event's key lower-cased, or None if it has no usable key."""
        key = getattr(ev, 'key', None)
        return key.lower() if isinstance(key, str) and key else None

    def _step_frame(self, delta):
        """Move ``delta`` frames, clamped to the valid range; redraw only on change."""
        if not self._timestamps:
            return
        new_frame = min(max(self._current_frame + delta, 0), len(self._timestamps) - 1)
        if new_frame != self._current_frame:
            self._current_frame = new_frame
            self._update_positions()

    def _setup_keyboard_handler(self):
        """Step one frame per key press; update() keeps stepping while a key is held."""
        self._keys_pressed = set()
        self._key_hold_counter = 0
        self._key_hold_delay = 10  # update() calls to wait before continuous stepping

        def on_key_down(ev):
            key = self._normalize_key(ev)
            if key is None:
                return
            self._keys_pressed.add(key)
            self._key_hold_counter = 0
            # Immediate single frame movement on first press
            if key in self._FORWARD_KEYS:
                self._step_frame(1)
            elif key in self._BACKWARD_KEYS:
                self._step_frame(-1)

        def on_key_up(ev):
            key = self._normalize_key(ev)
            if key in self._keys_pressed:
                self._keys_pressed.discard(key)
                self._key_hold_counter = 0

        self._figure.renderer.add_event_handler(on_key_down, "key_down")
        self._figure.renderer.add_event_handler(on_key_up, "key_up")

    def _handle_held_keys(self):
        """After _key_hold_delay update() calls, step one frame per call while a key is held."""
        if not self._keys_pressed:
            return
        self._key_hold_counter += 1
        if self._key_hold_counter <= self._key_hold_delay:
            return
        if self._keys_pressed & self._FORWARD_KEYS:
            self._step_frame(1)
        elif self._keys_pressed & self._BACKWARD_KEYS:
            self._step_frame(-1)

    def _advance_playback(self):
        """While playing, advance _skip_frames frames roughly every _play_speed ms."""
        if not self._is_playing:
            return
        self._frame_counter += 1
        # Converts ms to update() calls assuming ~16 ms per rendered frame (~60 fps).
        if self._frame_counter < max(1, self._play_speed // 16):
            return
        self._frame_counter = 0
        self._step_frame(self._skip_frames)
        if self._current_frame >= len(self._timestamps) - 1:
            self._is_playing = False

    # ------------------------------------------------------------------- GUI

    def update(self):
        """Draw the control panel and drive playback.

        Presumably invoked by fastplotlib once per rendered frame; the held-key
        and playback timing both count calls to this method.
        """
        try:
            self._handle_held_keys()

            # Session selector
            imgui.text_colored((0.3, 0.7, 1.0, 1.0), "Session Selection:")
            changed, new_idx = imgui.combo(
                "##session",
                self._current_session_idx,
                self._session_labels
            )
            if changed and new_idx != self._current_session_idx:
                self._switch_session(new_idx)
            imgui.text(f"Drill Type: {self._current_drill_type if self._current_drill_type else 'Unknown'}")
            imgui.separator()

            if not self._timestamps:
                imgui.text("No frames in this session")
                return

            # Display current timestamp
            current_ts = self._timestamps[self._current_frame]
            imgui.text(f"Timestamp: {current_ts}")
            imgui.same_line()
            if imgui.button("Copy TS"):
                try:
                    pyperclip.copy(str(current_ts))
                    print(f"Copied timestamp to clipboard: {current_ts}")
                except Exception as e:
                    print(f"Failed to copy to clipboard: {e}")
            imgui.text(f"Frame: {self._current_frame + 1:,} / {len(self._timestamps):,}")
            imgui.separator()

            # Display selected player info
            if self._clicked_player_info:
                imgui.text_colored((0.2, 0.8, 0.2, 1.0), "Selected Player:")
                imgui.text(self._clicked_player_info)
                if imgui.button("Copy GSIS ID"):
                    try:
                        pyperclip.copy(str(self._selected_player_id))
                        print(f"Copied GSIS ID to clipboard: {self._selected_player_id}")
                    except Exception as e:
                        print(f"Failed to copy GSIS ID to clipboard: {e}")
                imgui.same_line()
                if imgui.button("Clear Selection"):
                    self._clear_selection()
                imgui.separator()

            # Frame slider
            changed, new_frame = imgui.slider_int(
                "Frame",
                v=self._current_frame,
                v_min=0,
                v_max=len(self._timestamps) - 1
            )
            if changed:
                self._current_frame = new_frame
                self._update_positions()

            # Skip frames slider
            changed, skip = imgui.slider_int(
                "Skip Frames",
                v=self._skip_frames,
                v_min=1,
                v_max=100
            )
            if changed:
                self._skip_frames = skip
            imgui.text(f"(Play every {self._skip_frames} frame(s))")
            imgui.separator()

            # Play/Pause button
            if self._is_playing:
                if imgui.button("Pause"):
                    self._is_playing = False
            else:
                if imgui.button("Play"):
                    self._is_playing = True
            imgui.same_line()
            if imgui.button("Reset"):
                self._current_frame = 0
                self._is_playing = False
                self._update_positions()
            imgui.same_line()
            if imgui.button("End"):
                self._current_frame = len(self._timestamps) - 1
                self._update_positions()

            # Play speed slider
            changed, speed = imgui.slider_int(
                "Speed (ms/frame)",
                v=self._play_speed,
                v_min=10,
                v_max=500
            )
            if changed:
                self._play_speed = speed
            imgui.separator()

            imgui.text(f"Total Players: {len(self._player_scatters)}")
            imgui.text(f"Total Frames: {len(self._timestamps):,}")
            imgui.separator()

            imgui.text_colored((0.7, 0.7, 0.7, 1.0), "Keyboard Controls:")
            imgui.text("W/↑ or D/→: Next frame")
            imgui.text("S/↓ or A/←: Previous frame")
            imgui.text("(Hold key to fast-forward)")

            self._advance_playback()
        except Exception as e:
            print(f"\nERROR IN GUI UPDATE: {str(e)}")
            traceback.print_exc()
            imgui.text(f"Error: {str(e)}")
            imgui.text("Check console for details")
            self._is_playing = False

    def _update_positions(self):
        """Move every player marker and the ball marker to the current frame."""
        if not self._timestamps:
            return
        frame = self._current_frame
        try:
            for pid, scatter in list(self._player_scatters.items()):
                self._player_scatters[pid] = self._show_points(
                    scatter,
                    self._points_at(self._player_tracks[pid], frame),
                    self._player_style(self._player_colors[pid]),
                    {'player_id': pid},
                )
            if self._ball_scatter is not None and self._ball_track is not None:
                self._ball_scatter = self._show_points(
                    self._ball_scatter,
                    self._points_at(self._ball_track, frame),
                    self._BALL_STYLE,
                    {'is_ball': True},
                )
        except Exception as e:
            print(f"\nERROR IN _update_positions(): {str(e)}")
            traceback.print_exc()
            raise
# Attach the timelapse control panel on the right edge of the figure, then show it.
gui = TimelapseController(
    figure,
    df,
    session_list,
    size=320,
    location="right",
    title="Timelapse Controls",
)
figure.add_gui(gui)
figure.show()

# Usage summary printed once the window is up.
_usage_lines = (
    "\nVisualization ready!",
    f"Total sessions available: {len(session_list)}",
    "\nNew Features:",
    "  - Select different sessions from the dropdown menu",
    "  - View drill type for each session",
    "  - Click on any player to see their GSIS ID, Jersey Number, and Position",
    "  - Click 'Copy TS' button to copy current timestamp to clipboard",
    "  - Click 'Copy GSIS ID' button to copy selected player's GSIS ID to clipboard",
    "  - Use WASD or Arrow keys to navigate frames",
    "  - HOLD any navigation key to fast-forward frame-by-frame",
    "\nTips:",
    "  - Use 'Skip Frames' slider to speed up playback for large datasets",
    "  - Switch sessions using the dropdown at the top of the control panel",
    "  - Focus the visualization window to use keyboard controls",
    "  - Hold navigation keys for continuous frame advancement",
)
for _line in _usage_lines:
    print(_line)

# Run the event loop only when executed as a standalone script.
if __name__ == "__main__":
    fpl.loop.run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment