Skip to content

Instantly share code, notes, and snippets.

@myociss
Last active August 16, 2025 19:03
Show Gist options
  • Select an option

  • Save myociss/a552a262c8842d9bae54cb34ba008d0a to your computer and use it in GitHub Desktop.

Select an option

Save myociss/a552a262c8842d9bae54cb34ba008d0a to your computer and use it in GitHub Desktop.
Lightning Flash Clustering: cluster flashes
from sklearn.cluster import DBSCAN
# Column layout of `sources` (one row per source): [time, x, y, z, power, grid_lat, grid_lon, init_lat, init_lon, init_alt]
def cluster(sources: array_type, xyz_scale: float, t_scale: float, grid_max: int, min_samples: int=10, epsilon: float=1.0, max_duration: float=3.0) -> Tuple[List[array_type], int]:
    """Group sources into flashes by running DBSCAN over sliding time windows.

    The time axis is walked in steps of ``max_duration``. At each step every
    remaining source inside a window of length ``2 * max_duration`` is
    clustered on the 4-D points ``(x, y, z) / xyz_scale`` and
    ``time / t_scale``. Every label that occurs in the first half of the
    window is then handed to ``get_cluster_list`` together with all of its
    members (including those in the second half), and those rows are removed
    from ``sources`` before the window advances.

    Args:
        sources: 2-D array with one row per source. Column 0 is read as time
            and columns 1-3 as x, y, z; the full rows are passed on to
            ``get_cluster_list``. Rows do not need to be sorted by time.
        xyz_scale: Divisor applied to the spatial columns before clustering.
        t_scale: Divisor applied to the time column before clustering.
        grid_max: Forwarded unchanged to ``get_cluster_list``.
        min_samples: DBSCAN ``min_samples``.
        epsilon: DBSCAN ``eps``.
        max_duration: Step of the sliding window (half the window length),
            in the units of the time column. Must be positive.

    Returns:
        ``(flashes, total_removed)``: the concatenated per-window lists
        returned by ``get_cluster_list`` and the sum of its removal counts.

    Raises:
        ValueError: If ``max_duration`` is not positive (the window would
            never advance).
    """
    if max_duration <= 0:
        raise ValueError("max_duration must be positive")
    if sources.shape[0] == 0:
        # np.min / np.max raise on empty input; there is nothing to cluster.
        return [], 0

    all_flashes: List[array_type] = []
    total_removed = 0
    algorithm = DBSCAN(eps=epsilon, min_samples=min_samples)

    time_start = np.min(sources[:, 0])
    max_t = np.max(sources[:, 0])
    # Stop early once every source has been consumed; later windows would be empty.
    while time_start <= max_t and sources.shape[0] > 0:
        times = sources[:, 0]
        in_window = (times >= time_start) & (times < time_start + max_duration * 2)
        # Row numbers in `sources` of the windowed points. DBSCAN labels are
        # positional within the window, so this array maps them back to rows.
        window_rows = np.flatnonzero(in_window)

        if window_rows.size > 0:
            window_times = times[window_rows]
            dbscan_data = np.empty((window_rows.size, 4))
            dbscan_data[:, :3] = sources[window_rows, 1:4] / xyz_scale
            dbscan_data[:, 3] = window_times / t_scale
            labels = algorithm.fit(dbscan_data).labels_

            # Labels seen in the first half of the window, selected by time
            # rather than by position so unsorted input is handled correctly.
            in_first_half = window_times < time_start + max_duration
            first_half_labels = np.unique(labels[in_first_half])

            # All window points carrying one of those labels, wherever they
            # fall in the window.
            # NOTE(review): the noise label -1 is treated like any other label
            # here, so when the first half contains noise, second-half noise
            # points are removed too and never reconsidered in the next
            # window. Behaviour kept as in the original; confirm it is intended.
            selected = np.isin(labels, first_half_labels)
            selected_rows = window_rows[selected]

            cluster_info, n_removed = get_cluster_list(sources[selected_rows, :], labels[selected], grid_max)
            all_flashes += cluster_info
            total_removed += n_removed

            # Drop the processed rows so they are not clustered again.
            keep = np.ones(len(sources), bool)
            keep[selected_rows] = False
            sources = sources[keep]

        time_start += max_duration

    return all_flashes, total_removed
def get_cluster_list(sources: array_type, cluster_ids: np.ndarray[int, np.dtype[np.int64]], grid_max: int, min_sources: int = 5) -> Tuple[List[array_type], int]:
    """Split labelled sources into one array per cluster, rejecting off-grid clusters.

    Args:
        sources: 2-D array with one row per source; columns 5 and 6 are the
            values checked against the grid bounds.
        cluster_ids: One label per row of ``sources``. Rows labelled ``-1``
            are skipped entirely (``cluster`` passes DBSCAN labels here, where
            ``-1`` marks noise).
        grid_max: Exclusive upper bound for columns 5 and 6; the lower bound
            is ``> -1``.
        min_sources: Minimum number of rows an in-bounds cluster needs in
            order to be returned. Defaults to 5.

    Returns:
        ``(clusters, n_removed)``: ``clusters`` holds the rows of each kept
        cluster, in ascending label order. ``n_removed`` counts clusters
        rejected because at least one row was out of bounds; clusters dropped
        only for being smaller than ``min_sources`` are not counted.
    """
    all_cluster_sources: List[array_type] = []
    n_removed = 0

    for cluster_id in np.unique(cluster_ids):
        if cluster_id == -1:
            continue

        cluster_sources = sources[cluster_ids == cluster_id]

        # Reject the whole cluster if any of its sources is out of bounds.
        # The bounds check runs before the size check, so small out-of-bounds
        # clusters are still counted as removed.
        grid_cells = cluster_sources[:, 5:7]
        if not np.all((grid_cells > -1) & (grid_cells < grid_max)):
            n_removed += 1
            continue

        if cluster_sources.shape[0] >= min_sources:
            all_cluster_sources.append(cluster_sources)

    return all_cluster_sources, n_removed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment