import argparse
import json
import os
import shutil
import sys
import time
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace

import cv2
import numpy as np
from tqdm import tqdm

import process_detections_package as pdp
import utils_package as up
import constants as cfg

# Add parent directory to path for importing ultralytics_patch
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from ultralytics_patch import ensure_custom_ultralytics_layers

# Make sure custom blocks expected by exported YOLO weights are available
ensure_custom_ultralytics_layers()

from ultralytics import YOLO
import torch
# Allow importing the classification stack
CLASSIFICATION_MODULE_DIR = (
    Path(__file__).resolve().parents[2] /
    'development' /
    'classification' /
    'detron_and_cross_entropy' /
    'classification_rectangle_v9-2'
).resolve()
if CLASSIFICATION_MODULE_DIR.exists():
    module_path = str(CLASSIFICATION_MODULE_DIR)
    if module_path not in sys.path:
        sys.path.append(module_path)

# Import Siamese network modules
SIAMESE_MODULE_DIR = (
    Path(__file__).resolve().parents[2] /
    'development' /
    'siamese'
).resolve()
if SIAMESE_MODULE_DIR.exists():
    siamese_path = str(SIAMESE_MODULE_DIR)
    if siamese_path not in sys.path:
        sys.path.append(siamese_path)

try:
    from inference import EmbeddingClassifier
    from classify_fish import load_images_from_folder, classify_images
    from species_to_family import SPECIES_TO_FAMILY, FAMILY_NAMES
except Exception as exc:
    EmbeddingClassifier = None
    load_images_from_folder = None
    classify_images = None
    SPECIES_TO_FAMILY = {}
    FAMILY_NAMES = []
    CLASSIFICATION_IMPORT_ERROR = exc
else:
    CLASSIFICATION_IMPORT_ERROR = None

try:
    from example.models.siamese_network import SiameseNetwork
    from example.utils.model_io import load_model
    from example.transforms.custom_transforms import get_transforms
    from PIL import Image
    SIAMESE_AVAILABLE = True
except Exception as exc:
    SIAMESE_AVAILABLE = False
    SIAMESE_IMPORT_ERROR = exc
    print(f"Warning: Siamese network not available: {exc}")
class FishTracker:
    """Manages fish identification and tracking across frames."""

    def __init__(self, similarity_threshold=0.85, siamese_model_path=None):
        self.similarity_threshold = similarity_threshold
        self.fish_database = defaultdict(list)  # family -> list of fish instances
        self.fish_counter = defaultdict(int)  # family -> next fish ID
        self.siamese_model = None
        self.transform = None
        if SIAMESE_AVAILABLE and siamese_model_path:
            self._load_siamese_model(siamese_model_path)

    def _load_siamese_model(self, model_path):
        """Load the Siamese network model."""
        try:
            if Path(model_path).exists():
                device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
                self.siamese_model = load_model(SiameseNetwork(), model_path, device)
                self.siamese_model.eval()
                self.transform = get_transforms()['val']
                print(f"Siamese model loaded from {model_path} on device: {device}")
            else:
                print(f"Siamese model not found at {model_path}")
        except Exception as e:
            print(f"Error loading Siamese model: {e}")
            self.siamese_model = None

    def _compute_siamese_similarity(self, img1_array, img2_array):
        """Compute similarity between two images using the Siamese network."""
        if self.siamese_model is None or self.transform is None:
            return 0.0
        try:
            # Convert numpy arrays to PIL Images
            img1_pil = Image.fromarray(img1_array)
            img2_pil = Image.fromarray(img2_array)
            # Apply transforms
            img1_tensor = self.transform(img1_pil).unsqueeze(0)
            img2_tensor = self.transform(img2_pil).unsqueeze(0)
            # Move tensors to the same device as the model
            device = next(self.siamese_model.parameters()).device
            img1_tensor = img1_tensor.to(device)
            img2_tensor = img2_tensor.to(device)
            # Get embeddings
            with torch.no_grad():
                emb1 = self.siamese_model.forward_once(img1_tensor)
                emb2 = self.siamese_model.forward_once(img2_tensor)
            # Compute cosine similarity
            similarity = torch.nn.functional.cosine_similarity(emb1, emb2).item()
            return similarity
        except Exception as e:
            print(f"Error computing Siamese similarity: {e}")
            return 0.0

    def _compute_basic_similarity(self, img1_array, img2_array):
        """Compute basic similarity using histogram comparison."""
        try:
            # Resize images to the same size
            size = (128, 128)
            img1_resized = cv2.resize(img1_array, size)
            img2_resized = cv2.resize(img2_array, size)
            # Compute 3D color histograms
            hist1 = cv2.calcHist([img1_resized], [0, 1, 2], None, [32, 32, 32],
                                 [0, 256, 0, 256, 0, 256])
            hist2 = cv2.calcHist([img2_resized], [0, 1, 2], None, [32, 32, 32],
                                 [0, 256, 0, 256, 0, 256])
            # Normalize histograms
            hist1 = cv2.normalize(hist1, hist1).flatten()
            hist2 = cv2.normalize(hist2, hist2).flatten()
            # Compute correlation
            similarity = cv2.compareHist(hist1, hist2, cv2.HISTCMP_CORREL)
            return similarity
        except Exception as e:
            print(f"Error computing basic similarity: {e}")
            return 0.0

    def _bbox_iou(self, boxA, boxB):
        """Compute IoU between two [x1, y1, x2, y2] boxes."""
        xA = max(boxA[0], boxB[0])
        yA = max(boxA[1], boxB[1])
        xB = min(boxA[2], boxB[2])
        yB = min(boxA[3], boxB[3])
        inter_w = max(0, xB - xA)
        inter_h = max(0, yB - yA)
        inter_area = inter_w * inter_h
        if inter_area <= 0:
            return 0.0
        boxA_area = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
        boxB_area = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
        return inter_area / float(boxA_area + boxB_area - inter_area)
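
    # Worked example with hypothetical boxes: boxA = [0, 0, 10, 10] and
    # boxB = [5, 5, 15, 15] overlap in a 5x5 region, so the intersection is 25,
    # the union is 100 + 100 - 25 = 175, and the IoU is 25 / 175 ~= 0.143.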
    def find_or_create_fish_id(self, family, img_array, crop_path, frame_idx, bbox):
        """Find an existing fish or create a new ID, using Siamese only for ambiguous matches."""
        HIGH_IOU = 0.6
        LOW_IOU = 0.3
        MAX_FRAME_GAP = 10
        AMBIG_MARGIN = 0.05

        def _create_new_fish():
            # New fish of this family
            fish_id = f"{family}_fish_{self.fish_counter[family]:03d}"
            self.fish_counter[family] += 1
            self.fish_database[family].append({
                'id': fish_id,
                'images': [crop_path],
                'reference_image': img_array,
                'first_seen': datetime.now(),
                'last_seen': datetime.now(),
                'occurrence_count': 1,
                'last_bbox': bbox,
                'last_frame': frame_idx
            })
            return fish_id, True  # New fish

        fish_list = self.fish_database.get(family, [])
        if not fish_list:
            return _create_new_fish()

        candidates = []
        best_iou = 0.0
        # Cheap gating: only compare to recent fish with overlapping boxes
        for fish in fish_list:
            last_frame = fish.get('last_frame')
            last_bbox = fish.get('last_bbox')
            if last_frame is None or last_bbox is None:
                continue
            if abs(frame_idx - last_frame) > MAX_FRAME_GAP:
                continue
            iou = self._bbox_iou(bbox, last_bbox)
            if iou <= 0:
                continue
            candidates.append((fish, iou))
            best_iou = max(best_iou, iou)

        # No viable candidates by geometry
        if not candidates:
            return _create_new_fish()

        # Strong, unambiguous IoU match: skip Siamese
        strong_candidates = [c for c in candidates if c[1] >= HIGH_IOU]
        if len(strong_candidates) == 1:
            fish, _ = strong_candidates[0]
            fish['images'].append(crop_path)
            fish['last_seen'] = datetime.now()
            fish['occurrence_count'] += 1
            fish['last_bbox'] = bbox
            fish['last_frame'] = frame_idx
            return fish['id'], False

        # Low IoU and no Siamese rescue available
        if best_iou < LOW_IOU and self.siamese_model is None:
            return _create_new_fish()

        # Ambiguous: run Siamese (or basic similarity if Siamese unavailable) on near-best candidates
        near_best = [
            (fish, iou) for fish, iou in candidates
            if best_iou - iou <= AMBIG_MARGIN and iou >= LOW_IOU
        ]
        best_match = None
        best_similarity = -1.0
        for fish, _ in near_best:
            if self.siamese_model is not None:
                similarity = self._compute_siamese_similarity(img_array, fish['reference_image'])
            else:
                similarity = self._compute_basic_similarity(img_array, fish['reference_image'])
            if similarity > best_similarity:
                best_similarity = similarity
                best_match = fish
        if best_match is not None and best_similarity >= self.similarity_threshold:
            best_match['images'].append(crop_path)
            best_match['last_seen'] = datetime.now()
            best_match['occurrence_count'] += 1
            best_match['last_bbox'] = bbox
            best_match['last_frame'] = frame_idx
            return best_match['id'], False  # Existing fish
        return _create_new_fish()
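
    # A note on the matching cascade above: geometric gating (frame gap + IoU)
    # runs first, a single strong IoU overlap (>= HIGH_IOU) is accepted without
    # any embedding work, and only near-tied candidates above LOW_IOU fall
    # through to the Siamese (or histogram) similarity re-ranking.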
    def register_tracked_fish(self, family, track_id, img_array, crop_path, frame_idx, bbox):
        """Register a fish using an external tracker ID (e.g., ByteTrack)."""
        if track_id is None:
            return self.find_or_create_fish_id(family, img_array, crop_path, frame_idx, bbox)
        try:
            track_val = float(track_id)
            if np.isnan(track_val):
                return self.find_or_create_fish_id(family, img_array, crop_path, frame_idx, bbox)
            track_int = int(track_val)
        except (TypeError, ValueError):
            return self.find_or_create_fish_id(family, img_array, crop_path, frame_idx, bbox)

        fish_id = f"{family}_track_{track_int:05d}"
        fish_list = self.fish_database[family]
        for fish in fish_list:
            if fish['id'] == fish_id:
                fish['images'].append(crop_path)
                fish['last_seen'] = datetime.now()
                fish['occurrence_count'] += 1
                fish['last_bbox'] = bbox
                fish['last_frame'] = frame_idx
                fish['track_id'] = track_int
                return fish_id, False
        fish_list.append({
            'id': fish_id,
            'track_id': track_int,
            'images': [crop_path],
            'reference_image': img_array,
            'first_seen': datetime.now(),
            'last_seen': datetime.now(),
            'occurrence_count': 1,
            'last_bbox': bbox,
            'last_frame': frame_idx
        })
        return fish_id, True

    def get_tracking_summary(self):
        """Get a summary of tracked fish."""
        summary = {}
        for family, fish_list in self.fish_database.items():
            summary[family] = {
                'total_individuals': len(fish_list),
                'fish': []
            }
            for fish in fish_list:
                summary[family]['fish'].append({
                    'id': fish['id'],
                    'occurrences': fish['occurrence_count'],
                    'first_seen': str(fish['first_seen']),
                    'last_seen': str(fish['last_seen']),
                    'num_images': len(fish['images']),
                    'track_id': fish.get('track_id')
                })
        return summary
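
    # Example get_tracking_summary() output (hypothetical family and values):
    # {"Scaridae": {"total_individuals": 1,
    #               "fish": [{"id": "Scaridae_track_00007", "occurrences": 12,
    #                         "first_seen": "...", "last_seen": "...",
    #                         "num_images": 12, "track_id": 7}]}}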
class DetectionProcessor:
    """Process YOLO detections, classify, and track fish."""

    process_detections = pdp.process_detections

    def __init__(self, args):
        self.args = args
        self.output_dir = Path(self.args.output)
        self.crops_dir = self.output_dir / 'crops'
        self.families_dir = self.output_dir / 'families'
        self.tracking_dir = self.output_dir / 'tracking'
        self.analysis_dir = self.output_dir / 'analysis'
        annotated_output = getattr(self.args, 'annotated_output', None)
        self.annotated_output = Path(annotated_output) if annotated_output else None
        self.video_writer = None
        self.video_meta = {}
        for dir_path in [self.output_dir, self.crops_dir, self.families_dir,
                         self.tracking_dir, self.analysis_dir]:
            dir_path.mkdir(parents=True, exist_ok=True)
        self.detections = []
        self.predictions = []
        self.crop_paths = []
        self.class_names = {}
        self.species_to_family = SPECIES_TO_FAMILY or {}
        self.use_bytetrack = True
        self.tracker_config = getattr(self.args, 'tracker_config', 'bytetrack.yaml')
        self.track_db = defaultdict(dict)  # family -> fish_id -> record
        print("Tracking method: ByteTrack (Ultralytics tracker IDs), single-pass pipeline.")
        # Classification state
        self.classifier = None
        self.classifier_classes = []
        self.family_to_id = {}
        self.classification_batch_size = getattr(self.args, 'classification_batch_size',
                                                 cfg.CLASSIFICATION_BATCH_SIZE)
        self.classification_enabled = bool(self.args.classify)
        if self.classification_enabled:
            self._initialize_embedding_classifier()
        else:
            print("Classification disabled via CLI flag.")
    def _initialize_embedding_classifier(self):
        """Load the embedding classifier."""
        if EmbeddingClassifier is None:
            reason = f"{CLASSIFICATION_IMPORT_ERROR}" if CLASSIFICATION_IMPORT_ERROR else "unknown import error"
            print(f"Classification modules unavailable ({reason}). Skipping classifier load.")
            return
        model_path = Path(getattr(self.args, 'classification_model', cfg.CLASSIFICATION_MODEL_PATH))
        dataset_path = Path(getattr(self.args, 'classification_dataset', cfg.CLASSIFICATION_DATASET_PATH))
        missing_paths = [str(path) for path in (model_path, dataset_path) if not Path(path).exists()]
        if missing_paths:
            print(f"Classification assets missing: {', '.join(missing_paths)}")
            return
        config = {
            'log_level': getattr(self.args, 'classification_log_level', cfg.CLASSIFICATION_LOG_LEVEL),
            'dataset': {'path': str(dataset_path)},
            'model': {
                'path': str(model_path),
                'device': getattr(self.args, 'classification_device', cfg.CLASSIFICATION_DEVICE)
            }
        }
        try:
            print(f"Loading embedding classifier from {model_path} ...")
            self.classifier = EmbeddingClassifier(config)
            base_classes = list(FAMILY_NAMES) if FAMILY_NAMES else []
            self.classifier_classes = base_classes.copy()
            self.family_to_id = {name: idx for idx, name in enumerate(self.classifier_classes)}
            if "Unknown" not in self.family_to_id:
                self.family_to_id["Unknown"] = len(self.classifier_classes)
                self.classifier_classes.append("Unknown")
            if not self.species_to_family:
                self.species_to_family = SPECIES_TO_FAMILY
            print(f"Classifier loaded with {len(self.classifier_classes)} family labels.")
        except Exception as exc:
            print(f"Warning: Failed to initialize embedding classifier: {exc}")
            self.classifier = None
    def _init_video_writer(self, width, height, fps):
        """Initialize the video writer for annotated output if requested."""
        if not self.annotated_output or self.video_writer:
            return
        width, height = int(width), int(height)
        if width <= 0 or height <= 0:
            return
        self.annotated_output.parent.mkdir(parents=True, exist_ok=True)
        safe_fps = fps if fps and not np.isnan(fps) and fps > 0 else 30.0
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        self.video_writer = cv2.VideoWriter(
            str(self.annotated_output),
            fourcc,
            safe_fps,
            (width, height)
        )
        self.video_meta = {'fps': safe_fps, 'size': (width, height)}

    def _annotate_frame(self, frame, boxes, confidences, class_ids, track_ids=None):
        """Draw bounding boxes and labels on a frame."""
        if frame is None:
            return None
        annotated = frame.copy()
        for idx, (box, conf, class_id) in enumerate(zip(boxes, confidences, class_ids)):
            x1, y1, x2, y2 = map(int, box)
            color = (0, 255, 0)
            cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)
            label = self.class_names.get(class_id, f"class_{class_id}")
            track_text = ""
            if track_ids is not None and len(track_ids) > idx:
                try:
                    tid_raw = float(track_ids[idx]) if track_ids[idx] is not None else None
                    if tid_raw is not None and not np.isnan(tid_raw):
                        track_text = f" id:{int(tid_raw)}"
                except Exception:
                    track_text = ""
            text = f"{label}{track_text} {conf:.2f}"
            cv2.putText(
                annotated,
                text,
                (x1, max(0, y1 - 5)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                color,
                1,
                cv2.LINE_AA
            )
        return annotated

    def close(self):
        """Release resources."""
        if self.video_writer:
            self.video_writer.release()
            self.video_writer = None

    def _finalize_annotation(self):
        """Finalize the annotated video if enabled."""
        if self.video_writer:
            try:
                self.video_writer.release()
            except Exception:
                pass
            self.video_writer = None
    def process_and_track_detections(self, model, device, progress_callback=None):
        """Single-pass detection + ByteTrack + batched classification."""
        print("\nSingle-pass detect/track/classify with ByteTrack...")
        cap = cv2.VideoCapture(self.args.source)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        cap.release()
        if width > 0 and height > 0:
            duration_seconds = None
            if total_frames and fps and fps > 0:
                duration_seconds = total_frames / fps
            self.video_meta = {
                'fps': fps if fps and not np.isnan(fps) and fps > 0 else None,
                'size': (width, height),
                'total_frames': total_frames,
                'duration_seconds': duration_seconds
            }
            if self.annotated_output:
                self._init_video_writer(width, height, fps)

        stream = model.track(
            source=self.args.source,
            conf=self.args.conf,
            iou=self.args.iou,
            device=device,
            tracker=self.tracker_config,
            persist=True,
            stream=True,
            verbose=False
        )

        detection_count = 0
        frame_idx = -1
        last_time = time.time()
        last_progress = 0
        pbar = tqdm(stream, total=total_frames if total_frames > 0 else None,
                    desc="Detect+track", unit="frame")
        for result in pbar:
            frame_idx += 1
            orig_img = result.orig_img
            if orig_img is None:
                continue
            # Emit incremental progress updates while running YOLO/ByteTrack
            if progress_callback and total_frames > 0:
                # Reserve 10-90% for this stage; leave the final 10% for post-processing
                progress = 10 + int(((frame_idx + 1) / total_frames) * 80)
                progress = max(10, min(progress, 90))
                if progress > last_progress:
                    last_progress = progress
                    progress_callback(stage='yolo', progress=progress)
            if hasattr(result, 'names'):
                self.class_names.update(result.names)
            boxes = result.boxes
            if boxes is None or boxes.xyxy is None or len(boxes) == 0:
                if self.video_writer:
                    self.video_writer.write(orig_img)
                continue
            xyxy = boxes.xyxy.cpu().numpy()
            confs = boxes.conf.cpu().numpy() if boxes.conf is not None else np.zeros(len(xyxy))
            class_ids = boxes.cls.cpu().numpy().astype(int)
            track_ids = boxes.id.cpu().numpy() if boxes.id is not None else np.array([None] * len(class_ids))

            frame_dets = []
            for i, (box, conf, class_id, track_id) in enumerate(zip(xyxy, confs, class_ids, track_ids)):
                # Secondary hard confidence floor, applied on top of the tracker's conf setting
                if conf < 0.3:
                    continue
                x1, y1, x2, y2 = map(int, box)
                h, w = orig_img.shape[:2]
                x1 = max(0, x1 - self.args.padding)
                y1 = max(0, y1 - self.args.padding)
                x2 = min(w, x2 + self.args.padding)
                y2 = min(h, y2 + self.args.padding)
                crop_w, crop_h = x2 - x1, y2 - y1
                if crop_w < self.args.min_crop_size or crop_h < self.args.min_crop_size:
                    continue
                crop = orig_img[y1:y2, x1:x2]
                if crop.size == 0:
                    continue
                crop_rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
                frame_dets.append({
                    'frame': frame_idx,
                    'bbox': [x1, y1, x2, y2],
                    'crop_bgr': crop,
                    'crop_rgb': crop_rgb,
                    'confidence': float(conf),
                    'class_id': class_id,
                    'class_name': self.class_names.get(class_id, f"class_{class_id}"),
                    'track_id': int(track_id) if track_id is not None and not np.isnan(track_id) else None,
                    'seq_idx': i
                })

            if self.video_writer:
                annotated_frame = result.plot() if hasattr(result, 'plot') else orig_img
                if annotated_frame is None:
                    annotated_frame = orig_img
                self.video_writer.write(annotated_frame)
            if not frame_dets:
                continue

            batch_imgs = [det['crop_rgb'] for det in frame_dets]
            if self.classifier and self.classification_enabled:
                try:
                    predictions_batch = self.classifier.inference_numpy_batch(batch_imgs)
                except Exception as exc:
                    print(f"Classification batch failed on frame {frame_idx}: {exc}")
                    predictions_batch = [None] * len(frame_dets)
            else:
                predictions_batch = [None] * len(frame_dets)

            for det_idx, det in enumerate(frame_dets):
                predictions = predictions_batch[det_idx] if det_idx < len(predictions_batch) else None
                classification_guesses = []
                species_name = None
                family_name = None
                if predictions:
                    sorted_preds = sorted(
                        predictions,
                        key=lambda p: getattr(p, 'accuracy', 0.0),
                        reverse=True
                    )
                    if sorted_preds:
                        best_pred = sorted_preds[0]
                        best_conf = float(getattr(best_pred, 'accuracy', 0.0))
                        # Treat low-confidence predictions as unidentified
                        if best_conf < 0.70:
                            family_name = "unidentified fish"
                            species_name = "unidentified fish"
                        else:
                            species_name = best_pred.name
                        for pred in sorted_preds[:2]:
                            fam_name = self.species_to_family.get(pred.name, pred.name)
                            classification_guesses.append({
                                'family': fam_name,
                                'confidence': float(getattr(pred, 'accuracy', 0.0))
                            })
                if family_name is None:
                    family_name = self.species_to_family.get(species_name, None) if species_name else None
                if not family_name:
                    family_name = self.species_to_family.get(det['class_name'], det['class_name'])

                track_id = det['track_id']
                fish_id = f"track_{track_id:05d}" if track_id is not None else f"det_{detection_count:06d}"
                family_dir = self.families_dir / family_name
                family_dir.mkdir(exist_ok=True)
                fish_dir = family_dir / fish_id
                fish_dir.mkdir(exist_ok=True)
                crop_filename = f"frame_{det['frame']:06d}_det_{det['seq_idx']:03d}.jpg"
                crop_path = fish_dir / crop_filename
                cv2.imwrite(str(crop_path), det['crop_bgr'])
                general_crop_path = self.crops_dir / f"{family_name}_{fish_id}_{crop_filename}"
                shutil.copy2(crop_path, general_crop_path)

                family_tracks = self.track_db[family_name]
                track_rec = family_tracks.get(fish_id, {
                    'images': [],
                    'occurrence_count': 0,
                    'first_seen': det['frame'],
                    'last_seen': det['frame']
                })
                track_rec['images'].append(str(crop_path))
                track_rec['occurrence_count'] += 1
                track_rec['last_seen'] = det['frame']
                family_tracks[fish_id] = track_rec

                rel_crop_path = os.path.relpath(crop_path, self.output_dir)
                detection_info = {
                    'frame': det['frame'],
                    'crop_path': str(crop_path),
                    'image_url': f"/{rel_crop_path.replace(os.sep, '/')}",
                    'family': family_name,
                    'fish_id': fish_id,
                    'track_id': track_id,
                    'is_new_fish': False,
                    'confidence': det['confidence'],
                    'bbox': det['bbox'],
                    'crop_size': [det['bbox'][2] - det['bbox'][0], det['bbox'][3] - det['bbox'][1]],
                    'class_name': det['class_name'],
                    'species_name': species_name or det['class_name'],
                    'classification_guesses': classification_guesses,
                    'fps': self.video_meta.get('fps', None)
                }
                self.detections.append(detection_info)
                self.crop_paths.append(crop_path)
                detection_count += 1

            pbar.set_postfix({
                'Tracked': detection_count,
                'FPS': f'{1 / (time.time() - last_time + 1e-6):.1f}'
            })
            last_time = time.time()

        self._finalize_annotation()
        print(f"\nProcessed {detection_count} detections across all frames")
        return detection_count
    def save_tracking_results(self):
        """Save tracking results and analysis."""
        # Save tracking summary
        summary = {}
        for family, fish_map in self.track_db.items():
            summary[family] = {
                'total_individuals': len(fish_map),
                'fish': []
            }
            for fish_id, rec in fish_map.items():
                summary[family]['fish'].append({
                    'id': fish_id,
                    'occurrences': rec.get('occurrence_count', 0),
                    'first_seen': rec.get('first_seen'),
                    'last_seen': rec.get('last_seen'),
                    'num_images': len(rec.get('images', []))
                })
        with open(self.analysis_dir / 'tracking_summary.json', 'w') as f:
            json.dump(summary, f, indent=2, default=str)

        # Save detailed tracking database
        detailed_db = {}
        for family, fish_map in self.track_db.items():
            detailed_db[family] = []
            for fish_id, rec in fish_map.items():
                detailed_db[family].append({
                    'id': fish_id,
                    'occurrences': rec.get('occurrence_count', 0),
                    'first_seen': rec.get('first_seen'),
                    'last_seen': rec.get('last_seen'),
                    'images': rec.get('images', []),
                    'track_id': fish_id
                })
        with open(self.analysis_dir / 'tracking_database.json', 'w') as f:
            json.dump(detailed_db, f, indent=2)

        # Generate statistics
        stats = {
            'total_families': len(self.track_db),
            'total_individuals': sum(len(fish_map) for fish_map in self.track_db.values()),
            'total_detections': len(self.detections),
            'families': {}
        }
        for family, fish_map in self.track_db.items():
            stats['families'][family] = {
                'individuals': len(fish_map),
                'total_occurrences': sum(rec.get('occurrence_count', 0) for rec in fish_map.values()),
                'avg_occurrences_per_fish': (
                    sum(rec.get('occurrence_count', 0) for rec in fish_map.values()) / len(fish_map)
                    if fish_map else 0
                )
            }
        with open(self.analysis_dir / 'tracking_statistics.json', 'w') as f:
            json.dump(stats, f, indent=2)

        print("\nTracking Results:")
        print(f"  - Total families detected: {stats['total_families']}")
        print(f"  - Total individual fish tracked: {stats['total_individuals']}")
        print(f"  - Total detections: {stats['total_detections']}")
        print("\n  Family breakdown:")
        for family, family_stats in stats['families'].items():
            print(f"    {family}: {family_stats['individuals']} individuals, "
                  f"{family_stats['total_occurrences']} total occurrences")

    def save_detection_details(self):
        """Persist detection details for downstream consumers (e.g., API responses)."""
        details_path = self.analysis_dir / 'detection_details.json'
        payload = {
            'detections': self.detections,
            'meta': {
                'fps': self.video_meta.get('fps'),
                'output_dir': str(self.output_dir),
                'total_frames': self.video_meta.get('total_frames'),
                'duration_seconds': self.video_meta.get('duration_seconds'),
                'width': self.video_meta.get('size', (None, None))[0] if self.video_meta.get('size') else None,
                'height': self.video_meta.get('size', (None, None))[1] if self.video_meta.get('size') else None
            }
        }
        with open(details_path, 'w') as f:
            json.dump(payload, f, indent=2)
        return details_path
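
    # Shape of analysis/detection_details.json written above (hypothetical values):
    # {"detections": [{"frame": 42, "family": "Lutjanidae", "fish_id": "track_00003",
    #                  "confidence": 0.91, "bbox": [x1, y1, x2, y2], ...}],
    #  "meta": {"fps": 30.0, "total_frames": 900, "width": 1920, "height": 1080, ...}}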
def run_detection_pipeline(model_path,
                           source,
                           output_dir,
                           conf=cfg.CONF_THRESHOLD,
                           iou=cfg.IOU_THRESHOLD,
                           device=cfg.DEVICE,
                           min_crop_size=50,
                           padding=cfg.PADDING,
                           classify=True,
                           tracking_method="bytetrack",
                           tracker_config="bytetrack.yaml",
                           similarity_threshold=0.85,
                           siamese_model=None,
                           annotated_output=None,
                           progress_callback=None):
    """Programmatic entrypoint to run detection, tracking, and optional annotation."""
    args = SimpleNamespace(
        model=str(model_path),
        source=str(source),
        output=str(output_dir),
        conf=conf,
        iou=iou,
        device=device,
        min_crop_size=min_crop_size,
        padding=padding,
        classify=classify,
        tracking_method=tracking_method,
        tracker_config=tracker_config,
        similarity_threshold=similarity_threshold,
        siamese_model=siamese_model or cfg.SIAMESE_MODEL_PATH,
        annotated_output=str(annotated_output) if annotated_output else None,
        classification_model=cfg.CLASSIFICATION_MODEL_PATH,
        classification_dataset=cfg.CLASSIFICATION_DATASET_PATH,
        classification_batch_size=cfg.CLASSIFICATION_BATCH_SIZE,
        classification_device=cfg.CLASSIFICATION_DEVICE,
        classification_log_level=cfg.CLASSIFICATION_LOG_LEVEL
    )
    model_path = Path(args.model)
    if not model_path.exists():
        return {'success': False, 'error': f"Model file not found: {model_path}"}
    source_path = Path(args.source)
    if not source_path.exists():
        return {'success': False, 'error': f"Source not found: {source_path}"}

    processor = None
    try:
        processor = DetectionProcessor(args)
        device_to_use = up.get_device(args.device)
        model = YOLO(args.model)
        model.to(device_to_use)
        detection_count = processor.process_and_track_detections(
            model, device_to_use, progress_callback=progress_callback)
        if detection_count > 0:
            processor.save_tracking_results()
        detection_details_path = processor.save_detection_details()
        tracking_stats_path = processor.analysis_dir / 'tracking_statistics.json'
        tracking_summary_path = processor.analysis_dir / 'tracking_summary.json'

        def _build_fish_families(details_path):
            if not details_path or not Path(details_path).exists():
                return []
            try:
                with open(details_path, 'r') as f:
                    details = json.load(f)
            except Exception:
                return []
            detections = details.get('detections', [])
            fps = details.get('meta', {}).get('fps') or 30
            families = defaultdict(lambda: defaultdict(list))
            for det in detections:
                family = det.get('family', 'Unknown')
                fish_id = det.get('fish_id', 'fish_000')
                families[family][fish_id].append(det)

            def _ts(frame_idx):
                total_seconds = frame_idx / fps if fps else 0
                hours = int(total_seconds // 3600)
                minutes = int((total_seconds % 3600) // 60)
                seconds = int(total_seconds % 60)
                frames = int(round((total_seconds - int(total_seconds)) * fps)) if fps else 0
                return f"{hours:02d}:{minutes:02d}:{seconds:02d}:{frames:02d}"
            fish_families = []
            for family, fish_map in families.items():
                individual_fish = []
                for fish_id, fish_detections in fish_map.items():
                    fish_entries = []
                    for det in fish_detections:
                        rel_path = os.path.relpath(det.get('crop_path', ''), processor.output_dir)
                        image_url = det.get('image_url') or f"/{rel_path.replace(os.sep, '/')}"
                        frame_idx = det.get('frame', 0)
                        fish_entries.append({
                            "imageURL": image_url,
                            "timestamp": _ts(frame_idx),
                            "objectDetection": round(float(det.get('confidence', 0)), 2),
                            "confidence": [
                                {
                                    "familyName": guess.get('family', 'Unknown'),
                                    "classifyConfidence": round(float(guess.get('confidence', 0)), 2)
                                }
                                for guess in det.get('classification_guesses', [])[:2]
                            ]
                        })
                    if fish_entries:
                        individual_fish.append({
                            "fishID": fish_id,
                            "fish": fish_entries
                        })
                if individual_fish:
                    fish_families.append({
                        "familyName": family,
                        "fishCount": len(individual_fish),
                        "individualFish": individual_fish
                    })
            return fish_families

        fish_families = _build_fish_families(detection_details_path)
        response_payload = {
            "success": True,
            "statusCode": 200,
            "message": "Upload processed successfully.",
            "processPercentage": 100,
            "stage": "Finished",
            "data": {
                "annotatedVideoURL": str(processor.annotated_output) if processor.annotated_output else None,
                "fishFamilies": fish_families,
                "trackingStatistics": str(tracking_stats_path) if tracking_stats_path.exists() else None,
                "trackingSummary": str(tracking_summary_path) if tracking_summary_path.exists() else None,
                "detectionDetails": str(detection_details_path) if detection_details_path else None,
                "outputDir": str(processor.output_dir),
                "totalDetections": detection_count
            }
        }
        return response_payload
    except Exception as exc:
        return {'success': False, 'error': str(exc)}
    finally:
        if processor:
            try:
                processor.close()
            except Exception:
                pass
def main():
    """Main function to run the detection, classification, and tracking pipeline."""
    parser = argparse.ArgumentParser(
        description="Run YOLO detection, classification, and fish tracking pipeline")
    parser.add_argument('--model', type=str, default=cfg.MODEL_DEFAULT,
                        help="Path to YOLO model file")
    parser.add_argument('--source', type=str, default=cfg.SOURCE_DEFAULT,
                        help="Source video/image file or directory")
    parser.add_argument('--output', type=str, default=cfg.OUTPUT_DIR,
                        help="Output directory")
    parser.add_argument('--conf', type=float, default=cfg.CONF_THRESHOLD,
                        help="Confidence threshold (default: 0.25)")
    parser.add_argument('--iou', type=float, default=cfg.IOU_THRESHOLD,
                        help="IoU threshold for NMS (default: 0.45)")
    parser.add_argument('--device', type=str, default=cfg.DEVICE,
                        help="Device to run on (auto, cpu, cuda, mps)")
    parser.add_argument('--min-crop-size', type=int, default=50,
                        help="Minimum crop size in pixels (default: 50)")
    parser.add_argument('--padding', type=int, default=cfg.PADDING,
                        help="Padding around bounding box in pixels")
    # BooleanOptionalAction (Python 3.9+) so --no-classify can actually disable
    # classification; the original store_true flag with default=True was a no-op.
    parser.add_argument('--classify', action=argparse.BooleanOptionalAction, default=True,
                        help="Enable classification (default: True)")
    parser.add_argument('--tracking-method', type=str, default='bytetrack',
                        choices=['siamese', 'bytetrack'],
                        help="Tracking backend (ByteTrack recommended)")
    parser.add_argument('--tracker-config', type=str, default='bytetrack.yaml',
                        help="Tracker config name/path when using --tracking-method=bytetrack")
    parser.add_argument('--similarity-threshold', type=float, default=0.85,
                        help="Similarity threshold for fish matching (default: 0.85)")
    parser.add_argument('--siamese-model', type=str, default=None,
                        help="Path to Siamese model checkpoint")
    parser.add_argument('--annotated-output', type=str, default=None,
                        help="Optional path to save annotated video")
    args_cli = parser.parse_args()

    args = SimpleNamespace(
        model=args_cli.model,
        source=args_cli.source,
        output=args_cli.output,
        conf=args_cli.conf,
        iou=args_cli.iou,
        device=args_cli.device,
        min_crop_size=args_cli.min_crop_size,
        padding=args_cli.padding,
        classify=args_cli.classify,
        tracking_method=args_cli.tracking_method,
        tracker_config=args_cli.tracker_config,
        similarity_threshold=args_cli.similarity_threshold,
        siamese_model=args_cli.siamese_model or cfg.SIAMESE_MODEL_PATH,
        annotated_output=args_cli.annotated_output,
        classification_model=cfg.CLASSIFICATION_MODEL_PATH,
        classification_dataset=cfg.CLASSIFICATION_DATASET_PATH,
        classification_batch_size=cfg.CLASSIFICATION_BATCH_SIZE,
        classification_device=cfg.CLASSIFICATION_DEVICE,
        classification_log_level=cfg.CLASSIFICATION_LOG_LEVEL
    )
    if not cfg.ULTRALYTICS_AVAILABLE:
        print("Error: ultralytics is required. Run 'pip install ultralytics'")
        sys.exit(1)
    if not Path(args.model).exists():
        print(f"Error: Model file not found: {args.model}")
        sys.exit(1)
    source_path = Path(args.source)
    if not source_path.exists():
        print(f"Error: Source not found: {args.source}")
        sys.exit(1)

    run_result = run_detection_pipeline(
        model_path=args.model,
        source=args.source,
        output_dir=args.output,
        conf=args.conf,
        iou=args.iou,
        device=args.device,
        min_crop_size=args.min_crop_size,
        padding=args.padding,
        classify=args.classify,
        tracking_method=args.tracking_method,
        tracker_config=args.tracker_config,
        similarity_threshold=args.similarity_threshold,
        siamese_model=args.siamese_model,
        annotated_output=args.annotated_output
    )
    if not run_result.get('success'):
        print(f"An error occurred: {run_result.get('error', 'unknown error')}")
        sys.exit(1)

    # Read results from the payload's 'data' block; the original looked up
    # top-level keys ('detection_count', 'output_dir', 'annotated_video_path')
    # that run_detection_pipeline never sets.
    data = run_result.get('data', {})
    detection_count = data.get('totalDetections', 0)
    output_dir = data.get('outputDir')
    if detection_count > 0:
        print("\nProcessing completed successfully!")
        print(f"Results saved to: {output_dir}")
        print(f"  - Families folder: {Path(output_dir) / 'families'}")
        print(f"  - Tracking analysis: {Path(output_dir) / 'analysis'}")
    else:
        print("No detections found to process.")
    if data.get('annotatedVideoURL'):
        print(f"Annotated video: {data['annotatedVideoURL']}")


if __name__ == "__main__":
    main()
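
# Example invocation (hypothetical script name and paths; adjust to your checkout):
#
#   python run_pipeline.py --model weights/yolo_fish.pt --source dive_01.mp4 \
#       --output runs/dive_01 --annotated-output runs/dive_01/annotated.mp4
#
# Or programmatically, e.g. from an API worker:
#
#   result = run_detection_pipeline(
#       model_path="weights/yolo_fish.pt",
#       source="dive_01.mp4",
#       output_dir="runs/dive_01",
#       progress_callback=lambda stage, progress: print(stage, progress),
#   )
#   if result["success"]:
#       print(result["data"]["totalDetections"], "detections")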