Created
March 11, 2026 03:16
-
-
Save bbarad/b6b662939325e025130307a27e07bf02 to your computer and use it in GitHub Desktop.
Filter Tardis instance filaments based on a cleaned label mrc from Amira.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Filter filament instance CSVs from tardis to points within a configurable
radius (default: 1 voxel) of a chosen label (default: 4) in the corresponding
scores.labels.mrc, then renumber IDs sequentially.

Reads work_dir (CSVs) and data_dir (MRCs) from a morphometrics config.yml.
"""
| import warnings | |
| from pathlib import Path | |
| import click | |
| import mrcfile | |
| import numpy as np | |
| import pandas as pd | |
| import yaml | |
| from scipy.ndimage import binary_dilation, map_coordinates | |
def load_mrc(path):
    """Load an MRC volume and return ``(data, voxel_size, origin)``.

    ``data`` is a float32 copy of the map with axes ordered (Z, Y, X) as
    stored by mrcfile. ``voxel_size`` and ``origin`` are 3-element arrays
    ordered (x, y, z). When the header origin is all zeros, the origin is
    reconstructed from the nx/ny/nzstart offset fields scaled by the
    voxel size. mrcfile header warnings are suppressed.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        with mrcfile.open(path, permissive=True) as mrc:
            volume = mrc.data.astype(np.float32, copy=True)
            vs = mrc.voxel_size
            hdr = mrc.header
            voxel_size = np.array([vs.x, vs.y, vs.z])
            origin = np.array([hdr.origin.x, hdr.origin.y, hdr.origin.z])
            if np.allclose(origin, 0):
                # Origin field unset: fall back to the start-offset fields,
                # converted from voxel units to physical units.
                starts = np.array([hdr.nxstart, hdr.nystart, hdr.nzstart],
                                  dtype=float)
                origin = starts * voxel_size
    return volume, voxel_size, origin
def process(csv_path, mrc_path, out_dir, label=4, radius=1):
    """Filter one instance CSV to points near ``label`` in the paired MRC.

    A point survives when it lies within ``radius`` voxels (Euclidean, via
    spherical dilation of the label mask) of any voxel equal to ``label``.
    Surviving instance IDs (column "IDs", if present) are renumbered
    densely from 0 and the result is written to ``out_dir`` with the
    suffix ``_instances_label{label}.csv``.

    Parameters
    ----------
    csv_path : Path
        CSV with X/Y/Z coordinate columns (physical units matching the
        MRC header, typically Angstroms) and an optional "IDs" column.
    mrc_path : Path
        Label volume whose header supplies voxel size and origin.
    out_dir : Path or str
        Directory for the filtered CSV.
    label : int
        Label value to keep points near.
    radius : int
        Dilation radius in voxels.

    Raises
    ------
    ValueError
        If an X/Y/Z coordinate column cannot be found, or the MRC header
        reports a zero voxel size (which would make the physical-to-voxel
        conversion divide by zero).
    """
    print(f"\n--- {csv_path.name} ---")
    data, voxel_size, origin = load_mrc(str(mrc_path))
    print(f" MRC shape (Z,Y,X): {data.shape}, voxel size: {voxel_size[0]:.2f} Å")
    df = pd.read_csv(csv_path)
    print(f" CSV: {len(df)} rows")
    # Map whatever the coordinate columns are called (e.g. "X [A]") to X/Y/Z.
    col_map = {}
    for col in df.columns:
        cl = col.strip().upper()
        if cl.startswith("X"):
            col_map["X"] = col
        elif cl.startswith("Y"):
            col_map["Y"] = col
        elif cl.startswith("Z"):
            col_map["Z"] = col
    missing = [ax for ax in ("X", "Y", "Z") if ax not in col_map]
    if missing:
        raise ValueError(f"{csv_path.name}: no column found for axis/axes {missing}")
    if np.any(voxel_size == 0):
        raise ValueError(f"{mrc_path.name}: header reports a zero voxel size")
    xyz = df[[col_map["X"], col_map["Y"], col_map["Z"]]].values
    # Convert physical coordinates to fractional voxel indices.
    voxel_coords = (xyz - origin) / voxel_size
    # Spherical structuring element of the requested radius, built with
    # ogrid instead of the former triple Python loop.
    zz, yy, xx = np.ogrid[-radius:radius + 1, -radius:radius + 1, -radius:radius + 1]
    struct = (zz**2 + yy**2 + xx**2) <= radius**2
    dilated = binary_dilation(data == label, structure=struct)
    # Sample the dilated mask at each point (nearest-neighbor, order=0).
    z, y, x = voxel_coords[:, 2], voxel_coords[:, 1], voxel_coords[:, 0]
    in_mask = map_coordinates(dilated.astype(np.float32), np.vstack([z, y, x]),
                              order=0, mode="nearest") > 0.5
    print(f" Kept {in_mask.sum()} / {len(df)} points within {radius} voxel(s) of label {label}")
    out_df = df[in_mask].copy()
    if "IDs" in out_df.columns:
        # Renumber surviving filament IDs densely from 0, preserving order.
        id_map = {old: new for new, old in enumerate(out_df["IDs"].unique())}
        out_df["IDs"] = out_df["IDs"].map(id_map)
    # BUG FIX: the suffix previously hard-coded "label4" regardless of `label`.
    out_path = Path(out_dir) / csv_path.name.replace(
        "_instances.csv", f"_instances_label{label}.csv")
    out_df.to_csv(out_path, index=False)
    print(f" Saved -> {out_path.name}")
@click.command()
@click.argument("config", type=click.Path(exists=True))
@click.option("--label", default=4, show_default=True, help="Label value to filter to")
@click.option("--radius", default=1, show_default=True, help="Dilation radius in voxels")
def main(config, label, radius):
    """Filter filament instance CSVs to points near a given label in the scores.labels.mrc."""
    # Pull the CSV and MRC directories out of the morphometrics config.
    with open(config) as fh:
        settings = yaml.safe_load(fh)
    work_dir = Path(settings["work_dir"])
    data_dir = Path(settings["data_dir"])

    csv_files = sorted(work_dir.glob("*_instances.csv"))
    if not csv_files:
        click.echo(f"No *_instances.csv files found in {work_dir}")
        return

    # Pair each instance CSV with its scores.labels.mrc; warn on misses.
    matched = []
    for csv_file in csv_files:
        base = csv_file.name.replace("_instances.csv", "")
        mrc_file = data_dir / f"{base}_scores.labels.mrc"
        if not mrc_file.exists():
            click.echo(f"WARNING: no matching MRC for {csv_file.name} in {data_dir}, skipping")
            continue
        matched.append((csv_file, mrc_file))

    click.echo(f"Found {len(matched)} matched pair(s)")
    for csv_file, mrc_file in matched:
        process(csv_file, mrc_file, out_dir=work_dir, label=label, radius=radius)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment