Source code for geomfum.wrap.transformer
"""Transformer for Feature Extraction.
Copyright (c) 2025 Alessandro Riva
References
----------
[RRM2024] Riva, A., Raganato, A, Melzi, S., "Localized Gaussians as Self-Attention Weights for Point Clouds Correspondence",
Smart Tools and Applications in Graphics, 2024, arXiv:2409.13291
https://github.com/ariva00/aided_transformer/
"""
import gsops.backend as gs
import torch
import torch.nn as nn
from geomfum.descriptor.learned import BaseFeatureExtractor
class TransformerFeatureExtractor(BaseFeatureExtractor, nn.Module):
"""Transformer Feature Extractor for point clouds.
This feature extractor uses the Transformer architecture to extract
features from shapes.
Parameters
----------
in_channels: int
Input feature dimension (typically 3 for xyz coordinates)
embed_dim: int
Embedding dimension for the transformer
num_heads: int
Number of attention heads
num_layers: int
Number of transformer layers
output_dim: int
Output feature dimension
dropout: float
Dropout probability
use_global_pool: bool
Whether to use global max pooling for global features
"""
def __init__(
self,
in_channels: int = 3,
embed_dim: int = 256,
num_heads: int = 8,
num_layers: int = 4,
out_channels: int = 512,
dropout: float = 0.1,
use_global_pool: bool = True,
k_neighbors: int = 16,
device=None,
descriptor=None,
):
super(TransformerFeatureExtractor, self).__init__()
self.device = torch.device(device) if device else torch.device("cpu")
self.descriptor = descriptor
self.in_channels = in_channels
self.embed_dim = embed_dim
self.out_channels = out_channels
self.use_global_pool = use_global_pool
self.k_neighbors = k_neighbors
# Input projection
self.input_projection = (
nn.Sequential(
nn.Linear(in_channels, embed_dim // 2),
nn.ReLU(),
nn.Linear(embed_dim // 2, embed_dim),
)
.to(self.device)
.float()
)
# Transformer
self.transformer = (
Transformer(
embed_dim=embed_dim,
num_heads=num_heads,
num_layers=num_layers,
dropout=dropout,
)
.to(self.device)
.float()
)
self.output_projection = (
nn.Sequential(nn.Linear(embed_dim, out_channels)).to(self.device).float()
)
def forward(self, shape, return_intermediate=False):
"""Forward pass of the feature extractor.
Parameters
----------
shape: Shape
Input shape
Returns
-------
features: torch.Tensor
Extracted features tensor of shape (batch_size, num_points, output_dim)
"""
if self.descriptor is None:
input_feat = shape.vertices
else:
input_feat = self.descriptor(shape).T
input_feat = gs.to_torch(input_feat).to(self.device).float().unsqueeze(0)  # (1, N, in_channels)
# Project input features
x = self.input_projection(input_feat) # (B, N, embed_dim)
# Apply transformer (self-attention: x attends to itself)
transformer_output = self.transformer(x, x) # (B, N, embed_dim)
# Apply output projection
point_features = self.output_projection(
transformer_output
) # (B, N, out_channels)
if return_intermediate:
return {
"point_features": point_features,
"transformer_features": transformer_output,
}
return point_features
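# Usage sketch (illustrative, not part of the original module): extracting
# pointwise features from a shape-like object. ``_DummyShape`` is a
# hypothetical stand-in for a geomfum shape exposing a ``vertices`` array,
# and the sketch assumes ``gs.to_torch`` accepts a torch tensor of vertices.
def _example_feature_extraction():
    class _DummyShape:
        # Hypothetical minimal container with a ``vertices`` attribute.
        def __init__(self, vertices):
            self.vertices = vertices

    shape = _DummyShape(torch.rand(100, 3))  # 100 points with xyz coordinates
    extractor = TransformerFeatureExtractor(
        in_channels=3, embed_dim=64, num_heads=4, num_layers=2, out_channels=128
    )
    features = extractor(shape)  # expected shape: (1, 100, 128)
    return features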
class MultiHeadAttention(torch.nn.Module):
"""Multi-Head Attention layer.
Parameters
----------
embed_dim: int
Dimension of the used embedding
num_heads: int
Number of attention heads
dropout: float
Dropout rate
bias: bool
Set the use of leaned bias in the output linear layer of the attention heads
"""
def __init__(
self, embed_dim, num_heads: int, dropout: float = 0.0, bias: bool = True
):
super(MultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.dropout = torch.nn.Dropout(dropout)
self.linear_out = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
self.scale = self.head_dim**-0.5
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor = None,
attn_prev: torch.Tensor = None,
):
"""Forward pass of the Multi-Head Attention layer.
Parameters
----------
query: torch.Tensor
Query vectors of x
key: torch.Tensor
Key vectors of y
value: torch.Tensor
Value vectors of y
attn_mask: torch.Tensor
Mask tensor for the attention weights of shape (batch_size, num_heads, num_points_x, num_points_y) or (batch_size, num_points_x, num_points_y).
If (batch_size, num_points_x, num_points_y) the mask will be broadcasted over the num_heads dimension
attn_prev: torch.Tensor
Attention weights of the previous layer to be used in the residual attention of shape (batch_size, num_heads, num_points_x, num_points_y)
Returns
-------
output: torch.Tensor
Attention output tensor of shape (batch_size, num_points_x, embed_dim)
hiddens: dict
Intermediate activations of the attention mechanism
"""
query = query.reshape(
query.size(0), self.num_heads, query.size(1), self.head_dim
)
key = key.reshape(key.size(0), self.num_heads, key.size(1), self.head_dim)
value = value.reshape(
value.size(0), self.num_heads, value.size(1), self.head_dim
)
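# Scaled dot-product attention scores: (batch_size, num_heads, num_points_x, num_points_y).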
attn = torch.matmul(query, key.transpose(2, 3))
attn = attn * self.scale
attn = self.dropout(attn)
if attn_mask is not None:
if attn_mask.dim() == 3:
attn_mask = attn_mask.unsqueeze(1)
attn = attn.masked_fill(attn_mask == 0, -1e9)
pre_softmax_attn = attn
attn = attn.softmax(dim=-1)
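# Residual attention: add the previous layer's post-softmax attention weights.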
attn = attn + attn_prev if attn_prev is not None else attn
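# Weighted sum of the values, then merge the heads back to (batch_size, num_points_x, embed_dim).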
output = torch.matmul(attn, value)
output = output.transpose(1, 2).reshape(
output.size(0), output.size(2), self.embed_dim
)
output = self.linear_out(output)
hiddens = {
"q": query,
"k": key,
"v": value,
"attn": attn,
"pre_softmax_attn": pre_softmax_attn,
}
return output, hiddens
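# Shape sketch (illustrative, not part of the original module): feeding
# pre-projected query/key/value tensors directly into MultiHeadAttention.
# In this module the q/k/v projections normally live in AttentionLayer.
def _example_multi_head_attention():
    mha = MultiHeadAttention(embed_dim=64, num_heads=4)
    q = torch.rand(2, 10, 64)   # (batch_size, num_points_x, embed_dim)
    kv = torch.rand(2, 15, 64)  # (batch_size, num_points_y, embed_dim)
    output, hiddens = mha(q, kv, kv)
    # output: (2, 10, 64); hiddens["attn"]: (2, 4, 10, 15)
    return output, hiddens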
class AttentionLayer(torch.nn.Module):
"""Attention Layer for the Transformer.
Parameters
----------
embed_dim: int
Dimension of the used embedding
num_heads: int
Number of attention heads
dropout: float
Dropout rate
attn_bias: bool
Set the use of leaned bias in the output linear layer of the attention heads
ff_mult: int
Dimension factor of the feed forward section of the attention layer. The dimension expansion is computed as ff_mult * embed_dim
"""
def __init__(
self,
embed_dim,
num_heads,
dropout: float = 0.0,
attn_bias: bool = True,
ff_mult: int = 4,
):
super(AttentionLayer, self).__init__()
self.attn = MultiHeadAttention(
embed_dim,
num_heads,
dropout,
attn_bias,
)
self.norm1 = torch.nn.LayerNorm(embed_dim)
self.norm2 = torch.nn.LayerNorm(embed_dim)
self.to_q = torch.nn.Linear(embed_dim, embed_dim)
self.to_k = torch.nn.Linear(embed_dim, embed_dim)
self.to_v = torch.nn.Linear(embed_dim, embed_dim)
self.feed_forward = torch.nn.Sequential(
torch.nn.Linear(embed_dim, embed_dim * ff_mult),
torch.nn.GELU(),
torch.nn.Linear(embed_dim * ff_mult, embed_dim),
)
def forward(self, x, y, attn_mask=None, x_mask=None, y_mask=None, attn_prev=None):
"""Forward pass of the Attention Layer.
Parameters
----------
x: torch.Tensor
Tokens of x tensor of shape (batch_size, num_points_x, embed_dim)
y: torch.Tensor
Tokens of y tensor of shape (batch_size, num_points_y, embed_dim)
attn_mask: torch.Tensor
Mask tensor for the attention weights of shape (batch_size, num_heads, num_points_x, num_points_y) or (batch_size, num_points_x, num_points_y).
If (batch_size, num_points_x, num_points_y) the mask will be broadcasted over the num_heads dimension
x_mask: torch.Tensor
Mask tensor of x of shape (batch_size, num_points_x)
y_mask: torch.Tensor
Mask tensor of y of shape (batch_size, num_points_y)
attn_prev: torch.Tensor
Attention weights of the previous layer to be used in the residual attention of shape (batch_size, num_heads, num_points_x, num_points_y)
Returns
-------
output: torch.Tensor
Attention output tensor of shape (batch_size, num_points_x, embed_dim)
hiddens: dict
Intermediate activations of the attention mechanism
"""
if x_mask is not None:
if y_mask is None:
y_mask = x_mask
input_mask = (
(x_mask.float().unsqueeze(-1))
.bmm(y_mask.float().unsqueeze(-1).transpose(-1, -2))
.long()
.unsqueeze(1)
)
if attn_mask is None:
attn_mask = input_mask
else:
attn_mask = attn_mask & input_mask
attn_output, hiddens = self.attn(
self.to_q(x),
self.to_k(y),
self.to_v(y),
attn_mask=attn_mask,
attn_prev=attn_prev,
)
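# Residual connection with post-attention LayerNorm, followed by the feed-forward block.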
attn_output = self.norm1(x + attn_output)
output = self.feed_forward(attn_output)
output = self.norm2(attn_output + output)
return output, hiddens
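# Mask sketch (illustrative, not part of the original module): cross-attention
# between two point sets with per-point validity masks, which the layer
# combines into a pairwise attention mask.
def _example_attention_layer():
    layer = AttentionLayer(embed_dim=64, num_heads=4)
    x = torch.rand(2, 10, 64)
    y = torch.rand(2, 15, 64)
    x_mask = torch.ones(2, 10, dtype=torch.bool)  # all points of x are valid
    y_mask = torch.ones(2, 15, dtype=torch.bool)  # all points of y are valid
    output, hiddens = layer(x, y, x_mask=x_mask, y_mask=y_mask)
    # output: (2, 10, 64); hiddens["attn"]: (2, 4, 10, 15)
    return output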
class Transformer(torch.nn.Module):
"""Transformer for feature extraction.
Parameters
----------
embed_dim: int
Dimension of the used embedding
num_heads: int
Number of attention heads
num_layers: int
Number of attention layers
dropout: float
Dropout rate
residual: bool
Boolean control the use of the residual attention
"""
def __init__(
self,
embed_dim,
num_heads,
num_layers,
dropout=0.0,
residual=False,
):
super(Transformer, self).__init__()
self.layers = torch.nn.ModuleList(
[
AttentionLayer(
embed_dim,
num_heads,
dropout,
)
for _ in range(num_layers)
]
)
self.residual = residual
def forward(
self, x, y, attn_mask=None, x_mask=None, y_mask=None, return_hiddens=False
):
"""Forward pass of the Transformer.
Parameters
----------
x: torch.Tensor
Tokens of x tensor of shape (batch_size, num_points_x, embed_dim)
y: torch.Tensor
Tokens of y tensor of shape (batch_size, num_points_y, embed_dim)
attn_mask: torch.Tensor
Mask tensor for the attention weights of shape (batch_size, num_heads, num_points_x, num_points_y) or (batch_size, num_points_x, num_points_y).
If (batch_size, num_points_x, num_points_y) the mask will be broadcasted over the num_heads dimension
x_mask: torch.Tensor
Mask tensor of x of shape (batch_size, num_points_x)
y_mask: torch.Tensor
Mask tensor of y of shape (batch_size, num_points_y)
return_hiddens: bool
Boolean to return intermediate activations of the attention mechanism. If True the output is a tuple (output, hiddens)
Returns
-------
output: torch.Tensor
Attention output tensor of shape (batch_size, num_points_x, embed_dim)
hiddens: dict
Intermediate activations of the attention mechanism
"""
attn_hiddens = []
for layer in self.layers:
if self.residual and len(attn_hiddens) > 0:
attn_prev = attn_hiddens[-1]["attn"]
else:
attn_prev = None
output, hiddens = layer(
x,
y,
attn_mask=attn_mask,
x_mask=x_mask,
y_mask=y_mask,
attn_prev=attn_prev,
)
attn_hiddens.append(hiddens)
if return_hiddens:
return output, attn_hiddens
else:
return output
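# Usage sketch (illustrative, not part of the original module): running the
# standalone Transformer in self-attention mode with residual attention and
# inspecting the per-layer attention weights.
def _example_transformer():
    model = Transformer(embed_dim=64, num_heads=4, num_layers=2, residual=True)
    tokens = torch.rand(2, 10, 64)  # (batch_size, num_points, embed_dim)
    output, hiddens = model(tokens, tokens, return_hiddens=True)
    # output: (2, 10, 64); hiddens is a list with one dict per layer,
    # e.g. hiddens[0]["attn"] has shape (2, 4, 10, 10).
    return output, hiddens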