diff --git a/examples/0_Introduction_to_K3d_Widgets.ipynb b/examples/0_Introduction_to_K3d_Widgets.ipynb new file mode 100644 index 0000000..cff64e9 --- /dev/null +++ b/examples/0_Introduction_to_K3d_Widgets.ipynb @@ -0,0 +1,664 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import menpo3d.io as m3io\n", + "import menpo.io as mio\n", + "from menpo.shape import PointCloud, ColouredTriMesh\n", + "from menpo.landmark import face_ibug_68_to_face_ibug_68\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Load the data (Mesh, landmarks and model)

" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mesh = m3io.import_mesh('../menpo3d/data/james.obj')\n", + "lms = m3io.import_landmark_file('../menpo3d/data/james.ljson')['LJSON']\n", + "\n", + "# Load model and its landmarks indices \n", + "model = mio.import_pickle('../menpo3d/data/3DMD_all_all_all_10.pkl')['model']\n", + "lms_indices = [21868, 22404, 22298, 22327, 43430, 45175, 46312, 47132, 47911, 48692,\n", + " 49737, 51376, 53136, 32516, 32616, 32205, 32701, 38910, 39396, 39693,\n", + " 39934, 40131, 40843, 41006, 41179, 41430, 13399, 8161, 8172, 8179, 8185,\n", + " 5622, 6881, 8202, 9403, 10764, 1831, 3887, 5049, 6214, 4805, 3643, 9955,\n", + " 11095, 12255, 14197, 12397, 11366, 5779, 6024, 7014, 8215, 9294, 10267,\n", + " 10922, 9556, 8836, 8236, 7636, 6794, 5905, 7264, 8223, 9063, 10404, 8828,\n", + " 8228, 7509]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Create new random instances

" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cov = np.diag(model.eigenvalues)\n", + "model_mean = model.mean()\n", + "synthetic_weights = np.random.multivariate_normal(np.zeros(model.n_active_components),\n", + " cov, 1000)\n", + "random_mesh = model.instance(synthetic_weights[5])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Show the mesh

" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "# Default values for TriMesh, TextureMesh viewer are\n", + "# figure_id None\n", + "# new_figure True\n", + "# in that case an automatic figure_id will be given\n", + "# with 'Figure_{n}' format\n", + "# n will be an increased integer starting from zero\n", + "mesh.view() # wait a bit before magic happens" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Show the mesh and landmarks

" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "mesh.view(figure_id='James')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Add landmarks to figure with id James\n", + "lms_poincloud = PointCloud(lms.points)\n", + "lms_poincloud.view(figure_id='James',new_figure=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Add landmarks to figure with id Figure_0\n", + "lms_poincloud.view(figure_id='Figure_0', new_figure=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Show a mesh that has landmarks

" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mesh.landmarks = lms" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lms.view(new_figure=True, render_numbering=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The mesh has now landmarks, so they would be plotted as well\n", + "# the figure id is now Figure_2\n", + "mesh.view()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Show the TexturedMesh without texture\n", + "mesh.view(render_texture=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mesh.view(render_texture=False, mesh_type='wireframe')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "random_mesh.view(mesh_type='surface')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

HeatMaps

" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Heatmap between a random mesh and mean mesh\n", + "random_mesh.heatmap(model_mean)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Heatmap with statistics \n", + "# Be careful, since we have already drawn a heatmap between\n", + "# random and mean, we should use another name for figure\n", + "random_mesh.heatmap(model_mean, show_statistics=True, figure_id='Heatmap2')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Heatmap with landmarks\n", + "model_mean.landmarks = face_ibug_68_to_face_ibug_68(PointCloud(model_mean.points[lms_indices]))\n", + "model_mean.heatmap(random_mesh, inline=True, show_statistics=True, figure_id='Heatmap3')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "random_mesh_landmarks = face_ibug_68_to_face_ibug_68(random_mesh.points[lms_indices])\n", + "random_mesh_landmarks.view(inline=True, new_figure=False, figure_id='Heatmap2')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Show Normals

" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pts = random_mesh.points[lms_indices]\n", + "vrt = np.zeros((random_mesh.n_points,3))\n", + "vrt[lms_indices] = random_mesh.vertex_normals()[lms_indices] / 5" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "random_mesh.view(normals=vrt, \n", + " normals_marker_size= 0.5,\n", + " normals_line_width = 0.01,\n", + " figure_id='Normals')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "random_mesh_landmarks.view(figure_id='Normals', new_figure=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "random_mesh.landmarks = random_mesh_landmarks\n", + "random_mesh.view( normals=vrt, \n", + " normals_marker_size= 0.5,\n", + " normals_line_width = 0.01)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Show PointCloud with colours

" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.utils.tempdir import TemporaryWorkingDirectory\n", + "with TemporaryWorkingDirectory() as tmpdir:\n", + " !mkdir -p ./data/PittsburghBridge\n", + " !wget -P ./data/PittsburghBridge https://dl.fbaipublicfiles.com/pytorch3d/data/PittsburghBridge/pointcloud.npz\n", + " pointcloud = np.load('./data/PittsburghBridge/pointcloud.npz')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "points = pointcloud['verts']\n", + "colours = pointcloud['rgb']\n", + "\n", + "coloured_pointcloud = PointCloud(points)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "coloured_pointcloud.view(colours=colours)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Show Surface

" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Show ColouredTriMesh

\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "colors = np.random.rand(random_mesh.n_points)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "colors[0:1000]= 0.1\n", + "colors[1000:10000]= 0.2\n", + "colors[1000:10000]= 0.4\n", + "colors[10000:30000]= 0.6\n", + "colors[30000:40000]= 0.8\n", + "colors[40000:50000]= 0.8\n", + "colors[50000:]= 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "new_mesh = ColouredTriMesh(random_mesh.points, random_mesh.trilist, colours=colors)\n", + "new_mesh.view()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "new_mesh.landmarks = random_mesh_landmarks" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Show Graphs

" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from menpo.shape import PointUndirectedGraph \n", + "import numpy as np\n", + "points = np.array([[10, 30, 10], [0, 20, 11], [20, 20, 11], [0, 10, 12], [20, 10, 12], [0, 0, 12]]) \n", + "edges = np.array([[0, 1], [1, 0], [0, 2], [2, 0], [1, 2], [2, 1], \n", + " [1, 3], [3, 1], [2, 4], [4, 2], [3, 4], [4, 3],[3, 5], [5, 3]]) \n", + "colors = [\n", + " 0xff,\n", + " 0xffff,\n", + " 0xff00ff,\n", + " 0x00ffff,\n", + " 0xffff00,]\n", + "\n", + "graph = PointUndirectedGraph.init_from_edges(points, edges) " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "graph.view()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "graph.view(line_colour=colors, render_numbering=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

View a Morphable Model

" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "model.view(figure_id='Model', n_parameters=10, landmarks_indices=lms_indices)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "random_mesh.view()\n", + "lms.view(new_figure=False)\n", + "mesh.view(new_figure=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Load big meshes

" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mesh = m3io.import_mesh('/data/meshes/mesh/33_plain.obj')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mesh.view( )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mesh_hq = m3io.import_mesh('/data/meshes/mesh/HQ_mesh_alex.obj')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mesh_hq" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mesh_hq.view()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Fail cases (supposed you have already executed all the above cells)

" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# It should fail if the previous cells have been executed, as default values for landmarker viewer are\n", + "# figure_id = None and new_figure=False, so it could not\n", + "# find a figure with id None\n", + "lms.view()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# It should fail if the previous cells have been executed, as we have already had a figure with id \n", + "# James and we cannot create a new one with the same figure_id\n", + "mesh.view(figure_id='James', new_figure=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# It should fail if the previous cells have been executed, as we have already had a figure with id \n", + "# Model and we cannot create a new one with the same figure_id\n", + "model.view(inline=True, figure_id='Model')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# You have already created a heatmap between random_mesh and model_mean\n", + "random_mesh.heatmap(model_mean, inline=True)`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Additional functions

" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from menpo3d.visualize import list_figures, dict_figures" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "list_figures()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dict_figures()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Testing

" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ipywidgets import Widget\n", + "from menpo3d.visualize.viewk3dwidgets import K3dwidgetsRenderer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for x in Widget.widgets.values():\n", + " print(type(x),x.model_id)\n", + "# if isinstance(x, K3dwidgetsRenderer):\n", + " if hasattr(x,'figure_id'):\n", + " print(type(x),x.model_id, x.figure_id)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "list_figures()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/menpo3d/visualize/__init__.py b/menpo3d/visualize/__init__.py index 87e7761..d221a22 100644 --- a/menpo3d/visualize/__init__.py +++ b/menpo3d/visualize/__init__.py @@ -1,8 +1,10 @@ -from .base import ( - PointGraphViewer3d, - TriMeshViewer3d, - VectorViewer3d, - ColouredTriMeshViewer3d, - TexturedTriMeshViewer3d, - LandmarkViewer3d, -) +from .base import (PointGraphViewer3d, TriMeshViewer3d, VectorViewer3d, + ColouredTriMeshViewer3d, TexturedTriMeshViewer3d, + LandmarkViewer3d, HeatmapViewer3d, + TriMeshInlineViewer3d, ColouredTriMeshInlineViewer3d, + PointGraphInlineViewer3d, TexturedTriMeshInlineViewer3d, + LandmarkInlineViewer3d, PointGraphInlineViewer3d, + VectorInlineViewer3d, HeatmapInlineViewer3d, + PCAModelInlineViewer3d) + +from .viewk3dwidgets import (list_figures, clear_figure, dict_figures) diff --git a/menpo3d/visualize/base.py b/menpo3d/visualize/base.py index 696c898..b7f524c 100644 --- a/menpo3d/visualize/base.py +++ b/menpo3d/visualize/base.py @@ -1,11 +1,17 @@ from .viewmayavi import ( - MayaviTriMeshViewer3d, - MayaviPointGraphViewer3d, - MayaviTexturedTriMeshViewer3d, - MayaviLandmarkViewer3d, - MayaviVectorViewer3d, - MayaviColouredTriMeshViewer3d, -) + MayaviTriMeshViewer3d, MayaviPointGraphViewer3d, + MayaviTexturedTriMeshViewer3d, MayaviLandmarkViewer3d, + MayaviVectorViewer3d, MayaviColouredTriMeshViewer3d, MayaviHeatmapViewer3d) + + +from .viewk3dwidgets import (K3dwidgetsTriMeshViewer3d, + K3dwidgetsPointGraphViewer3d, + K3dwidgetsVectorViewer3d, + K3dwidgetsLandmarkViewer3d, + K3dwidgetsTexturedTriMeshViewer3d, + K3dwidgetsColouredTriMeshViewer3d, + K3dwidgetsHeatmapViewer3d, + K3dwidgetsPCAModelViewer3d) PointGraphViewer3d = MayaviPointGraphViewer3d TriMeshViewer3d = MayaviTriMeshViewer3d @@ -13,3 +19,13 @@ ColouredTriMeshViewer3d = MayaviColouredTriMeshViewer3d LandmarkViewer3d = MayaviLandmarkViewer3d VectorViewer3d = MayaviVectorViewer3d +HeatmapViewer3d = MayaviHeatmapViewer3d + +TriMeshInlineViewer3d = K3dwidgetsTriMeshViewer3d +TexturedTriMeshInlineViewer3d = K3dwidgetsTexturedTriMeshViewer3d +LandmarkInlineViewer3d = K3dwidgetsLandmarkViewer3d +PointGraphInlineViewer3d = K3dwidgetsPointGraphViewer3d +VectorInlineViewer3d = K3dwidgetsVectorViewer3d +HeatmapInlineViewer3d = K3dwidgetsHeatmapViewer3d +PCAModelInlineViewer3d = K3dwidgetsPCAModelViewer3d +ColouredTriMeshInlineViewer3d = K3dwidgetsColouredTriMeshViewer3d diff --git a/menpo3d/visualize/menpowidgets.py b/menpo3d/visualize/menpowidgets.py new file mode 100644 index 0000000..749dba2 --- /dev/null +++ b/menpo3d/visualize/menpowidgets.py @@ -0,0 +1,828 @@ +from collections import OrderedDict +from time import sleep +from IPython import get_ipython +from ipywidgets import Box +import ipywidgets +from traitlets.traitlets import List + +# The below classes have been copied from +# the deprecated menpowidgets package +# MenpoWidget can be found in abstract.py +# LinearModelParametersWidget in options.py +class MenpoWidget(Box): + r""" + Base class for defining a Menpo widget. + + The widget has a `selected_values` trait that can be used in order to + inspect any changes that occur to its children. It also has functionality + for adding, removing, replacing or calling the handler callback function of + the `selected_values` trait. + + Parameters + ---------- + children : `list` of `ipywidgets` + The `list` of `ipywidgets` objects to be set as children in the + `ipywidgets.Box`. + trait : `traitlets.TraitType` subclass + The type of the `selected_values` object that gets added as a trait + in the widget. Possible options from `traitlets` are {``Int``, ``Float``, + ``Dict``, ``List``, ``Tuple``}. + trait_initial_value : `int` or `float` or `dict` or `list` or `tuple` + The initial value of the `selected_values` trait. + render_function : `callable` or ``None``, optional + The render function that behaves as a callback handler of the + `selected_values` trait for the `change` event. Its signature can be + ``render_function()`` or ``render_function(change)``, where ``change`` + is a `dict` with the following keys: + + - ``owner`` : the `HasTraits` instance + - ``old`` : the old value of the modified trait attribute + - ``new`` : the new value of the modified trait attribute + - ``name`` : the name of the modified trait attribute. + - ``type`` : ``'change'`` + + If ``None``, then nothing is added. + """ + def __init__(self, children, trait, trait_initial_value, + render_function=None): + # Create box object + super(MenpoWidget, self).__init__(children=children) + + # Add trait for selected values + selected_values = trait(default_value=trait_initial_value) + selected_values_trait = {'selected_values': selected_values} + self.add_traits(**selected_values_trait) + self.selected_values = trait_initial_value + + # Set render function + self._render_function = None + self.add_render_function(render_function) + + def add_render_function(self, render_function): + r""" + Method that adds the provided `render_function()` as a callback handler + to the `selected_values` trait of the widget. The given function is + also stored in `self._render_function`. + + Parameters + ---------- + render_function : `callable` or ``None``, optional + The render function that behaves as a callback handler of the + `selected_values` trait for the `change` event. Its signature can be + ``render_function()`` or ``render_function(change)``, where + ``change`` is a `dict` with the following keys: + + - ``owner`` : the `HasTraits` instance + - ``old`` : the old value of the modified trait attribute + - ``new`` : the new value of the modified trait attribute + - ``name`` : the name of the modified trait attribute. + - ``type`` : ``'change'`` + + If ``None``, then nothing is added. + """ + self._render_function = render_function + if self._render_function is not None: + self.observe(self._render_function, names='selected_values', + type='change') + + def remove_render_function(self): + r""" + Method that removes the current `self._render_function()` as a callback + handler to the `selected_values` trait of the widget and sets + ``self._render_function = None``. + """ + if self._render_function is not None: + self.unobserve(self._render_function, names='selected_values', + type='change') + self._render_function = None + + def replace_render_function(self, render_function): + r""" + Method that replaces the current `self._render_function()` with the + given `render_function()` as a callback handler to the `selected_values` + trait of the widget. + + Parameters + ---------- + render_function : `callable` or ``None``, optional + The render function that behaves as a callback handler of the + `selected_values` trait for the `change` event. Its signature can be + ``render_function()`` or ``render_function(change)``, where + ``change`` is a `dict` with the following keys: + + - ``owner`` : the `HasTraits` instance + - ``old`` : the old value of the modified trait attribute + - ``new`` : the new value of the modified trait attribute + - ``name`` : the name of the modified trait attribute. + - ``type`` : ``'change'`` + + If ``None``, then nothing is added. + """ + # remove old function + self.remove_render_function() + + # add new function + self.add_render_function(render_function) + + def call_render_function(self, old_value, new_value, type_value='change'): + r""" + Method that calls the existing `render_function()` callback handler. + + Parameters + ---------- + old_value : `int` or `float` or `dict` or `list` or `tuple` + The old `selected_values` value. + new_value : `int` or `float` or `dict` or `list` or `tuple` + The new `selected_values` value. + type_value : `str`, optional + The trait event type. + """ + if self._render_function is not None: + change_dict = {'type': 'change', 'old': old_value, + 'name': type_value, 'new': new_value, + 'owner': self.__str__()} + self._render_function(change_dict) + + +class LinearModelParametersWidget(MenpoWidget): + r""" + Creates a widget for selecting parameters values when visualizing a linear + model (e.g. PCA model). + + Note that: + + * To update the state of the widget, please refer to the + :meth:`set_widget_state` method. + * The selected values are stored in the ``self.selected_values`` `trait` + which is a `list`. + * To set the styling of this widget please refer to the + :meth:`predefined_style` method. + * To update the handler callback functions of the widget, please refer to + the :meth:`replace_render_function` and :meth:`replace_variance_function` + methods. + + Parameters + ---------- + n_parameters : `int` + The `list` of initial parameters values. + render_function : `callable` or ``None``, optional + The render function that is executed when a widgets' value changes. + It must have signature ``render_function(change)`` where ``change`` is + a `dict` with the following keys: + + * ``type`` : The type of notification (normally ``'change'``). + * ``owner`` : the `HasTraits` instance + * ``old`` : the old value of the modified trait attribute + * ``new`` : the new value of the modified trait attribute + * ``name`` : the name of the modified trait attribute. + + If ``None``, then nothing is assigned. + mode : ``{'single', 'multiple'}``, optional + If ``'single'``, only a single slider is constructed along with a + dropdown menu that allows the parameter selection. + If ``'multiple'``, a slider is constructed for each parameter. + params_str : `str`, optional + The string that will be used as description of the slider(s). The final + description has the form ``"{}{}".format(params_str, p)``, where ``p`` + is the parameter number. + params_bounds : (`float`, `float`), optional + The minimum and maximum bounds, in std units, for the sliders. + params_step : `float`, optional + The step, in std units, of the sliders. + plot_variance_visible : `bool`, optional + Defines whether the button for plotting the variance will be visible + upon construction. + plot_variance_function : `callable` or ``None``, optional + The plot function that is executed when the plot variance button is + clicked. If ``None``, then nothing is assigned. + animation_visible : `bool`, optional + Defines whether the animation options will be visible. + loop_enabled : `bool`, optional + If ``True``, then the repeat mode of the animation is enabled. + interval : `float`, optional + The interval between the animation progress in seconds. + interval_step : `float`, optional + The interval step (in seconds) that is applied when fast + forward/backward buttons are pressed. + animation_step : `float`, optional + The parameters step that is applied when animation is enabled. + style : `str` (see below), optional + Sets a predefined style at the widget. Possible options are: + + ============= ================== + Style Description + ============= ================== + ``'success'`` Green-based style + ``'info'`` Blue-based style + ``'warning'`` Yellow-based style + ``'danger'`` Red-based style + ``''`` No style + ============= ================== + + continuous_update : `bool`, optional + If ``True``, then the render function is called while moving a + slider's handle. If ``False``, then the the function is called only + when the handle (mouse click) is released. + + Example + ------- + Let's create a linear model parameters values widget and then update its + state. Firstly, we need to import it: + + >>> from menpowidgets.options import LinearModelParametersWidget + + Now let's define a render function that will get called on every widget + change and will dynamically print the selected parameters: + + >>> from menpo.visualize import print_dynamic + >>> def render_function(change): + >>> s = "Selected parameters: {}".format(wid.selected_values) + >>> print_dynamic(s) + + Create the widget with some initial options and display it: + + >>> wid = LinearModelParametersWidget(n_parameters=5, + >>> render_function=render_function, + >>> params_str='Parameter ', + >>> mode='multiple', + >>> params_bounds=(-3., 3.), + >>> plot_variance_visible=True, + >>> style='info') + >>> wid + + By moving the sliders, the printed message gets updated. Finally, let's + change the widget status with a new set of options: + + >>> wid.set_widget_state(n_parameters=10, params_str='', + >>> params_step=0.1, params_bounds=(-10, 10), + >>> plot_variance_visible=False, + >>> allow_callback=True) + """ + def __init__(self, n_parameters, render_function=None, mode='multiple', + params_str='Parameter ', params_bounds=(-3., 3.), + params_step=0.1, plot_variance_visible=True, + plot_variance_function=None, animation_visible=True, + loop_enabled=False, interval=0., interval_step=0.05, + animation_step=0.5, style='', continuous_update=False): + + # Get the kernel to use it later in order to make sure that the widgets' + # traits changes are passed during a while-loop + self.kernel = get_ipython().kernel + + # If only one slider requested, then set mode to multiple + if n_parameters == 1: + mode = 'multiple' + + # Create children + if mode == 'multiple': + self.sliders = [] + self.parameters_children = [] + for p in range(n_parameters): + slider_title = ipywidgets.HTML( + value="{}{}".format(params_str, p)) + slider_wid = ipywidgets.FloatSlider( + description='', min=params_bounds[0], max=params_bounds[1], + step=params_step, value=0., + continuous_update=continuous_update, + layout=ipywidgets.Layout(width='8cm')) + tmp = ipywidgets.HBox([slider_title, slider_wid]) + tmp.layout.align_items = 'center' + self.sliders.append(slider_wid) + self.parameters_children.append(tmp) + self.parameters_wid = ipywidgets.VBox(self.parameters_children) + self.parameters_wid.layout.align_items = 'flex-end' + else: + vals = OrderedDict() + for p in range(n_parameters): + vals["{}{}".format(params_str, p)] = p + self.slider = ipywidgets.FloatSlider( + description='', min=params_bounds[0], max=params_bounds[1], + step=params_step, value=0., readout=True, + layout=ipywidgets.Layout(width='8cm'), + continuous_update=continuous_update) + self.dropdown_params = ipywidgets.Dropdown( + options=vals, layout=ipywidgets.Layout(width='3cm')) + self.dropdown_params.layout.margin = '0px 10px 0px 0px' + self.parameters_wid = ipywidgets.HBox([self.dropdown_params, + self.slider]) + self.parameters_wid.layout.margin = '0px 0px 10px 0px' + self.plot_button = ipywidgets.Button( + description='Variance', layout=ipywidgets.Layout(width='80px')) + self.plot_button.layout.display = ( + 'inline' if plot_variance_visible else 'none') + self.reset_button = ipywidgets.Button( + description='Reset', layout=ipywidgets.Layout(width='80px')) + self.plot_and_reset = ipywidgets.HBox([self.reset_button, + self.plot_button]) + self.play_button = ipywidgets.Button( + icon='play', description='', tooltip='Play animation', + layout=ipywidgets.Layout(width='40px')) + self.stop_button = ipywidgets.Button( + icon='stop', description='', tooltip='Stop animation', + layout=ipywidgets.Layout(width='40px')) + self.fast_forward_button = ipywidgets.Button( + icon='fast-forward', description='', + layout=ipywidgets.Layout(width='40px'), + tooltip='Increase animation speed') + self.fast_backward_button = ipywidgets.Button( + icon='fast-backward', description='', + layout=ipywidgets.Layout(width='40px'), + tooltip='Decrease animation speed') + loop_icon = 'repeat' if loop_enabled else 'long-arrow-right' + self.loop_toggle = ipywidgets.ToggleButton( + icon=loop_icon, description='', value=loop_enabled, + layout=ipywidgets.Layout(width='40px'), tooltip='Repeat animation') + self.animation_buttons = ipywidgets.HBox( + [self.play_button, self.stop_button, self.loop_toggle, + self.fast_backward_button, self.fast_forward_button]) + self.animation_buttons.layout.display = ( + 'flex' if animation_visible else 'none') + self.animation_buttons.layout.margin = '0px 15px 0px 0px' + self.buttons_box = ipywidgets.HBox([self.animation_buttons, + self.plot_and_reset]) + self.container = ipywidgets.VBox([self.parameters_wid, + self.buttons_box]) + + # Create final widget + super(LinearModelParametersWidget, self).__init__( + [self.container], List, [0.] * n_parameters, + render_function=render_function) + + # Assign output + self.n_parameters = n_parameters + self.mode = mode + self.params_str = params_str + self.params_bounds = params_bounds + self.params_step = params_step + self.plot_variance_visible = plot_variance_visible + self.loop_enabled = loop_enabled + self.continuous_update = continuous_update + self.interval = interval + self.interval_step = interval_step + self.animation_step = animation_step + self.animation_visible = animation_visible + self.please_stop = False + + # Set style + self.predefined_style(style) + + # Set functionality + if mode == 'single': + # Assign slider value to parameters values list + def save_slider_value(change): + current_parameters = list(self.selected_values) + current_parameters[self.dropdown_params.value] = change['new'] + self.selected_values = current_parameters + self.slider.observe(save_slider_value, names='value', type='change') + + # Set correct value to slider when drop down menu value changes + def set_slider_value(change): + # Temporarily remove render callback + render_fun = self._render_function + self.remove_render_function() + # Set slider value + self.slider.value = self.selected_values[change['new']] + # Re-assign render callback + self.add_render_function(render_fun) + self.dropdown_params.observe(set_slider_value, names='value', + type='change') + else: + # Assign saving values and main plotting function to all sliders + for w in self.sliders: + w.observe(self._save_slider_value_from_id, names='value', + type='change') + + def reset_parameters(name): + # Keep old value + old_value = self.selected_values + + # Temporarily remove render callback + render_fun = self._render_function + self.remove_render_function() + + # Set parameters to 0 + self.selected_values = [0.0] * self.n_parameters + if mode == 'multiple': + for ww in self.sliders: + ww.value = 0. + else: + self.parameters_wid.children[0].value = 0 + self.parameters_wid.children[1].value = 0. + + # Re-assign render callback and trigger it + self.add_render_function(render_fun) + self.call_render_function(old_value, self.selected_values) + self.reset_button.on_click(reset_parameters) + + # Set functionality + def loop_pressed(change): + if change['new']: + self.loop_toggle.icon = 'repeat' + else: + self.loop_toggle.icon = 'long-arrow-right' + self.kernel.do_one_iteration() + self.loop_toggle.observe(loop_pressed, names='value', type='change') + + def fast_forward_pressed(name): + tmp = self.interval + tmp -= self.interval_step + if tmp < 0: + tmp = 0 + self.interval = tmp + self.kernel.do_one_iteration() + self.fast_forward_button.on_click(fast_forward_pressed) + + def fast_backward_pressed(name): + self.interval += self.interval_step + self.kernel.do_one_iteration() + self.fast_backward_button.on_click(fast_backward_pressed) + + def animate(change): + reset_parameters('') + self.please_stop = False + self.reset_button.disabled = True + self.plot_button.disabled = True + if mode == 'multiple': + n_sliders = self.n_parameters + slider_id = 0 + while slider_id < n_sliders: + # animate from 0 to min + slider_val = 0. + while slider_val > self.params_bounds[0]: + # Run IPython iteration. + self.kernel.do_one_iteration() + + # Check stop flag + if self.please_stop: + break + + # update slider value + slider_val -= self.animation_step + + # set value + self.sliders[slider_id].value = slider_val + + # wait + sleep(self.interval) + + # Run IPython iteration. + self.kernel.do_one_iteration() + + # animate from min to max + slider_val = self.params_bounds[0] + while slider_val < self.params_bounds[1]: + # Run IPython iteration. + self.kernel.do_one_iteration() + + # Check stop flag + if self.please_stop: + break + + # update slider value + slider_val += self.animation_step + + # set value + self.sliders[slider_id].value = slider_val + + # wait + sleep(self.interval) + + # Run IPython iteration. + self.kernel.do_one_iteration() + + # animate from max to 0 + slider_val = self.params_bounds[1] + while slider_val > 0.: + # Run IPython iteration. + self.kernel.do_one_iteration() + + # Check stop flag + if self.please_stop: + break + + # update slider value + slider_val -= self.animation_step + + # set value + self.sliders[slider_id].value = slider_val + + # wait + sleep(self.interval) + + # Run IPython iteration. + self.kernel.do_one_iteration() + + # reset value + self.sliders[slider_id].value = 0. + + # Check stop flag + if self.please_stop: + break + + # update slider id + if self.loop_toggle.value and slider_id == n_sliders - 1: + slider_id = 0 + else: + slider_id += 1 + + if not self.loop_toggle.value and slider_id >= n_sliders: + self.stop_animation() + else: + n_sliders = self.n_parameters + slider_id = 0 + self.please_stop = False + while slider_id < n_sliders: + # set dropdown value + self.parameters_wid.children[0].value = slider_id + + # animate from 0 to min + slider_val = 0. + while slider_val > self.params_bounds[0]: + # Run IPython iteration. + self.kernel.do_one_iteration() + + # Check stop flag + if self.please_stop: + break + + # update slider value + slider_val -= self.animation_step + + # set value + self.parameters_wid.children[1].value = slider_val + + # wait + sleep(self.interval) + + # Run IPython iteration. + self.kernel.do_one_iteration() + + # animate from min to max + slider_val = self.params_bounds[0] + while slider_val < self.params_bounds[1]: + # Run IPython iteration. + self.kernel.do_one_iteration() + + # Check stop flag + if self.please_stop: + break + + # update slider value + slider_val += self.animation_step + + # set value + self.parameters_wid.children[1].value = slider_val + + # wait + sleep(self.interval) + + # Run IPython iteration. + self.kernel.do_one_iteration() + + # animate from max to 0 + slider_val = self.params_bounds[1] + while slider_val > 0.: + # Run IPython iteration. + self.kernel.do_one_iteration() + + # Check stop flag + if self.please_stop: + break + + # update slider value + slider_val -= self.animation_step + + # set value + self.parameters_wid.children[1].value = slider_val + + # wait + sleep(self.interval) + + # Run IPython iteration. + self.kernel.do_one_iteration() + + # reset value + self.parameters_wid.children[1].value = 0. + + # Check stop flag + if self.please_stop: + break + + # update slider id + if self.loop_toggle.value and slider_id == n_sliders - 1: + slider_id = 0 + else: + slider_id += 1 + self.reset_button.disabled = False + self.plot_button.disabled = False + self.play_button.on_click(animate) + + def stop_pressed(_): + self.stop_animation() + self.stop_button.on_click(stop_pressed) + + # Set plot variance function + self._variance_function = None + self.add_variance_function(plot_variance_function) + + def _save_slider_value_from_id(self, change): + current_parameters = list(self.selected_values) + i = self.sliders.index(change['owner']) + current_parameters[i] = change['new'] + self.selected_values = current_parameters + + def predefined_style(self, style): + r""" + Function that sets a predefined style on the widget. + + Parameters + ---------- + style : `str` (see below) + Style options: + + ============= ================== + Style Description + ============= ================== + ``'success'`` Green-based style + ``'info'`` Blue-based style + ``'warning'`` Yellow-based style + ``'danger'`` Red-based style + ``''`` No style + ============= ================== + """ + self.container.box_style = style + self.container.border = '0px' + self.play_button.button_style = 'success' + self.stop_button.button_style = 'danger' + self.fast_forward_button.button_style = 'info' + self.fast_backward_button.button_style = 'info' + self.loop_toggle.button_style = 'warning' + self.reset_button.button_style = 'danger' + self.plot_button.button_style = 'primary' + + def stop_animation(self): + r""" + Method that stops an active annotation. + """ + self.please_stop = True + + def add_variance_function(self, variance_function): + r""" + Method that adds a `variance_function()` to the `Variance` button of the + widget. The given function is also stored in `self._variance_function`. + + Parameters + ---------- + variance_function : `callable` or ``None``, optional + The variance function that behaves as a callback. If ``None``, + then nothing is added. + """ + self._variance_function = variance_function + if self._variance_function is not None: + self.plot_button.on_click(self._variance_function) + + def remove_variance_function(self): + r""" + Method that removes the current `self._variance_function()` from + the `Variance` button of the widget and sets + ``self._variance_function = None``. + """ + self.plot_button.on_click(self._variance_function, remove=True) + self._variance_function = None + + def replace_variance_function(self, variance_function): + r""" + Method that replaces the current `self._variance_function()` of the + `Variance` button of the widget with the given `variance_function()`. + + Parameters + ---------- + variance_function : `callable` or ``None``, optional + The variance function that behaves as a callback. If ``None``, + then nothing happens. + """ + # remove old function + self.remove_variance_function() + + # add new function + self.add_variance_function(variance_function) + + def set_widget_state(self, n_parameters=None, params_str=None, + params_bounds=None, params_step=None, + plot_variance_visible=True, animation_step=0.5, + allow_callback=True): + r""" + Method that updates the state of the widget with a new set of options. + + Parameters + ---------- + n_parameters : `int` + The `list` of initial parameters values. + params_str : `str`, optional + The string that will be used as description of the slider(s). The + final description has the form ``"{}{}".format(params_str, p)``, + where ``p`` is the parameter number. + params_bounds : (`float`, `float`), optional + The minimum and maximum bounds, in std units, for the sliders. + params_step : `float`, optional + The step, in std units, of the sliders. + plot_variance_visible : `bool`, optional + Defines whether the button for plotting the variance will be visible + upon construction. + animation_step : `float`, optional + The parameters step that is applied when animation is enabled. + allow_callback : `bool`, optional + If ``True``, it allows triggering of any callback functions. + """ + # Keep old value + old_value = self.selected_values + + # Temporarily remove render callback + render_function = self._render_function + self.remove_render_function() + + # Parse given options + if n_parameters is None: + n_parameters = self.n_parameters + if params_str is None: + params_str = '' + if params_bounds is None: + params_bounds = self.params_bounds + if params_step is None: + params_step = self.params_step + + # Set plot variance visibility + self.plot_button.layout.visibility = ( + 'visible' if plot_variance_visible else 'hidden') + self.animation_step = animation_step + + # Update widget + if n_parameters == self.n_parameters: + # The number of parameters hasn't changed + if self.mode == 'multiple': + for p, sl in enumerate(self.sliders): + self.parameters_children[p].children[0].value = \ + "{}{}".format(params_str, p) + sl.min = params_bounds[0] + sl.max = params_bounds[1] + sl.step = params_step + else: + self.slider.min = params_bounds[0] + self.slider.max = params_bounds[1] + self.slider.step = params_step + if not params_str == '': + vals = OrderedDict() + for p in range(n_parameters): + vals["{}{}".format(params_str, p)] = p + self.dropdown_params.options = vals + else: + # The number of parameters has changed + self.selected_values = [0.] * n_parameters + if self.mode == 'multiple': + # Create new sliders + self.sliders = [] + self.parameters_children = [] + for p in range(n_parameters): + slider_title = ipywidgets.HTML( + value="{}{}".format(params_str, p)) + slider_wid = ipywidgets.FloatSlider( + description='', min=params_bounds[0], + max=params_bounds[1], + step=params_step, value=0., width='8cm', + continuous_update=self.continuous_update) + tmp = ipywidgets.HBox([slider_title, slider_wid]) + tmp.layout.align_items = 'center' + self.sliders.append(slider_wid) + self.parameters_children.append(tmp) + self.parameters_wid.children = self.parameters_children + + # Assign saving values and main plotting function to all sliders + for w in self.sliders: + w.observe(self._save_slider_value_from_id, names='value', + type='change') + else: + self.slider.min = params_bounds[0] + self.slider.max = params_bounds[1] + self.slider.step = params_step + vals = OrderedDict() + for p in range(n_parameters): + vals["{}{}".format(params_str, p)] = p + if self.dropdown_params.value == 0 and n_parameters > 1: + self.dropdown_params.value = 1 + self.dropdown_params.value = 0 + self.dropdown_params.options = vals + self.slider.value = 0. + + # Re-assign render callback + self.add_render_function(render_function) + + # Assign new selected options + self.n_parameters = n_parameters + self.params_str = params_str + self.params_bounds = params_bounds + self.params_step = params_step + self.plot_variance_visible = plot_variance_visible + + # trigger render function if allowed + if allow_callback: + self.call_render_function(old_value, self.selected_values) diff --git a/menpo3d/visualize/viewk3dwidgets.py b/menpo3d/visualize/viewk3dwidgets.py new file mode 100644 index 0000000..ff958a6 --- /dev/null +++ b/menpo3d/visualize/viewk3dwidgets.py @@ -0,0 +1,789 @@ +import numpy as np +from k3d import (Plot, mesh as k3d_mesh, points as k3d_points, + text as k3d_text, vectors as k3d_vectors, + line as k3d_line) +from k3d.colormaps import matplotlib_color_maps +from io import BytesIO +from ipywidgets import GridBox, Layout, Widget +from collections import defaultdict +# The colour map used for all lines and markers +GLOBAL_CMAP = 'jet' + + +def dict_figures(): + dict_fig = defaultdict(list) + for x in Widget.widgets.values(): + if hasattr(x, 'figure_id'): + dict_fig[x.figure_id].append(x.model_id) + return dict_fig + + +def list_figures(): + list_figures = list(dict_figures().keys()) + for figure_id in list_figures: + print(figure_id) + + +def clear_figure(figure_id=None): + # TODO remove figures, clear memory + dict_fig = dict_figures() + + +def _calc_distance(points): + from menpo.shape import PointCloud + pc = PointCloud(points, copy=False) + # This is the way that mayavi automatically computes the scale factor + # in case the user passes scale_factor = 'auto'. We use it for both + # the marker_size as well as the numbers_size. + xyz_min, xyz_max = pc.bounds() + x_min, y_min, z_min = xyz_min + x_max, y_max, z_max = xyz_max + distance = np.sqrt(((x_max - x_min) ** 2 + + (y_max - y_min) ** 2 + + (z_max - z_min) ** 2) / + (4 * pc.n_points ** 0.33)) + return distance + + +def rgb2int(rgb_array, keep_alpha=False): + """ + Convert rgb_array to an int color + + Parameters + ---------- + rgb_array: ndarray + An RGBA array + keep_alpha: bool + If True, the alpha value is also used + Returns + -------- + A ndarray with an int color value for each point + """ + + type_error_message = "RGB shape should be (num_points,3) or (num_points,4)" + if isinstance(rgb_array, np.ndarray): + if len(rgb_array.shape) != 2: + raise TypeError(type_error_message) + if rgb_array.shape[1] != 3 and rgb_array.shape[1] != 4: + print(rgb_array.shape[1]) + raise TypeError(type_error_message) + else: + raise TypeError("RGB shape should be numpy ndarray") + + if not keep_alpha: + rgb_array = rgb_array[:, :3] + + num_points, num_colors = rgb_array.shape + if rgb_array.dtype in (np.float32, np.float64): + rgb_array = np.asarray(np.round(255*rgb_array), dtype='uint32') + # TODO + # check for overfloat + if num_colors == 4: + return ((rgb_array[:, 0] << 32) + (rgb_array[:, 1] << 16) + + (rgb_array[:, 2] << 8) + rgb_array[:, 3]) + + return ((rgb_array[:, 0] << 16) + (rgb_array[:, 1] << 8) + rgb_array[:, 2]) + +def _parse_marker_size(marker_size, points): + distance = _calc_distance(points) + if marker_size is None: + if distance == 0: + marker_size = 1 + else: + marker_size = 0.1 * distance + return marker_size + + +def _parse_colour(colour): + from matplotlib.colors import rgb2hex + if isinstance(colour, int): + return colour + else: + return int(rgb2hex(colour)[1:], base=16) + + +def _check_colours_list(render_flag, colours_list, n_objects, error_str): + from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap + if render_flag: + if colours_list is None: + # sample colours from jet colour map + colours_list = sample_colours_from_colourmap(n_objects, + GLOBAL_CMAP) + colours_list = list(map(_parse_colour, colours_list)) + + if isinstance(colours_list, list): + if len(colours_list) == 1: + colours_list[0] = _parse_colour(colours_list[0]) + colours_list *= n_objects + elif len(colours_list) != n_objects: + raise ValueError(error_str) + else: + colours_list = [_parse_colour(colours_list)] * n_objects + else: + colours_list = [0x00FF00] * n_objects + return colours_list + + +def _calc_camera_position(points): + from menpo.shape import PointCloud + + pc = PointCloud(points, copy=False) + bounds = pc.bounding_box().points + distance = np.max(bounds[1::2] - bounds[::2]) * 2.0 + camera = [0, 0, distance, 0, 0, 0, 0, 1, 0] + + return camera + + +def _check_figure_id(obj, figure_id, new_figure): + if figure_id is None: + if new_figure: + # A new figure is created but with no figure_id + # we should create an id of 'Figure_n form' + list_ids = [] + for x in obj.widgets.values(): + if hasattr(x, 'figure_id') and x is not obj: + if x.figure_id is not None and 'Figure_' in str(x.figure_id): + try: + n_figure_id = int(x.figure_id.split('Figure_')[1]) + except ValueError: + continue + list_ids.append(n_figure_id) + if len(list_ids): + figure_id = 'Figure_{}'.format(sorted(list_ids)[-1] + 1) + else: + figure_id = 'Figure_0' + + else: + if len(obj.list_figures_ids): + figure_id = obj.list_figures_ids[-1] + else: + obj.remove_widget() + raise ValueError('You cannot plot a figure with no id and new figure False') + else: + if new_figure: + for x in obj.widgets.values(): + if hasattr(x, 'figure_id') and x is not obj: + if x.figure_id == figure_id: + obj.remove_widget() + raise ValueError('Figure id is already given') + else: + return figure_id + + obj.list_figures_ids.append(figure_id) + if hasattr(obj, 'model_id'): + obj.dict_figure_id_to_model_id[figure_id] = obj.model_id + return figure_id + + +class K3dwidgetIdentity(): + list_figures_ids = [] + dict_figure_id_to_model_id = {} + + +class K3dwidgetsRenderer(Plot, K3dwidgetIdentity): + """ Abstract class for performing visualizations using K3dwidgets. + + Parameters + ---------- + figure_id : str or `None` + A figure name or `None`. + new_figure : bool + If `True`, creates a new figure on the cell. + """ + # list_figures_ids = [] + + def __init__(self, figure_id, new_figure): + super(K3dwidgetsRenderer, self).__init__() + + self.figure_id = _check_figure_id(self, figure_id, new_figure) + self.new_figure = new_figure + self.grid_visible = False + + def _render(self): + widg_to_draw = self + if not self.new_figure: + for widg in self.widgets.values(): + if isinstance(widg, K3dwidgetsRenderer): + if widg.figure_id == self.figure_id and widg.model_id != self.model_id and widg.new_figure: + widg_to_draw = widg + return widg_to_draw + self.remove_widget() + raise Exception('Figure with id {} was not found '.format(self.figure_id)) + + return widg_to_draw + + def remove_widget(self): + super(K3dwidgetsRenderer, self).close() + # copy from close from ipywidgets.widget.Widget + self.widgets.pop(self.model_id, None) + self.comm.close() + self.comm = None + self._repr_mimebundle_ = None + + def get_figure(self): + r""" + Gets the figure specified by the combination of `self.figure_id` and + `self.new_figure`. If `self.figure_id == None` then `mlab.gcf()` + is used. `self.figure_id` is also set to the correct id of the figure + if a new figure is created. + + Returns + ------- + figure : Mayavi figure object + The figure we will be rendering on. + """ + # return self.figure + pass + + def save_figure(self, filename, format='png', size=None, + magnification='auto', overwrite=False): + r""" + Method for saving the figure of the current `figure_id` to file. + + Parameters + ---------- + filename : `str` or `file`-like object + The string path or file-like object to save the figure at/into. + format : `str` + The format to use. This must match the file path if the file path is + a `str`. + size : `tuple` of `int` or ``None``, optional + The size of the image created (unless magnification is set, + in which case it is the size of the window used for rendering). If + ``None``, then the figure size is used. + magnification : `double` or ``'auto'``, optional + The magnification is the scaling between the pixels on the screen, + and the pixels in the file saved. If you do not specify it, it will + be calculated so that the file is saved with the specified size. + If you specify a magnification, Mayavi will use the given size as a + screen size, and the file size will be ``magnification * size``. + If ``'auto'``, then the magnification will be set automatically. + overwrite : `bool`, optional + If ``True``, the file will be overwritten if it already exists. + """ + pass + + @property + def modelview_matrix(self): + r""" + Retrieves the modelview matrix for this scene. + + :type: ``(4, 4)`` `ndarray` + """ + pass + + @property + def projection_matrix(self): + r""" + Retrieves the projection matrix for this scene. + + :type: ``(4, 4)`` `ndarray` + """ + pass + + @property + def renderer_settings(self): + r""" + Returns all the information required to construct an identical + renderer to this one. + + Returns + ------- + settings : `dict` + The dictionary with the following keys: + + * ``'width'`` (`int`) : The width of the scene. + * ``'height'`` (`int`) : The height of the scene. + * ``'model_matrix'`` (`ndarray`) : The model array (identity). + * ``'view_matrix'`` (`ndarray`) : The view array. + * ``'projection_matrix'`` (`ndarray`) : The projection array. + + """ + pass + + def force_draw(self): + r""" + Method for forcing the current figure to render. This is useful for + the widgets animation. + """ + self.render() + + +class K3dwidgetsVectorViewer3d(K3dwidgetsRenderer): + def __init__(self, figure_id, new_figure, points, vectors): + super(K3dwidgetsVectorViewer3d, self).__init__(figure_id, new_figure) + non_zero_indices = np.unique(np.nonzero(vectors.reshape(-1, 3))[0]) + self.points = points[non_zero_indices].astype(np.float32) + self.vectors = vectors[non_zero_indices].astype(np.float32) + + def _render(self, colour='r', line_width=2, marker_size=None): + marker_size = _parse_marker_size(marker_size, self.points) + colour = _parse_colour(colour) + + widg_to_draw = super(K3dwidgetsVectorViewer3d, self)._render() + vectors_to_add = k3d_vectors(self.points, self.vectors, + color=colour, head_size=marker_size, + line_width=line_width) + widg_to_draw += vectors_to_add + return widg_to_draw + + +class K3dwidgetsPointGraphViewer3d(K3dwidgetsRenderer): + def __init__(self, figure_id, new_figure, points, edges): + super(K3dwidgetsPointGraphViewer3d, self).__init__(figure_id, + new_figure) + self.points = points.astype(np.float32) + self.edges = edges + + def _render(self, render_lines=True, line_colour='r', line_width=2, + render_markers=True, marker_style='flat', marker_size=10, + marker_colour='g', alpha=1.0, render_numbering=False, + numbers_colour='k', numbers_size=None, + colours=None, keep_alpha=False): + + widg_to_draw = super(K3dwidgetsPointGraphViewer3d, self)._render() + # Render the lines if requested + if render_lines and self.edges is not None: + if isinstance(line_colour, list): + line_colour = [_parse_colour(i_color) for i_color in + line_colour] + else: + line_colour = _parse_colour(line_colour) + + lines_to_add = None + for edge in self.edges: + if isinstance(line_colour, list): + if len(line_colour): + color_this_line = line_colour.pop() + else: + color_this_line = 0xFF0000 + else: + color_this_line = line_colour + + if lines_to_add is None: + lines_to_add = k3d_line(self.points[edge], + color=color_this_line) + else: + lines_to_add += k3d_line(self.points[edge], + color=color_this_line) + widg_to_draw += lines_to_add + + # Render the markers if requested + if render_markers: + marker_size = _parse_marker_size(marker_size, self.points) + if colours is not None: + colours = rgb2int(colours, keep_alpha) + marker_colour = 'w' + else: + colours = [] + + marker_colour = _parse_colour(marker_colour) + + # In order to be compatible with mayavi, we just change the + # default value for marker_style to mesh + if marker_style == 'sphere': + marker_style = 'mesh' + + # When the number of points is greater than 1000, it is recommended + # to use fast shaders: flat, 3d or 3dSpecular. + # The mesh shader generates much bigger overhead, + # but it has a properly triangularized sphere + # representing each point. + if self.points.shape[0] > 1000: + marker_style = '3dSpecular' + + points_to_add = k3d_points(self.points, colors=colours, + color=marker_colour, + point_size=marker_size, + opacity=alpha, + shader=marker_style) + widg_to_draw += points_to_add + + # TODO + # A class of k3d.texts that groups all texts should be created + # Till then, we go that way + if render_numbering: + text_to_add = None + + numbers_colour = _parse_colour(numbers_colour) + for i, point in enumerate(self.points): + if text_to_add is None: + text_to_add = k3d_text(str(i), color=numbers_colour, + position=point, label_box=False) + else: + text_to_add += k3d_text(str(i), color=numbers_colour, + position=point, label_box=False) + widg_to_draw += text_to_add + + return widg_to_draw + + +class K3dwidgetsTriMeshViewer3d(K3dwidgetsRenderer): + def __init__(self, figure_id, new_figure, points, trilist, landmarks=None): + super(K3dwidgetsTriMeshViewer3d, self).__init__(figure_id, new_figure) + self.points = points.astype(np.float32) + self.trilist = trilist.astype(np.uint32) + self.landmarks = landmarks + + def _render_mesh(self, line_width, colour, mesh_type, + marker_style, marker_size, alpha=1.0): + marker_size = _parse_marker_size(marker_size, self.points) + colour = _parse_colour(colour) + + widg_to_draw = super(K3dwidgetsTriMeshViewer3d, self)._render() + wireframe = False + opacity = alpha + if mesh_type == 'wireframe': + wireframe = True + opacity = 0.3 + + mesh_to_add = k3d_mesh(self.points, self.trilist.flatten(), + flat_shading=False, opacity=opacity, + color=colour, wireframe=wireframe, + side='double') + widg_to_draw += mesh_to_add + + if hasattr(self.landmarks, 'points'): + self.landmarks.view(figure_id=self.figure_id, + new_figure=False, + marker_style=marker_style, + marker_size=marker_size, + inline=True) + return widg_to_draw + + def _render(self, line_width=2, colour='r', mesh_type='surface', + marker_style='mesh', marker_size=None, + normals=None, normals_colour='k', normals_line_width=2, + normals_marker_size=None, alpha=1.0): + + widg_to_draw = self._render_mesh(line_width, colour, mesh_type, + marker_style, marker_size, alpha) + if normals is not None: + tmp_normals_widget = K3dwidgetsVectorViewer3d(self.figure_id, + False, self.points, + normals) + tmp_normals_widget._render(colour=normals_colour, + line_width=normals_line_width, + marker_size=normals_marker_size) + + widg_to_draw.camera = _calc_camera_position(self.points) + widg_to_draw.camera_auto_fit = False + + return widg_to_draw + + +class K3dwidgetsTexturedTriMeshViewer3d(K3dwidgetsRenderer): + def __init__(self, figure_id, new_figure, points, trilist, texture, + tcoords, landmarks): + super(K3dwidgetsTexturedTriMeshViewer3d, self).__init__(figure_id, + new_figure) + self.points = points + self.trilist = trilist + self.texture = texture + self.tcoords = tcoords + self.landmarks = landmarks + + def _render_mesh(self, mesh_type='surface', ambient_light=0.0, + specular_light=0.0, alpha=1.0): + + widg_to_draw = super(K3dwidgetsTexturedTriMeshViewer3d, self)._render() + + uvs = self.tcoords.points + tmp_img = self.texture.mirror(axis=0).as_PILImage() + img_byte_arr = BytesIO() + tmp_img.save(img_byte_arr, format='PNG') + texture = img_byte_arr.getvalue() + texture_file_format = 'png' + + mesh_to_add = k3d_mesh(self.points.astype(np.float32), + self.trilist.flatten().astype(np.uint32), + flat_shading=False, + color=0xFFFFFF, side='front', texture=texture, + uvs=uvs, + texture_file_format=texture_file_format) + + widg_to_draw += mesh_to_add + + if hasattr(self.landmarks, 'points'): + self.landmarks.view(inline=True, new_figure=False, + figure_id=self.figure_id) + + widg_to_draw.camera = _calc_camera_position(self.points) + widg_to_draw.camera_auto_fit = False + + return widg_to_draw + + def _render(self, normals=None, normals_colour='k', + normals_line_width=2, normals_marker_size=None): + + if normals is not None: + tmp_normals_widget = K3dwidgetsVectorViewer3d(self.figure_id, + False, self.points, + normals) + tmp_normals_widget._render(colour=normals_colour, + line_width=normals_line_width, + marker_size=normals_marker_size) + + self._render_mesh() + return self + + +class K3dwidgetsColouredTriMeshViewer3d(K3dwidgetsRenderer): + # TODO + def __init__(self, figure_id, new_figure, points, trilist, + colour_per_point, landmarks): + super(K3dwidgetsColouredTriMeshViewer3d, self).__init__(figure_id, + new_figure) + self.points = points + self.trilist = trilist + self.colour_per_point = colour_per_point + self.colorbar_object_id = False + self.landmarks = landmarks + + def _render_mesh(self): + widg_to_draw = super(K3dwidgetsColouredTriMeshViewer3d, self)._render() + + mesh_to_add = k3d_mesh(self.points.astype(np.float32), + self.trilist.flatten().astype(np.uint32), + attribute=self.colour_per_point, + ) + widg_to_draw += mesh_to_add + + if hasattr(self.landmarks, 'points'): + self.landmarks.view(inline=True, new_figure=False, + figure_id=self.figure_id) + widg_to_draw.camera = _calc_camera_position(self.points) + widg_to_draw.camera_auto_fit = False + + def _render(self, normals=None, normals_colour='k', normals_line_width=2, + normals_marker_size=None): + if normals is not None: + K3dwidgetsVectorViewer3d(self.figure_id, False, + self.points, normals)._render( + colour=normals_colour, line_width=normals_line_width, + marker_size=normals_marker_size) + self._render_mesh() + return self + + +class K3dwidgetsSurfaceViewer3d(K3dwidgetsRenderer): + def __init__(self, figure_id, new_figure, values, mask=None): + super(K3dwidgetsSurfaceViewer3d, self).__init__(figure_id, new_figure) + if mask is not None: + values[~mask] = np.nan + self.values = values + + def render(self, colour=(1, 0, 0), line_width=2, step=None, + marker_style='2darrow', marker_resolution=8, marker_size=0.05, + alpha=1.0): + # warp_scale = kwargs.get('warp_scale', 'auto') + # mlab.surf(self.values, warp_scale=warp_scale) + return self + + +class K3dwidgetsLandmarkViewer3d(K3dwidgetsRenderer): + def __init__(self, figure_id, new_figure, group, landmark_group): + super(K3dwidgetsLandmarkViewer3d, self).__init__(figure_id, new_figure) + self.group = group + self.landmark_group = landmark_group + + def _render(self, render_lines=True, line_colour='r', line_width=2, + render_markers=True, marker_style='mesh', marker_size=None, + marker_colour='r', alpha=1.0, render_numbering=False, + numbers_colour='k', numbers_size=None): + # Regarding the labels colours, we may get passed either no colours (in + # which case we generate random colours) or a single colour to colour + # all the labels with + # TODO: All marker and line options could be defined as lists... + n_labels = self.landmark_group.n_labels + line_colour = _check_colours_list( + render_lines, line_colour, n_labels, + 'Must pass a list of line colours with length n_labels or a single' + 'line colour for all labels.') + marker_colour = _check_colours_list( + render_markers, marker_colour, n_labels, + 'Must pass a list of marker colours with length n_labels or a ' + 'single marker face colour for all labels.') + marker_size = _parse_marker_size(marker_size, + self.landmark_group.points) + numbers_size = _parse_marker_size(numbers_size, + self.landmark_group.points) + + # get pointcloud of each label + sub_pointclouds = self._build_sub_pointclouds() + + widg_to_draw = super(K3dwidgetsLandmarkViewer3d, self)._render() + + if marker_style == 'sphere': + marker_style = 'mesh' + + for i, (label, pc) in enumerate(sub_pointclouds): + points_to_add = k3d_points(pc.points.astype(np.float32), + color=marker_colour[i], + point_size=marker_size, + shader=marker_style) + widg_to_draw += points_to_add + if render_numbering: + text_to_add = None + numbers_colour = _parse_colour(numbers_colour) + for i, point in enumerate(self.landmark_group.points): + if text_to_add is None: + text_to_add = k3d_text(str(i), color=numbers_colour, + position=point, label_box=False) + else: + text_to_add += k3d_text(str(i), color=numbers_colour, + position=point, label_box=False) + widg_to_draw += text_to_add + # widg_to_draw.camera = _calc_camera_position(pc.points) + # widg_to_draw.camera_auto_fit = False + + return widg_to_draw + + def _build_sub_pointclouds(self): + return [(label, self.landmark_group.get_label(label)) + for label in self.landmark_group.labels] + + +class K3dwidgetsHeatmapViewer3d(K3dwidgetsRenderer): + def __init__(self, figure_id, new_figure, points, trilist, landmarks=None): + super(K3dwidgetsHeatmapViewer3d, self).__init__(figure_id, new_figure) + self.points = points + self.trilist = trilist + self.landmarks = landmarks + + def _render_mesh(self, distances_between_meshes, type_cmap, + scalar_range, show_statistics=False): + + marker_size = _parse_marker_size(None, self.points) + widg_to_draw = super(K3dwidgetsHeatmapViewer3d, self)._render() + + try: + color_map = getattr(matplotlib_color_maps, type_cmap) + except AttributeError: + print('Could not find colormap {}. Hot_r is going to be used instead'.format(type_cmap)) + color_map = getattr(matplotlib_color_maps, 'hot_r') + + mesh_to_add = k3d_mesh(self.points.astype(np.float32), + self.trilist.flatten().astype(np.uint32), + color_map=color_map, + attribute=distances_between_meshes, + color_range=scalar_range + ) + widg_to_draw += mesh_to_add + + if hasattr(self.landmarks, 'points'): + self.landmarks.view(figure_id=self.figure_id, + new_figure=False, + marker_size=marker_size, + inline=True) + + if show_statistics: + text = '\\begin{{matrix}} \\mu & {:.3} \\\\ \\sigma^2 & {:.3} \\\\ \\max & {:.3} \\end{{matrix}}'\ + .format(distances_between_meshes.mean(), + distances_between_meshes.std(), + distances_between_meshes.max()) + min_b = np.min(self.points, axis=0) + max_b = np.max(self.points, axis=0) + text_position = (max_b-min_b)/2 + widg_to_draw += k3d_text(text, position=text_position, + color=0xff0000, size=1) + + widg_to_draw.camera = _calc_camera_position(self.points) + widg_to_draw.camera_auto_fit = False + + return widg_to_draw + + def _render(self, distances_between_meshes, type_cmap='hot_r', + scalar_range=[0, 2], show_statistics=False): + return self._render_mesh(distances_between_meshes, type_cmap, + scalar_range, show_statistics) + + +class K3dwidgetsPCAModelViewer3d(GridBox, K3dwidgetIdentity): + def __init__(self, figure_id, new_figure, points, trilist, + components, eigenvalues, n_parameters, parameters_bound, + landmarks_indices, widget_style): + + from .menpowidgets import LinearModelParametersWidget + + self.figure_id = _check_figure_id(self, figure_id, new_figure) + self.new_figure = new_figure + self.points = points.astype(np.float32) + if trilist is None: + self.trilist = None + else: + self.trilist = trilist.astype(np.uint32) + self.components = components.astype(np.float32) + self.eigenvalues = eigenvalues.astype(np.float32) + self.n_parameters = n_parameters + self.landmarks_indices = landmarks_indices + self.layout = Layout(grid_template_columns='1fr 1fr') + self.wid = LinearModelParametersWidget(n_parameters=n_parameters, + render_function=self.render_function, + params_str='Parameter ', + mode='multiple', + params_bounds=parameters_bound, + plot_variance_visible=False, + style=widget_style) + if self.trilist is None: + self.mesh_window = K3dwidgetsPointGraphViewer3d(self.figure_id, False, + self.points, self.trilist) + else: + self.mesh_window = K3dwidgetsTriMeshViewer3d(self.figure_id, False, + self.points, self.trilist) + super(K3dwidgetsPCAModelViewer3d, self).__init__(children=[self.wid, self.mesh_window], + layout=Layout(grid_template_columns='1fr 1fr')) + + self.dict_figure_id_to_model_id[figure_id] = self.model_id + + def _render_mesh(self, mesh_type, line_width, colour, + marker_size, marker_style, alpha): + marker_size = _parse_marker_size(marker_size, self.points) + colour = _parse_colour(colour) + + if self.trilist is None: + mesh_to_add = k3d_points(self.points, color=colour, + opacity=alpha, + point_size=marker_size, + shader='3dSpecular') + else: + mesh_to_add = k3d_mesh(self.points, self.trilist.flatten(), + flat_shading=False, opacity=alpha, + color=colour, name='Instance', + side='double') + + self.mesh_window += mesh_to_add + + if self.landmarks_indices is not None: + landmarks_to_add = k3d_points(self.points[self.landmarks_indices], + color=0x00FF00, name='landmarks', + point_size=marker_size, + shader='mesh') + self.mesh_window += landmarks_to_add + self.mesh_window.camera = _calc_camera_position(self.points) + self.mesh_window.camera_auto_fit = False + + return self + + def render_function(self, change): + weights = np.asarray(self.wid.selected_values).astype(np.float32) + weighted_eigenvalues = weights * self.eigenvalues[:self.n_parameters]**0.5 + new_instance = (self.components[:self.n_parameters, :].T@weighted_eigenvalues).reshape(-1, 3) + new_points = self.points + new_instance + + if self.trilist is None: + self.mesh_window.objects[0].positions = new_points + else: + self.mesh_window.objects[0].vertices = new_points + if self.landmarks_indices is not None: + self.mesh_window.objects[1].positions = new_points[self.landmarks_indices] + + def _render(self, mesh_type='wireframe', line_width=2, colour='r', + marker_style='mesh', marker_size=None, alpha=1.0): + + return self._render_mesh(mesh_type, line_width, colour, + marker_size, marker_style, alpha) + + def remove_widget(self): + super(K3dwidgetsPCAModelViewer3d, self).close() diff --git a/menpo3d/visualize/viewmayavi.py b/menpo3d/visualize/viewmayavi.py index a72f1f3..5fac21f 100644 --- a/menpo3d/visualize/viewmayavi.py +++ b/menpo3d/visualize/viewmayavi.py @@ -695,29 +695,67 @@ def render( self.figure.scene.disable_render = True for i, (label, pc) in enumerate(sub_pointclouds): # render pointcloud - pc.view( - figure_id=self.figure_id, - new_figure=False, - render_lines=render_lines, - line_colour=line_colour[i], - line_width=line_width, - render_markers=render_markers, - marker_style=marker_style, - marker_size=marker_size, - marker_colour=marker_colour[i], - marker_resolution=marker_resolution, - step=step, - alpha=alpha, - render_numbering=render_numbering, - numbers_colour=numbers_colour, - numbers_size=numbers_size, - ) + pc.view(figure_id=self.figure_id, new_figure=False, + render_lines=render_lines, line_colour=line_colour[i], + line_width=line_width, render_markers=render_markers, + marker_style=marker_style, marker_size=marker_size, + marker_colour=marker_colour[i], + marker_resolution=marker_resolution, step=step, + alpha=alpha, render_numbering=render_numbering, + numbers_colour=numbers_colour, numbers_size=numbers_size, + inline=False) self.figure.scene.disable_render = False return self def _build_sub_pointclouds(self): - return [ - (label, self.landmark_group.get_label(label)) - for label in self.landmark_group.labels - ] + return [(label, self.landmark_group.get_label(label)) + for label in self.landmark_group.labels] + + +class MayaviHeatmapViewer3d(MayaviRenderer): + def __init__(self, figure_id, new_figure, points, trilist): + super(MayaviHeatmapViewer3d, self).__init__(figure_id, new_figure) + self.points = points + self.trilist = trilist + + def _render_mesh(self, scaled_distances_between_meshes, + type_cmap, scalar_range, show_statistics): + from mayavi import mlab + # v = mlab.figure(figure=figure_name, size=size, + # bgcolor=(1, 1, 1), fgcolor=(0, 0, 0)) + src = mlab.pipeline.triangular_mesh_source(self.points[:, 0], + self.points[:, 1], + self.points[:, 2], + self.trilist, + scalars=scaled_distances_between_meshes) + surf = mlab.pipeline.surface(src, colormap=type_cmap) + # When font size bug resolved, uncomment + # cb=mlab.colorbar(title='Distances in mm', + # orientation='vertical', nb_labels=5) + # cb.title_text_property.font_size = 20 + # cb.label_text_property.font_family = 'times' + # cb.label_text_property.font_size=10 + cb = mlab.colorbar(orientation='vertical', nb_labels=5) + cb.data_range = scalar_range + cb.scalar_bar_representation.position = [0.8, 0.15] + cb.scalar_bar_representation.position2 = [0.15, 0.7] + text = mlab.text(0.8, 0.85, 'Distances in mm') + text.width = 0.20 + if show_statistics: + text2 = mlab.text(0.5, 0.02, + 'Mean error {:.3}mm \nMax error {:.3}mm \ + '.format(scaled_distances_between_meshes.mean(), + scaled_distances_between_meshes.max())) + text2.width = 0.20 + surf.module_manager.scalar_lut_manager.reverse_lut = True + # perhaps we shouud usew kwargs + # if camera_settings is None: + mlab.gcf().scene.z_plus_view() + + def render(self, scaled_distances_between_meshes, type_cmap='hot', + scalar_range=(0, 2), show_statistics=False): + + self._render_mesh(scaled_distances_between_meshes, type_cmap, + scalar_range, show_statistics) + return self diff --git a/setup.py b/setup.py index 1e8e5f2..5498039 100644 --- a/setup.py +++ b/setup.py @@ -111,7 +111,8 @@ def build_extension_from_pyx(pyx_path, extra_sources_paths=None): version, cmdclass = get_version_and_cmdclass("menpo3d") -install_requires = ["menpo>=0.9.0,<0.12.0", "mayavi>=4.7.0", "moderngl>=5.6.*,<6.0"] +install_requires = ["menpo>=0.9.0,<0.12.0", "mayavi>=4.7.0", + "moderngl>=5.6,<6.0", "k3d<=2.9.2", "ipywidgets<8"] setup( name="menpo3d",