Skip to content

Instantly share code, notes, and snippets.

@bbarad
Created March 11, 2026 03:16
Show Gist options
  • Select an option

  • Save bbarad/b6b662939325e025130307a27e07bf02 to your computer and use it in GitHub Desktop.

Select an option

Save bbarad/b6b662939325e025130307a27e07bf02 to your computer and use it in GitHub Desktop.
Filter TARDIS filament instances using a cleaned label MRC exported from Amira.
"""
Filter filament instance CSVs from TARDIS to the points lying within a given
radius (default: 1 voxel) of a given label (default: 4) in the corresponding
<stem>_scores.labels.mrc, then renumber the filament IDs sequentially.

Reads work_dir (instance CSVs) and data_dir (label MRCs) from a
morphometrics config.yml.
"""
import warnings
from pathlib import Path
import click
import mrcfile
import numpy as np
import pandas as pd
import yaml
from scipy.ndimage import binary_dilation, map_coordinates
def load_mrc(path):
    """Load an MRC volume together with its voxel size and origin.

    Parameters
    ----------
    path : str or os.PathLike
        File to read. It is opened with ``permissive=True``, and every warning
        raised while reading is suppressed.

    Returns
    -------
    data : numpy.ndarray
        float32 copy of the volume, in the axis order mrcfile provides
        (the rest of this script treats it as Z, Y, X).
    voxel_size : numpy.ndarray
        Header voxel size as ``[x, y, z]``.
    origin : numpy.ndarray
        Header origin as ``[x, y, z]``. When the header origin is all zeros,
        ``[nxstart, nystart, nzstart] * voxel_size`` is used instead.

    Raises
    ------
    ValueError
        If the file has no readable data block.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        with mrcfile.open(path, permissive=True) as mrc:
            # In permissive mode mrcfile may leave .data as None rather than
            # raising; fail with a clear message instead of an AttributeError.
            if mrc.data is None:
                raise ValueError(f"{path}: MRC file has no readable data block")
            # astype() returns a new array by default, so a separate .copy()
            # would only duplicate the whole volume a second time.
            data = mrc.data.astype(np.float32)
            voxel_size = np.array([mrc.voxel_size.x, mrc.voxel_size.y, mrc.voxel_size.z])
            header = mrc.header
            origin = np.array([header.origin.x, header.origin.y, header.origin.z])
            if np.allclose(origin, 0):
                # No explicit origin: derive it from the start indices (voxels)
                # scaled by the voxel size.
                origin = np.array([header.nxstart, header.nystart, header.nzstart]) * voxel_size
    return data, voxel_size, origin
def _xyz_columns(df):
    """Return the names of the X, Y and Z coordinate columns of *df*.

    A column matches an axis when its stripped name starts with that letter,
    case-insensitively. If several columns share an initial, the last wins.

    Raises
    ------
    ValueError
        If any axis has no matching column.
    """
    found = {}
    for col in df.columns:
        initial = str(col).strip().upper()[:1]
        if initial in ("X", "Y", "Z"):
            found[initial] = col
    missing = [axis for axis in "XYZ" if axis not in found]
    if missing:
        raise ValueError(
            f"no coordinate column for axis {', '.join(missing)}; "
            f"columns are {list(df.columns)}"
        )
    return found["X"], found["Y"], found["Z"]


def _ball(radius):
    """Return a boolean (2r+1)^3 structuring element of voxels within *radius* of the centre."""
    if radius < 0:
        raise ValueError(f"radius must be >= 0, got {radius}")
    span = slice(-radius, radius + 1)
    zz, yy, xx = np.ogrid[span, span, span]
    return zz**2 + yy**2 + xx**2 <= radius**2


def _lookup_nearest(mask, voxel_coords):
    """Sample *mask* at the voxel nearest each point; points outside *mask* give False.

    ``voxel_coords`` is an (N, 3) array of fractional voxel coordinates in
    (x, y, z) order; ``mask`` is indexed ``[z, y, x]``. Non-finite
    coordinates compare False against the bounds and therefore count as outside.
    """
    # Reverse to (z, y, x); floor(c + 0.5) selects the nearest voxel centre.
    nearest = np.floor(np.asarray(voxel_coords, dtype=np.float64)[:, ::-1] + 0.5)
    inside = np.all((nearest >= 0) & (nearest < np.array(mask.shape)), axis=1)
    hits = np.zeros(len(nearest), dtype=bool)
    zi, yi, xi = nearest[inside].astype(np.intp).T
    hits[inside] = mask[zi, yi, xi]
    return hits


def process(csv_path, mrc_path, out_dir, label=4, radius=1):
    """Keep the rows of one instance CSV that lie within *radius* voxels of *label*.

    The CSV's X/Y/Z columns are converted to voxel coordinates using the MRC
    header voxel size and origin (assumed to share the CSV's units and frame
    — confirm for new data sources). The ``data == label`` mask is dilated by
    a ball of *radius* voxels. A row is kept when its nearest voxel is set in
    that dilated mask. The volume is padded by *radius* first, so points just
    beyond the edge are judged correctly and points farther out are dropped.
    If an ``IDs`` column exists, it is renumbered 0..k-1 in order of first
    appearance among the kept rows.

    Parameters
    ----------
    csv_path : path-like
        Instance CSV, normally named ``<stem>_instances.csv``.
    mrc_path : path-like
        Matching label volume.
    out_dir : path-like
        Directory receiving ``<stem>_instances_label<label>.csv``.
    label : int
        Label value to keep points near.
    radius : int
        Dilation radius in voxels (must be >= 0).

    Raises
    ------
    ValueError
        If *radius* is negative, the CSV lacks X/Y/Z columns, or the MRC voxel
        size has a zero component.
    """
    csv_path = Path(csv_path)
    struct = _ball(radius)  # validate radius before any expensive I/O
    print(f"\n--- {csv_path.name} ---")
    data, voxel_size, origin = load_mrc(str(mrc_path))
    print(f" MRC shape (Z,Y,X): {data.shape}, voxel size: {voxel_size[0]:.2f} Å")
    if np.any(voxel_size == 0):
        raise ValueError(f"{mrc_path}: voxel size {voxel_size} has a zero component")
    label_mask = data == label
    del data  # only the boolean mask is needed; release the float32 volume early

    df = pd.read_csv(csv_path)
    print(f" CSV: {len(df)} rows")
    xyz = np.asarray(df[list(_xyz_columns(df))].values, dtype=np.float64)
    voxel_coords = (xyz - origin) / voxel_size

    # Pad by `radius` so a label on the volume edge still reaches points just
    # outside it, while points farther out fall off the padded array.
    padded = np.pad(label_mask, radius, mode="constant", constant_values=False)
    dilated = binary_dilation(padded, structure=struct)
    in_mask = _lookup_nearest(dilated, voxel_coords + radius)
    print(f" Kept {in_mask.sum()} / {len(df)} points within {radius} voxel(s) of label {label}")

    out_df = df[in_mask].copy()
    if "IDs" in out_df.columns:
        id_map = {old: new for new, old in enumerate(out_df["IDs"].unique())}
        out_df["IDs"] = out_df["IDs"].map(id_map)

    # Name the output after the label actually used, and never reuse the input
    # file name (str.replace would leave it unchanged if the suffix were absent).
    suffix = "_instances.csv"
    name = csv_path.name
    stem = name[: -len(suffix)] if name.endswith(suffix) else csv_path.stem
    out_path = Path(out_dir) / f"{stem}_instances_label{label}.csv"
    out_df.to_csv(out_path, index=False)
    print(f" Saved -> {out_path.name}")
@click.command()
@click.argument("config", type=click.Path(exists=True))
@click.option("--label", default=4, show_default=True, help="Label value to filter to")
@click.option("--radius", default=1, show_default=True, type=click.IntRange(min=0),
              help="Dilation radius in voxels")
def main(config, label, radius):
    """Filter filament instance CSVs to points near a given label in the scores.labels.mrc.

    CONFIG is a YAML file that must define work_dir (holding the
    *_instances.csv files; results are written there too) and data_dir
    (holding the matching *_scores.labels.mrc files).
    """
    suffix = "_instances.csv"
    with open(config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    # An empty file loads as None; anything other than a mapping cannot hold the keys.
    if not isinstance(cfg, dict):
        raise click.ClickException(f"{config}: expected a YAML mapping with work_dir and data_dir")
    missing = [key for key in ("work_dir", "data_dir") if key not in cfg]
    if missing:
        raise click.ClickException(f"{config}: missing required key(s): {', '.join(missing)}")
    work_dir = Path(cfg["work_dir"])
    data_dir = Path(cfg["data_dir"])

    csvs = sorted(work_dir.glob("*" + suffix))
    if not csvs:
        click.echo(f"No *_instances.csv files found in {work_dir}")
        return

    pairs = []
    for csv_path in csvs:
        # The glob guarantees the suffix, so slicing it off yields the dataset stem.
        stem = csv_path.name[: -len(suffix)]
        mrc_path = data_dir / f"{stem}_scores.labels.mrc"
        if mrc_path.exists():
            pairs.append((csv_path, mrc_path))
        else:
            click.echo(f"WARNING: no matching MRC for {csv_path.name} in {data_dir}, skipping")
    click.echo(f"Found {len(pairs)} matched pair(s)")
    for csv_path, mrc_path in pairs:
        process(csv_path, mrc_path, out_dir=work_dir, label=label, radius=radius)
if __name__ == "__main__":
    # Run the click command only when executed as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment