"""Medical imaging dataset classes.
Requires optional dependencies: nibabel, scikit-image.
Install with: pip install nibabel scikit-image
"""
import os
import re
import numpy as np
from scipy import sparse
from geomfum.shape import TriangleMesh
# Integer values used in ACDC ground-truth segmentations for each labelled
# cardiac structure.
LABEL_RV = 1  # right ventricle
LABEL_MYO = 2  # myocardium
LABEL_LV = 3  # left ventricle

# Short structure keyword -> segmentation label value.
STRUCTURE_LABELS = dict(lv=LABEL_LV, myo=LABEL_MYO, rv=LABEL_RV)
def _laplacian_smooth(verts, faces, n_iter=3, lamb=0.5):
    """Umbrella (uniform-weight) Laplacian smoothing.

    Each iteration moves every vertex a fraction ``lamb`` of the way
    towards the unweighted mean of its edge-connected neighbours.

    Parameters
    ----------
    verts : array, shape=[n, 3]
        Vertex positions. The input is not modified.
    faces : array, shape=[m, 3]
        Triangle vertex indices.
    n_iter : int
        Number of smoothing iterations. Values ``<= 0`` return the
        vertices unchanged (as a float copy).
    lamb : float
        Relaxation factor in (0, 1).

    Returns
    -------
    verts : array, shape=[n, 3]
        Smoothed vertex positions as a float array.
    """
    verts = np.array(verts, dtype=float)
    n = len(verts)
    # reshape lets an empty or flat face list through the column indexing.
    tri = np.asarray(faces, dtype=np.int64).reshape(-1, 3)
    if n == 0 or len(tri) == 0 or n_iter <= 0:
        return verts

    # Both directions of every triangle edge.
    rows = np.concatenate(
        [tri[:, 0], tri[:, 1], tri[:, 2], tri[:, 1], tri[:, 2], tri[:, 0]]
    )
    cols = np.concatenate(
        [tri[:, 1], tri[:, 2], tri[:, 0], tri[:, 0], tri[:, 1], tri[:, 2]]
    )
    adjacency = sparse.csr_matrix(
        (np.ones(len(rows)), (rows, cols)), shape=(n, n)
    )
    # An edge shared by two triangles is inserted twice and summed to 2.
    # Reset every stored entry to 1 so all neighbours weigh the same,
    # whether the edge is interior or on an open boundary.
    adjacency.sum_duplicates()
    adjacency.data[:] = 1.0

    degree = np.asarray(adjacency.sum(axis=1)).ravel()
    isolated = degree == 0
    degree[isolated] = 1.0  # avoid division by zero
    # Vertices without neighbours get an identity row so they stay put
    # instead of being scaled towards the origin by (1 - lamb).
    averaging = sparse.diags(1.0 / degree) @ adjacency + sparse.diags(
        isolated.astype(float)
    )

    for _ in range(n_iter):
        verts = (1.0 - lamb) * verts + lamb * (averaging @ verts)
    return verts
def nifti_seg_to_mesh(seg_path, label, voxel_spacing=True, smooth_iter=3):
    """Build a triangle mesh of one labelled structure in a NIfTI segmentation.

    The volume is binarised against ``label``, a surface is extracted with
    marching cubes at the 0.5 iso-level, and the vertices are optionally
    smoothed. Only the header voxel sizes are used for scaling; the NIfTI
    affine is not applied.

    Parameters
    ----------
    seg_path : str
        Path to the ``_gt.nii.gz`` segmentation file.
    label : int
        Integer label of the structure to extract.
    voxel_spacing : bool
        If True, scale vertex positions by the first three voxel sizes
        reported by the NIfTI header; otherwise use unit spacing.
    smooth_iter : int
        Number of Laplacian smoothing iterations. Set to 0 to skip.

    Returns
    -------
    mesh : TriangleMesh

    Raises
    ------
    ImportError
        If nibabel or scikit-image is not installed.
    ValueError
        If no voxel in the volume equals ``label``.
    """
    # Optional dependencies are imported lazily so the module itself can be
    # imported without them.
    try:
        import nibabel as nib
    except ImportError as err:
        raise ImportError(
            "nibabel is required. Install with: pip install nibabel"
        ) from err
    try:
        from skimage.measure import marching_cubes
    except ImportError as err:
        raise ImportError(
            "scikit-image is required. Install with: pip install scikit-image"
        ) from err

    image = nib.load(seg_path)
    volume = image.get_fdata()

    if voxel_spacing:
        spacing = tuple(float(zoom) for zoom in image.header.get_zooms()[:3])
    else:
        spacing = (1.0, 1.0, 1.0)

    # Binary occupancy volume for the requested structure.
    mask = (volume == label).astype(np.float32)
    if not mask.any():
        raise ValueError(
            f"Label {label} not found in '{seg_path}'. "
            f"Present labels: {np.unique(volume).tolist()}"
        )

    verts, faces, _normals, _values = marching_cubes(
        mask, level=0.5, spacing=spacing
    )
    if smooth_iter > 0:
        verts = _laplacian_smooth(verts, faces, n_iter=smooth_iter)

    return TriangleMesh(vertices=verts, faces=faces)
def _parse_acdc_info(info_path):
    """Read an ACDC ``Info.cfg`` file into a ``{key: value}`` dict of strings.

    Every line containing a colon is split at the first colon; both sides
    are stripped of surrounding whitespace. Lines without a colon are
    ignored, and a repeated key keeps its last value.
    """
    entries = {}
    with open(info_path) as cfg:
        for raw in cfg:
            name, colon, value = raw.partition(":")
            if not colon:
                continue
            entries[name.strip()] = value.strip()
    return entries
class AcdcPatient:
    """A single ACDC patient with lazy-loaded ED/ES meshes.

    Construction only reads ``Info.cfg``; segmentation files are opened
    when one of the ``load_*`` methods is called.

    Parameters
    ----------
    patient_dir : str
        Path to the patient directory (e.g. ``/data/acdc/training/patient001``).
        A trailing path separator is accepted.
    structure : str
        Cardiac structure to extract: ``'lv'``, ``'rv'``, or ``'myo'``.
    voxel_spacing : bool
        Apply voxel spacing so mesh coordinates are in mm.
    smooth_iter : int
        Laplacian smoothing iterations on the raw marching-cubes mesh.

    Attributes
    ----------
    patient_id : str
        Name of the patient directory (e.g. ``'patient001'``).
    group : str
        ``Group`` entry of ``Info.cfg``; ``'UNK'`` when missing.
    ed_frame, es_frame : int
        ``ED`` / ``ES`` entries of ``Info.cfg``; 1 when missing.
    height, weight : float
        ``Height`` / ``Weight`` entries of ``Info.cfg``; 0.0 when missing
        or empty.
    """

    def __init__(self, patient_dir, structure="lv", voxel_spacing=True, smooth_iter=3):
        self.patient_dir = patient_dir
        self.patient_id = self._patient_id_from_dir(patient_dir)
        self.structure = structure
        self.voxel_spacing = voxel_spacing
        self.smooth_iter = smooth_iter

        info = _parse_acdc_info(os.path.join(patient_dir, "Info.cfg"))
        self.group = info.get("Group", "UNK")
        self.ed_frame = int(info.get("ED", 1))
        self.es_frame = int(info.get("ES", 1))
        # ``or 0`` also covers keys that are present but empty.
        self.height = float(info.get("Height", 0) or 0)
        self.weight = float(info.get("Weight", 0) or 0)

    @staticmethod
    def _patient_id_from_dir(patient_dir):
        """Return the directory name of ``patient_dir``.

        The path is normalised first: a bare ``os.path.basename`` returns
        an empty string for a path ending in a separator, which would
        produce wrong segmentation file names.
        """
        return os.path.basename(os.path.abspath(patient_dir))

    def _seg_path(self, frame):
        """Return the path of the ground-truth segmentation for ``frame``."""
        fname = f"{self.patient_id}_frame{frame:02d}_gt.nii.gz"
        return os.path.join(self.patient_dir, fname)

    def _load_mesh(self, seg_path):
        """Extract the configured structure from ``seg_path`` as a mesh."""
        return nifti_seg_to_mesh(
            seg_path,
            STRUCTURE_LABELS[self.structure],
            self.voxel_spacing,
            self.smooth_iter,
        )

    @property
    def ed_seg_path(self):
        """Path to end-diastole segmentation file."""
        return self._seg_path(self.ed_frame)

    @property
    def es_seg_path(self):
        """Path to end-systole segmentation file."""
        return self._seg_path(self.es_frame)

    def load_ed(self):
        """Load and return the end-diastole TriangleMesh."""
        return self._load_mesh(self.ed_seg_path)

    def load_es(self):
        """Load and return the end-systole TriangleMesh."""
        return self._load_mesh(self.es_seg_path)

    def load_pair(self):
        """Return ``(mesh_ed, mesh_es)``."""
        return self.load_ed(), self.load_es()

    @property
    def metadata(self):
        """Metadata dict (no mesh loading)."""
        return {
            "patient_id": self.patient_id,
            "group": self.group,
            "ed_frame": self.ed_frame,
            "es_frame": self.es_frame,
            "height_cm": self.height,
            "weight_kg": self.weight,
        }

    def __repr__(self):
        """Return a string representation of the AcdcPatient."""
        return (
            f"AcdcPatient(id={self.patient_id!r}, group={self.group!r}, "
            f"structure={self.structure!r})"
        )
class AcdcDataset:
    """ACDC Automated Cardiac Diagnosis Challenge dataset.

    Patient records are discovered at construction time by reading each
    ``Info.cfg``; meshes are only built when an item is accessed.

    Parameters
    ----------
    root : str
        Path to the ACDC data directory. Should contain ``patient001/``,
        ``patient002/``, … subdirectories, either directly or inside a
        ``training/``, ``testing/``, ``train/`` or ``test/`` subfolder
        (checked in that order; the first one holding patient directories
        is used).
    structure : str
        Cardiac structure: ``'lv'`` (left ventricle, default), ``'rv'``
        (right ventricle), ``'myo'`` (myocardium).
    groups : list[str] or str or None
        Filter by diagnostic group. ACDC groups: ``'NOR'``, ``'DCM'``,
        ``'HCM'``, ``'MINF'``, ``'RVA'``. A single string selects one
        group. ``None`` keeps all groups.
    voxel_spacing : bool
        Apply voxel spacing so mesh coordinates are in mm. Default ``True``.
    smooth_iter : int
        Laplacian smoothing iterations on raw marching-cubes meshes.
        Default 3. Set to 0 to skip smoothing.

    Attributes
    ----------
    root : str
        The ``root`` argument exactly as passed in.
    patients : list[AcdcPatient]
        Patients that have an ``Info.cfg`` and pass the group filter,
        sorted by directory path.

    Notes
    -----
    **Download.** Register and download from
    https://www.creatis.insa-lyon.fr/Challenge/acdc/

    **Expected layout**::

        root/
        ├── patient001/
        │   ├── Info.cfg
        │   ├── patient001_frame01.nii.gz
        │   ├── patient001_frame01_gt.nii.gz   ← ED segmentation
        │   ├── patient001_frame12.nii.gz
        │   └── patient001_frame12_gt.nii.gz   ← ES segmentation
        ├── patient002/
        └── ...

    Patient directories without an ``Info.cfg`` are skipped.

    **Segmentation labels:** 0=background, 1=RV, 2=myocardium, 3=LV.

    Examples
    --------
    >>> dataset = AcdcDataset("/path/to/acdc/training", structure="lv")
    >>> print(dataset)
    AcdcDataset(root=..., n_patients=100, groups={'NOR': 20, ...}, structure='lv')
    >>> mesh_ed, mesh_es, meta = dataset[0]
    >>> mesh_ed, mesh_es = dataset.patients[0].load_pair()
    """

    # Directory names starting with "patient" followed by digits.
    _PATIENT_DIR_RE = re.compile(r"patient\d+")
    # Sub-folders searched when ``root`` holds no patient directory itself.
    _SPLIT_SUBDIRS = ("training", "testing", "train", "test")

    def __init__(
        self,
        root,
        structure="lv",
        groups=None,
        voxel_spacing=True,
        smooth_iter=3,
    ):
        self.root = root
        self.structure = structure
        self.voxel_spacing = voxel_spacing
        self.smooth_iter = smooth_iter

        patient_root = self._resolve_patient_root(root)
        patients = [
            AcdcPatient(
                patient_dir,
                structure=structure,
                voxel_spacing=voxel_spacing,
                smooth_iter=smooth_iter,
            )
            for patient_dir in self._find_patient_dirs(patient_root)
            if os.path.exists(os.path.join(patient_dir, "Info.cfg"))
        ]

        if groups is not None:
            # A bare string would otherwise be split into characters by set().
            wanted = {groups} if isinstance(groups, str) else set(groups)
            patients = [p for p in patients if p.group in wanted]
        self.patients = patients

    @classmethod
    def _find_patient_dirs(cls, path):
        """Return the sorted patient directories located directly in ``path``.

        Only real directories count; files whose name happens to match
        the pattern are ignored.
        """
        return sorted(
            os.path.join(path, name)
            for name in os.listdir(path)
            if cls._PATIENT_DIR_RE.match(name)
            and os.path.isdir(os.path.join(path, name))
        )

    @classmethod
    def _resolve_patient_root(cls, root):
        """Return the directory that actually holds the patient folders.

        ``root`` itself is used when it contains patient directories;
        otherwise the first known split sub-folder that does. Falls back
        to ``root`` when nothing is found.
        """
        if cls._find_patient_dirs(root):
            return root
        for sub in cls._SPLIT_SUBDIRS:
            candidate = os.path.join(root, sub)
            if os.path.isdir(candidate) and cls._find_patient_dirs(candidate):
                return candidate
        return root

    # ------------------------------------------------------------------
    # Container protocol
    # ------------------------------------------------------------------
    def __len__(self):
        """Return the number of patients in the dataset."""
        return len(self.patients)

    def __getitem__(self, idx):
        """Return ``(mesh_ed, mesh_es, metadata)`` for patient at ``idx``."""
        patient = self.patients[idx]
        mesh_ed, mesh_es = patient.load_pair()
        return mesh_ed, mesh_es, patient.metadata

    def __iter__(self):
        """Iterate over patients, yielding ``(mesh_ed, mesh_es, metadata)``."""
        for patient in self.patients:
            mesh_ed, mesh_es = patient.load_pair()
            yield mesh_ed, mesh_es, patient.metadata

    # ------------------------------------------------------------------
    # Convenience
    # ------------------------------------------------------------------
    @property
    def group_counts(self):
        """Dict mapping group name → patient count."""
        from collections import Counter

        return dict(Counter(p.group for p in self.patients))

    @property
    def metadata_list(self):
        """List of metadata dicts for all patients (no mesh loading)."""
        return [p.metadata for p in self.patients]

    def get_patient(self, patient_id):
        """Return ``AcdcPatient`` by ID string (e.g. ``'patient001'``).

        Raises
        ------
        KeyError
            If no patient in the dataset has that ID.
        """
        for p in self.patients:
            if p.patient_id == patient_id:
                return p
        raise KeyError(f"Patient '{patient_id}' not found in dataset.")

    def __repr__(self):
        """Return a string representation of the AcdcDataset."""
        return (
            f"AcdcDataset(root={self.root!r}, n_patients={len(self)}, "
            f"groups={self.group_counts}, structure={self.structure!r})"
        )