Created
October 25, 2025 20:04
-
-
Save filiptronicek/2de31dc5106fea162681af4d9b655eaa to your computer and use it in GitHub Desktop.
A GUI for SAM 2.1 (Windows)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import torch | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg | |
| from PIL import Image | |
| from sam2.build_sam import build_sam2_video_predictor | |
| import cv2 | |
| import tkinter as tk | |
| from tkinter import filedialog, messagebox, ttk | |
| import tempfile | |
| import shutil | |
class SAM2VideoSegmentationApp:
    """Tkinter GUI that drives SAM 2.1 interactive video segmentation.

    Workflow: select a video -> click point annotations on one frame ->
    preview the mask -> propagate it through the clip -> export masks.
    """

    def __init__(self, root):
        """Build the UI and prepare SAM2 runtime state.

        Args:
            root: The tk.Tk root window hosting the application.
        """
        self.root = root
        self.root.title("SAM2 Video Segmentation Tool")
        # Enable high DPI scaling (Windows-only API; harmless elsewhere).
        # Fix: the original used a bare `except:`, which also swallows
        # KeyboardInterrupt/SystemExit; catch only the failures this call
        # can realistically raise on non-Windows platforms.
        try:
            from ctypes import windll
            windll.shcore.SetProcessDpiAwareness(1)
        except (ImportError, AttributeError, OSError):
            pass
        # Scale for high DPI displays
        self.scale_factor = 1.5  # Adjust this if needed (1.5-2.0 for 5K)
        self.root.tk.call('tk', 'scaling', self.scale_factor * 72 / 96)
        self.root.geometry("1800x1200")
        # Initialize variables
        self.video_path = None       # path of the selected video file
        self.video_folder = None     # temp dir holding extracted frames
        self.frame_names = []        # frame file names, numerically sorted
        self.current_frame_idx = 0
        self.points = []             # [x, y] annotation coordinates
        self.labels = []             # 1 = positive point, 0 = negative point
        self.predictor = None
        self.inference_state = None
        self.video_segments = None   # {frame_idx: {obj_id: bool mask}}
        self.preview_mask = None
        self.current_step = 1  # 1: Select video, 2: Annotate, 3: Preview mask, 4: Propagate, 5: Export
        # Setup device
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            # NOTE(review): entering autocast globally and never exiting it
            # mirrors the SAM2 demo notebooks; deliberate, not a leak.
            torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
        else:
            self.device = torch.device("cpu")
        # SAM2 model paths. Fix: the hard-coded per-user Windows paths can
        # now be overridden via environment variables; the originals remain
        # the defaults so existing behavior is unchanged.
        self.sam_2_checkpoint = os.environ.get(
            "SAM2_CHECKPOINT",
            "C:\\Users\\Filip\\Documents\\Git\\sam2\\checkpoints\\sam2.1_hiera_large.pt",
        )
        self.model_cfg = os.environ.get(
            "SAM2_MODEL_CFG",
            "C:\\Users\\Filip\\Documents\\Git\\sam2\\sam2\\configs\\sam2.1\\sam2.1_hiera_l.yaml",
        )
        self.setup_ui()
        self.update_step_display()
| def setup_ui(self): | |
| # Configure font sizes for high DPI | |
| default_font = ('Segoe UI', 11) | |
| button_font = ('Segoe UI', 11, 'bold') | |
| title_font = ('Segoe UI', 12, 'bold') | |
| self.root.option_add('*Font', default_font) | |
| self.root.option_add('*Button.Font', button_font) | |
| self.root.option_add('*Label.Font', default_font) | |
| # Main container | |
| main_frame = ttk.Frame(self.root, padding="15") | |
| main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S)) | |
| # Configure grid weights | |
| self.root.columnconfigure(0, weight=1) | |
| self.root.rowconfigure(0, weight=1) | |
| main_frame.columnconfigure(1, weight=1) | |
| main_frame.rowconfigure(1, weight=1) | |
| # Control panel (left side) | |
| control_frame = ttk.LabelFrame(main_frame, text="Workflow Steps", padding="15") | |
| control_frame.grid(row=0, column=0, rowspan=2, sticky=(tk.W, tk.E, tk.N, tk.S), padx=(0, 15)) | |
| # Step indicator | |
| self.step_label = ttk.Label(control_frame, text="", font=('Segoe UI', 13, 'bold')) | |
| self.step_label.pack(fill=tk.X, pady=(0, 15)) | |
| ttk.Separator(control_frame, orient='horizontal').pack(fill=tk.X, pady=10) | |
| # Step 1: Video Selection | |
| self.step1_frame = ttk.LabelFrame(control_frame, text="Step 1: Select Video", padding="10") | |
| self.step1_frame.pack(fill=tk.X, pady=8) | |
| ttk.Button(self.step1_frame, text="Browse Video File", command=self.select_video).pack(fill=tk.X, pady=5) | |
| self.video_label = ttk.Label(self.step1_frame, text="No video selected", wraplength=300) | |
| self.video_label.pack(fill=tk.X, pady=5) | |
| # Step 2: Annotation | |
| self.step2_frame = ttk.LabelFrame(control_frame, text="Step 2: Annotate Frame", padding="10") | |
| self.step2_frame.pack(fill=tk.X, pady=8) | |
| ttk.Label(self.step2_frame, text="Frame Navigation:").pack(anchor=tk.W) | |
| frame_nav_frame = ttk.Frame(self.step2_frame) | |
| frame_nav_frame.pack(fill=tk.X, pady=5) | |
| ttk.Button(frame_nav_frame, text="<<", command=self.prev_frame, width=5).pack(side=tk.LEFT) | |
| self.frame_label = ttk.Label(frame_nav_frame, text="Frame: 0/0") | |
| self.frame_label.pack(side=tk.LEFT, padx=10) | |
| ttk.Button(frame_nav_frame, text=">>", command=self.next_frame, width=5).pack(side=tk.LEFT) | |
| ttk.Label(self.step2_frame, text="Click on frame to add points:").pack(anchor=tk.W, pady=(10, 0)) | |
| self.annotation_mode = tk.StringVar(value="positive") | |
| ttk.Radiobutton(self.step2_frame, text="Positive Point (Green)", | |
| variable=self.annotation_mode, value="positive").pack(anchor=tk.W) | |
| ttk.Radiobutton(self.step2_frame, text="Negative Point (Red)", | |
| variable=self.annotation_mode, value="negative").pack(anchor=tk.W) | |
| ttk.Label(self.step2_frame, text="Manual Coordinates:").pack(anchor=tk.W, pady=(10, 0)) | |
| coord_frame = ttk.Frame(self.step2_frame) | |
| coord_frame.pack(fill=tk.X, pady=5) | |
| ttk.Label(coord_frame, text="X:").grid(row=0, column=0) | |
| self.coord_x = ttk.Entry(coord_frame, width=10) | |
| self.coord_x.grid(row=0, column=1, padx=2) | |
| ttk.Label(coord_frame, text="Y:").grid(row=0, column=2, padx=(10, 0)) | |
| self.coord_y = ttk.Entry(coord_frame, width=10) | |
| self.coord_y.grid(row=0, column=3, padx=2) | |
| ttk.Button(self.step2_frame, text="Add Point", command=self.add_manual_point).pack(fill=tk.X, pady=5) | |
| self.points_label = ttk.Label(self.step2_frame, text="Points: 0", foreground="blue") | |
| self.points_label.pack(fill=tk.X, pady=5) | |
| ttk.Button(self.step2_frame, text="Clear All Points", command=self.clear_points).pack(fill=tk.X, pady=5) | |
| self.confirm_annotation_btn = ttk.Button(self.step2_frame, text="Confirm Annotation →", | |
| command=self.confirm_annotation, style='Accent.TButton') | |
| self.confirm_annotation_btn.pack(fill=tk.X, pady=(10, 5)) | |
| # Step 3: Preview Mask | |
| self.step3_frame = ttk.LabelFrame(control_frame, text="Step 3: Preview Mask", padding="10") | |
| self.step3_frame.pack(fill=tk.X, pady=8) | |
| self.preview_info = ttk.Label(self.step3_frame, text="Preview the segmentation mask\nfor the current frame", | |
| justify=tk.LEFT) | |
| self.preview_info.pack(fill=tk.X, pady=5) | |
| ttk.Button(self.step3_frame, text="← Back to Annotation", | |
| command=self.back_to_annotation).pack(fill=tk.X, pady=5) | |
| self.propagate_btn = ttk.Button(self.step3_frame, text="Propagate to All Frames →", | |
| command=self.run_segmentation, style='Accent.TButton') | |
| self.propagate_btn.pack(fill=tk.X, pady=5) | |
| # Step 4: Propagation | |
| self.step4_frame = ttk.LabelFrame(control_frame, text="Step 4: Propagate Segmentation", padding="10") | |
| self.step4_frame.pack(fill=tk.X, pady=8) | |
| self.propagate_info = ttk.Label(self.step4_frame, text="Segmenting all frames...", justify=tk.LEFT) | |
| self.propagate_info.pack(fill=tk.X, pady=5) | |
| self.propagate_progress = ttk.Progressbar(self.step4_frame, mode='indeterminate') | |
| self.propagate_progress.pack(fill=tk.X, pady=5) | |
| # Step 5: Export | |
| self.step5_frame = ttk.LabelFrame(control_frame, text="Step 5: Export Results", padding="10") | |
| self.step5_frame.pack(fill=tk.X, pady=8) | |
| self.export_info = ttk.Label(self.step5_frame, text="Segmentation complete!\nReady to export masks.", | |
| justify=tk.LEFT) | |
| self.export_info.pack(fill=tk.X, pady=5) | |
| ttk.Button(self.step5_frame, text="Export Masks", command=self.export_masks).pack(fill=tk.X, pady=5) | |
| ttk.Button(self.step5_frame, text="Start Over", command=self.reset_app).pack(fill=tk.X, pady=5) | |
| ttk.Separator(control_frame, orient='horizontal').pack(fill=tk.X, pady=15) | |
| # Status | |
| self.status_label = ttk.Label(control_frame, text="Status: Ready", foreground="green", | |
| font=('Segoe UI', 11, 'bold')) | |
| self.status_label.pack(fill=tk.X, pady=10) | |
| # Canvas frame (right side) | |
| canvas_frame = ttk.LabelFrame(main_frame, text="Video Frame", padding="15") | |
| canvas_frame.grid(row=0, column=1, rowspan=2, sticky=(tk.W, tk.E, tk.N, tk.S)) | |
| # Create matplotlib figure with larger size | |
| self.fig, self.ax = plt.subplots(figsize=(12, 9)) | |
| self.fig.patch.set_facecolor('#f0f0f0') | |
| self.canvas = FigureCanvasTkAgg(self.fig, master=canvas_frame) | |
| self.canvas.draw() | |
| self.canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True) | |
| # Bind click event | |
| self.canvas.mpl_connect('button_press_event', self.on_canvas_click) | |
| def update_step_display(self): | |
| """Update UI to show only relevant controls for current step""" | |
| # Hide all step frames | |
| for frame in [self.step1_frame, self.step2_frame, self.step3_frame, | |
| self.step4_frame, self.step5_frame]: | |
| frame.pack_forget() | |
| # Show relevant frames based on current step | |
| if self.current_step == 1: | |
| self.step_label.config(text="📹 Step 1: Select Video File") | |
| self.step1_frame.pack(fill=tk.X, pady=8) | |
| elif self.current_step == 2: | |
| self.step_label.config(text="🎯 Step 2: Annotate Frame") | |
| self.step1_frame.pack(fill=tk.X, pady=8) | |
| self.step2_frame.pack(fill=tk.X, pady=8) | |
| elif self.current_step == 3: | |
| self.step_label.config(text="👁️ Step 3: Preview Mask") | |
| self.step3_frame.pack(fill=tk.X, pady=8) | |
| elif self.current_step == 4: | |
| self.step_label.config(text="⚡ Step 4: Propagating...") | |
| self.step4_frame.pack(fill=tk.X, pady=8) | |
| elif self.current_step == 5: | |
| self.step_label.config(text="✅ Step 5: Export Results") | |
| self.step5_frame.pack(fill=tk.X, pady=8) | |
| def select_video(self): | |
| filepath = filedialog.askopenfilename( | |
| title="Select Video File", | |
| filetypes=[("Video files", "*.mp4 *.avi *.mov *.mkv"), ("All files", "*.*")] | |
| ) | |
| if filepath: | |
| self.video_path = filepath | |
| self.video_label.config(text=f"✓ {os.path.basename(filepath)}") | |
| self.extract_frames() | |
| def extract_frames(self): | |
| try: | |
| self.status_label.config(text="Status: Extracting frames...", foreground="orange") | |
| self.root.update() | |
| # Create temporary folder for frames | |
| self.video_folder = tempfile.mkdtemp() | |
| # Extract frames using OpenCV | |
| cap = cv2.VideoCapture(self.video_path) | |
| frame_idx = 0 | |
| while True: | |
| ret, frame = cap.read() | |
| if not ret: | |
| break | |
| frame_path = os.path.join(self.video_folder, f"{frame_idx}.jpg") | |
| cv2.imwrite(frame_path, frame) | |
| frame_idx += 1 | |
| cap.release() | |
| # Get frame names | |
| self.frame_names = [ | |
| p for p in os.listdir(self.video_folder) | |
| if p.endswith(".jpg") or p.endswith(".png") | |
| ] | |
| self.frame_names.sort(key=lambda x: int(x.split(".")[0])) | |
| self.current_frame_idx = 0 | |
| self.current_step = 2 | |
| self.update_step_display() | |
| self.display_frame() | |
| # Initialize SAM2 automatically | |
| self.init_sam2() | |
| self.status_label.config(text=f"Status: Ready - {len(self.frame_names)} frames extracted", | |
| foreground="green") | |
| except Exception as e: | |
| self.status_label.config(text=f"Status: Error - {str(e)}", foreground="red") | |
| messagebox.showerror("Error", f"Failed to extract frames: {str(e)}") | |
| def display_frame(self, show_mask=False): | |
| if not self.frame_names: | |
| return | |
| self.ax.clear() | |
| frame_path = os.path.join(self.video_folder, self.frame_names[self.current_frame_idx]) | |
| img = Image.open(frame_path) | |
| self.ax.imshow(img) | |
| # Show mask if in preview mode | |
| if show_mask and self.preview_mask is not None: | |
| self.show_mask(self.preview_mask, self.ax, obj_id=1) | |
| # Draw existing points | |
| for i, (point, label) in enumerate(zip(self.points, self.labels)): | |
| color = "green" if label == 1 else "red" | |
| self.ax.plot(point[0], point[1], marker="*", color=color, markersize=25, | |
| markeredgecolor="white", markeredgewidth=2.5) | |
| title = f"Frame {self.current_frame_idx}/{len(self.frame_names)-1}" | |
| if show_mask: | |
| title += " - Mask Preview" | |
| self.ax.set_title(title, fontsize=14, fontweight='bold') | |
| self.ax.axis('off') | |
| self.canvas.draw() | |
| self.frame_label.config(text=f"Frame: {self.current_frame_idx}/{len(self.frame_names)-1}") | |
| def show_mask(self, mask, ax, obj_id=None, random_color=False): | |
| """Display mask overlay on the image""" | |
| if random_color: | |
| color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) | |
| else: | |
| cmap = plt.get_cmap("tab10") | |
| cmap_idx = 0 if obj_id is None else obj_id | |
| color = np.array([*cmap(cmap_idx)[:3], 0.6]) | |
| h, w = mask.shape[-2:] | |
| mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) | |
| ax.imshow(mask_image) | |
| def on_canvas_click(self, event): | |
| if event.inaxes != self.ax or not self.frame_names or self.current_step != 2: | |
| return | |
| x, y = int(event.xdata), int(event.ydata) | |
| label = 1 if self.annotation_mode.get() == "positive" else 0 | |
| self.points.append([x, y]) | |
| self.labels.append(label) | |
| self.display_frame() | |
| self.update_points_label() | |
| def add_manual_point(self): | |
| try: | |
| x = int(self.coord_x.get()) | |
| y = int(self.coord_y.get()) | |
| label = 1 if self.annotation_mode.get() == "positive" else 0 | |
| self.points.append([x, y]) | |
| self.labels.append(label) | |
| self.coord_x.delete(0, tk.END) | |
| self.coord_y.delete(0, tk.END) | |
| self.display_frame() | |
| self.update_points_label() | |
| except ValueError: | |
| messagebox.showerror("Error", "Please enter valid integer coordinates") | |
| def clear_points(self): | |
| self.points = [] | |
| self.labels = [] | |
| self.display_frame() | |
| self.update_points_label() | |
| def update_points_label(self): | |
| pos_count = sum(1 for l in self.labels if l == 1) | |
| neg_count = sum(1 for l in self.labels if l == 0) | |
| self.points_label.config(text=f"Points: {len(self.points)} (✓ {pos_count}, ✗ {neg_count})") | |
| def prev_frame(self): | |
| if self.current_frame_idx > 0: | |
| self.current_frame_idx -= 1 | |
| self.display_frame() | |
| def next_frame(self): | |
| if self.current_frame_idx < len(self.frame_names) - 1: | |
| self.current_frame_idx += 1 | |
| self.display_frame() | |
| def init_sam2(self): | |
| """Initialize SAM2 model""" | |
| try: | |
| self.status_label.config(text="Status: Initializing SAM2...", foreground="orange") | |
| self.root.update() | |
| self.predictor = build_sam2_video_predictor( | |
| self.model_cfg, | |
| self.sam_2_checkpoint, | |
| device=self.device | |
| ) | |
| self.inference_state = self.predictor.init_state(video_path=self.video_folder) | |
| self.predictor.reset_state(self.inference_state) | |
| self.status_label.config(text="Status: SAM2 ready", foreground="green") | |
| except Exception as e: | |
| self.status_label.config(text=f"Status: Error - {str(e)}", foreground="red") | |
| messagebox.showerror("Error", f"Failed to initialize SAM2: {str(e)}") | |
| def confirm_annotation(self): | |
| """Generate preview mask for current frame""" | |
| if not self.predictor: | |
| messagebox.showwarning("Warning", "SAM2 not initialized") | |
| return | |
| if not self.points: | |
| messagebox.showwarning("Warning", "Please add at least one point annotation") | |
| return | |
| try: | |
| self.status_label.config(text="Status: Generating mask preview...", foreground="orange") | |
| self.root.update() | |
| # Add points to predictor | |
| points_array = np.array(self.points, dtype=np.float32) | |
| labels_array = np.array(self.labels, dtype=np.int32) | |
| _, out_obj_ids, out_mask_logits = self.predictor.add_new_points( | |
| inference_state=self.inference_state, | |
| frame_idx=self.current_frame_idx, | |
| points=points_array, | |
| labels=labels_array, | |
| obj_id=1, | |
| ) | |
| # Store the preview mask | |
| self.preview_mask = (out_mask_logits[0] > 0.0).cpu().numpy() | |
| # Move to preview step | |
| self.current_step = 3 | |
| self.update_step_display() | |
| self.display_frame(show_mask=True) | |
| self.status_label.config(text="Status: Preview ready", foreground="green") | |
| except Exception as e: | |
| self.status_label.config(text=f"Status: Error - {str(e)}", foreground="red") | |
| messagebox.showerror("Error", f"Failed to generate preview: {str(e)}") | |
| def back_to_annotation(self): | |
| """Go back to annotation step""" | |
| self.current_step = 2 | |
| self.preview_mask = None | |
| self.update_step_display() | |
| self.display_frame() | |
| # Reset the inference state | |
| if self.predictor: | |
| self.predictor.reset_state(self.inference_state) | |
| def run_segmentation(self): | |
| """Propagate segmentation to all frames""" | |
| if not self.predictor or self.preview_mask is None: | |
| messagebox.showwarning("Warning", "Please confirm annotation first") | |
| return | |
| try: | |
| self.current_step = 4 | |
| self.update_step_display() | |
| self.status_label.config(text="Status: Propagating to all frames...", foreground="orange") | |
| self.propagate_progress.start() | |
| self.root.update() | |
| # Propagate through video | |
| self.video_segments = {} | |
| for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video( | |
| self.inference_state | |
| ): | |
| self.video_segments[out_frame_idx] = { | |
| out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() | |
| for i, out_obj_id in enumerate(out_obj_ids) | |
| } | |
| # Update progress | |
| progress_pct = (out_frame_idx + 1) / len(self.frame_names) * 100 | |
| self.propagate_info.config(text=f"Processing: {out_frame_idx + 1}/{len(self.frame_names)} frames\n({progress_pct:.1f}%)") | |
| self.root.update() | |
| self.propagate_progress.stop() | |
| # Move to export step | |
| self.current_step = 5 | |
| self.update_step_display() | |
| self.status_label.config( | |
| text=f"Status: Segmentation complete! {len(self.video_segments)} frames processed", | |
| foreground="green" | |
| ) | |
| messagebox.showinfo("Success", f"Segmentation completed for {len(self.video_segments)} frames!") | |
| except Exception as e: | |
| self.propagate_progress.stop() | |
| self.status_label.config(text=f"Status: Error - {str(e)}", foreground="red") | |
| messagebox.showerror("Error", f"Failed to run segmentation: {str(e)}") | |
| def export_masks(self): | |
| """Export masks to selected directory""" | |
| if not self.video_segments: | |
| messagebox.showwarning("Warning", "No segmentation results to export") | |
| return | |
| output_dir = filedialog.askdirectory(title="Select Output Directory") | |
| if not output_dir: | |
| return | |
| try: | |
| self.status_label.config(text="Status: Exporting masks...", foreground="orange") | |
| self.root.update() | |
| output_mask_dir = os.path.join(output_dir, "output_masks") | |
| os.makedirs(output_mask_dir, exist_ok=True) | |
| for out_frame_idx in range(len(self.frame_names)): | |
| # Get the first mask to determine shape | |
| first_mask = next(iter(self.video_segments[out_frame_idx].values())) | |
| if len(first_mask.shape) == 3: | |
| h, w = first_mask.shape[-2:] | |
| combined_mask = np.zeros((h, w), dtype=np.uint8) | |
| else: | |
| combined_mask = np.zeros_like(first_mask, dtype=np.uint8) | |
| # Combine all masks | |
| for out_mask in self.video_segments[out_frame_idx].values(): | |
| mask_2d = np.squeeze(out_mask) | |
| combined_mask = np.logical_or(combined_mask, mask_2d).astype(np.uint8) | |
| # Scale to 0-255 | |
| mask_img = combined_mask * 255 | |
| # Save as PNG | |
| cv2.imwrite( | |
| os.path.join(output_mask_dir, f"mask_{out_frame_idx:04d}.png"), | |
| mask_img | |
| ) | |
| self.status_label.config(text="Status: Masks exported successfully!", foreground="green") | |
| messagebox.showinfo("Success", f"Masks exported to:\n{output_mask_dir}") | |
| except Exception as e: | |
| self.status_label.config(text=f"Status: Error - {str(e)}", foreground="red") | |
| messagebox.showerror("Error", f"Failed to export masks: {str(e)}") | |
| def reset_app(self): | |
| """Reset the application to start over""" | |
| if messagebox.askyesno("Confirm", "Start over? This will clear all current progress."): | |
| # Clean up | |
| if self.video_folder and os.path.exists(self.video_folder): | |
| shutil.rmtree(self.video_folder) | |
| # Reset variables | |
| self.video_path = None | |
| self.video_folder = None | |
| self.frame_names = [] | |
| self.current_frame_idx = 0 | |
| self.points = [] | |
| self.labels = [] | |
| self.inference_state = None | |
| self.video_segments = None | |
| self.preview_mask = None | |
| self.current_step = 1 | |
| # Reset UI | |
| self.video_label.config(text="No video selected") | |
| self.points_label.config(text="Points: 0") | |
| self.frame_label.config(text="Frame: 0/0") | |
| self.ax.clear() | |
| self.canvas.draw() | |
| self.update_step_display() | |
| self.status_label.config(text="Status: Ready", foreground="green") | |
| def __del__(self): | |
| # Cleanup temporary folder | |
| if self.video_folder and os.path.exists(self.video_folder): | |
| shutil.rmtree(self.video_folder) | |
def main():
    """Entry point: create the Tk root, build the app, run the event loop."""
    window = tk.Tk()
    # Keep a reference so the app object lives for the whole main loop.
    application = SAM2VideoSegmentationApp(window)
    window.mainloop()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment