Skip to content

Instantly share code, notes, and snippets.

@filiptronicek
Created October 25, 2025 20:04
Show Gist options
  • Select an option

  • Save filiptronicek/2de31dc5106fea162681af4d9b655eaa to your computer and use it in GitHub Desktop.

Select an option

Save filiptronicek/2de31dc5106fea162681af4d9b655eaa to your computer and use it in GitHub Desktop.
A GUI for SAM 2.1 (Windows)
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor
import cv2
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
import tempfile
import shutil
class SAM2VideoSegmentationApp:
    """Step-by-step Tkinter GUI for segmenting a video with SAM 2.1.

    Workflow: (1) pick a video file, (2) annotate one frame with
    positive/negative points, (3) preview the resulting mask,
    (4) propagate the mask through every frame, (5) export the
    per-frame masks as PNG files.
    """

    def __init__(self, root):
        """Build the UI and initialise all workflow state.

        Args:
            root: the ``tk.Tk`` root window to attach the UI to.
        """
        self.root = root
        self.root.title("SAM2 Video Segmentation Tool")

        # Enable high DPI scaling on Windows.  Catch only the errors this
        # call can actually raise (missing module/attribute on non-Windows,
        # or an OS-level failure) instead of a bare `except:` that would
        # also swallow KeyboardInterrupt/SystemExit.
        try:
            from ctypes import windll
            windll.shcore.SetProcessDpiAwareness(1)
        except (ImportError, AttributeError, OSError):
            pass

        # Scale for high DPI displays.
        self.scale_factor = 1.5  # Adjust this if needed (1.5-2.0 for 5K)
        self.root.tk.call('tk', 'scaling', self.scale_factor * 72 / 96)
        self.root.geometry("1800x1200")

        # Workflow state.
        self.video_path = None        # path of the selected video file
        self.video_folder = None      # temp dir holding extracted JPEG frames
        self.frame_names = []         # frame file names, sorted by frame index
        self.current_frame_idx = 0
        self.points = []              # [x, y] annotation coordinates
        self.labels = []              # 1 = positive point, 0 = negative point
        self.predictor = None         # SAM2 video predictor instance
        self.inference_state = None   # predictor state for the current video
        self.video_segments = None    # {frame_idx: {obj_id: bool mask}}
        self.preview_mask = None      # mask shown in the preview step
        self.current_step = 1  # 1: Select video, 2: Annotate, 3: Preview mask, 4: Propagate, 5: Export

        # Setup device.  NOTE: entering autocast without ever exiting is
        # deliberate -- bfloat16 autocast stays on for the process lifetime.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
        else:
            self.device = torch.device("cpu")

        # SAM2 model paths (machine-specific; edit for your install).
        self.sam_2_checkpoint = "C:\\Users\\Filip\\Documents\\Git\\sam2\\checkpoints\\sam2.1_hiera_large.pt"
        self.model_cfg = "C:\\Users\\Filip\\Documents\\Git\\sam2\\sam2\\configs\\sam2.1\\sam2.1_hiera_l.yaml"

        self.setup_ui()
        self.update_step_display()

    def setup_ui(self):
        """Create the control panel (left) and matplotlib canvas (right)."""
        # Configure font sizes for high DPI.
        default_font = ('Segoe UI', 11)
        button_font = ('Segoe UI', 11, 'bold')
        title_font = ('Segoe UI', 12, 'bold')
        self.root.option_add('*Font', default_font)
        self.root.option_add('*Button.Font', button_font)
        self.root.option_add('*Label.Font', default_font)

        # Main container.
        main_frame = ttk.Frame(self.root, padding="15")
        main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

        # Configure grid weights so the canvas column expands.
        self.root.columnconfigure(0, weight=1)
        self.root.rowconfigure(0, weight=1)
        main_frame.columnconfigure(1, weight=1)
        main_frame.rowconfigure(1, weight=1)

        # Control panel (left side).
        control_frame = ttk.LabelFrame(main_frame, text="Workflow Steps", padding="15")
        control_frame.grid(row=0, column=0, rowspan=2, sticky=(tk.W, tk.E, tk.N, tk.S), padx=(0, 15))

        # Step indicator.
        self.step_label = ttk.Label(control_frame, text="", font=('Segoe UI', 13, 'bold'))
        self.step_label.pack(fill=tk.X, pady=(0, 15))
        ttk.Separator(control_frame, orient='horizontal').pack(fill=tk.X, pady=10)

        # Step 1: Video Selection.
        self.step1_frame = ttk.LabelFrame(control_frame, text="Step 1: Select Video", padding="10")
        self.step1_frame.pack(fill=tk.X, pady=8)
        ttk.Button(self.step1_frame, text="Browse Video File", command=self.select_video).pack(fill=tk.X, pady=5)
        self.video_label = ttk.Label(self.step1_frame, text="No video selected", wraplength=300)
        self.video_label.pack(fill=tk.X, pady=5)

        # Step 2: Annotation.
        self.step2_frame = ttk.LabelFrame(control_frame, text="Step 2: Annotate Frame", padding="10")
        self.step2_frame.pack(fill=tk.X, pady=8)
        ttk.Label(self.step2_frame, text="Frame Navigation:").pack(anchor=tk.W)
        frame_nav_frame = ttk.Frame(self.step2_frame)
        frame_nav_frame.pack(fill=tk.X, pady=5)
        ttk.Button(frame_nav_frame, text="<<", command=self.prev_frame, width=5).pack(side=tk.LEFT)
        self.frame_label = ttk.Label(frame_nav_frame, text="Frame: 0/0")
        self.frame_label.pack(side=tk.LEFT, padx=10)
        ttk.Button(frame_nav_frame, text=">>", command=self.next_frame, width=5).pack(side=tk.LEFT)
        ttk.Label(self.step2_frame, text="Click on frame to add points:").pack(anchor=tk.W, pady=(10, 0))
        self.annotation_mode = tk.StringVar(value="positive")
        ttk.Radiobutton(self.step2_frame, text="Positive Point (Green)",
                        variable=self.annotation_mode, value="positive").pack(anchor=tk.W)
        ttk.Radiobutton(self.step2_frame, text="Negative Point (Red)",
                        variable=self.annotation_mode, value="negative").pack(anchor=tk.W)
        ttk.Label(self.step2_frame, text="Manual Coordinates:").pack(anchor=tk.W, pady=(10, 0))
        coord_frame = ttk.Frame(self.step2_frame)
        coord_frame.pack(fill=tk.X, pady=5)
        ttk.Label(coord_frame, text="X:").grid(row=0, column=0)
        self.coord_x = ttk.Entry(coord_frame, width=10)
        self.coord_x.grid(row=0, column=1, padx=2)
        ttk.Label(coord_frame, text="Y:").grid(row=0, column=2, padx=(10, 0))
        self.coord_y = ttk.Entry(coord_frame, width=10)
        self.coord_y.grid(row=0, column=3, padx=2)
        ttk.Button(self.step2_frame, text="Add Point", command=self.add_manual_point).pack(fill=tk.X, pady=5)
        self.points_label = ttk.Label(self.step2_frame, text="Points: 0", foreground="blue")
        self.points_label.pack(fill=tk.X, pady=5)
        ttk.Button(self.step2_frame, text="Clear All Points", command=self.clear_points).pack(fill=tk.X, pady=5)
        self.confirm_annotation_btn = ttk.Button(self.step2_frame, text="Confirm Annotation →",
                                                 command=self.confirm_annotation, style='Accent.TButton')
        self.confirm_annotation_btn.pack(fill=tk.X, pady=(10, 5))

        # Step 3: Preview Mask.
        self.step3_frame = ttk.LabelFrame(control_frame, text="Step 3: Preview Mask", padding="10")
        self.step3_frame.pack(fill=tk.X, pady=8)
        self.preview_info = ttk.Label(self.step3_frame, text="Preview the segmentation mask\nfor the current frame",
                                      justify=tk.LEFT)
        self.preview_info.pack(fill=tk.X, pady=5)
        ttk.Button(self.step3_frame, text="← Back to Annotation",
                   command=self.back_to_annotation).pack(fill=tk.X, pady=5)
        self.propagate_btn = ttk.Button(self.step3_frame, text="Propagate to All Frames →",
                                        command=self.run_segmentation, style='Accent.TButton')
        self.propagate_btn.pack(fill=tk.X, pady=5)

        # Step 4: Propagation.
        self.step4_frame = ttk.LabelFrame(control_frame, text="Step 4: Propagate Segmentation", padding="10")
        self.step4_frame.pack(fill=tk.X, pady=8)
        self.propagate_info = ttk.Label(self.step4_frame, text="Segmenting all frames...", justify=tk.LEFT)
        self.propagate_info.pack(fill=tk.X, pady=5)
        self.propagate_progress = ttk.Progressbar(self.step4_frame, mode='indeterminate')
        self.propagate_progress.pack(fill=tk.X, pady=5)

        # Step 5: Export.
        self.step5_frame = ttk.LabelFrame(control_frame, text="Step 5: Export Results", padding="10")
        self.step5_frame.pack(fill=tk.X, pady=8)
        self.export_info = ttk.Label(self.step5_frame, text="Segmentation complete!\nReady to export masks.",
                                     justify=tk.LEFT)
        self.export_info.pack(fill=tk.X, pady=5)
        ttk.Button(self.step5_frame, text="Export Masks", command=self.export_masks).pack(fill=tk.X, pady=5)
        ttk.Button(self.step5_frame, text="Start Over", command=self.reset_app).pack(fill=tk.X, pady=5)
        ttk.Separator(control_frame, orient='horizontal').pack(fill=tk.X, pady=15)

        # Status line.
        self.status_label = ttk.Label(control_frame, text="Status: Ready", foreground="green",
                                      font=('Segoe UI', 11, 'bold'))
        self.status_label.pack(fill=tk.X, pady=10)

        # Canvas frame (right side).
        canvas_frame = ttk.LabelFrame(main_frame, text="Video Frame", padding="15")
        canvas_frame.grid(row=0, column=1, rowspan=2, sticky=(tk.W, tk.E, tk.N, tk.S))

        # Create matplotlib figure with larger size.
        self.fig, self.ax = plt.subplots(figsize=(12, 9))
        self.fig.patch.set_facecolor('#f0f0f0')
        self.canvas = FigureCanvasTkAgg(self.fig, master=canvas_frame)
        self.canvas.draw()
        self.canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

        # Bind click event for point annotation.
        self.canvas.mpl_connect('button_press_event', self.on_canvas_click)

    def update_step_display(self):
        """Update UI to show only relevant controls for current step."""
        # Hide all step frames, then re-show the ones for the active step.
        for frame in [self.step1_frame, self.step2_frame, self.step3_frame,
                      self.step4_frame, self.step5_frame]:
            frame.pack_forget()
        if self.current_step == 1:
            self.step_label.config(text="📹 Step 1: Select Video File")
            self.step1_frame.pack(fill=tk.X, pady=8)
        elif self.current_step == 2:
            self.step_label.config(text="🎯 Step 2: Annotate Frame")
            self.step1_frame.pack(fill=tk.X, pady=8)
            self.step2_frame.pack(fill=tk.X, pady=8)
        elif self.current_step == 3:
            self.step_label.config(text="👁️ Step 3: Preview Mask")
            self.step3_frame.pack(fill=tk.X, pady=8)
        elif self.current_step == 4:
            self.step_label.config(text="⚡ Step 4: Propagating...")
            self.step4_frame.pack(fill=tk.X, pady=8)
        elif self.current_step == 5:
            self.step_label.config(text="✅ Step 5: Export Results")
            self.step5_frame.pack(fill=tk.X, pady=8)

    def select_video(self):
        """Ask the user for a video file and start frame extraction."""
        filepath = filedialog.askopenfilename(
            title="Select Video File",
            filetypes=[("Video files", "*.mp4 *.avi *.mov *.mkv"), ("All files", "*.*")]
        )
        if filepath:
            self.video_path = filepath
            self.video_label.config(text=f"✓ {os.path.basename(filepath)}")
            self.extract_frames()

    def extract_frames(self):
        """Decode the video into numbered JPEGs in a temp dir, then init SAM2."""
        try:
            self.status_label.config(text="Status: Extracting frames...", foreground="orange")
            self.root.update()
            # Create temporary folder for frames (cleaned up in reset/__del__).
            self.video_folder = tempfile.mkdtemp()
            # Extract frames using OpenCV; file names are "<index>.jpg",
            # the layout SAM2's init_state expects.
            cap = cv2.VideoCapture(self.video_path)
            frame_idx = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frame_path = os.path.join(self.video_folder, f"{frame_idx}.jpg")
                cv2.imwrite(frame_path, frame)
                frame_idx += 1
            cap.release()
            # Collect and numerically sort the frame names.
            self.frame_names = [
                p for p in os.listdir(self.video_folder)
                if p.endswith(".jpg") or p.endswith(".png")
            ]
            self.frame_names.sort(key=lambda x: int(x.split(".")[0]))
            self.current_frame_idx = 0
            self.current_step = 2
            self.update_step_display()
            self.display_frame()
            # Initialize SAM2 automatically.
            self.init_sam2()
            self.status_label.config(text=f"Status: Ready - {len(self.frame_names)} frames extracted",
                                     foreground="green")
        except Exception as e:
            self.status_label.config(text=f"Status: Error - {str(e)}", foreground="red")
            messagebox.showerror("Error", f"Failed to extract frames: {str(e)}")

    def display_frame(self, show_mask=False):
        """Render the current frame, annotation points, and optional mask.

        Args:
            show_mask: when True and a preview mask exists, overlay it.
        """
        if not self.frame_names:
            return
        self.ax.clear()
        frame_path = os.path.join(self.video_folder, self.frame_names[self.current_frame_idx])
        img = Image.open(frame_path)
        self.ax.imshow(img)
        # Show mask if in preview mode.
        if show_mask and self.preview_mask is not None:
            self.show_mask(self.preview_mask, self.ax, obj_id=1)
        # Draw existing points (index was unused, so no enumerate needed).
        for point, label in zip(self.points, self.labels):
            color = "green" if label == 1 else "red"
            self.ax.plot(point[0], point[1], marker="*", color=color, markersize=25,
                         markeredgecolor="white", markeredgewidth=2.5)
        title = f"Frame {self.current_frame_idx}/{len(self.frame_names)-1}"
        if show_mask:
            title += " - Mask Preview"
        self.ax.set_title(title, fontsize=14, fontweight='bold')
        self.ax.axis('off')
        self.canvas.draw()
        self.frame_label.config(text=f"Frame: {self.current_frame_idx}/{len(self.frame_names)-1}")

    def show_mask(self, mask, ax, obj_id=None, random_color=False):
        """Display mask overlay on the image.

        Args:
            mask: boolean/0-1 array whose last two dims are (H, W).
            ax: matplotlib axes to draw on.
            obj_id: index into the tab10 colormap (None -> 0).
            random_color: pick a random RGBA color instead.
        """
        if random_color:
            color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        else:
            cmap = plt.get_cmap("tab10")
            cmap_idx = 0 if obj_id is None else obj_id
            color = np.array([*cmap(cmap_idx)[:3], 0.6])
        h, w = mask.shape[-2:]
        # Broadcast the (H, W, 1) mask against the (1, 1, 4) RGBA color.
        mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        ax.imshow(mask_image)

    def on_canvas_click(self, event):
        """Record a point annotation from a click during the annotate step."""
        # Ignore clicks outside the image axes or outside step 2.
        if event.inaxes != self.ax or not self.frame_names or self.current_step != 2:
            return
        x, y = int(event.xdata), int(event.ydata)
        label = 1 if self.annotation_mode.get() == "positive" else 0
        self.points.append([x, y])
        self.labels.append(label)
        self.display_frame()
        self.update_points_label()

    def add_manual_point(self):
        """Add a point from the X/Y entry fields; clears them on success."""
        try:
            x = int(self.coord_x.get())
            y = int(self.coord_y.get())
            label = 1 if self.annotation_mode.get() == "positive" else 0
            self.points.append([x, y])
            self.labels.append(label)
            self.coord_x.delete(0, tk.END)
            self.coord_y.delete(0, tk.END)
            self.display_frame()
            self.update_points_label()
        except ValueError:
            messagebox.showerror("Error", "Please enter valid integer coordinates")

    def clear_points(self):
        """Remove all annotation points and redraw the frame."""
        self.points = []
        self.labels = []
        self.display_frame()
        self.update_points_label()

    def update_points_label(self):
        """Refresh the point-count label with positive/negative tallies."""
        pos_count = sum(1 for l in self.labels if l == 1)
        neg_count = sum(1 for l in self.labels if l == 0)
        self.points_label.config(text=f"Points: {len(self.points)} (✓ {pos_count}, ✗ {neg_count})")

    def prev_frame(self):
        """Step one frame backwards (no-op at the first frame)."""
        if self.current_frame_idx > 0:
            self.current_frame_idx -= 1
            self.display_frame()

    def next_frame(self):
        """Step one frame forwards (no-op at the last frame)."""
        if self.current_frame_idx < len(self.frame_names) - 1:
            self.current_frame_idx += 1
            self.display_frame()

    def init_sam2(self):
        """Initialize SAM2 model and its inference state for the frame dir."""
        try:
            self.status_label.config(text="Status: Initializing SAM2...", foreground="orange")
            self.root.update()
            self.predictor = build_sam2_video_predictor(
                self.model_cfg,
                self.sam_2_checkpoint,
                device=self.device
            )
            self.inference_state = self.predictor.init_state(video_path=self.video_folder)
            self.predictor.reset_state(self.inference_state)
            self.status_label.config(text="Status: SAM2 ready", foreground="green")
        except Exception as e:
            self.status_label.config(text=f"Status: Error - {str(e)}", foreground="red")
            messagebox.showerror("Error", f"Failed to initialize SAM2: {str(e)}")

    def confirm_annotation(self):
        """Generate preview mask for current frame."""
        if not self.predictor:
            messagebox.showwarning("Warning", "SAM2 not initialized")
            return
        if not self.points:
            messagebox.showwarning("Warning", "Please add at least one point annotation")
            return
        try:
            self.status_label.config(text="Status: Generating mask preview...", foreground="orange")
            self.root.update()
            # Feed all accumulated points to the predictor for object id 1.
            points_array = np.array(self.points, dtype=np.float32)
            labels_array = np.array(self.labels, dtype=np.int32)
            _, out_obj_ids, out_mask_logits = self.predictor.add_new_points(
                inference_state=self.inference_state,
                frame_idx=self.current_frame_idx,
                points=points_array,
                labels=labels_array,
                obj_id=1,
            )
            # Threshold logits at 0 to get a boolean preview mask.
            self.preview_mask = (out_mask_logits[0] > 0.0).cpu().numpy()
            # Move to preview step.
            self.current_step = 3
            self.update_step_display()
            self.display_frame(show_mask=True)
            self.status_label.config(text="Status: Preview ready", foreground="green")
        except Exception as e:
            self.status_label.config(text=f"Status: Error - {str(e)}", foreground="red")
            messagebox.showerror("Error", f"Failed to generate preview: {str(e)}")

    def back_to_annotation(self):
        """Go back to annotation step, discarding the preview."""
        self.current_step = 2
        self.preview_mask = None
        self.update_step_display()
        self.display_frame()
        # Reset the inference state so re-confirming starts fresh.
        if self.predictor:
            self.predictor.reset_state(self.inference_state)

    def run_segmentation(self):
        """Propagate segmentation to all frames."""
        if not self.predictor or self.preview_mask is None:
            messagebox.showwarning("Warning", "Please confirm annotation first")
            return
        try:
            self.current_step = 4
            self.update_step_display()
            self.status_label.config(text="Status: Propagating to all frames...", foreground="orange")
            self.propagate_progress.start()
            self.root.update()
            # Propagate through the video, collecting one boolean mask per
            # object per frame.
            self.video_segments = {}
            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(
                self.inference_state
            ):
                self.video_segments[out_frame_idx] = {
                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                    for i, out_obj_id in enumerate(out_obj_ids)
                }
                # Update progress display.
                progress_pct = (out_frame_idx + 1) / len(self.frame_names) * 100
                self.propagate_info.config(text=f"Processing: {out_frame_idx + 1}/{len(self.frame_names)} frames\n({progress_pct:.1f}%)")
                self.root.update()
            self.propagate_progress.stop()
            # Move to export step.
            self.current_step = 5
            self.update_step_display()
            self.status_label.config(
                text=f"Status: Segmentation complete! {len(self.video_segments)} frames processed",
                foreground="green"
            )
            messagebox.showinfo("Success", f"Segmentation completed for {len(self.video_segments)} frames!")
        except Exception as e:
            self.propagate_progress.stop()
            self.status_label.config(text=f"Status: Error - {str(e)}", foreground="red")
            messagebox.showerror("Error", f"Failed to run segmentation: {str(e)}")

    def export_masks(self):
        """Export combined per-frame masks as PNGs to a chosen directory."""
        if not self.video_segments:
            messagebox.showwarning("Warning", "No segmentation results to export")
            return
        output_dir = filedialog.askdirectory(title="Select Output Directory")
        if not output_dir:
            return
        try:
            self.status_label.config(text="Status: Exporting masks...", foreground="orange")
            self.root.update()
            output_mask_dir = os.path.join(output_dir, "output_masks")
            os.makedirs(output_mask_dir, exist_ok=True)
            for out_frame_idx in range(len(self.frame_names)):
                # Guard against frames without results (e.g. interrupted
                # propagation) instead of raising KeyError.
                segments = self.video_segments.get(out_frame_idx)
                if not segments:
                    continue
                # Use the first mask to determine the output shape.
                first_mask = next(iter(segments.values()))
                if len(first_mask.shape) == 3:
                    h, w = first_mask.shape[-2:]
                    combined_mask = np.zeros((h, w), dtype=np.uint8)
                else:
                    combined_mask = np.zeros_like(first_mask, dtype=np.uint8)
                # Union all object masks into one binary mask.
                for out_mask in segments.values():
                    mask_2d = np.squeeze(out_mask)
                    combined_mask = np.logical_or(combined_mask, mask_2d).astype(np.uint8)
                # Scale to 0-255 for an 8-bit PNG.
                mask_img = combined_mask * 255
                cv2.imwrite(
                    os.path.join(output_mask_dir, f"mask_{out_frame_idx:04d}.png"),
                    mask_img
                )
            self.status_label.config(text="Status: Masks exported successfully!", foreground="green")
            messagebox.showinfo("Success", f"Masks exported to:\n{output_mask_dir}")
        except Exception as e:
            self.status_label.config(text=f"Status: Error - {str(e)}", foreground="red")
            messagebox.showerror("Error", f"Failed to export masks: {str(e)}")

    def reset_app(self):
        """Reset the application to start over."""
        if messagebox.askyesno("Confirm", "Start over? This will clear all current progress."):
            # Clean up the temporary frame folder.
            if self.video_folder and os.path.exists(self.video_folder):
                shutil.rmtree(self.video_folder)
            # Reset workflow state (predictor is kept; a new video rebuilds it).
            self.video_path = None
            self.video_folder = None
            self.frame_names = []
            self.current_frame_idx = 0
            self.points = []
            self.labels = []
            self.inference_state = None
            self.video_segments = None
            self.preview_mask = None
            self.current_step = 1
            # Reset UI.
            self.video_label.config(text="No video selected")
            self.points_label.config(text="Points: 0")
            self.frame_label.config(text="Frame: 0/0")
            self.ax.clear()
            self.canvas.draw()
            self.update_step_display()
            self.status_label.config(text="Status: Ready", foreground="green")

    def __del__(self):
        # Best-effort cleanup of the temporary frame folder.  Guarded with
        # getattr/try because __del__ can run on a partially constructed
        # instance or during interpreter shutdown.
        try:
            folder = getattr(self, "video_folder", None)
            if folder and os.path.exists(folder):
                shutil.rmtree(folder)
        except Exception:
            pass
def main():
    """Build the Tk root window, attach the app, and start the event loop."""
    window = tk.Tk()
    # Keep a reference so the app object lives for the whole main loop.
    app = SAM2VideoSegmentationApp(window)
    window.mainloop()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment