@DMelisena · Created December 4, 2025 00:23

import argparse
import os
import sys
import time
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
import cv2
import json
import numpy as np
import shutil
from tqdm import tqdm
import process_detections_package as pdp
import utils_package as up
import constants as cfg
# Add parent directory to path for importing ultralytics_patch
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from ultralytics_patch import ensure_custom_ultralytics_layers
# Make sure custom blocks expected by exported YOLO weights are available
ensure_custom_ultralytics_layers()
from ultralytics import YOLO
import torch
# Allow importing the classification stack
CLASSIFICATION_MODULE_DIR = (
    Path(__file__).resolve().parents[2] /
    'development' /
    'classification' /
    'detron_and_cross_entropy' /
    'classification_rectangle_v9-2'
).resolve()
if CLASSIFICATION_MODULE_DIR.exists():
    module_path = str(CLASSIFICATION_MODULE_DIR)
    if module_path not in sys.path:
        sys.path.append(module_path)
# Import Siamese network modules
SIAMESE_MODULE_DIR = (
    Path(__file__).resolve().parents[2] /
    'development' /
    'siamese'
).resolve()
if SIAMESE_MODULE_DIR.exists():
    siamese_path = str(SIAMESE_MODULE_DIR)
    if siamese_path not in sys.path:
        sys.path.append(siamese_path)
try:
    from inference import EmbeddingClassifier
    from classify_fish import load_images_from_folder, classify_images
    from species_to_family import SPECIES_TO_FAMILY, FAMILY_NAMES
except Exception as exc:
    EmbeddingClassifier = None
    load_images_from_folder = None
    classify_images = None
    SPECIES_TO_FAMILY = {}
    FAMILY_NAMES = []
    CLASSIFICATION_IMPORT_ERROR = exc
else:
    CLASSIFICATION_IMPORT_ERROR = None
try:
    from example.models.siamese_network import SiameseNetwork
    from example.utils.model_io import load_model
    from example.transforms.custom_transforms import get_transforms
    from PIL import Image
    SIAMESE_AVAILABLE = True
    SIAMESE_IMPORT_ERROR = None
except Exception as exc:
    SIAMESE_AVAILABLE = False
    SIAMESE_IMPORT_ERROR = exc
    print(f"Warning: Siamese network not available: {exc}")

class FishTracker:
    """Manages fish identification and tracking across frames."""

    def __init__(self, similarity_threshold=0.85, siamese_model_path=None):
        self.similarity_threshold = similarity_threshold
        self.fish_database = defaultdict(list)  # family -> list of fish instances
        self.fish_counter = defaultdict(int)  # family -> next fish ID
        self.siamese_model = None
        self.transform = None
        if SIAMESE_AVAILABLE and siamese_model_path:
            self._load_siamese_model(siamese_model_path)

    def _load_siamese_model(self, model_path):
        """Load the Siamese network model."""
        try:
            if Path(model_path).exists():
                import torch
                device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
                self.siamese_model = load_model(SiameseNetwork(), model_path, device)
                self.siamese_model.eval()
                self.transform = get_transforms()['val']
                print(f"Siamese model loaded from {model_path} on device: {device}")
            else:
                print(f"Siamese model not found at {model_path}")
        except Exception as e:
            print(f"Error loading Siamese model: {e}")
            self.siamese_model = None

    def _compute_siamese_similarity(self, img1_array, img2_array):
        """Compute similarity between two images using Siamese network."""
        if self.siamese_model is None or self.transform is None:
            return 0.0
        try:
            # Convert numpy arrays to PIL Images
            img1_pil = Image.fromarray(img1_array)
            img2_pil = Image.fromarray(img2_array)
            # Apply transforms
            img1_tensor = self.transform(img1_pil).unsqueeze(0)
            img2_tensor = self.transform(img2_pil).unsqueeze(0)
            # Move tensors to the same device as the model
            device = next(self.siamese_model.parameters()).device
            img1_tensor = img1_tensor.to(device)
            img2_tensor = img2_tensor.to(device)
            # Get embeddings
            with torch.no_grad():
                emb1 = self.siamese_model.forward_once(img1_tensor)
                emb2 = self.siamese_model.forward_once(img2_tensor)
            # Compute cosine similarity
            similarity = torch.nn.functional.cosine_similarity(emb1, emb2).item()
            return similarity
        except Exception as e:
            print(f"Error computing Siamese similarity: {e}")
            return 0.0
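
    # Cosine similarity lies in [-1, 1], with 1.0 for identical embeddings, so
    # the default similarity_threshold of 0.85 only accepts near-duplicate crops.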

    def _compute_basic_similarity(self, img1_array, img2_array):
        """Compute basic similarity using histogram comparison."""
        try:
            # Resize images to same size
            size = (128, 128)
            img1_resized = cv2.resize(img1_array, size)
            img2_resized = cv2.resize(img2_array, size)
            # Compute histograms
            hist1 = cv2.calcHist([img1_resized], [0, 1, 2], None, [32, 32, 32], [0, 256, 0, 256, 0, 256])
            hist2 = cv2.calcHist([img2_resized], [0, 1, 2], None, [32, 32, 32], [0, 256, 0, 256, 0, 256])
            # Normalize histograms
            hist1 = cv2.normalize(hist1, hist1).flatten()
            hist2 = cv2.normalize(hist2, hist2).flatten()
            # Compute correlation
            similarity = cv2.compareHist(hist1, hist2, cv2.HISTCMP_CORREL)
            return similarity
        except Exception as e:
            print(f"Error computing basic similarity: {e}")
            return 0.0

    def _bbox_iou(self, boxA, boxB):
        """Compute IoU between two [x1, y1, x2, y2] boxes."""
        xA = max(boxA[0], boxB[0])
        yA = max(boxA[1], boxB[1])
        xB = min(boxA[2], boxB[2])
        yB = min(boxA[3], boxB[3])
        inter_w = max(0, xB - xA)
        inter_h = max(0, yB - yA)
        inter_area = inter_w * inter_h
        if inter_area <= 0:
            return 0.0
        boxA_area = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
        boxB_area = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
        return inter_area / float(boxA_area + boxB_area - inter_area)
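
    # Worked example: _bbox_iou([0, 0, 10, 10], [5, 5, 15, 15])
    #   intersection = 5 * 5 = 25; union = 100 + 100 - 25 = 175; IoU = 25/175 ≈ 0.143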

    def find_or_create_fish_id(self, family, img_array, crop_path, frame_idx, bbox):
        """Find existing fish or create new ID, using Siamese only for ambiguous matches."""
        HIGH_IOU = 0.6
        LOW_IOU = 0.3
        MAX_FRAME_GAP = 10
        AMBIG_MARGIN = 0.05

        def _create_new_fish():
            # New fish of this family
            fish_id = f"{family}_fish_{self.fish_counter[family]:03d}"
            self.fish_counter[family] += 1
            self.fish_database[family].append({
                'id': fish_id,
                'images': [crop_path],
                'reference_image': img_array,
                'first_seen': datetime.now(),
                'last_seen': datetime.now(),
                'occurrence_count': 1,
                'last_bbox': bbox,
                'last_frame': frame_idx
            })
            return fish_id, True  # New fish

        fish_list = self.fish_database.get(family, [])
        if not fish_list:
            return _create_new_fish()
        candidates = []
        best_iou = 0.0
        # Cheap gating: only compare to recent fish with overlapping boxes
        for fish in fish_list:
            last_frame = fish.get('last_frame')
            last_bbox = fish.get('last_bbox')
            if last_frame is None or last_bbox is None:
                continue
            if abs(frame_idx - last_frame) > MAX_FRAME_GAP:
                continue
            iou = self._bbox_iou(bbox, last_bbox)
            if iou <= 0:
                continue
            candidates.append((fish, iou))
            best_iou = max(best_iou, iou)
        # No viable candidates by geometry
        if not candidates:
            return _create_new_fish()
        # Strong, unambiguous IoU match: skip Siamese
        strong_candidates = [c for c in candidates if c[1] >= HIGH_IOU]
        if len(strong_candidates) == 1:
            fish, _ = strong_candidates[0]
            fish['images'].append(crop_path)
            fish['last_seen'] = datetime.now()
            fish['occurrence_count'] += 1
            fish['last_bbox'] = bbox
            fish['last_frame'] = frame_idx
            return fish['id'], False
        # Low IoU and no Siamese rescue requested
        if best_iou < LOW_IOU and self.siamese_model is None:
            return _create_new_fish()
        # Ambiguous: run Siamese (or basic similarity if Siamese unavailable) on near-best candidates
        near_best = [
            (fish, iou) for fish, iou in candidates
            if best_iou - iou <= AMBIG_MARGIN and iou >= LOW_IOU
        ]
        best_match = None
        best_similarity = -1.0
        for fish, _ in near_best:
            if self.siamese_model is not None:
                similarity = self._compute_siamese_similarity(img_array, fish['reference_image'])
            else:
                similarity = self._compute_basic_similarity(img_array, fish['reference_image'])
            if similarity > best_similarity:
                best_similarity = similarity
                best_match = fish
        if best_match is not None and best_similarity >= self.similarity_threshold:
            best_match['images'].append(crop_path)
            best_match['last_seen'] = datetime.now()
            best_match['occurrence_count'] += 1
            best_match['last_bbox'] = bbox
            best_match['last_frame'] = frame_idx
            return best_match['id'], False  # Existing fish
        return _create_new_fish()

    def register_tracked_fish(self, family, track_id, img_array, crop_path, frame_idx, bbox):
        """Register fish using an external tracker ID (e.g., ByteTrack)."""
        if track_id is None:
            return self.find_or_create_fish_id(family, img_array, crop_path, frame_idx, bbox)
        try:
            track_val = float(track_id)
            if np.isnan(track_val):
                return self.find_or_create_fish_id(family, img_array, crop_path, frame_idx, bbox)
            track_int = int(track_val)
        except (TypeError, ValueError):
            return self.find_or_create_fish_id(family, img_array, crop_path, frame_idx, bbox)
        fish_id = f"{family}_track_{track_int:05d}"
        fish_list = self.fish_database[family]
        for fish in fish_list:
            if fish['id'] == fish_id:
                fish['images'].append(crop_path)
                fish['last_seen'] = datetime.now()
                fish['occurrence_count'] += 1
                fish['last_bbox'] = bbox
                fish['last_frame'] = frame_idx
                fish['track_id'] = track_int
                return fish_id, False
        fish_list.append({
            'id': fish_id,
            'track_id': track_int,
            'images': [crop_path],
            'reference_image': img_array,
            'first_seen': datetime.now(),
            'last_seen': datetime.now(),
            'occurrence_count': 1,
            'last_bbox': bbox,
            'last_frame': frame_idx
        })
        return fish_id, True

    def get_tracking_summary(self):
        """Get summary of tracked fish."""
        summary = {}
        for family, fish_list in self.fish_database.items():
            summary[family] = {
                'total_individuals': len(fish_list),
                'fish': []
            }
            for fish in fish_list:
                summary[family]['fish'].append({
                    'id': fish['id'],
                    'occurrences': fish['occurrence_count'],
                    'first_seen': str(fish['first_seen']),
                    'last_seen': str(fish['last_seen']),
                    'num_images': len(fish['images']),
                    'track_id': fish.get('track_id')
                })
        return summary
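
# Minimal FishTracker usage sketch (values hypothetical); feed per-frame crops
# and boxes, get back a stable per-family fish ID plus a "new fish" flag:
#   tracker = FishTracker(similarity_threshold=0.85)
#   fish_id, is_new = tracker.find_or_create_fish_id(
#       'Serranidae', crop_rgb, 'crops/frame_000001.jpg', frame_idx=1,
#       bbox=[10, 20, 110, 140])
#   summary = tracker.get_tracking_summary()
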

class DetectionProcessor:
    """Process YOLO detections, classify, and track fish."""

    process_detections = pdp.process_detections

    def __init__(self, args):
        self.args = args
        self.output_dir = Path(self.args.output)
        self.crops_dir = self.output_dir / 'crops'
        self.families_dir = self.output_dir / 'families'
        self.tracking_dir = self.output_dir / 'tracking'
        self.analysis_dir = self.output_dir / 'analysis'
        annotated_output = getattr(self.args, 'annotated_output', None)
        self.annotated_output = Path(annotated_output) if annotated_output else None
        self.video_writer = None
        self.video_meta = {}
        for dir_path in [self.output_dir, self.crops_dir, self.families_dir,
                         self.tracking_dir, self.analysis_dir]:
            dir_path.mkdir(parents=True, exist_ok=True)
        self.detections = []
        self.predictions = []
        self.crop_paths = []
        self.class_names = {}
        self.species_to_family = SPECIES_TO_FAMILY or {}
        self.use_bytetrack = True
        self.tracker_config = getattr(self.args, 'tracker_config', 'bytetrack.yaml')
        self.track_db = defaultdict(dict)  # family -> fish_id -> record
        print("Tracking method: ByteTrack (Ultralytics tracker IDs), single-pass pipeline.")
        # Classification state
        self.classifier = None
        self.classifier_classes = []
        self.family_to_id = {}
        self.classification_batch_size = getattr(self.args, 'classification_batch_size', cfg.CLASSIFICATION_BATCH_SIZE)
        self.classification_enabled = bool(self.args.classify)
        if self.classification_enabled:
            self._initialize_embedding_classifier()
        else:
            print("Classification disabled via CLI flag.")

    def _initialize_embedding_classifier(self):
        """Load the embedding classifier."""
        if EmbeddingClassifier is None:
            reason = f"{CLASSIFICATION_IMPORT_ERROR}" if CLASSIFICATION_IMPORT_ERROR else "unknown import error"
            print(f"Classification modules unavailable ({reason}). Skipping classifier load.")
            return
        model_path = Path(getattr(self.args, 'classification_model', cfg.CLASSIFICATION_MODEL_PATH))
        dataset_path = Path(getattr(self.args, 'classification_dataset', cfg.CLASSIFICATION_DATASET_PATH))
        missing_paths = [str(path) for path in (model_path, dataset_path) if not Path(path).exists()]
        if missing_paths:
            print(f"Classification assets missing: {', '.join(missing_paths)}")
            return
        config = {
            'log_level': getattr(self.args, 'classification_log_level', cfg.CLASSIFICATION_LOG_LEVEL),
            'dataset': {'path': str(dataset_path)},
            'model': {
                'path': str(model_path),
                'device': getattr(self.args, 'classification_device', cfg.CLASSIFICATION_DEVICE)
            }
        }
        try:
            print(f"Loading embedding classifier from {model_path} ...")
            self.classifier = EmbeddingClassifier(config)
            base_classes = list(FAMILY_NAMES) if FAMILY_NAMES else []
            self.classifier_classes = base_classes.copy()
            self.family_to_id = {name: idx for idx, name in enumerate(self.classifier_classes)}
            if "Unknown" not in self.family_to_id:
                self.family_to_id["Unknown"] = len(self.classifier_classes)
                self.classifier_classes.append("Unknown")
            if not self.species_to_family:
                self.species_to_family = SPECIES_TO_FAMILY
            print(f"Classifier loaded with {len(self.classifier_classes)} family labels.")
        except Exception as exc:
            print(f"Warning: Failed to initialize embedding classifier: {exc}")
            self.classifier = None

    def _init_video_writer(self, width, height, fps):
        """Initialize video writer for annotated output if requested."""
        if not self.annotated_output or self.video_writer:
            return
        width, height = int(width), int(height)
        if width <= 0 or height <= 0:
            return
        self.annotated_output.parent.mkdir(parents=True, exist_ok=True)
        safe_fps = fps if fps and not np.isnan(fps) and fps > 0 else 30.0
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        self.video_writer = cv2.VideoWriter(
            str(self.annotated_output),
            fourcc,
            safe_fps,
            (width, height)
        )
        self.video_meta = {'fps': safe_fps, 'size': (width, height)}

    def _annotate_frame(self, frame, boxes, confidences, class_ids, track_ids=None):
        """Draw bounding boxes and labels on a frame."""
        if frame is None:
            return None
        annotated = frame.copy()
        for idx, (box, conf, class_id) in enumerate(zip(boxes, confidences, class_ids)):
            x1, y1, x2, y2 = map(int, box)
            color = (0, 255, 0)
            cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)
            label = self.class_names.get(class_id, f"class_{class_id}")
            track_text = ""
            if track_ids is not None and len(track_ids) > idx:
                try:
                    tid_raw = float(track_ids[idx]) if track_ids[idx] is not None else None
                    if tid_raw is not None and not np.isnan(tid_raw):
                        track_text = f" id:{int(tid_raw)}"
                except Exception:
                    track_text = ""
            text = f"{label}{track_text} {conf:.2f}"
            cv2.putText(
                annotated,
                text,
                (x1, max(0, y1 - 5)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                color,
                1,
                cv2.LINE_AA
            )
        return annotated

    def close(self):
        """Release resources."""
        if self.video_writer:
            self.video_writer.release()
            self.video_writer = None

    def _finalize_annotation(self):
        """Finalize annotated video if enabled."""
        if self.video_writer:
            try:
                self.video_writer.release()
            except Exception:
                pass
            self.video_writer = None

    def process_and_track_detections(self, model, device, progress_callback=None):
        """Single-pass detection + ByteTrack + batched classification."""
        print(f"\n🔍 Single-pass detect/track/classify with ByteTrack...")
        cap = cv2.VideoCapture(self.args.source)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        cap.release()
        if width > 0 and height > 0:
            duration_seconds = None
            if total_frames and fps and fps > 0:
                duration_seconds = total_frames / fps
            self.video_meta = {
                'fps': fps if fps and not np.isnan(fps) and fps > 0 else None,
                'size': (width, height),
                'total_frames': total_frames,
                'duration_seconds': duration_seconds
            }
            if self.annotated_output:
                self._init_video_writer(width, height, fps)
        stream = model.track(
            source=self.args.source,
            conf=self.args.conf,
            iou=self.args.iou,
            device=device,
            tracker=self.tracker_config,
            persist=True,
            stream=True,
            verbose=False
        )
        detection_count = 0
        frame_idx = -1
        last_time = time.time()
        last_progress = 0
        pbar = tqdm(stream, total=total_frames if total_frames > 0 else None, desc="Detect+track", unit="frame")
        for result in pbar:
            frame_idx += 1
            orig_img = result.orig_img
            if orig_img is None:
                continue
            # Emit incremental progress updates while running YOLO/ByteTrack
            if progress_callback and total_frames > 0:
                # Reserve 10–90% for this stage; leave final 10% for post-processing
                progress = 10 + int(((frame_idx + 1) / total_frames) * 80)
                progress = max(10, min(progress, 90))
                if progress > last_progress:
                    last_progress = progress
                    progress_callback(stage='yolo', progress=progress)
            if hasattr(result, 'names'):
                self.class_names.update(result.names)
            boxes = result.boxes
            if boxes is None or boxes.xyxy is None or len(boxes) == 0:
                if self.video_writer:
                    self.video_writer.write(orig_img)
                continue
            xyxy = boxes.xyxy.cpu().numpy()
            confs = boxes.conf.cpu().numpy() if boxes.conf is not None else np.zeros(len(xyxy))
            class_ids = boxes.cls.cpu().numpy().astype(int)
            track_ids = boxes.id.cpu().numpy() if boxes.id is not None else np.array([None] * len(class_ids))
            frame_dets = []
            for i, (box, conf, class_id, track_id) in enumerate(zip(xyxy, confs, class_ids, track_ids)):
                if conf < 0.3:
                    continue
                x1, y1, x2, y2 = map(int, box)
                h, w = orig_img.shape[:2]
                x1 = max(0, x1 - self.args.padding)
                y1 = max(0, y1 - self.args.padding)
                x2 = min(w, x2 + self.args.padding)
                y2 = min(h, y2 + self.args.padding)
                crop_w, crop_h = x2 - x1, y2 - y1
                if crop_w < self.args.min_crop_size or crop_h < self.args.min_crop_size:
                    continue
                crop = orig_img[y1:y2, x1:x2]
                if crop.size == 0:
                    continue
                crop_rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
                frame_dets.append({
                    'frame': frame_idx,
                    'bbox': [x1, y1, x2, y2],
                    'crop_bgr': crop,
                    'crop_rgb': crop_rgb,
                    'confidence': float(conf),
                    'class_id': class_id,
                    'class_name': self.class_names.get(class_id, f"class_{class_id}"),
                    'track_id': int(track_id) if track_id is not None and not np.isnan(track_id) else None,
                    'seq_idx': i
                })
            if self.video_writer:
                annotated_frame = result.plot() if hasattr(result, 'plot') else orig_img
                if annotated_frame is None:
                    annotated_frame = orig_img
                self.video_writer.write(annotated_frame)
            if not frame_dets:
                continue
            batch_imgs = [det['crop_rgb'] for det in frame_dets]
            if self.classifier and self.classification_enabled:
                try:
                    predictions_batch = self.classifier.inference_numpy_batch(batch_imgs)
                except Exception as exc:
                    print(f"Classification batch failed on frame {frame_idx}: {exc}")
                    predictions_batch = [None] * len(frame_dets)
            else:
                predictions_batch = [None] * len(frame_dets)
            for det_idx, det in enumerate(frame_dets):
                predictions = predictions_batch[det_idx] if det_idx < len(predictions_batch) else None
                classification_guesses = []
                species_name = None
                family_name = None
                if predictions:
                    sorted_preds = sorted(
                        predictions,
                        key=lambda p: getattr(p, 'accuracy', 0.0),
                        reverse=True
                    )
                    if sorted_preds:
                        best_pred = sorted_preds[0]
                        best_conf = float(getattr(best_pred, 'accuracy', 0.0))
                        # Treat low-confidence predictions as unidentified
                        if best_conf < 0.70:
                            family_name = "unidentified fish"
                            species_name = "unidentified fish"
                        else:
                            species_name = best_pred.name
                        for pred in sorted_preds[:2]:
                            fam_name = self.species_to_family.get(pred.name, pred.name)
                            classification_guesses.append({
                                'family': fam_name,
                                'confidence': float(getattr(pred, 'accuracy', 0.0))
                            })
                if family_name is None:
                    family_name = self.species_to_family.get(species_name, None) if species_name else None
                if not family_name:
                    family_name = self.species_to_family.get(det['class_name'], det['class_name'])
                track_id = det['track_id']
                fish_id = f"track_{track_id:05d}" if track_id is not None else f"det_{detection_count:06d}"
                family_dir = self.families_dir / family_name
                family_dir.mkdir(exist_ok=True)
                fish_dir = family_dir / fish_id
                fish_dir.mkdir(exist_ok=True)
                crop_filename = f"frame_{det['frame']:06d}_det_{det['seq_idx']:03d}.jpg"
                crop_path = fish_dir / crop_filename
                cv2.imwrite(str(crop_path), det['crop_bgr'])
                general_crop_path = self.crops_dir / f"{family_name}_{fish_id}_{crop_filename}"
                shutil.copy2(crop_path, general_crop_path)
                family_tracks = self.track_db[family_name]
                track_rec = family_tracks.get(fish_id, {
                    'images': [],
                    'occurrence_count': 0,
                    'first_seen': det['frame'],
                    'last_seen': det['frame']
                })
                track_rec['images'].append(str(crop_path))
                track_rec['occurrence_count'] += 1
                track_rec['last_seen'] = det['frame']
                family_tracks[fish_id] = track_rec
                rel_crop_path = os.path.relpath(crop_path, self.output_dir)
                detection_info = {
                    'frame': det['frame'],
                    'crop_path': str(crop_path),
                    'image_url': f"/{rel_crop_path.replace(os.sep, '/')}",
                    'family': family_name,
                    'fish_id': fish_id,
                    'track_id': track_id,
                    'is_new_fish': False,
                    'confidence': det['confidence'],
                    'bbox': det['bbox'],
                    'crop_size': [det['bbox'][2] - det['bbox'][0], det['bbox'][3] - det['bbox'][1]],
                    'class_name': det['class_name'],
                    'species_name': species_name or det['class_name'],
                    'classification_guesses': classification_guesses,
                    'fps': self.video_meta.get('fps', None)
                }
                self.detections.append(detection_info)
                self.crop_paths.append(crop_path)
                detection_count += 1
            pbar.set_postfix({
                'Tracked': detection_count,
                'FPS': f'{1/(time.time()-last_time+1e-6):.1f}'
            })
            last_time = time.time()
        self._finalize_annotation()
        print(f"\n✅ Processed {detection_count} detections across all frames")
        return detection_count
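
    # For reference, the loop above writes each detection crop twice:
    #   <output>/families/<family>/<fish_id>/frame_000123_det_001.jpg  (per-fish folder)
    #   <output>/crops/<family>_<fish_id>_frame_000123_det_001.jpg     (flat copy)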

    def save_tracking_results(self):
        """Save tracking results and analysis."""
        # Save tracking summary
        summary = {}
        for family, fish_map in self.track_db.items():
            summary[family] = {
                'total_individuals': len(fish_map),
                'fish': []
            }
            for fish_id, rec in fish_map.items():
                summary[family]['fish'].append({
                    'id': fish_id,
                    'occurrences': rec.get('occurrence_count', 0),
                    'first_seen': rec.get('first_seen'),
                    'last_seen': rec.get('last_seen'),
                    'num_images': len(rec.get('images', []))
                })
        with open(self.analysis_dir / 'tracking_summary.json', 'w') as f:
            json.dump(summary, f, indent=2, default=str)
        # Save detailed tracking database
        detailed_db = {}
        for family, fish_map in self.track_db.items():
            detailed_db[family] = []
            for fish_id, rec in fish_map.items():
                detailed_db[family].append({
                    'id': fish_id,
                    'occurrences': rec.get('occurrence_count', 0),
                    'first_seen': rec.get('first_seen'),
                    'last_seen': rec.get('last_seen'),
                    'images': rec.get('images', []),
                    'track_id': fish_id
                })
        with open(self.analysis_dir / 'tracking_database.json', 'w') as f:
            json.dump(detailed_db, f, indent=2)
        # Generate statistics
        stats = {
            'total_families': len(self.track_db),
            'total_individuals': sum(len(fish_map) for fish_map in self.track_db.values()),
            'total_detections': len(self.detections),
            'families': {}
        }
        for family, fish_map in self.track_db.items():
            stats['families'][family] = {
                'individuals': len(fish_map),
                'total_occurrences': sum(rec.get('occurrence_count', 0) for rec in fish_map.values()),
                'avg_occurrences_per_fish': (
                    sum(rec.get('occurrence_count', 0) for rec in fish_map.values()) / len(fish_map)
                    if fish_map else 0
                )
            }
        with open(self.analysis_dir / 'tracking_statistics.json', 'w') as f:
            json.dump(stats, f, indent=2)
        print(f"\n📊 Tracking Results:")
        print(f" - Total families detected: {stats['total_families']}")
        print(f" - Total individual fish tracked: {stats['total_individuals']}")
        print(f" - Total detections: {stats['total_detections']}")
        print(f"\n Family breakdown:")
        for family, family_stats in stats['families'].items():
            print(f"   {family}: {family_stats['individuals']} individuals, "
                  f"{family_stats['total_occurrences']} total occurrences")

    def save_detection_details(self):
        """Persist detection details for downstream consumers (e.g., API responses)."""
        details_path = self.analysis_dir / 'detection_details.json'
        payload = {
            'detections': self.detections,
            'meta': {
                'fps': self.video_meta.get('fps'),
                'output_dir': str(self.output_dir),
                'total_frames': self.video_meta.get('total_frames'),
                'duration_seconds': self.video_meta.get('duration_seconds'),
                'width': self.video_meta.get('size', (None, None))[0] if self.video_meta.get('size') else None,
                'height': self.video_meta.get('size', (None, None))[1] if self.video_meta.get('size') else None
            }
        }
        with open(details_path, 'w') as f:
            json.dump(payload, f, indent=2)
        return details_path


def run_detection_pipeline(model_path,
                           source,
                           output_dir,
                           conf=cfg.CONF_THRESHOLD,
                           iou=cfg.IOU_THRESHOLD,
                           device=cfg.DEVICE,
                           min_crop_size=50,
                           padding=cfg.PADDING,
                           classify=True,
                           tracking_method="bytetrack",
                           tracker_config="bytetrack.yaml",
                           similarity_threshold=0.85,
                           siamese_model=None,
                           annotated_output=None,
                           progress_callback=None):
    """Programmatic entrypoint to run detection, tracking, and optional annotation."""
    args = SimpleNamespace(
        model=str(model_path),
        source=str(source),
        output=str(output_dir),
        conf=conf,
        iou=iou,
        device=device,
        min_crop_size=min_crop_size,
        padding=padding,
        classify=classify,
        tracking_method=tracking_method,
        tracker_config=tracker_config,
        similarity_threshold=similarity_threshold,
        siamese_model=siamese_model or cfg.SIAMESE_MODEL_PATH,
        annotated_output=str(annotated_output) if annotated_output else None,
        classification_model=cfg.CLASSIFICATION_MODEL_PATH,
        classification_dataset=cfg.CLASSIFICATION_DATASET_PATH,
        classification_batch_size=cfg.CLASSIFICATION_BATCH_SIZE,
        classification_device=cfg.CLASSIFICATION_DEVICE,
        classification_log_level=cfg.CLASSIFICATION_LOG_LEVEL
    )
    model_path = Path(args.model)
    if not model_path.exists():
        return {'success': False, 'error': f"Model file not found: {model_path}"}
    source_path = Path(args.source)
    if not source_path.exists():
        return {'success': False, 'error': f"Source not found: {source_path}"}
    processor = None
    try:
        processor = DetectionProcessor(args)
        device_to_use = up.get_device(args.device)
        model = YOLO(args.model)
        model.to(device_to_use)
        detection_count = processor.process_and_track_detections(model, device_to_use, progress_callback=progress_callback)
        if detection_count > 0:
            processor.save_tracking_results()
        detection_details_path = processor.save_detection_details()
        tracking_stats_path = processor.analysis_dir / 'tracking_statistics.json'
        tracking_summary_path = processor.analysis_dir / 'tracking_summary.json'

        def _build_fish_families(details_path):
            if not details_path or not Path(details_path).exists():
                return []
            try:
                with open(details_path, 'r') as f:
                    details = json.load(f)
            except Exception:
                return []
            detections = details.get('detections', [])
            fps = details.get('meta', {}).get('fps') or 30
            families = defaultdict(lambda: defaultdict(list))
            for det in detections:
                family = det.get('family', 'Unknown')
                fish_id = det.get('fish_id', 'fish_000')
                families[family][fish_id].append(det)

            def _ts(frame_idx):
                total_seconds = frame_idx / fps if fps else 0
                hours = int(total_seconds // 3600)
                minutes = int((total_seconds % 3600) // 60)
                seconds = int(total_seconds % 60)
                frames = int(round((total_seconds - int(total_seconds)) * fps)) if fps else 0
                return f"{hours:02d}:{minutes:02d}:{seconds:02d}:{frames:02d}"
            fish_families = []
            for family, fish_map in families.items():
                individual_fish = []
                for fish_id, fish_detections in fish_map.items():
                    fish_entries = []
                    for det in fish_detections:
                        rel_path = os.path.relpath(det.get('crop_path', ''), processor.output_dir)
                        image_url = det.get('image_url') or f"/{rel_path.replace(os.sep, '/')}"
                        frame_idx = det.get('frame', 0)
                        fish_entries.append({
                            "imageURL": image_url,
                            "timestamp": _ts(frame_idx),
                            "objectDetection": round(float(det.get('confidence', 0)), 2),
                            "confidence": [
                                {
                                    "familyName": guess.get('family', 'Unknown'),
                                    "classifyConfidence": round(float(guess.get('confidence', 0)), 2)
                                }
                                for guess in det.get('classification_guesses', [])[:2]
                            ]
                        })
                    if fish_entries:
                        individual_fish.append({
                            "fishID": fish_id,
                            "fish": fish_entries
                        })
                if individual_fish:
                    fish_families.append({
                        "familyName": family,
                        "fishCount": len(individual_fish),
                        "individualFish": individual_fish
                    })
            return fish_families

        fish_families = _build_fish_families(detection_details_path)
        response_payload = {
            "success": True,
            "statusCode": 200,
            "message": "Upload processed successfully.",
            "processPrecentage": 100,
            "stage": "Finished",
            "data": {
                "annotatedVideoURL": (str(processor.annotated_output) if processor.annotated_output else None),
                "fishFamilies": fish_families,
                "trackingStatistics": str(tracking_stats_path) if tracking_stats_path.exists() else None,
                "trackingSummary": str(tracking_summary_path) if tracking_summary_path.exists() else None,
                "detectionDetails": str(detection_details_path) if detection_details_path else None,
                "outputDir": str(processor.output_dir),
                "totalDetections": detection_count
            }
        }
        return response_payload
    except Exception as exc:
        return {'success': False, 'error': str(exc)}
    finally:
        if processor:
            try:
                processor.close()
            except Exception:
                pass


def main():
    """Main function to run the detection, classification, and tracking pipeline."""
    parser = argparse.ArgumentParser(description="Run YOLO detection, classification, and fish tracking pipeline")
    parser.add_argument('--model', type=str, default=cfg.MODEL_DEFAULT,
                        help="Path to YOLO model file")
    parser.add_argument('--source', type=str, default=cfg.SOURCE_DEFAULT,
                        help="Source video/image file or directory")
    parser.add_argument('--output', type=str, default=cfg.OUTPUT_DIR,
                        help="Output directory")
    parser.add_argument('--conf', type=float, default=cfg.CONF_THRESHOLD,
                        help="Confidence threshold (default: 0.25)")
    parser.add_argument('--iou', type=float, default=cfg.IOU_THRESHOLD,
                        help="IoU threshold for NMS (default: 0.45)")
    parser.add_argument('--device', type=str, default=cfg.DEVICE,
                        help="Device to run on (auto, cpu, cuda, mps)")
    parser.add_argument('--min-crop-size', type=int, default=50,
                        help="Minimum crop size in pixels (default: 50)")
    parser.add_argument('--padding', type=int, default=cfg.PADDING,
                        help="Padding around bounding box in pixels")
    # 'store_true' with default=True made --classify a no-op; pair it with
    # --no-classify so classification can actually be disabled from the CLI
    parser.add_argument('--classify', dest='classify', action='store_true', default=True,
                        help="Enable classification (default: enabled)")
    parser.add_argument('--no-classify', dest='classify', action='store_false',
                        help="Disable classification")
    parser.add_argument('--tracking-method', type=str, default='bytetrack',
                        choices=['siamese', 'bytetrack'],
                        help="Tracking backend (ByteTrack recommended)")
    parser.add_argument('--tracker-config', type=str, default='bytetrack.yaml',
                        help="Tracker config name/path when using --tracking-method=bytetrack")
    parser.add_argument('--similarity-threshold', type=float, default=0.85,
                        help="Similarity threshold for fish matching (default: 0.85)")
    parser.add_argument('--siamese-model', type=str, default=None,
                        help="Path to Siamese model checkpoint")
    parser.add_argument('--annotated-output', type=str, default=None,
                        help="Optional path to save annotated video")
    args_cli = parser.parse_args()
    args = SimpleNamespace(
        model=args_cli.model,
        source=args_cli.source,
        output=args_cli.output,
        conf=args_cli.conf,
        iou=args_cli.iou,
        device=args_cli.device,
        min_crop_size=args_cli.min_crop_size,
        padding=args_cli.padding,
        classify=args_cli.classify,
        tracking_method=args_cli.tracking_method,
        tracker_config=args_cli.tracker_config,
        similarity_threshold=args_cli.similarity_threshold,
        siamese_model=args_cli.siamese_model or cfg.SIAMESE_MODEL_PATH,
        annotated_output=args_cli.annotated_output,
        classification_model=cfg.CLASSIFICATION_MODEL_PATH,
        classification_dataset=cfg.CLASSIFICATION_DATASET_PATH,
        classification_batch_size=cfg.CLASSIFICATION_BATCH_SIZE,
        classification_device=cfg.CLASSIFICATION_DEVICE,
        classification_log_level=cfg.CLASSIFICATION_LOG_LEVEL
    )
    if not cfg.ULTRALYTICS_AVAILABLE:
        print("Error: ultralytics is required. Run 'pip install ultralytics'")
        sys.exit(1)
    if not Path(args.model).exists():
        print(f"Error: Model file not found: {args.model}")
        sys.exit(1)
    source_path = Path(args.source)
    if not source_path.exists():
        print(f"Error: Source not found: {args.source}")
        sys.exit(1)
    run_result = run_detection_pipeline(
        model_path=args.model,
        source=args.source,
        output_dir=args.output,
        conf=args.conf,
        iou=args.iou,
        device=args.device,
        min_crop_size=args.min_crop_size,
        padding=args.padding,
        classify=args.classify,
        tracking_method=args.tracking_method,
        tracker_config=args.tracker_config,
        similarity_threshold=args.similarity_threshold,
        siamese_model=args.siamese_model,
        annotated_output=args.annotated_output
    )
    if not run_result.get('success'):
        print(f"An error occurred: {run_result.get('error', 'unknown error')}")
        sys.exit(1)
    # run_detection_pipeline reports results under the 'data' key
    # (totalDetections, outputDir, annotatedVideoURL), not at the top level
    data = run_result.get('data', {})
    detection_count = data.get('totalDetections', 0)
    output_dir = data.get('outputDir', args.output)
    if detection_count > 0:
        print("\n🎉 Processing completed successfully!")
        print(f"📁 Results saved to: {output_dir}")
        print(f" - Families folder: {Path(output_dir) / 'families'}")
        print(f" - Tracking analysis: {Path(output_dir) / 'analysis'}")
    else:
        print("No detections found to process.")
    if data.get('annotatedVideoURL'):
        print(f"📹 Annotated video: {data['annotatedVideoURL']}")


if __name__ == "__main__":
    main()
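
# Example invocation (script name and paths are placeholders, not shipped defaults):
#   python detect_track_classify.py --model weights/yolo_fish.pt \
#       --source reef_clip.mp4 --output results/ \
#       --annotated-output results/annotated.mp4 --no-classify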