Source code for geomfum.dataset.torch

"""Datasets for Loading Meshes and Point Clouds using PyTorch."""

import itertools
import os
import random
import warnings

import gsops.backend as gs
import meshio
import numpy as np
import scipy
import torch
from torch.utils.data import Dataset

from geomfum.metric import VertexEuclideanMetric
from geomfum.metric.mesh import ScipyGraphShortestPathMetric
from geomfum.shape import PointCloud, TriangleMesh


class ShapeDataset(Dataset):
    """General dataset for loading and preprocessing meshes or point clouds.

    All supported shape files found in ``<dataset_dir>/shapes`` are loaded
    eagerly at construction time and kept in memory.

    Parameters
    ----------
    dataset_dir : str
        Path to the directory containing the dataset. We assume the dataset
        directory to have a subfolder ``shapes``, for shapes, ``corr``, for
        correspondences and ``dist``, for cached distance matrices.
    shape_type : str
        Type of shape to load. Either 'mesh' or 'pointcloud'.
    spectral : bool
        Whether to compute the spectral features.
    distances : bool
        Whether to return distance matrices. A matrix is read from
        ``dist/<name>.mat`` (variable ``"D"``) when that file exists;
        otherwise it is computed on first access and cached to that file.
    correspondences : bool
        Whether to load correspondences.
    k : int
        Number of eigenvectors to use for the spectral features.
    device : torch.device, optional
        Device to move the data to. If None, uses CUDA if available,
        else CPU.

    Raises
    ------
    ValueError
        If ``shape_type`` is neither 'mesh' nor 'pointcloud'.
    """

    def __init__(
        self,
        dataset_dir,
        shape_type="mesh",
        spectral=False,
        distances=False,
        correspondences=True,
        k=200,
        device=None,
    ):
        if shape_type not in ["mesh", "pointcloud"]:
            raise ValueError("shape_type must be either 'mesh' or 'pointcloud'")

        self.dataset_dir = dataset_dir
        self.shape_type = shape_type
        self.shape_dir = os.path.join(dataset_dir, "shapes")
        all_shape_files = sorted(
            [
                f
                for f in os.listdir(self.shape_dir)
                if f.lower().endswith((".off", ".ply", ".obj"))
            ]
        )

        self.device = (
            device
            if device is not None
            else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        self.spectral = spectral
        self.k = k
        self.distances = distances
        self.correspondences = correspondences

        # Preload shapes (or their important features) into memory
        self.shapes = {}
        self.corrs = {}
        # Only files that were actually loaded are indexable; keeping skipped
        # files in ``shape_files`` would make ``__getitem__`` fail on them.
        loaded_files = []
        for filename in all_shape_files:
            ext = os.path.splitext(filename)[1][1:]
            if ext not in meshio._helpers._writer_map:
                warnings.warn(f"Skipped unsupported mesh file: {filename}")
                continue

            filepath = os.path.join(self.shape_dir, filename)
            # Load shape based on type
            if self.shape_type == "mesh":
                shape = TriangleMesh.from_file(filepath)
            else:  # pointcloud
                shape = PointCloud.from_file(filepath)

            base_name, _ = os.path.splitext(filename)

            # preprocess
            if spectral:
                shape.laplacian.find_spectrum(spectrum_size=k, set_as_basis=True)

            self.shapes[filename] = shape
            loaded_files.append(filename)

            if self.correspondences:
                corr_path = os.path.join(
                    self.dataset_dir, "corr", base_name + ".vts"
                )
                if os.path.exists(corr_path):
                    # Load correspondences from file, subtract 1 to convert
                    # to zero-based indexing.
                    self.corrs[filename] = np.loadtxt(corr_path).astype(np.int32) - 1
                else:
                    # No correspondence file: fall back to the identity map.
                    self.corrs[filename] = np.arange(shape.vertices.shape[0])

        self.shape_files = loaded_files

    def _load_or_compute_dist_matrix(self, filename, shape):
        """Return the distance matrix for a shape, using the on-disk cache."""
        base_name, _ = os.path.splitext(filename)
        dist_path = os.path.join(self.dataset_dir, "dist", base_name + ".mat")

        geod_distance_matrix = None
        if os.path.exists(dist_path):
            mat_contents = scipy.io.loadmat(dist_path)
            if "D" in mat_contents:
                geod_distance_matrix = mat_contents["D"]

        if geod_distance_matrix is None:
            if self.shape_type == "mesh":
                metric = ScipyGraphShortestPathMetric(shape)
            else:  # pointcloud
                metric = VertexEuclideanMetric(shape)
            geod_distance_matrix = metric.dist_matrix()

            # Cache the result so later accesses only need to read the file.
            os.makedirs(os.path.dirname(dist_path), exist_ok=True)
            scipy.io.savemat(
                dist_path,
                {"D": gs.to_numpy(geod_distance_matrix)},
            )

        return geod_distance_matrix

    def __getitem__(self, idx):
        """Retrieve a data sample by index.

        Parameters
        ----------
        idx : int
            Index of the item to retrieve.

        Returns
        -------
        shape_data: dict
            Dictionary containing the shape (key ``"shape"``), the
            correspondence (key ``"corr"``) and the distances (key
            ``"dist_matrix"``) if available and required.
        """
        filename = self.shape_files[idx]
        shape = self.shapes[filename]
        shape_data = {}

        if self.correspondences:
            shape_data.update({"corr": gs.array(self.corrs[filename])})

        if self.distances:
            geod_distance_matrix = self._load_or_compute_dist_matrix(filename, shape)
            shape_data.update({"dist_matrix": gs.array(geod_distance_matrix)})

        # Move shape data to device
        shape.vertices = gs.to_device(shape.vertices, self.device)
        # NOTE(review): the basis and mass matrix are only computed in
        # __init__ when spectral=True -- confirm these accesses are valid
        # when spectral=False.
        shape.basis.full_vals = gs.to_device(shape.basis.full_vals, self.device)
        shape.basis.full_vecs = gs.to_device(shape.basis.full_vecs, self.device)
        shape.laplacian._mass_matrix = gs.to_device(
            shape.laplacian._mass_matrix, self.device
        )

        # Only move faces to device for meshes
        if self.shape_type == "mesh":
            shape.faces = gs.to_device(shape.faces, self.device)

        shape_data.update({"shape": shape})
        return shape_data

    def __len__(self):
        """Get the length of the dataset."""
        return len(self.shape_files)
# Convenience classes for backward compatibility
class MeshDataset(ShapeDataset):
    """Convenience ``ShapeDataset`` that always loads triangle meshes.

    Accepts the same arguments as ``ShapeDataset`` except ``shape_type``,
    which is fixed to ``"mesh"``.
    """

    def __init__(
        self,
        dataset_dir,
        spectral=False,
        distances=False,
        correspondences=True,
        k=200,
        device=None,
    ):
        # Forward everything unchanged; only the shape type is pinned here.
        super().__init__(
            dataset_dir,
            "mesh",
            spectral=spectral,
            distances=distances,
            correspondences=correspondences,
            k=k,
            device=device,
        )
class PointCloudDataset(ShapeDataset):
    """Convenience ``ShapeDataset`` that always loads point clouds.

    Accepts the same arguments as ``ShapeDataset`` except ``shape_type``,
    which is fixed to ``"pointcloud"``.
    """

    def __init__(
        self,
        dataset_dir,
        spectral=False,
        distances=False,
        correspondences=True,
        k=200,
        device=None,
    ):
        # Forward everything unchanged; only the shape type is pinned here.
        super().__init__(
            dataset_dir,
            "pointcloud",
            spectral=spectral,
            distances=distances,
            correspondences=correspondences,
            k=k,
            device=device,
        )
class PairsDataset(Dataset):
    """Dataset of pairs of shapes.

    Each item is a pair (source, target) of shapes from the provided dataset.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset or list
        Preloaded dataset or list of shape data objects. Must support
        ``len()`` and integer indexing.
    pair_mode : str, optional
        Strategy to generate pairs. Options: 'all', 'random'.
        Default is 'all'.
    pairs_ratio : float, optional
        Only used if pair_mode is 'random'. The number of sampled pairs is
        ``int(len(dataset) * pairs_ratio)``, capped at the number of
        distinct unordered pairs. Default is 100.
    device : torch.device, optional
        Device to move the data to. If None, uses CUDA if available,
        else CPU.

    Raises
    ------
    ValueError
        If ``dataset`` is None or ``pair_mode`` is not supported.
    """

    def __init__(self, dataset=None, pair_mode="all", pairs_ratio=100, device=None):
        if dataset is None:
            raise ValueError("dataset must be provided")

        # Preload meshes
        self.shape_data = dataset
        self.pair_mode = pair_mode
        self.device = (
            device
            if device is not None
            else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )

        # Depending on pair_mode, choose the appropriate strategy
        if pair_mode == "all":
            self.pairs = self.generate_all_pairs()
        elif pair_mode == "random":
            self.pairs = self.generate_random_pairs(pairs_ratio)
        else:
            raise ValueError(f"Unsupported pair_mode: {pair_mode}")

    def generate_all_pairs(self):
        """Generate all ordered pairs of distinct shapes from the dataset.

        Returns
        -------
        pairs : list of tuple of int
            Every ``(source_index, target_index)`` with distinct indices;
            both ``(i, j)`` and ``(j, i)`` are included.
        """
        return list(itertools.permutations(range(len(self.shape_data)), 2))

    def generate_random_pairs(self, pairs_ratio=0.5):
        """Generate a random sample of unordered pairs of shapes.

        Parameters
        ----------
        pairs_ratio : float
            Number of pairs to generate relative to the dataset size: the
            sample contains ``int(len(dataset) * pairs_ratio)`` pairs, capped
            at the number of distinct unordered pairs. Default is 0.5.

        Returns
        -------
        pairs : list of tuple of int
            Sampled ``(source_index, target_index)`` pairs, without repeats.
        """
        n_shapes = len(self.shape_data)
        candidates = list(itertools.combinations(range(n_shapes), 2))
        # random.sample raises ValueError when asked for more items than
        # exist, so never request more than the available pairs.
        n_samples = min(int(n_shapes * pairs_ratio), len(candidates))
        return random.sample(candidates, n_samples)

    def __getitem__(self, idx):
        """Get item by index.

        Parameters
        ----------
        idx : int
            Index of the item to retrieve.

        Returns
        -------
        data: dict
            Dictionary containing the source and target shapes.
        """
        src_idx, tgt_idx = self.pairs[idx]
        return {"source": self.shape_data[src_idx], "target": self.shape_data[tgt_idx]}

    def __len__(self):
        """Get the length of the dataset."""
        return len(self.pairs)