"""Plotly plotting module."""
import plotly.graph_objects as go
from matplotlib import colors as mcolors
from geomfum.plot import ShapePlotter
from geomfum.shape.convert import to_go_mesh3d, to_go_pointcloud
class PlotlyShapePlotter(ShapePlotter):
    """Base plotting object for 3D shapes using Plotly.

    ``set_colormap`` acts on the figure's first trace (``data[0]``).
    All public methods except ``show`` return ``self`` so calls can be
    chained.

    Parameters
    ----------
    colormap : str
        Name of the Plotly colorscale used for the shape.
    """

    def __init__(self, colormap="viridis"):
        self.colormap = colormap
        # ``fig`` and ``_plotter`` are two names for the same figure object.
        # aspectmode="data" keeps scene axes proportional to the data ranges.
        self._plotter = self.fig = go.Figure(
            data=[],
            layout=go.Layout(scene=dict(aspectmode="data")),
        )

    def highlight_vertices(self, coords, color="red", size=4):
        """Highlight vertices on shape.

        Adds a separate marker trace on top of whatever is already plotted.

        Parameters
        ----------
        coords : array-like, shape=[n_vertices, 3]
            Coordinates of vertices to highlight. Must support 2D slicing
            (``coords[:, 0]``), e.g. a NumPy array.
        color : str
            Color of the highlighted vertices.
        size : int
            Size of the highlighted vertices.

        Returns
        -------
        self : PlotlyShapePlotter
        """
        marker = go.Scatter3d(
            x=coords[:, 0],
            y=coords[:, 1],
            z=coords[:, 2],
            mode="markers",
            marker=dict(size=size, color=color),
            name="Highlighted_points",
        )
        self._plotter.add_trace(marker)
        return self

    def add_vectors(self, origins, vectors, color="blue", scale=1.0, name="vectors"):
        """Add vector field visualization.

        All vectors are drawn as line segments in a single trace.

        Parameters
        ----------
        origins : array-like, shape=[n_points, 3]
            Starting points for vectors.
        vectors : array-like, shape=[n_points, 3]
            Vector directions and magnitudes.
        color : str
            Color of the vectors.
        scale : float
            Scale factor for vector length.
        name : str
            Name for the vector trace.

        Returns
        -------
        self : PlotlyShapePlotter

        Raises
        ------
        ValueError
            If ``origins`` and ``vectors`` have different lengths.
        """
        if len(origins) != len(vectors):
            raise ValueError(
                "origins and vectors must have the same length, got "
                f"{len(origins)} and {len(vectors)}."
            )

        x_lines, y_lines, z_lines = [], [], []
        for start, vector in zip(origins, vectors):
            end = start + scale * vector
            # The trailing None breaks the polyline so that consecutive
            # vectors are not connected to each other.
            x_lines.extend([start[0], end[0], None])
            y_lines.extend([start[1], end[1], None])
            z_lines.extend([start[2], end[2], None])

        vector_trace = go.Scatter3d(
            x=x_lines,
            y=y_lines,
            z=z_lines,
            mode="lines",
            line=dict(color=color, width=3),
            name=name,
        )
        self._plotter.add_trace(vector_trace)
        return self

    def set_colormap(self, colormap):
        """Update the colormap.

        The new colormap is stored for later traces and, if the figure
        already has traces, applied to the first one.

        Parameters
        ----------
        colormap : str
            Name of the colormap to use.

        Returns
        -------
        self : PlotlyShapePlotter
        """
        self.colormap = colormap
        if len(self._plotter.data) > 0:
            trace = self._plotter.data[0]
            # Mesh3d has a top-level ``colorscale``; Scatter3d only has
            # ``marker.colorscale`` and rejects a top-level one (ValueError).
            if trace.type == "scatter3d":
                trace.update(marker=dict(colorscale=colormap))
            else:
                trace.update(colorscale=colormap)
        return self

    def show(self):
        """Display plot."""
        self._plotter.show()

    def save(self, filename, **kwargs):
        """Save plot to file.

        Files ending in ``.html``/``.htm`` (case-insensitive) are written
        with ``write_html``; anything else goes to ``write_image``, which
        picks the image format from the extension.

        Parameters
        ----------
        filename : str or pathlib.Path
            Filename to save to.
        **kwargs
            Additional arguments passed to plotly's write_html or write_image.

        Returns
        -------
        self : PlotlyShapePlotter
        """
        # str() so pathlib.Path objects work; lower() so ".HTML" is not sent
        # to write_image, which cannot produce HTML.
        if str(filename).lower().endswith((".html", ".htm")):
            self._plotter.write_html(filename, **kwargs)
        else:
            self._plotter.write_image(filename, **kwargs)
        return self
class PlotlyMeshPlotter(PlotlyShapePlotter):
    """Plotting object to display meshes.

    The mesh trace is kept at ``data[0]`` of the figure. That is the trace
    ``set_vertex_scalars`` and ``set_vertex_colors`` modify.
    """

    def __init__(self, colormap="viridis"):
        super().__init__(colormap=colormap)

    def add_mesh(self, mesh, **kwargs):
        """Add mesh to plot.

        The mesh is inserted as the first trace of the figure. Traces that
        were already present (e.g. highlighted vertices or vectors) are kept
        after it.

        Parameters
        ----------
        mesh : TriangleMesh
            Mesh to be plotted.
        **kwargs
            Extra ``go.Mesh3d`` properties. They are applied after the
            default colorscale, so passing ``colorscale`` overrides it.

        Returns
        -------
        self : PlotlyMeshPlotter
        """
        plotly_obj = to_go_mesh3d(mesh)
        # Two updates instead of one call so that a user-supplied
        # ``colorscale`` overrides the default instead of raising TypeError.
        plotly_obj.update(colorscale=self.colormap)
        plotly_obj.update(**kwargs)

        fig = self._plotter
        # ``fig.update(data=[...])`` would positionally merge the mesh into
        # every existing trace. Instead, append the trace and move it to the
        # front. Plotly allows reassigning ``data`` to a permutation of its
        # existing traces.
        fig.add_trace(plotly_obj)
        fig.data = (fig.data[-1],) + tuple(fig.data[:-1])

        # Add hover text with vertex indices
        hover_text = [f"Index: {index}" for index in range(len(mesh.vertices))]
        fig.data[0]["text"] = hover_text
        return self

    def set_vertex_scalars(self, scalars, name="scalars"):
        """Set vertex scalars on mesh.

        Colors the first trace by ``scalars`` mapped through the current
        colormap.

        Parameters
        ----------
        scalars : array-like
            Values assigned to the Mesh3d ``intensity`` property.
        name : str
            Unused by this implementation.

        Returns
        -------
        self : PlotlyMeshPlotter
        """
        trace = self._plotter.data[0]
        trace["intensity"] = scalars
        trace["colorscale"] = self.colormap
        return self

    def set_vertex_colors(self, colors):
        """Set vertex colors on mesh.

        Parameters
        ----------
        colors : array-like
            One color per vertex. It is assigned unchanged to the Mesh3d
            ``vertexcolor`` property. NOTE(review): Plotly documents
            vertexcolor RGB channels in the 0-255 range; confirm the
            expected input range.

        Returns
        -------
        self : PlotlyMeshPlotter
        """
        trace = self._plotter.data[0]
        trace["vertexcolor"] = colors
        return self
class PlotlyPointCloudPlotter(PlotlyShapePlotter):
    """Plotting object to display point clouds.

    The point-cloud trace is kept at ``data[0]`` of the figure. That is the
    trace ``set_vertex_scalars`` and ``set_vertex_colors`` modify.
    """

    def __init__(self, colormap="viridis"):
        super().__init__(colormap=colormap)

    def add_point_cloud(self, pointcloud, **kwargs):
        """Add point cloud to plot.

        The point cloud is inserted as the first trace of the figure. Traces
        that were already present (e.g. highlighted vertices or vectors) are
        kept after it.

        Parameters
        ----------
        pointcloud : PointCloud
            Point cloud to be plotted.
        **kwargs
            Extra trace properties, applied after the default
            ``marker.colorscale``. Nested dicts such as ``marker=dict(...)``
            are merged with it rather than raising TypeError.

        Returns
        -------
        self : PlotlyPointCloudPlotter
        """
        plotly_obj = to_go_pointcloud(pointcloud)
        plotly_obj.update(marker=dict(colorscale=self.colormap))
        plotly_obj.update(**kwargs)

        fig = self._plotter
        # ``fig.update(data=[...])`` would positionally merge this trace into
        # every existing trace. Instead, append it and move it to the front
        # (reassigning ``data`` to a permutation of its traces is allowed).
        fig.add_trace(plotly_obj)
        fig.data = (fig.data[-1],) + tuple(fig.data[:-1])

        # Add hover text with vertex indices
        hover_text = [f"Index: {index}" for index in range(len(pointcloud.vertices))]
        fig.data[0]["text"] = hover_text
        return self

    def set_vertex_scalars(self, scalars, name="scalars"):
        """Set vertex scalars on point cloud.

        Colors the markers of the first trace by ``scalars`` mapped through
        the current colormap.

        Parameters
        ----------
        scalars : array-like
            Values assigned to ``marker.color``.
        name : str
            Unused by this implementation.

        Returns
        -------
        self : PlotlyPointCloudPlotter
        """
        trace = self._plotter.data[0]
        trace["marker"]["color"] = scalars
        trace["marker"]["colorscale"] = self.colormap
        return self

    def set_vertex_colors(self, colors):
        """Set vertex colors on point cloud.

        Parameters
        ----------
        colors : iterable
            One color per point. Each is converted with
            ``matplotlib.colors.rgb2hex``, so RGB(A) values must be floats in
            [0, 1] (or any other matplotlib color spec). Alpha is dropped.

        Returns
        -------
        self : PlotlyPointCloudPlotter
        """
        trace = self._plotter.data[0]
        trace["marker"]["color"] = [mcolors.rgb2hex(color) for color in colors]
        return self