diff --git a/examples/0_Introduction_to_K3d_Widgets.ipynb b/examples/0_Introduction_to_K3d_Widgets.ipynb
new file mode 100644
index 0000000..cff64e9
--- /dev/null
+++ b/examples/0_Introduction_to_K3d_Widgets.ipynb
@@ -0,0 +1,664 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import menpo3d.io as m3io\n",
+ "import menpo.io as mio\n",
+ "from menpo.shape import PointCloud, ColouredTriMesh\n",
+ "from menpo.landmark import face_ibug_68_to_face_ibug_68\n",
+ "import numpy as np"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "
Load the data (Mesh, landmarks and model)
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mesh = m3io.import_mesh('../menpo3d/data/james.obj')\n",
+ "lms = m3io.import_landmark_file('../menpo3d/data/james.ljson')['LJSON']\n",
+ "\n",
+ "# Load model and its landmarks indices \n",
+ "model = mio.import_pickle('../menpo3d/data/3DMD_all_all_all_10.pkl')['model']\n",
+ "lms_indices = [21868, 22404, 22298, 22327, 43430, 45175, 46312, 47132, 47911, 48692,\n",
+ " 49737, 51376, 53136, 32516, 32616, 32205, 32701, 38910, 39396, 39693,\n",
+ " 39934, 40131, 40843, 41006, 41179, 41430, 13399, 8161, 8172, 8179, 8185,\n",
+ " 5622, 6881, 8202, 9403, 10764, 1831, 3887, 5049, 6214, 4805, 3643, 9955,\n",
+ " 11095, 12255, 14197, 12397, 11366, 5779, 6024, 7014, 8215, 9294, 10267,\n",
+ " 10922, 9556, 8836, 8236, 7636, 6794, 5905, 7264, 8223, 9063, 10404, 8828,\n",
+ " 8228, 7509]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " Create new random instances
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "cov = np.diag(model.eigenvalues)\n",
+ "model_mean = model.mean()\n",
+ "synthetic_weights = np.random.multivariate_normal(np.zeros(model.n_active_components),\n",
+ " cov, 1000)\n",
+ "random_mesh = model.instance(synthetic_weights[5])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " Show the mesh
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "# Default values for TriMesh, TextureMesh viewer are\n",
+ "# figure_id None\n",
+ "# new_figure True\n",
+ "# in that case an automatic figure_id will be given\n",
+ "# with 'Figure_{n}' format\n",
+ "# n will be an increased integer starting from zero\n",
+ "mesh.view() # wait a bit before magic happens"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " Show the mesh and landmarks
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "mesh.view(figure_id='James')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Add landmarks to figure with id James\n",
+ "lms_poincloud = PointCloud(lms.points)\n",
+ "lms_poincloud.view(figure_id='James',new_figure=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Add landmarks to figure with id Figure_0\n",
+ "lms_poincloud.view(figure_id='Figure_0', new_figure=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " Show a mesh that has landmarks
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mesh.landmarks = lms"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "lms.view(new_figure=True, render_numbering=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# The mesh has now landmarks, so they would be plotted as well\n",
+ "# the figure id is now Figure_2\n",
+ "mesh.view()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Show the TexturedMesh without texture\n",
+ "mesh.view(render_texture=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mesh.view(render_texture=False, mesh_type='wireframe')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "random_mesh.view(mesh_type='surface')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " HeatMaps
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Heatmap between a random mesh and mean mesh\n",
+ "random_mesh.heatmap(model_mean)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Heatmap with statistics \n",
+ "# Be careful, since we have already drawn a heatmap between\n",
+ "# random and mean, we should use another name for figure\n",
+ "random_mesh.heatmap(model_mean, show_statistics=True, figure_id='Heatmap2')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Heatmap with landmarks\n",
+ "model_mean.landmarks = face_ibug_68_to_face_ibug_68(PointCloud(model_mean.points[lms_indices]))\n",
+ "model_mean.heatmap(random_mesh, inline=True, show_statistics=True, figure_id='Heatmap3')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "random_mesh_landmarks = face_ibug_68_to_face_ibug_68(random_mesh.points[lms_indices])\n",
+ "random_mesh_landmarks.view(inline=True, new_figure=False, figure_id='Heatmap2')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " Show Normals
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pts = random_mesh.points[lms_indices]\n",
+ "vrt = np.zeros((random_mesh.n_points,3))\n",
+ "vrt[lms_indices] = random_mesh.vertex_normals()[lms_indices] / 5"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "random_mesh.view(normals=vrt, \n",
+ " normals_marker_size= 0.5,\n",
+ " normals_line_width = 0.01,\n",
+ " figure_id='Normals')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "random_mesh_landmarks.view(figure_id='Normals', new_figure=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "random_mesh.landmarks = random_mesh_landmarks\n",
+ "random_mesh.view( normals=vrt, \n",
+ " normals_marker_size= 0.5,\n",
+ " normals_line_width = 0.01)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " Show PointCloud with colours
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from IPython.utils.tempdir import TemporaryWorkingDirectory\n",
+ "with TemporaryWorkingDirectory() as tmpdir:\n",
+ " !mkdir -p ./data/PittsburghBridge\n",
+ " !wget -P ./data/PittsburghBridge https://dl.fbaipublicfiles.com/pytorch3d/data/PittsburghBridge/pointcloud.npz\n",
+ " pointcloud = np.load('./data/PittsburghBridge/pointcloud.npz')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "points = pointcloud['verts']\n",
+ "colours = pointcloud['rgb']\n",
+ "\n",
+ "coloured_pointcloud = PointCloud(points)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "coloured_pointcloud.view(colours=colours)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " Show Surface
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " Show ColouredTriMesh
\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "colors = np.random.rand(random_mesh.n_points)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "colors[0:1000]= 0.1\n",
+ "colors[1000:10000]= 0.2\n",
+ "colors[1000:10000]= 0.4\n",
+ "colors[10000:30000]= 0.6\n",
+ "colors[30000:40000]= 0.8\n",
+ "colors[40000:50000]= 0.8\n",
+ "colors[50000:]= 1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "new_mesh = ColouredTriMesh(random_mesh.points, random_mesh.trilist, colours=colors)\n",
+ "new_mesh.view()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "new_mesh.landmarks = random_mesh_landmarks"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " Show Graphs
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from menpo.shape import PointUndirectedGraph \n",
+ "import numpy as np\n",
+ "points = np.array([[10, 30, 10], [0, 20, 11], [20, 20, 11], [0, 10, 12], [20, 10, 12], [0, 0, 12]]) \n",
+ "edges = np.array([[0, 1], [1, 0], [0, 2], [2, 0], [1, 2], [2, 1], \n",
+ " [1, 3], [3, 1], [2, 4], [4, 2], [3, 4], [4, 3],[3, 5], [5, 3]]) \n",
+ "colors = [\n",
+ " 0xff,\n",
+ " 0xffff,\n",
+ " 0xff00ff,\n",
+ " 0x00ffff,\n",
+ " 0xffff00,]\n",
+ "\n",
+ "graph = PointUndirectedGraph.init_from_edges(points, edges) "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "graph.view()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "graph.view(line_colour=colors, render_numbering=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " View a Morphable Model
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "model.view(figure_id='Model', n_parameters=10, landmarks_indices=lms_indices)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "random_mesh.view()\n",
+ "lms.view(new_figure=False)\n",
+ "mesh.view(new_figure=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " Load big meshes
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mesh = m3io.import_mesh('/data/meshes/mesh/33_plain.obj')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mesh.view( )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mesh_hq = m3io.import_mesh('/data/meshes/mesh/HQ_mesh_alex.obj')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mesh_hq"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mesh_hq.view()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " Fail cases (supposed you have already executed all the above cells)
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# It should fail if the previous cells have been executed, as default values for landmarker viewer are\n",
+ "# figure_id = None and new_figure=False, so it could not\n",
+ "# find a figure with id None\n",
+ "lms.view()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# It should fail if the previous cells have been executed, as we have already had a figure with id \n",
+ "# James and we cannot create a new one with the same figure_id\n",
+ "mesh.view(figure_id='James', new_figure=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# It should fail if the previous cells have been executed, as we have already had a figure with id \n",
+ "# Model and we cannot create a new one with the same figure_id\n",
+ "model.view(inline=True, figure_id='Model')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# You have already created a heatmap between random_mesh and model_mean\n",
+ "random_mesh.heatmap(model_mean, inline=True)`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " Additional functions
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from menpo3d.visualize import list_figures, dict_figures"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "list_figures()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dict_figures()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " Testing
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from ipywidgets import Widget\n",
+ "from menpo3d.visualize.viewk3dwidgets import K3dwidgetsRenderer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for x in Widget.widgets.values():\n",
+ " print(type(x),x.model_id)\n",
+ "# if isinstance(x, K3dwidgetsRenderer):\n",
+ " if hasattr(x,'figure_id'):\n",
+ " print(type(x),x.model_id, x.figure_id)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "list_figures()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/menpo3d/visualize/__init__.py b/menpo3d/visualize/__init__.py
index 87e7761..d221a22 100644
--- a/menpo3d/visualize/__init__.py
+++ b/menpo3d/visualize/__init__.py
@@ -1,8 +1,10 @@
-from .base import (
- PointGraphViewer3d,
- TriMeshViewer3d,
- VectorViewer3d,
- ColouredTriMeshViewer3d,
- TexturedTriMeshViewer3d,
- LandmarkViewer3d,
-)
+from .base import (PointGraphViewer3d, TriMeshViewer3d, VectorViewer3d,
+ ColouredTriMeshViewer3d, TexturedTriMeshViewer3d,
+ LandmarkViewer3d, HeatmapViewer3d,
+ TriMeshInlineViewer3d, ColouredTriMeshInlineViewer3d,
+ PointGraphInlineViewer3d, TexturedTriMeshInlineViewer3d,
+ LandmarkInlineViewer3d, PointGraphInlineViewer3d,
+ VectorInlineViewer3d, HeatmapInlineViewer3d,
+ PCAModelInlineViewer3d)
+
+from .viewk3dwidgets import (list_figures, clear_figure, dict_figures)
diff --git a/menpo3d/visualize/base.py b/menpo3d/visualize/base.py
index 696c898..b7f524c 100644
--- a/menpo3d/visualize/base.py
+++ b/menpo3d/visualize/base.py
@@ -1,11 +1,17 @@
from .viewmayavi import (
- MayaviTriMeshViewer3d,
- MayaviPointGraphViewer3d,
- MayaviTexturedTriMeshViewer3d,
- MayaviLandmarkViewer3d,
- MayaviVectorViewer3d,
- MayaviColouredTriMeshViewer3d,
-)
+ MayaviTriMeshViewer3d, MayaviPointGraphViewer3d,
+ MayaviTexturedTriMeshViewer3d, MayaviLandmarkViewer3d,
+ MayaviVectorViewer3d, MayaviColouredTriMeshViewer3d, MayaviHeatmapViewer3d)
+
+
+from .viewk3dwidgets import (K3dwidgetsTriMeshViewer3d,
+ K3dwidgetsPointGraphViewer3d,
+ K3dwidgetsVectorViewer3d,
+ K3dwidgetsLandmarkViewer3d,
+ K3dwidgetsTexturedTriMeshViewer3d,
+ K3dwidgetsColouredTriMeshViewer3d,
+ K3dwidgetsHeatmapViewer3d,
+ K3dwidgetsPCAModelViewer3d)
PointGraphViewer3d = MayaviPointGraphViewer3d
TriMeshViewer3d = MayaviTriMeshViewer3d
@@ -13,3 +19,13 @@
ColouredTriMeshViewer3d = MayaviColouredTriMeshViewer3d
LandmarkViewer3d = MayaviLandmarkViewer3d
VectorViewer3d = MayaviVectorViewer3d
+HeatmapViewer3d = MayaviHeatmapViewer3d
+
+TriMeshInlineViewer3d = K3dwidgetsTriMeshViewer3d
+TexturedTriMeshInlineViewer3d = K3dwidgetsTexturedTriMeshViewer3d
+LandmarkInlineViewer3d = K3dwidgetsLandmarkViewer3d
+PointGraphInlineViewer3d = K3dwidgetsPointGraphViewer3d
+VectorInlineViewer3d = K3dwidgetsVectorViewer3d
+HeatmapInlineViewer3d = K3dwidgetsHeatmapViewer3d
+PCAModelInlineViewer3d = K3dwidgetsPCAModelViewer3d
+ColouredTriMeshInlineViewer3d = K3dwidgetsColouredTriMeshViewer3d
diff --git a/menpo3d/visualize/menpowidgets.py b/menpo3d/visualize/menpowidgets.py
new file mode 100644
index 0000000..749dba2
--- /dev/null
+++ b/menpo3d/visualize/menpowidgets.py
@@ -0,0 +1,828 @@
+from collections import OrderedDict
+from time import sleep
+from IPython import get_ipython
+from ipywidgets import Box
+import ipywidgets
+from traitlets.traitlets import List
+
+# The below classes have been copied from
+# the deprecated menpowidgets package
+# MenpoWidget can be found in abstract.py
+# LinearModelParametersWidget in options.py
+class MenpoWidget(Box):
+ r"""
+ Base class for defining a Menpo widget.
+
+ The widget has a `selected_values` trait that can be used in order to
+ inspect any changes that occur to its children. It also has functionality
+ for adding, removing, replacing or calling the handler callback function of
+ the `selected_values` trait.
+
+ Parameters
+ ----------
+ children : `list` of `ipywidgets`
+ The `list` of `ipywidgets` objects to be set as children in the
+ `ipywidgets.Box`.
+ trait : `traitlets.TraitType` subclass
+ The type of the `selected_values` object that gets added as a trait
+ in the widget. Possible options from `traitlets` are {``Int``, ``Float``,
+ ``Dict``, ``List``, ``Tuple``}.
+ trait_initial_value : `int` or `float` or `dict` or `list` or `tuple`
+ The initial value of the `selected_values` trait.
+ render_function : `callable` or ``None``, optional
+ The render function that behaves as a callback handler of the
+ `selected_values` trait for the `change` event. Its signature can be
+ ``render_function()`` or ``render_function(change)``, where ``change``
+ is a `dict` with the following keys:
+
+ - ``owner`` : the `HasTraits` instance
+ - ``old`` : the old value of the modified trait attribute
+ - ``new`` : the new value of the modified trait attribute
+ - ``name`` : the name of the modified trait attribute.
+ - ``type`` : ``'change'``
+
+ If ``None``, then nothing is added.
+ """
+ def __init__(self, children, trait, trait_initial_value,
+ render_function=None):
+ # Create box object
+ super(MenpoWidget, self).__init__(children=children)
+
+ # Add trait for selected values
+ selected_values = trait(default_value=trait_initial_value)
+ selected_values_trait = {'selected_values': selected_values}
+ self.add_traits(**selected_values_trait)
+ self.selected_values = trait_initial_value
+
+ # Set render function
+ self._render_function = None
+ self.add_render_function(render_function)
+
+ def add_render_function(self, render_function):
+ r"""
+ Method that adds the provided `render_function()` as a callback handler
+ to the `selected_values` trait of the widget. The given function is
+ also stored in `self._render_function`.
+
+ Parameters
+ ----------
+ render_function : `callable` or ``None``, optional
+ The render function that behaves as a callback handler of the
+ `selected_values` trait for the `change` event. Its signature can be
+ ``render_function()`` or ``render_function(change)``, where
+ ``change`` is a `dict` with the following keys:
+
+ - ``owner`` : the `HasTraits` instance
+ - ``old`` : the old value of the modified trait attribute
+ - ``new`` : the new value of the modified trait attribute
+ - ``name`` : the name of the modified trait attribute.
+ - ``type`` : ``'change'``
+
+ If ``None``, then nothing is added.
+ """
+ self._render_function = render_function
+ if self._render_function is not None:
+ self.observe(self._render_function, names='selected_values',
+ type='change')
+
+ def remove_render_function(self):
+ r"""
+ Method that removes the current `self._render_function()` as a callback
+ handler to the `selected_values` trait of the widget and sets
+ ``self._render_function = None``.
+ """
+ if self._render_function is not None:
+ self.unobserve(self._render_function, names='selected_values',
+ type='change')
+ self._render_function = None
+
+ def replace_render_function(self, render_function):
+ r"""
+ Method that replaces the current `self._render_function()` with the
+ given `render_function()` as a callback handler to the `selected_values`
+ trait of the widget.
+
+ Parameters
+ ----------
+ render_function : `callable` or ``None``, optional
+ The render function that behaves as a callback handler of the
+ `selected_values` trait for the `change` event. Its signature can be
+ ``render_function()`` or ``render_function(change)``, where
+ ``change`` is a `dict` with the following keys:
+
+ - ``owner`` : the `HasTraits` instance
+ - ``old`` : the old value of the modified trait attribute
+ - ``new`` : the new value of the modified trait attribute
+ - ``name`` : the name of the modified trait attribute.
+ - ``type`` : ``'change'``
+
+ If ``None``, then nothing is added.
+ """
+ # remove old function
+ self.remove_render_function()
+
+ # add new function
+ self.add_render_function(render_function)
+
+ def call_render_function(self, old_value, new_value, type_value='change'):
+ r"""
+ Method that calls the existing `render_function()` callback handler.
+
+ Parameters
+ ----------
+ old_value : `int` or `float` or `dict` or `list` or `tuple`
+ The old `selected_values` value.
+ new_value : `int` or `float` or `dict` or `list` or `tuple`
+ The new `selected_values` value.
+ type_value : `str`, optional
+ The trait event type.
+ """
+ if self._render_function is not None:
+ change_dict = {'type': 'change', 'old': old_value,
+ 'name': type_value, 'new': new_value,
+ 'owner': self.__str__()}
+ self._render_function(change_dict)
+
+
+class LinearModelParametersWidget(MenpoWidget):
+ r"""
+ Creates a widget for selecting parameters values when visualizing a linear
+ model (e.g. PCA model).
+
+ Note that:
+
+ * To update the state of the widget, please refer to the
+ :meth:`set_widget_state` method.
+ * The selected values are stored in the ``self.selected_values`` `trait`
+ which is a `list`.
+ * To set the styling of this widget please refer to the
+ :meth:`predefined_style` method.
+ * To update the handler callback functions of the widget, please refer to
+ the :meth:`replace_render_function` and :meth:`replace_variance_function`
+ methods.
+
+ Parameters
+ ----------
+ n_parameters : `int`
+ The `list` of initial parameters values.
+ render_function : `callable` or ``None``, optional
+ The render function that is executed when a widgets' value changes.
+ It must have signature ``render_function(change)`` where ``change`` is
+ a `dict` with the following keys:
+
+ * ``type`` : The type of notification (normally ``'change'``).
+ * ``owner`` : the `HasTraits` instance
+ * ``old`` : the old value of the modified trait attribute
+ * ``new`` : the new value of the modified trait attribute
+ * ``name`` : the name of the modified trait attribute.
+
+ If ``None``, then nothing is assigned.
+ mode : ``{'single', 'multiple'}``, optional
+ If ``'single'``, only a single slider is constructed along with a
+ dropdown menu that allows the parameter selection.
+ If ``'multiple'``, a slider is constructed for each parameter.
+ params_str : `str`, optional
+ The string that will be used as description of the slider(s). The final
+ description has the form ``"{}{}".format(params_str, p)``, where ``p``
+ is the parameter number.
+ params_bounds : (`float`, `float`), optional
+ The minimum and maximum bounds, in std units, for the sliders.
+ params_step : `float`, optional
+ The step, in std units, of the sliders.
+ plot_variance_visible : `bool`, optional
+ Defines whether the button for plotting the variance will be visible
+ upon construction.
+ plot_variance_function : `callable` or ``None``, optional
+ The plot function that is executed when the plot variance button is
+ clicked. If ``None``, then nothing is assigned.
+ animation_visible : `bool`, optional
+ Defines whether the animation options will be visible.
+ loop_enabled : `bool`, optional
+ If ``True``, then the repeat mode of the animation is enabled.
+ interval : `float`, optional
+ The interval between the animation progress in seconds.
+ interval_step : `float`, optional
+ The interval step (in seconds) that is applied when fast
+ forward/backward buttons are pressed.
+ animation_step : `float`, optional
+ The parameters step that is applied when animation is enabled.
+ style : `str` (see below), optional
+ Sets a predefined style at the widget. Possible options are:
+
+ ============= ==================
+ Style Description
+ ============= ==================
+ ``'success'`` Green-based style
+ ``'info'`` Blue-based style
+ ``'warning'`` Yellow-based style
+ ``'danger'`` Red-based style
+ ``''`` No style
+ ============= ==================
+
+ continuous_update : `bool`, optional
+ If ``True``, then the render function is called while moving a
+ slider's handle. If ``False``, then the the function is called only
+ when the handle (mouse click) is released.
+
+ Example
+ -------
+ Let's create a linear model parameters values widget and then update its
+ state. Firstly, we need to import it:
+
+ >>> from menpowidgets.options import LinearModelParametersWidget
+
+ Now let's define a render function that will get called on every widget
+ change and will dynamically print the selected parameters:
+
+ >>> from menpo.visualize import print_dynamic
+ >>> def render_function(change):
+ >>> s = "Selected parameters: {}".format(wid.selected_values)
+ >>> print_dynamic(s)
+
+ Create the widget with some initial options and display it:
+
+ >>> wid = LinearModelParametersWidget(n_parameters=5,
+ >>> render_function=render_function,
+ >>> params_str='Parameter ',
+ >>> mode='multiple',
+ >>> params_bounds=(-3., 3.),
+ >>> plot_variance_visible=True,
+ >>> style='info')
+ >>> wid
+
+ By moving the sliders, the printed message gets updated. Finally, let's
+ change the widget status with a new set of options:
+
+ >>> wid.set_widget_state(n_parameters=10, params_str='',
+ >>> params_step=0.1, params_bounds=(-10, 10),
+ >>> plot_variance_visible=False,
+ >>> allow_callback=True)
+ """
+ def __init__(self, n_parameters, render_function=None, mode='multiple',
+ params_str='Parameter ', params_bounds=(-3., 3.),
+ params_step=0.1, plot_variance_visible=True,
+ plot_variance_function=None, animation_visible=True,
+ loop_enabled=False, interval=0., interval_step=0.05,
+ animation_step=0.5, style='', continuous_update=False):
+
+ # Get the kernel to use it later in order to make sure that the widgets'
+ # traits changes are passed during a while-loop
+ self.kernel = get_ipython().kernel
+
+ # If only one slider requested, then set mode to multiple
+ if n_parameters == 1:
+ mode = 'multiple'
+
+ # Create children
+ if mode == 'multiple':
+ self.sliders = []
+ self.parameters_children = []
+ for p in range(n_parameters):
+ slider_title = ipywidgets.HTML(
+ value="{}{}".format(params_str, p))
+ slider_wid = ipywidgets.FloatSlider(
+ description='', min=params_bounds[0], max=params_bounds[1],
+ step=params_step, value=0.,
+ continuous_update=continuous_update,
+ layout=ipywidgets.Layout(width='8cm'))
+ tmp = ipywidgets.HBox([slider_title, slider_wid])
+ tmp.layout.align_items = 'center'
+ self.sliders.append(slider_wid)
+ self.parameters_children.append(tmp)
+ self.parameters_wid = ipywidgets.VBox(self.parameters_children)
+ self.parameters_wid.layout.align_items = 'flex-end'
+ else:
+ vals = OrderedDict()
+ for p in range(n_parameters):
+ vals["{}{}".format(params_str, p)] = p
+ self.slider = ipywidgets.FloatSlider(
+ description='', min=params_bounds[0], max=params_bounds[1],
+ step=params_step, value=0., readout=True,
+ layout=ipywidgets.Layout(width='8cm'),
+ continuous_update=continuous_update)
+ self.dropdown_params = ipywidgets.Dropdown(
+ options=vals, layout=ipywidgets.Layout(width='3cm'))
+ self.dropdown_params.layout.margin = '0px 10px 0px 0px'
+ self.parameters_wid = ipywidgets.HBox([self.dropdown_params,
+ self.slider])
+ self.parameters_wid.layout.margin = '0px 0px 10px 0px'
+ self.plot_button = ipywidgets.Button(
+ description='Variance', layout=ipywidgets.Layout(width='80px'))
+ self.plot_button.layout.display = (
+ 'inline' if plot_variance_visible else 'none')
+ self.reset_button = ipywidgets.Button(
+ description='Reset', layout=ipywidgets.Layout(width='80px'))
+ self.plot_and_reset = ipywidgets.HBox([self.reset_button,
+ self.plot_button])
+ self.play_button = ipywidgets.Button(
+ icon='play', description='', tooltip='Play animation',
+ layout=ipywidgets.Layout(width='40px'))
+ self.stop_button = ipywidgets.Button(
+ icon='stop', description='', tooltip='Stop animation',
+ layout=ipywidgets.Layout(width='40px'))
+ self.fast_forward_button = ipywidgets.Button(
+ icon='fast-forward', description='',
+ layout=ipywidgets.Layout(width='40px'),
+ tooltip='Increase animation speed')
+ self.fast_backward_button = ipywidgets.Button(
+ icon='fast-backward', description='',
+ layout=ipywidgets.Layout(width='40px'),
+ tooltip='Decrease animation speed')
+ loop_icon = 'repeat' if loop_enabled else 'long-arrow-right'
+ self.loop_toggle = ipywidgets.ToggleButton(
+ icon=loop_icon, description='', value=loop_enabled,
+ layout=ipywidgets.Layout(width='40px'), tooltip='Repeat animation')
+ self.animation_buttons = ipywidgets.HBox(
+ [self.play_button, self.stop_button, self.loop_toggle,
+ self.fast_backward_button, self.fast_forward_button])
+ self.animation_buttons.layout.display = (
+ 'flex' if animation_visible else 'none')
+ self.animation_buttons.layout.margin = '0px 15px 0px 0px'
+ self.buttons_box = ipywidgets.HBox([self.animation_buttons,
+ self.plot_and_reset])
+ self.container = ipywidgets.VBox([self.parameters_wid,
+ self.buttons_box])
+
+ # Create final widget
+ super(LinearModelParametersWidget, self).__init__(
+ [self.container], List, [0.] * n_parameters,
+ render_function=render_function)
+
+ # Assign output
+ self.n_parameters = n_parameters
+ self.mode = mode
+ self.params_str = params_str
+ self.params_bounds = params_bounds
+ self.params_step = params_step
+ self.plot_variance_visible = plot_variance_visible
+ self.loop_enabled = loop_enabled
+ self.continuous_update = continuous_update
+ self.interval = interval
+ self.interval_step = interval_step
+ self.animation_step = animation_step
+ self.animation_visible = animation_visible
+ self.please_stop = False
+
+ # Set style
+ self.predefined_style(style)
+
+ # Set functionality
+ if mode == 'single':
+ # Assign slider value to parameters values list
+ def save_slider_value(change):
+ current_parameters = list(self.selected_values)
+ current_parameters[self.dropdown_params.value] = change['new']
+ self.selected_values = current_parameters
+ self.slider.observe(save_slider_value, names='value', type='change')
+
+ # Set correct value to slider when drop down menu value changes
+ def set_slider_value(change):
+ # Temporarily remove render callback
+ render_fun = self._render_function
+ self.remove_render_function()
+ # Set slider value
+ self.slider.value = self.selected_values[change['new']]
+ # Re-assign render callback
+ self.add_render_function(render_fun)
+ self.dropdown_params.observe(set_slider_value, names='value',
+ type='change')
+ else:
+ # Assign saving values and main plotting function to all sliders
+ for w in self.sliders:
+ w.observe(self._save_slider_value_from_id, names='value',
+ type='change')
+
+ def reset_parameters(name):
+ # Keep old value
+ old_value = self.selected_values
+
+ # Temporarily remove render callback
+ render_fun = self._render_function
+ self.remove_render_function()
+
+ # Set parameters to 0
+ self.selected_values = [0.0] * self.n_parameters
+ if mode == 'multiple':
+ for ww in self.sliders:
+ ww.value = 0.
+ else:
+ self.parameters_wid.children[0].value = 0
+ self.parameters_wid.children[1].value = 0.
+
+ # Re-assign render callback and trigger it
+ self.add_render_function(render_fun)
+ self.call_render_function(old_value, self.selected_values)
+ self.reset_button.on_click(reset_parameters)
+
+ # Set functionality
+ def loop_pressed(change):
+ if change['new']:
+ self.loop_toggle.icon = 'repeat'
+ else:
+ self.loop_toggle.icon = 'long-arrow-right'
+ self.kernel.do_one_iteration()
+ self.loop_toggle.observe(loop_pressed, names='value', type='change')
+
+ def fast_forward_pressed(name):
+ tmp = self.interval
+ tmp -= self.interval_step
+ if tmp < 0:
+ tmp = 0
+ self.interval = tmp
+ self.kernel.do_one_iteration()
+ self.fast_forward_button.on_click(fast_forward_pressed)
+
+ def fast_backward_pressed(name):
+ self.interval += self.interval_step
+ self.kernel.do_one_iteration()
+ self.fast_backward_button.on_click(fast_backward_pressed)
+
+ def animate(change):
+ reset_parameters('')
+ self.please_stop = False
+ self.reset_button.disabled = True
+ self.plot_button.disabled = True
+ if mode == 'multiple':
+ n_sliders = self.n_parameters
+ slider_id = 0
+ while slider_id < n_sliders:
+ # animate from 0 to min
+ slider_val = 0.
+ while slider_val > self.params_bounds[0]:
+ # Run IPython iteration.
+ self.kernel.do_one_iteration()
+
+ # Check stop flag
+ if self.please_stop:
+ break
+
+ # update slider value
+ slider_val -= self.animation_step
+
+ # set value
+ self.sliders[slider_id].value = slider_val
+
+ # wait
+ sleep(self.interval)
+
+ # Run IPython iteration.
+ self.kernel.do_one_iteration()
+
+ # animate from min to max
+ slider_val = self.params_bounds[0]
+ while slider_val < self.params_bounds[1]:
+ # Run IPython iteration.
+ self.kernel.do_one_iteration()
+
+ # Check stop flag
+ if self.please_stop:
+ break
+
+ # update slider value
+ slider_val += self.animation_step
+
+ # set value
+ self.sliders[slider_id].value = slider_val
+
+ # wait
+ sleep(self.interval)
+
+ # Run IPython iteration.
+ self.kernel.do_one_iteration()
+
+ # animate from max to 0
+ slider_val = self.params_bounds[1]
+ while slider_val > 0.:
+ # Run IPython iteration.
+ self.kernel.do_one_iteration()
+
+ # Check stop flag
+ if self.please_stop:
+ break
+
+ # update slider value
+ slider_val -= self.animation_step
+
+ # set value
+ self.sliders[slider_id].value = slider_val
+
+ # wait
+ sleep(self.interval)
+
+ # Run IPython iteration.
+ self.kernel.do_one_iteration()
+
+ # reset value
+ self.sliders[slider_id].value = 0.
+
+ # Check stop flag
+ if self.please_stop:
+ break
+
+ # update slider id
+ if self.loop_toggle.value and slider_id == n_sliders - 1:
+ slider_id = 0
+ else:
+ slider_id += 1
+
+ if not self.loop_toggle.value and slider_id >= n_sliders:
+ self.stop_animation()
+ else:
+ n_sliders = self.n_parameters
+ slider_id = 0
+ self.please_stop = False
+ while slider_id < n_sliders:
+ # set dropdown value
+ self.parameters_wid.children[0].value = slider_id
+
+ # animate from 0 to min
+ slider_val = 0.
+ while slider_val > self.params_bounds[0]:
+ # Run IPython iteration.
+ self.kernel.do_one_iteration()
+
+ # Check stop flag
+ if self.please_stop:
+ break
+
+ # update slider value
+ slider_val -= self.animation_step
+
+ # set value
+ self.parameters_wid.children[1].value = slider_val
+
+ # wait
+ sleep(self.interval)
+
+ # Run IPython iteration.
+ self.kernel.do_one_iteration()
+
+ # animate from min to max
+ slider_val = self.params_bounds[0]
+ while slider_val < self.params_bounds[1]:
+ # Run IPython iteration.
+ self.kernel.do_one_iteration()
+
+ # Check stop flag
+ if self.please_stop:
+ break
+
+ # update slider value
+ slider_val += self.animation_step
+
+ # set value
+ self.parameters_wid.children[1].value = slider_val
+
+ # wait
+ sleep(self.interval)
+
+ # Run IPython iteration.
+ self.kernel.do_one_iteration()
+
+ # animate from max to 0
+ slider_val = self.params_bounds[1]
+ while slider_val > 0.:
+ # Run IPython iteration.
+ self.kernel.do_one_iteration()
+
+ # Check stop flag
+ if self.please_stop:
+ break
+
+ # update slider value
+ slider_val -= self.animation_step
+
+ # set value
+ self.parameters_wid.children[1].value = slider_val
+
+ # wait
+ sleep(self.interval)
+
+ # Run IPython iteration.
+ self.kernel.do_one_iteration()
+
+ # reset value
+ self.parameters_wid.children[1].value = 0.
+
+ # Check stop flag
+ if self.please_stop:
+ break
+
+ # update slider id
+ if self.loop_toggle.value and slider_id == n_sliders - 1:
+ slider_id = 0
+ else:
+ slider_id += 1
+ self.reset_button.disabled = False
+ self.plot_button.disabled = False
+ self.play_button.on_click(animate)
+
+ def stop_pressed(_):
+ self.stop_animation()
+ self.stop_button.on_click(stop_pressed)
+
+ # Set plot variance function
+ self._variance_function = None
+ self.add_variance_function(plot_variance_function)
+
+ def _save_slider_value_from_id(self, change):
+ current_parameters = list(self.selected_values)
+ i = self.sliders.index(change['owner'])
+ current_parameters[i] = change['new']
+ self.selected_values = current_parameters
+
+ def predefined_style(self, style):
+ r"""
+ Function that sets a predefined style on the widget.
+
+ Parameters
+ ----------
+ style : `str` (see below)
+ Style options:
+
+ ============= ==================
+ Style Description
+ ============= ==================
+ ``'success'`` Green-based style
+ ``'info'`` Blue-based style
+ ``'warning'`` Yellow-based style
+ ``'danger'`` Red-based style
+ ``''`` No style
+ ============= ==================
+ """
+ self.container.box_style = style
+ self.container.border = '0px'
+ self.play_button.button_style = 'success'
+ self.stop_button.button_style = 'danger'
+ self.fast_forward_button.button_style = 'info'
+ self.fast_backward_button.button_style = 'info'
+ self.loop_toggle.button_style = 'warning'
+ self.reset_button.button_style = 'danger'
+ self.plot_button.button_style = 'primary'
+
+ def stop_animation(self):
+ r"""
+ Method that stops an active annotation.
+ """
+ self.please_stop = True
+
+ def add_variance_function(self, variance_function):
+ r"""
+ Method that adds a `variance_function()` to the `Variance` button of the
+ widget. The given function is also stored in `self._variance_function`.
+
+ Parameters
+ ----------
+ variance_function : `callable` or ``None``, optional
+ The variance function that behaves as a callback. If ``None``,
+ then nothing is added.
+ """
+ self._variance_function = variance_function
+ if self._variance_function is not None:
+ self.plot_button.on_click(self._variance_function)
+
+ def remove_variance_function(self):
+ r"""
+ Method that removes the current `self._variance_function()` from
+ the `Variance` button of the widget and sets
+ ``self._variance_function = None``.
+ """
+ self.plot_button.on_click(self._variance_function, remove=True)
+ self._variance_function = None
+
+ def replace_variance_function(self, variance_function):
+ r"""
+ Method that replaces the current `self._variance_function()` of the
+ `Variance` button of the widget with the given `variance_function()`.
+
+ Parameters
+ ----------
+ variance_function : `callable` or ``None``, optional
+ The variance function that behaves as a callback. If ``None``,
+ then nothing happens.
+ """
+ # remove old function
+ self.remove_variance_function()
+
+ # add new function
+ self.add_variance_function(variance_function)
+
+ def set_widget_state(self, n_parameters=None, params_str=None,
+ params_bounds=None, params_step=None,
+ plot_variance_visible=True, animation_step=0.5,
+ allow_callback=True):
+ r"""
+ Method that updates the state of the widget with a new set of options.
+
+ Parameters
+ ----------
+ n_parameters : `int`
+ The `list` of initial parameters values.
+ params_str : `str`, optional
+ The string that will be used as description of the slider(s). The
+ final description has the form ``"{}{}".format(params_str, p)``,
+ where ``p`` is the parameter number.
+ params_bounds : (`float`, `float`), optional
+ The minimum and maximum bounds, in std units, for the sliders.
+ params_step : `float`, optional
+ The step, in std units, of the sliders.
+ plot_variance_visible : `bool`, optional
+ Defines whether the button for plotting the variance will be visible
+ upon construction.
+ animation_step : `float`, optional
+ The parameters step that is applied when animation is enabled.
+ allow_callback : `bool`, optional
+ If ``True``, it allows triggering of any callback functions.
+ """
+ # Keep old value
+ old_value = self.selected_values
+
+ # Temporarily remove render callback
+ render_function = self._render_function
+ self.remove_render_function()
+
+ # Parse given options
+ if n_parameters is None:
+ n_parameters = self.n_parameters
+ if params_str is None:
+ params_str = ''
+ if params_bounds is None:
+ params_bounds = self.params_bounds
+ if params_step is None:
+ params_step = self.params_step
+
+ # Set plot variance visibility
+ self.plot_button.layout.visibility = (
+ 'visible' if plot_variance_visible else 'hidden')
+ self.animation_step = animation_step
+
+ # Update widget
+ if n_parameters == self.n_parameters:
+ # The number of parameters hasn't changed
+ if self.mode == 'multiple':
+ for p, sl in enumerate(self.sliders):
+ self.parameters_children[p].children[0].value = \
+ "{}{}".format(params_str, p)
+ sl.min = params_bounds[0]
+ sl.max = params_bounds[1]
+ sl.step = params_step
+ else:
+ self.slider.min = params_bounds[0]
+ self.slider.max = params_bounds[1]
+ self.slider.step = params_step
+ if not params_str == '':
+ vals = OrderedDict()
+ for p in range(n_parameters):
+ vals["{}{}".format(params_str, p)] = p
+ self.dropdown_params.options = vals
+ else:
+ # The number of parameters has changed
+ self.selected_values = [0.] * n_parameters
+ if self.mode == 'multiple':
+ # Create new sliders
+ self.sliders = []
+ self.parameters_children = []
+ for p in range(n_parameters):
+ slider_title = ipywidgets.HTML(
+ value="{}{}".format(params_str, p))
+ slider_wid = ipywidgets.FloatSlider(
+ description='', min=params_bounds[0],
+ max=params_bounds[1],
+ step=params_step, value=0., width='8cm',
+ continuous_update=self.continuous_update)
+ tmp = ipywidgets.HBox([slider_title, slider_wid])
+ tmp.layout.align_items = 'center'
+ self.sliders.append(slider_wid)
+ self.parameters_children.append(tmp)
+ self.parameters_wid.children = self.parameters_children
+
+ # Assign saving values and main plotting function to all sliders
+ for w in self.sliders:
+ w.observe(self._save_slider_value_from_id, names='value',
+ type='change')
+ else:
+ self.slider.min = params_bounds[0]
+ self.slider.max = params_bounds[1]
+ self.slider.step = params_step
+ vals = OrderedDict()
+ for p in range(n_parameters):
+ vals["{}{}".format(params_str, p)] = p
+ if self.dropdown_params.value == 0 and n_parameters > 1:
+ self.dropdown_params.value = 1
+ self.dropdown_params.value = 0
+ self.dropdown_params.options = vals
+ self.slider.value = 0.
+
+ # Re-assign render callback
+ self.add_render_function(render_function)
+
+ # Assign new selected options
+ self.n_parameters = n_parameters
+ self.params_str = params_str
+ self.params_bounds = params_bounds
+ self.params_step = params_step
+ self.plot_variance_visible = plot_variance_visible
+
+ # trigger render function if allowed
+ if allow_callback:
+ self.call_render_function(old_value, self.selected_values)
diff --git a/menpo3d/visualize/viewk3dwidgets.py b/menpo3d/visualize/viewk3dwidgets.py
new file mode 100644
index 0000000..ff958a6
--- /dev/null
+++ b/menpo3d/visualize/viewk3dwidgets.py
@@ -0,0 +1,789 @@
+import numpy as np
+from k3d import (Plot, mesh as k3d_mesh, points as k3d_points,
+ text as k3d_text, vectors as k3d_vectors,
+ line as k3d_line)
+from k3d.colormaps import matplotlib_color_maps
+from io import BytesIO
+from ipywidgets import GridBox, Layout, Widget
+from collections import defaultdict
+# The colour map used for all lines and markers
+GLOBAL_CMAP = 'jet'
+
+
+def dict_figures():
+ dict_fig = defaultdict(list)
+ for x in Widget.widgets.values():
+ if hasattr(x, 'figure_id'):
+ dict_fig[x.figure_id].append(x.model_id)
+ return dict_fig
+
+
+def list_figures():
+ list_figures = list(dict_figures().keys())
+ for figure_id in list_figures:
+ print(figure_id)
+
+
+def clear_figure(figure_id=None):
+ # TODO remove figures, clear memory
+ dict_fig = dict_figures()
+
+
+def _calc_distance(points):
+ from menpo.shape import PointCloud
+ pc = PointCloud(points, copy=False)
+ # This is the way that mayavi automatically computes the scale factor
+ # in case the user passes scale_factor = 'auto'. We use it for both
+ # the marker_size as well as the numbers_size.
+ xyz_min, xyz_max = pc.bounds()
+ x_min, y_min, z_min = xyz_min
+ x_max, y_max, z_max = xyz_max
+ distance = np.sqrt(((x_max - x_min) ** 2 +
+ (y_max - y_min) ** 2 +
+ (z_max - z_min) ** 2) /
+ (4 * pc.n_points ** 0.33))
+ return distance
+
+
+def rgb2int(rgb_array, keep_alpha=False):
+ """
+ Convert rgb_array to an int color
+
+ Parameters
+ ----------
+ rgb_array: ndarray
+ An RGBA array
+ keep_alpha: bool
+ If True, the alpha value is also used
+ Returns
+ --------
+ A ndarray with an int color value for each point
+ """
+
+ type_error_message = "RGB shape should be (num_points,3) or (num_points,4)"
+ if isinstance(rgb_array, np.ndarray):
+ if len(rgb_array.shape) != 2:
+ raise TypeError(type_error_message)
+ if rgb_array.shape[1] != 3 and rgb_array.shape[1] != 4:
+ print(rgb_array.shape[1])
+ raise TypeError(type_error_message)
+ else:
+ raise TypeError("RGB shape should be numpy ndarray")
+
+ if not keep_alpha:
+ rgb_array = rgb_array[:, :3]
+
+ num_points, num_colors = rgb_array.shape
+ if rgb_array.dtype in (np.float32, np.float64):
+ rgb_array = np.asarray(np.round(255*rgb_array), dtype='uint32')
+ # TODO
+ # check for overfloat
+ if num_colors == 4:
+ return ((rgb_array[:, 0] << 32) + (rgb_array[:, 1] << 16)
+ + (rgb_array[:, 2] << 8) + rgb_array[:, 3])
+
+ return ((rgb_array[:, 0] << 16) + (rgb_array[:, 1] << 8) + rgb_array[:, 2])
+
+def _parse_marker_size(marker_size, points):
+ distance = _calc_distance(points)
+ if marker_size is None:
+ if distance == 0:
+ marker_size = 1
+ else:
+ marker_size = 0.1 * distance
+ return marker_size
+
+
+def _parse_colour(colour):
+ from matplotlib.colors import rgb2hex
+ if isinstance(colour, int):
+ return colour
+ else:
+ return int(rgb2hex(colour)[1:], base=16)
+
+
+def _check_colours_list(render_flag, colours_list, n_objects, error_str):
+ from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
+ if render_flag:
+ if colours_list is None:
+ # sample colours from jet colour map
+ colours_list = sample_colours_from_colourmap(n_objects,
+ GLOBAL_CMAP)
+ colours_list = list(map(_parse_colour, colours_list))
+
+ if isinstance(colours_list, list):
+ if len(colours_list) == 1:
+ colours_list[0] = _parse_colour(colours_list[0])
+ colours_list *= n_objects
+ elif len(colours_list) != n_objects:
+ raise ValueError(error_str)
+ else:
+ colours_list = [_parse_colour(colours_list)] * n_objects
+ else:
+ colours_list = [0x00FF00] * n_objects
+ return colours_list
+
+
+def _calc_camera_position(points):
+ from menpo.shape import PointCloud
+
+ pc = PointCloud(points, copy=False)
+ bounds = pc.bounding_box().points
+ distance = np.max(bounds[1::2] - bounds[::2]) * 2.0
+ camera = [0, 0, distance, 0, 0, 0, 0, 1, 0]
+
+ return camera
+
+
+def _check_figure_id(obj, figure_id, new_figure):
+ if figure_id is None:
+ if new_figure:
+ # A new figure is created but with no figure_id
+ # we should create an id of 'Figure_n form'
+ list_ids = []
+ for x in obj.widgets.values():
+ if hasattr(x, 'figure_id') and x is not obj:
+ if x.figure_id is not None and 'Figure_' in str(x.figure_id):
+ try:
+ n_figure_id = int(x.figure_id.split('Figure_')[1])
+ except ValueError:
+ continue
+ list_ids.append(n_figure_id)
+ if len(list_ids):
+ figure_id = 'Figure_{}'.format(sorted(list_ids)[-1] + 1)
+ else:
+ figure_id = 'Figure_0'
+
+ else:
+ if len(obj.list_figures_ids):
+ figure_id = obj.list_figures_ids[-1]
+ else:
+ obj.remove_widget()
+ raise ValueError('You cannot plot a figure with no id and new figure False')
+ else:
+ if new_figure:
+ for x in obj.widgets.values():
+ if hasattr(x, 'figure_id') and x is not obj:
+ if x.figure_id == figure_id:
+ obj.remove_widget()
+ raise ValueError('Figure id is already given')
+ else:
+ return figure_id
+
+ obj.list_figures_ids.append(figure_id)
+ if hasattr(obj, 'model_id'):
+ obj.dict_figure_id_to_model_id[figure_id] = obj.model_id
+ return figure_id
+
+
+class K3dwidgetIdentity():
+ list_figures_ids = []
+ dict_figure_id_to_model_id = {}
+
+
+class K3dwidgetsRenderer(Plot, K3dwidgetIdentity):
+ """ Abstract class for performing visualizations using K3dwidgets.
+
+ Parameters
+ ----------
+ figure_id : str or `None`
+ A figure name or `None`.
+ new_figure : bool
+ If `True`, creates a new figure on the cell.
+ """
+ # list_figures_ids = []
+
+ def __init__(self, figure_id, new_figure):
+ super(K3dwidgetsRenderer, self).__init__()
+
+ self.figure_id = _check_figure_id(self, figure_id, new_figure)
+ self.new_figure = new_figure
+ self.grid_visible = False
+
+ def _render(self):
+ widg_to_draw = self
+ if not self.new_figure:
+ for widg in self.widgets.values():
+ if isinstance(widg, K3dwidgetsRenderer):
+ if widg.figure_id == self.figure_id and widg.model_id != self.model_id and widg.new_figure:
+ widg_to_draw = widg
+ return widg_to_draw
+ self.remove_widget()
+ raise Exception('Figure with id {} was not found '.format(self.figure_id))
+
+ return widg_to_draw
+
+ def remove_widget(self):
+ super(K3dwidgetsRenderer, self).close()
+ # copy from close from ipywidgets.widget.Widget
+ self.widgets.pop(self.model_id, None)
+ self.comm.close()
+ self.comm = None
+ self._repr_mimebundle_ = None
+
+ def get_figure(self):
+ r"""
+ Gets the figure specified by the combination of `self.figure_id` and
+ `self.new_figure`. If `self.figure_id == None` then `mlab.gcf()`
+ is used. `self.figure_id` is also set to the correct id of the figure
+ if a new figure is created.
+
+ Returns
+ -------
+ figure : Mayavi figure object
+ The figure we will be rendering on.
+ """
+ # return self.figure
+ pass
+
+ def save_figure(self, filename, format='png', size=None,
+ magnification='auto', overwrite=False):
+ r"""
+ Method for saving the figure of the current `figure_id` to file.
+
+ Parameters
+ ----------
+ filename : `str` or `file`-like object
+ The string path or file-like object to save the figure at/into.
+ format : `str`
+ The format to use. This must match the file path if the file path is
+ a `str`.
+ size : `tuple` of `int` or ``None``, optional
+ The size of the image created (unless magnification is set,
+ in which case it is the size of the window used for rendering). If
+ ``None``, then the figure size is used.
+ magnification : `double` or ``'auto'``, optional
+ The magnification is the scaling between the pixels on the screen,
+ and the pixels in the file saved. If you do not specify it, it will
+ be calculated so that the file is saved with the specified size.
+ If you specify a magnification, Mayavi will use the given size as a
+ screen size, and the file size will be ``magnification * size``.
+ If ``'auto'``, then the magnification will be set automatically.
+ overwrite : `bool`, optional
+ If ``True``, the file will be overwritten if it already exists.
+ """
+ pass
+
+ @property
+ def modelview_matrix(self):
+ r"""
+ Retrieves the modelview matrix for this scene.
+
+ :type: ``(4, 4)`` `ndarray`
+ """
+ pass
+
+ @property
+ def projection_matrix(self):
+ r"""
+ Retrieves the projection matrix for this scene.
+
+ :type: ``(4, 4)`` `ndarray`
+ """
+ pass
+
+ @property
+ def renderer_settings(self):
+ r"""
+ Returns all the information required to construct an identical
+ renderer to this one.
+
+ Returns
+ -------
+ settings : `dict`
+ The dictionary with the following keys:
+
+ * ``'width'`` (`int`) : The width of the scene.
+ * ``'height'`` (`int`) : The height of the scene.
+ * ``'model_matrix'`` (`ndarray`) : The model array (identity).
+ * ``'view_matrix'`` (`ndarray`) : The view array.
+ * ``'projection_matrix'`` (`ndarray`) : The projection array.
+
+ """
+ pass
+
+ def force_draw(self):
+ r"""
+ Method for forcing the current figure to render. This is useful for
+ the widgets animation.
+ """
+ self.render()
+
+
+class K3dwidgetsVectorViewer3d(K3dwidgetsRenderer):
+ def __init__(self, figure_id, new_figure, points, vectors):
+ super(K3dwidgetsVectorViewer3d, self).__init__(figure_id, new_figure)
+ non_zero_indices = np.unique(np.nonzero(vectors.reshape(-1, 3))[0])
+ self.points = points[non_zero_indices].astype(np.float32)
+ self.vectors = vectors[non_zero_indices].astype(np.float32)
+
+ def _render(self, colour='r', line_width=2, marker_size=None):
+ marker_size = _parse_marker_size(marker_size, self.points)
+ colour = _parse_colour(colour)
+
+ widg_to_draw = super(K3dwidgetsVectorViewer3d, self)._render()
+ vectors_to_add = k3d_vectors(self.points, self.vectors,
+ color=colour, head_size=marker_size,
+ line_width=line_width)
+ widg_to_draw += vectors_to_add
+ return widg_to_draw
+
+
+class K3dwidgetsPointGraphViewer3d(K3dwidgetsRenderer):
+ def __init__(self, figure_id, new_figure, points, edges):
+ super(K3dwidgetsPointGraphViewer3d, self).__init__(figure_id,
+ new_figure)
+ self.points = points.astype(np.float32)
+ self.edges = edges
+
+ def _render(self, render_lines=True, line_colour='r', line_width=2,
+ render_markers=True, marker_style='flat', marker_size=10,
+ marker_colour='g', alpha=1.0, render_numbering=False,
+ numbers_colour='k', numbers_size=None,
+ colours=None, keep_alpha=False):
+
+ widg_to_draw = super(K3dwidgetsPointGraphViewer3d, self)._render()
+ # Render the lines if requested
+ if render_lines and self.edges is not None:
+ if isinstance(line_colour, list):
+ line_colour = [_parse_colour(i_color) for i_color in
+ line_colour]
+ else:
+ line_colour = _parse_colour(line_colour)
+
+ lines_to_add = None
+ for edge in self.edges:
+ if isinstance(line_colour, list):
+ if len(line_colour):
+ color_this_line = line_colour.pop()
+ else:
+ color_this_line = 0xFF0000
+ else:
+ color_this_line = line_colour
+
+ if lines_to_add is None:
+ lines_to_add = k3d_line(self.points[edge],
+ color=color_this_line)
+ else:
+ lines_to_add += k3d_line(self.points[edge],
+ color=color_this_line)
+ widg_to_draw += lines_to_add
+
+ # Render the markers if requested
+ if render_markers:
+ marker_size = _parse_marker_size(marker_size, self.points)
+ if colours is not None:
+ colours = rgb2int(colours, keep_alpha)
+ marker_colour = 'w'
+ else:
+ colours = []
+
+ marker_colour = _parse_colour(marker_colour)
+
+ # In order to be compatible with mayavi, we just change the
+ # default value for marker_style to mesh
+ if marker_style == 'sphere':
+ marker_style = 'mesh'
+
+ # When the number of points is greater than 1000, it is recommended
+ # to use fast shaders: flat, 3d or 3dSpecular.
+ # The mesh shader generates much bigger overhead,
+ # but it has a properly triangularized sphere
+ # representing each point.
+ if self.points.shape[0] > 1000:
+ marker_style = '3dSpecular'
+
+ points_to_add = k3d_points(self.points, colors=colours,
+ color=marker_colour,
+ point_size=marker_size,
+ opacity=alpha,
+ shader=marker_style)
+ widg_to_draw += points_to_add
+
+ # TODO
+ # A class of k3d.texts that groups all texts should be created
+ # Till then, we go that way
+ if render_numbering:
+ text_to_add = None
+
+ numbers_colour = _parse_colour(numbers_colour)
+ for i, point in enumerate(self.points):
+ if text_to_add is None:
+ text_to_add = k3d_text(str(i), color=numbers_colour,
+ position=point, label_box=False)
+ else:
+ text_to_add += k3d_text(str(i), color=numbers_colour,
+ position=point, label_box=False)
+ widg_to_draw += text_to_add
+
+ return widg_to_draw
+
+
+class K3dwidgetsTriMeshViewer3d(K3dwidgetsRenderer):
+ def __init__(self, figure_id, new_figure, points, trilist, landmarks=None):
+ super(K3dwidgetsTriMeshViewer3d, self).__init__(figure_id, new_figure)
+ self.points = points.astype(np.float32)
+ self.trilist = trilist.astype(np.uint32)
+ self.landmarks = landmarks
+
+ def _render_mesh(self, line_width, colour, mesh_type,
+ marker_style, marker_size, alpha=1.0):
+ marker_size = _parse_marker_size(marker_size, self.points)
+ colour = _parse_colour(colour)
+
+ widg_to_draw = super(K3dwidgetsTriMeshViewer3d, self)._render()
+ wireframe = False
+ opacity = alpha
+ if mesh_type == 'wireframe':
+ wireframe = True
+ opacity = 0.3
+
+ mesh_to_add = k3d_mesh(self.points, self.trilist.flatten(),
+ flat_shading=False, opacity=opacity,
+ color=colour, wireframe=wireframe,
+ side='double')
+ widg_to_draw += mesh_to_add
+
+ if hasattr(self.landmarks, 'points'):
+ self.landmarks.view(figure_id=self.figure_id,
+ new_figure=False,
+ marker_style=marker_style,
+ marker_size=marker_size,
+ inline=True)
+ return widg_to_draw
+
+ def _render(self, line_width=2, colour='r', mesh_type='surface',
+ marker_style='mesh', marker_size=None,
+ normals=None, normals_colour='k', normals_line_width=2,
+ normals_marker_size=None, alpha=1.0):
+
+ widg_to_draw = self._render_mesh(line_width, colour, mesh_type,
+ marker_style, marker_size, alpha)
+ if normals is not None:
+ tmp_normals_widget = K3dwidgetsVectorViewer3d(self.figure_id,
+ False, self.points,
+ normals)
+ tmp_normals_widget._render(colour=normals_colour,
+ line_width=normals_line_width,
+ marker_size=normals_marker_size)
+
+ widg_to_draw.camera = _calc_camera_position(self.points)
+ widg_to_draw.camera_auto_fit = False
+
+ return widg_to_draw
+
+
+class K3dwidgetsTexturedTriMeshViewer3d(K3dwidgetsRenderer):
+ def __init__(self, figure_id, new_figure, points, trilist, texture,
+ tcoords, landmarks):
+ super(K3dwidgetsTexturedTriMeshViewer3d, self).__init__(figure_id,
+ new_figure)
+ self.points = points
+ self.trilist = trilist
+ self.texture = texture
+ self.tcoords = tcoords
+ self.landmarks = landmarks
+
+ def _render_mesh(self, mesh_type='surface', ambient_light=0.0,
+ specular_light=0.0, alpha=1.0):
+
+ widg_to_draw = super(K3dwidgetsTexturedTriMeshViewer3d, self)._render()
+
+ uvs = self.tcoords.points
+ tmp_img = self.texture.mirror(axis=0).as_PILImage()
+ img_byte_arr = BytesIO()
+ tmp_img.save(img_byte_arr, format='PNG')
+ texture = img_byte_arr.getvalue()
+ texture_file_format = 'png'
+
+ mesh_to_add = k3d_mesh(self.points.astype(np.float32),
+ self.trilist.flatten().astype(np.uint32),
+ flat_shading=False,
+ color=0xFFFFFF, side='front', texture=texture,
+ uvs=uvs,
+ texture_file_format=texture_file_format)
+
+ widg_to_draw += mesh_to_add
+
+ if hasattr(self.landmarks, 'points'):
+ self.landmarks.view(inline=True, new_figure=False,
+ figure_id=self.figure_id)
+
+ widg_to_draw.camera = _calc_camera_position(self.points)
+ widg_to_draw.camera_auto_fit = False
+
+ return widg_to_draw
+
+ def _render(self, normals=None, normals_colour='k',
+ normals_line_width=2, normals_marker_size=None):
+
+ if normals is not None:
+ tmp_normals_widget = K3dwidgetsVectorViewer3d(self.figure_id,
+ False, self.points,
+ normals)
+ tmp_normals_widget._render(colour=normals_colour,
+ line_width=normals_line_width,
+ marker_size=normals_marker_size)
+
+ self._render_mesh()
+ return self
+
+
+class K3dwidgetsColouredTriMeshViewer3d(K3dwidgetsRenderer):
+ # TODO
+ def __init__(self, figure_id, new_figure, points, trilist,
+ colour_per_point, landmarks):
+ super(K3dwidgetsColouredTriMeshViewer3d, self).__init__(figure_id,
+ new_figure)
+ self.points = points
+ self.trilist = trilist
+ self.colour_per_point = colour_per_point
+ self.colorbar_object_id = False
+ self.landmarks = landmarks
+
+ def _render_mesh(self):
+ widg_to_draw = super(K3dwidgetsColouredTriMeshViewer3d, self)._render()
+
+ mesh_to_add = k3d_mesh(self.points.astype(np.float32),
+ self.trilist.flatten().astype(np.uint32),
+ attribute=self.colour_per_point,
+ )
+ widg_to_draw += mesh_to_add
+
+ if hasattr(self.landmarks, 'points'):
+ self.landmarks.view(inline=True, new_figure=False,
+ figure_id=self.figure_id)
+ widg_to_draw.camera = _calc_camera_position(self.points)
+ widg_to_draw.camera_auto_fit = False
+
+ def _render(self, normals=None, normals_colour='k', normals_line_width=2,
+ normals_marker_size=None):
+ if normals is not None:
+ K3dwidgetsVectorViewer3d(self.figure_id, False,
+ self.points, normals)._render(
+ colour=normals_colour, line_width=normals_line_width,
+ marker_size=normals_marker_size)
+ self._render_mesh()
+ return self
+
+
+class K3dwidgetsSurfaceViewer3d(K3dwidgetsRenderer):
+ def __init__(self, figure_id, new_figure, values, mask=None):
+ super(K3dwidgetsSurfaceViewer3d, self).__init__(figure_id, new_figure)
+ if mask is not None:
+ values[~mask] = np.nan
+ self.values = values
+
+ def render(self, colour=(1, 0, 0), line_width=2, step=None,
+ marker_style='2darrow', marker_resolution=8, marker_size=0.05,
+ alpha=1.0):
+ # warp_scale = kwargs.get('warp_scale', 'auto')
+ # mlab.surf(self.values, warp_scale=warp_scale)
+ return self
+
+
+class K3dwidgetsLandmarkViewer3d(K3dwidgetsRenderer):
+ def __init__(self, figure_id, new_figure, group, landmark_group):
+ super(K3dwidgetsLandmarkViewer3d, self).__init__(figure_id, new_figure)
+ self.group = group
+ self.landmark_group = landmark_group
+
+ def _render(self, render_lines=True, line_colour='r', line_width=2,
+ render_markers=True, marker_style='mesh', marker_size=None,
+ marker_colour='r', alpha=1.0, render_numbering=False,
+ numbers_colour='k', numbers_size=None):
+ # Regarding the labels colours, we may get passed either no colours (in
+ # which case we generate random colours) or a single colour to colour
+ # all the labels with
+ # TODO: All marker and line options could be defined as lists...
+ n_labels = self.landmark_group.n_labels
+ line_colour = _check_colours_list(
+ render_lines, line_colour, n_labels,
+ 'Must pass a list of line colours with length n_labels or a single'
+ 'line colour for all labels.')
+ marker_colour = _check_colours_list(
+ render_markers, marker_colour, n_labels,
+ 'Must pass a list of marker colours with length n_labels or a '
+ 'single marker face colour for all labels.')
+ marker_size = _parse_marker_size(marker_size,
+ self.landmark_group.points)
+ numbers_size = _parse_marker_size(numbers_size,
+ self.landmark_group.points)
+
+ # get pointcloud of each label
+ sub_pointclouds = self._build_sub_pointclouds()
+
+ widg_to_draw = super(K3dwidgetsLandmarkViewer3d, self)._render()
+
+ if marker_style == 'sphere':
+ marker_style = 'mesh'
+
+ for i, (label, pc) in enumerate(sub_pointclouds):
+ points_to_add = k3d_points(pc.points.astype(np.float32),
+ color=marker_colour[i],
+ point_size=marker_size,
+ shader=marker_style)
+ widg_to_draw += points_to_add
+ if render_numbering:
+ text_to_add = None
+ numbers_colour = _parse_colour(numbers_colour)
+ for i, point in enumerate(self.landmark_group.points):
+ if text_to_add is None:
+ text_to_add = k3d_text(str(i), color=numbers_colour,
+ position=point, label_box=False)
+ else:
+ text_to_add += k3d_text(str(i), color=numbers_colour,
+ position=point, label_box=False)
+ widg_to_draw += text_to_add
+ # widg_to_draw.camera = _calc_camera_position(pc.points)
+ # widg_to_draw.camera_auto_fit = False
+
+ return widg_to_draw
+
+ def _build_sub_pointclouds(self):
+ return [(label, self.landmark_group.get_label(label))
+ for label in self.landmark_group.labels]
+
+
+class K3dwidgetsHeatmapViewer3d(K3dwidgetsRenderer):
+ def __init__(self, figure_id, new_figure, points, trilist, landmarks=None):
+ super(K3dwidgetsHeatmapViewer3d, self).__init__(figure_id, new_figure)
+ self.points = points
+ self.trilist = trilist
+ self.landmarks = landmarks
+
+ def _render_mesh(self, distances_between_meshes, type_cmap,
+ scalar_range, show_statistics=False):
+
+ marker_size = _parse_marker_size(None, self.points)
+ widg_to_draw = super(K3dwidgetsHeatmapViewer3d, self)._render()
+
+ try:
+ color_map = getattr(matplotlib_color_maps, type_cmap)
+ except AttributeError:
+ print('Could not find colormap {}. Hot_r is going to be used instead'.format(type_cmap))
+ color_map = getattr(matplotlib_color_maps, 'hot_r')
+
+ mesh_to_add = k3d_mesh(self.points.astype(np.float32),
+ self.trilist.flatten().astype(np.uint32),
+ color_map=color_map,
+ attribute=distances_between_meshes,
+ color_range=scalar_range
+ )
+ widg_to_draw += mesh_to_add
+
+ if hasattr(self.landmarks, 'points'):
+ self.landmarks.view(figure_id=self.figure_id,
+ new_figure=False,
+ marker_size=marker_size,
+ inline=True)
+
+ if show_statistics:
+ text = '\\begin{{matrix}} \\mu & {:.3} \\\\ \\sigma^2 & {:.3} \\\\ \\max & {:.3} \\end{{matrix}}'\
+ .format(distances_between_meshes.mean(),
+ distances_between_meshes.std(),
+ distances_between_meshes.max())
+ min_b = np.min(self.points, axis=0)
+ max_b = np.max(self.points, axis=0)
+ text_position = (max_b-min_b)/2
+ widg_to_draw += k3d_text(text, position=text_position,
+ color=0xff0000, size=1)
+
+ widg_to_draw.camera = _calc_camera_position(self.points)
+ widg_to_draw.camera_auto_fit = False
+
+ return widg_to_draw
+
+ def _render(self, distances_between_meshes, type_cmap='hot_r',
+ scalar_range=[0, 2], show_statistics=False):
+ return self._render_mesh(distances_between_meshes, type_cmap,
+ scalar_range, show_statistics)
+
+
+class K3dwidgetsPCAModelViewer3d(GridBox, K3dwidgetIdentity):
+ def __init__(self, figure_id, new_figure, points, trilist,
+ components, eigenvalues, n_parameters, parameters_bound,
+ landmarks_indices, widget_style):
+
+ from .menpowidgets import LinearModelParametersWidget
+
+ self.figure_id = _check_figure_id(self, figure_id, new_figure)
+ self.new_figure = new_figure
+ self.points = points.astype(np.float32)
+ if trilist is None:
+ self.trilist = None
+ else:
+ self.trilist = trilist.astype(np.uint32)
+ self.components = components.astype(np.float32)
+ self.eigenvalues = eigenvalues.astype(np.float32)
+ self.n_parameters = n_parameters
+ self.landmarks_indices = landmarks_indices
+ self.layout = Layout(grid_template_columns='1fr 1fr')
+ self.wid = LinearModelParametersWidget(n_parameters=n_parameters,
+ render_function=self.render_function,
+ params_str='Parameter ',
+ mode='multiple',
+ params_bounds=parameters_bound,
+ plot_variance_visible=False,
+ style=widget_style)
+ if self.trilist is None:
+ self.mesh_window = K3dwidgetsPointGraphViewer3d(self.figure_id, False,
+ self.points, self.trilist)
+ else:
+ self.mesh_window = K3dwidgetsTriMeshViewer3d(self.figure_id, False,
+ self.points, self.trilist)
+ super(K3dwidgetsPCAModelViewer3d, self).__init__(children=[self.wid, self.mesh_window],
+ layout=Layout(grid_template_columns='1fr 1fr'))
+
+ self.dict_figure_id_to_model_id[figure_id] = self.model_id
+
+ def _render_mesh(self, mesh_type, line_width, colour,
+ marker_size, marker_style, alpha):
+ marker_size = _parse_marker_size(marker_size, self.points)
+ colour = _parse_colour(colour)
+
+ if self.trilist is None:
+ mesh_to_add = k3d_points(self.points, color=colour,
+ opacity=alpha,
+ point_size=marker_size,
+ shader='3dSpecular')
+ else:
+ mesh_to_add = k3d_mesh(self.points, self.trilist.flatten(),
+ flat_shading=False, opacity=alpha,
+ color=colour, name='Instance',
+ side='double')
+
+ self.mesh_window += mesh_to_add
+
+ if self.landmarks_indices is not None:
+ landmarks_to_add = k3d_points(self.points[self.landmarks_indices],
+ color=0x00FF00, name='landmarks',
+ point_size=marker_size,
+ shader='mesh')
+ self.mesh_window += landmarks_to_add
+ self.mesh_window.camera = _calc_camera_position(self.points)
+ self.mesh_window.camera_auto_fit = False
+
+ return self
+
+ def render_function(self, change):
+ weights = np.asarray(self.wid.selected_values).astype(np.float32)
+ weighted_eigenvalues = weights * self.eigenvalues[:self.n_parameters]**0.5
+ new_instance = (self.components[:self.n_parameters, :].T@weighted_eigenvalues).reshape(-1, 3)
+ new_points = self.points + new_instance
+
+ if self.trilist is None:
+ self.mesh_window.objects[0].positions = new_points
+ else:
+ self.mesh_window.objects[0].vertices = new_points
+ if self.landmarks_indices is not None:
+ self.mesh_window.objects[1].positions = new_points[self.landmarks_indices]
+
+ def _render(self, mesh_type='wireframe', line_width=2, colour='r',
+ marker_style='mesh', marker_size=None, alpha=1.0):
+
+ return self._render_mesh(mesh_type, line_width, colour,
+ marker_size, marker_style, alpha)
+
+ def remove_widget(self):
+ super(K3dwidgetsPCAModelViewer3d, self).close()
diff --git a/menpo3d/visualize/viewmayavi.py b/menpo3d/visualize/viewmayavi.py
index a72f1f3..5fac21f 100644
--- a/menpo3d/visualize/viewmayavi.py
+++ b/menpo3d/visualize/viewmayavi.py
@@ -695,29 +695,67 @@ def render(
self.figure.scene.disable_render = True
for i, (label, pc) in enumerate(sub_pointclouds):
# render pointcloud
- pc.view(
- figure_id=self.figure_id,
- new_figure=False,
- render_lines=render_lines,
- line_colour=line_colour[i],
- line_width=line_width,
- render_markers=render_markers,
- marker_style=marker_style,
- marker_size=marker_size,
- marker_colour=marker_colour[i],
- marker_resolution=marker_resolution,
- step=step,
- alpha=alpha,
- render_numbering=render_numbering,
- numbers_colour=numbers_colour,
- numbers_size=numbers_size,
- )
+ pc.view(figure_id=self.figure_id, new_figure=False,
+ render_lines=render_lines, line_colour=line_colour[i],
+ line_width=line_width, render_markers=render_markers,
+ marker_style=marker_style, marker_size=marker_size,
+ marker_colour=marker_colour[i],
+ marker_resolution=marker_resolution, step=step,
+ alpha=alpha, render_numbering=render_numbering,
+ numbers_colour=numbers_colour, numbers_size=numbers_size,
+ inline=False)
self.figure.scene.disable_render = False
return self
def _build_sub_pointclouds(self):
- return [
- (label, self.landmark_group.get_label(label))
- for label in self.landmark_group.labels
- ]
+ return [(label, self.landmark_group.get_label(label))
+ for label in self.landmark_group.labels]
+
+
+class MayaviHeatmapViewer3d(MayaviRenderer):
+ def __init__(self, figure_id, new_figure, points, trilist):
+ super(MayaviHeatmapViewer3d, self).__init__(figure_id, new_figure)
+ self.points = points
+ self.trilist = trilist
+
+ def _render_mesh(self, scaled_distances_between_meshes,
+ type_cmap, scalar_range, show_statistics):
+ from mayavi import mlab
+ # v = mlab.figure(figure=figure_name, size=size,
+ # bgcolor=(1, 1, 1), fgcolor=(0, 0, 0))
+ src = mlab.pipeline.triangular_mesh_source(self.points[:, 0],
+ self.points[:, 1],
+ self.points[:, 2],
+ self.trilist,
+ scalars=scaled_distances_between_meshes)
+ surf = mlab.pipeline.surface(src, colormap=type_cmap)
+ # When font size bug resolved, uncomment
+ # cb=mlab.colorbar(title='Distances in mm',
+ # orientation='vertical', nb_labels=5)
+ # cb.title_text_property.font_size = 20
+ # cb.label_text_property.font_family = 'times'
+ # cb.label_text_property.font_size=10
+ cb = mlab.colorbar(orientation='vertical', nb_labels=5)
+ cb.data_range = scalar_range
+ cb.scalar_bar_representation.position = [0.8, 0.15]
+ cb.scalar_bar_representation.position2 = [0.15, 0.7]
+ text = mlab.text(0.8, 0.85, 'Distances in mm')
+ text.width = 0.20
+ if show_statistics:
+ text2 = mlab.text(0.5, 0.02,
+ 'Mean error {:.3}mm \nMax error {:.3}mm \
+ '.format(scaled_distances_between_meshes.mean(),
+ scaled_distances_between_meshes.max()))
+ text2.width = 0.20
+ surf.module_manager.scalar_lut_manager.reverse_lut = True
+ # perhaps we shouud usew kwargs
+ # if camera_settings is None:
+ mlab.gcf().scene.z_plus_view()
+
+ def render(self, scaled_distances_between_meshes, type_cmap='hot',
+ scalar_range=(0, 2), show_statistics=False):
+
+ self._render_mesh(scaled_distances_between_meshes, type_cmap,
+ scalar_range, show_statistics)
+ return self
diff --git a/setup.py b/setup.py
index 1e8e5f2..5498039 100644
--- a/setup.py
+++ b/setup.py
@@ -111,7 +111,8 @@ def build_extension_from_pyx(pyx_path, extra_sources_paths=None):
version, cmdclass = get_version_and_cmdclass("menpo3d")
-install_requires = ["menpo>=0.9.0,<0.12.0", "mayavi>=4.7.0", "moderngl>=5.6.*,<6.0"]
+install_requires = ["menpo>=0.9.0,<0.12.0", "mayavi>=4.7.0",
+ "moderngl>=5.6,<6.0", "k3d<=2.9.2", "ipywidgets<8"]
setup(
name="menpo3d",