Last active
August 16, 2025 19:03
-
-
Save myociss/a552a262c8842d9bae54cb34ba008d0a to your computer and use it in GitHub Desktop.
Lightning Flash Clustering: cluster flashes
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from sklearn.cluster import DBSCAN | |
| # sources[time, x, y, z, power, grid_lat, grid_lon, init_lat, init_lon, init_alt] | |
def cluster(sources: array_type, xyz_scale: float, t_scale: float, grid_max: int, min_samples: int=10, epsilon: float=1.0, max_duration: float=3.0) -> Tuple[List[array_type], int]:
    """Group sources into flashes by running DBSCAN over overlapping time windows.

    Each row of ``sources`` is laid out as
    ``[time, x, y, z, power, grid_lat, grid_lon, init_lat, init_lon, init_alt]``.
    The time axis is walked in steps of ``max_duration``. At every step the
    sources inside ``[time_start, time_start + 2 * max_duration)`` are clustered
    on ``(x, y, z) / xyz_scale`` and ``time / t_scale``. Every cluster label that
    occurs in the first half of the window is handed to ``get_cluster_list``
    together with all of its sources (including those in the second half), and
    those sources are then removed so that the next window does not see them
    again.

    The rows of ``sources`` do not need to be sorted by time.

    Args:
        sources: 2-D array with one row per source, columns as listed above.
        xyz_scale: Divisor applied to the x, y, z columns before clustering.
        t_scale: Divisor applied to the time column before clustering.
        grid_max: Forwarded to ``get_cluster_list``.
        min_samples: DBSCAN ``min_samples``.
        epsilon: DBSCAN ``eps``, measured in the scaled 4-D space.
        max_duration: Step of the sliding window, in the units of the time
            column; each window spans twice this value.

    Returns:
        ``(flashes, total_removed)``: the concatenated flash lists and the
        summed removal counts returned by ``get_cluster_list``.

    Raises:
        ValueError: If ``max_duration`` is not positive (the window would
            never advance).
    """
    if max_duration <= 0:
        raise ValueError("max_duration must be positive")

    all_flashes: List[array_type] = []
    total_removed = 0
    if sources.shape[0] == 0:
        # Nothing to cluster; np.min / np.max would raise on an empty array.
        return all_flashes, total_removed

    max_t = np.max(sources[:, 0])
    time_start = np.min(sources[:, 0])
    algorithm = DBSCAN(eps=epsilon, min_samples=min_samples)

    while time_start <= max_t:
        times = sources[:, 0]
        in_window = (times >= time_start) & (times < time_start + max_duration*2)
        # Row numbers in `sources` of the window members. DBSCAN labels are
        # positional within the window, so this is what maps them back.
        window_rows = np.flatnonzero(in_window)

        if window_rows.size > 0:
            window = sources[window_rows]
            dbscan_data = np.empty((window_rows.size, 4))
            dbscan_data[:, :3] = window[:, 1:4] / xyz_scale
            dbscan_data[:, 3] = window[:, 0] / t_scale
            labels = algorithm.fit(dbscan_data).labels_

            # Membership test instead of a prefix slice, so the result does
            # not depend on the sources being ordered by time.
            in_first_half = window[:, 0] < time_start + max_duration
            first_half_labels = np.unique(labels[in_first_half])

            # NOTE(review): the noise label (-1) counts as a "first half"
            # label here, so whenever the first half contains noise, noise
            # sources in the second half are consumed as well and never reach
            # the next window. Preserved from the original behaviour; confirm
            # that this is intended.
            selected = np.isin(labels, first_half_labels)
            selected_rows = window_rows[selected]

            cluster_info, n_removed = get_cluster_list(sources[selected_rows], labels[selected], grid_max)
            all_flashes += cluster_info
            total_removed += n_removed

            # Every first-half source carries a first-half label, so all of
            # them are dropped here along with the rest of their clusters.
            keep = np.ones(len(sources), bool)
            keep[selected_rows] = False
            sources = sources[keep]

        time_start += max_duration

    return all_flashes, total_removed
def get_cluster_list(sources: array_type, cluster_ids: np.ndarray[int, np.dtype[np.int64]], grid_max: int, min_sources: int=5) -> Tuple[List[array_type], int]:
    """Split labelled sources into per-cluster arrays, filtering unusable clusters.

    Args:
        sources: 2-D array with one row per source; columns 5 and 6 hold the
            grid coordinates that are bounds-checked.
        cluster_ids: One label per row of ``sources``. The label ``-1`` is
            skipped entirely.
        grid_max: Exclusive upper bound for the grid coordinates; values must
            also be greater than ``-1``.
        min_sources: Minimum number of sources a cluster needs in order to be
            returned. Defaults to 5.

    Returns:
        ``(clusters, n_removed)``. ``clusters`` holds one array of source rows
        per accepted cluster, in ascending label order. ``n_removed`` counts
        the clusters rejected because at least one source lies outside the
        grid. The bounds check runs before the size check, so an out-of-bounds
        cluster is counted regardless of its size, while an in-bounds cluster
        with fewer than ``min_sources`` sources is dropped without being
        counted.
    """
    # Normalise a lone source (1-D row with a scalar label) to the batched
    # shape so the boolean masking below behaves uniformly.
    sources = np.atleast_2d(sources)
    cluster_ids = np.atleast_1d(cluster_ids)

    all_cluster_sources: List[array_type] = []
    n_removed = 0

    for cluster_id in np.unique(cluster_ids):
        if cluster_id == -1:
            continue

        cluster_sources = sources[cluster_ids == cluster_id]

        # Reject the whole cluster if any of its sources falls off the grid.
        grid_coords = cluster_sources[:, 5:7]
        if not np.all((grid_coords > -1) & (grid_coords < grid_max)):
            n_removed += 1
            continue

        if cluster_sources.shape[0] >= min_sources:
            all_cluster_sources.append(cluster_sources)

    return all_cluster_sources, n_removed
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment