Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Slicing and nurbs derivatives #69

Draft
wants to merge 5 commits into
base: 6.x-dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 28 additions & 24 deletions geomdl/NURBS.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,11 @@ def __deepcopy__(self, memo):
result.init_cache()
return result

def init_cache(self):
self._cache['ctrlpts'] = self._init_array()
self._cache['weights'] = self._init_array()
def init_cache(self, ctrlpts=[], weights=[]):
self._cache['ctrlpts'] = self._array_type(iter(ctrlpts))
self._cache['weights'] = self._array_type(iter(weights))
self._cache['ctrlpts'].register_callback(lambda: setattr(self, '_control_points_valid', False))
self._cache['weights'].register_callback(lambda: setattr(self, '_control_points_valid', False))

@property
def ctrlptsw(self):
Expand All @@ -107,6 +109,9 @@ def ctrlptsw(self):
:getter: Gets the weighted control points
:setter: Sets the weighted control points
"""
if not self._control_points_valid:
ctrlptsw = compatibility.combine_ctrlpts_weights(self.ctrlpts, self.weights)
self.set_ctrlpts(ctrlptsw)
return self._control_points

@ctrlptsw.setter
Expand All @@ -127,22 +132,17 @@ def ctrlpts(self):
# Populate the cache, if necessary
if not self._cache['ctrlpts']:
c, w = compatibility.separate_ctrlpts_weights(self._control_points)
self._cache['ctrlpts'] = [crd for crd in c]
self._cache['weights'] = w
self.init_cache(c, w)
return self._cache['ctrlpts']

@ctrlpts.setter
def ctrlpts(self, value):
# Check if we can retrieve the existing weights. If not, generate a weights vector of 1.0s.
if not self.weights:
weights = [1.0 for _ in range(len(value))]
else:
weights = self.weights
self.weights[:] = [1.0 for _ in range(len(value))]

# Generate weighted control points using the new control points
ctrlptsw = compatibility.combine_ctrlpts_weights(value, weights)

# Set new weighted control points
ctrlptsw = compatibility.combine_ctrlpts_weights(value, self.weights)
self.set_ctrlpts(ctrlptsw)

@property
Expand All @@ -159,8 +159,7 @@ def weights(self):
# Populate the cache, if necessary
if not self._cache['weights']:
c, w = compatibility.separate_ctrlpts_weights(self._control_points)
self._cache['ctrlpts'] = [crd for crd in c]
self._cache['weights'] = w
self.init_cache(c, w)
return self._cache['weights']

@weights.setter
Expand All @@ -174,6 +173,12 @@ def weights(self, value):
# Set new weighted control points
self.set_ctrlpts(ctrlptsw)

def _check_variables(self):
super(Curve, self)._check_variables()
if not self._control_points_valid:
ctrlptsw = compatibility.combine_ctrlpts_weights(self.ctrlpts, self.weights)
self.set_ctrlpts(ctrlptsw)

def reset(self, **kwargs):
""" Resets control points and/or evaluated points.

Expand All @@ -189,10 +194,9 @@ def reset(self, **kwargs):
# Call parent function
super(Curve, self).reset(ctrlpts=reset_ctrlpts, evalpts=reset_evalpts)

# Delete the caches
if reset_ctrlpts:
# Delete the caches
self._cache['ctrlpts'] = self._init_array()
self._cache['weights'][:] = self._init_array()
self.init_cache()


@export
Expand Down Expand Up @@ -330,8 +334,8 @@ def ctrlpts(self):
"""
if not self._cache['ctrlpts']:
c, w = compatibility.separate_ctrlpts_weights(self._control_points)
self._cache['ctrlpts'] = [crd for crd in c]
self._cache['weights'] = w
self._cache['ctrlpts'] = self._array_type(iter(c))
self._cache['weights'] = self._array_type(iter(w))
return self._cache['ctrlpts']

@ctrlpts.setter
Expand Down Expand Up @@ -361,8 +365,8 @@ def weights(self):
"""
if not self._cache['weights']:
c, w = compatibility.separate_ctrlpts_weights(self._control_points)
self._cache['ctrlpts'] = [crd for crd in c]
self._cache['weights'] = w
self._cache['ctrlpts'] = self._array_type(iter(c))
self._cache['weights'] = self._array_type(iter(w))
return self._cache['weights']

@weights.setter
Expand Down Expand Up @@ -518,8 +522,8 @@ def ctrlpts(self):
"""
if not self._cache['ctrlpts']:
c, w = compatibility.separate_ctrlpts_weights(self._control_points)
self._cache['ctrlpts'] = [crd for crd in c]
self._cache['weights'] = w
self._cache['ctrlpts'] = self._array_type(iter(c))
self._cache['weights'] = self._array_type(iter(w))
return self._cache['ctrlpts']

@ctrlpts.setter
Expand Down Expand Up @@ -549,8 +553,8 @@ def weights(self):
"""
if not self._cache['weights']:
c, w = compatibility.separate_ctrlpts_weights(self._control_points)
self._cache['ctrlpts'] = [crd for crd in c]
self._cache['weights'] = w
self._cache['ctrlpts'] = self._array_type(iter(c))
self._cache['weights'] = self._array_type(iter(w))
return self._cache['weights']

@weights.setter
Expand Down
55 changes: 55 additions & 0 deletions geomdl/_collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# callback handlers for list modification
# https://stackoverflow.com/a/13259435/1162349

import sys

_pyversion = sys.version_info[0]

def callback_method(func):
def notify(self,*args,**kwargs):
for _,callback in self._callbacks:
callback()
return func(self,*args,**kwargs)
return notify

class NotifyList(list):
extend = callback_method(list.extend)
append = callback_method(list.append)
remove = callback_method(list.remove)
pop = callback_method(list.pop)
__delitem__ = callback_method(list.__delitem__)
__setitem__ = callback_method(list.__setitem__)
__iadd__ = callback_method(list.__iadd__)
__imul__ = callback_method(list.__imul__)

#Take care to return a new NotifyList if we slice it.
if _pyversion < 3:
__setslice__ = callback_method(list.__setslice__)
__delslice__ = callback_method(list.__delslice__)
def __getslice__(self,*args):
return self.__class__(list.__getslice__(self,*args))

def __getitem__(self,item):
if isinstance(item,slice):
return self.__class__(list.__getitem__(self,item))
else:
return list.__getitem__(self,item)

def __init__(self,*args):
list.__init__(self,*args)
self._callbacks = []
self._callback_cntr = 0

def register_callback(self,cb):
self._callbacks.append((self._callback_cntr,cb))
self._callback_cntr += 1
return self._callback_cntr - 1

def unregister_callback(self,cbid):
for idx,(i,cb) in enumerate(self._callbacks):
if i == cbid:
self._callbacks.pop(idx)
return cb
else:
return None

10 changes: 7 additions & 3 deletions geomdl/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .six import add_metaclass
from . import vis, helpers, knotvector, voxelize, utilities, tessellate
from .base import GeomdlBase, GeomdlEvaluator, GeomdlError, GeomdlWarning, GeomdlTypeSequence
from ._collections import NotifyList


@add_metaclass(abc.ABCMeta)
Expand All @@ -34,9 +35,9 @@ class Geometry(GeomdlBase):
# __slots__ = ('_iter_index', '_array_type', '_eval_points')

def __init__(self, *args, **kwargs):
self._geometry_type = "default" if not hasattr(self, '_geometry_type') else self._geometry_type # geometry type
super(Geometry, self).__init__(*args, **kwargs)
self._array_type = list if not hasattr(self, '_array_type') else self._array_type # array storage type
self._geometry_type = getattr(self, '_geometry_type', 'default') # geometry type
self._array_type = getattr(self, '_array_type', NotifyList) # array storage type
self._eval_points = self._init_array() # evaluated points

def __iter__(self):
Expand Down Expand Up @@ -134,6 +135,7 @@ def __init__(self, **kwargs):
self._knot_vector = [self._init_array() for _ in range(self._pdim)] # knot vector
self._control_points = self._init_array() # control points
self._control_points_size = [0 for _ in range(self._pdim)] # control points length
self._control_points_valid = False
self._delta = [self._dinit for _ in range(self._pdim)] # evaluation delta
self._bounding_box = self._init_array() # bounding box
self._evaluator = None # evaluator instance
Expand Down Expand Up @@ -465,7 +467,7 @@ def validate_and_clean(pts_in, check_for, dimension, pts_out, **kws):
raise ValueError("Number of arguments after ctrlpts must be " + str(self._pdim))

# Keyword arguments
array_init = kwargs.get('array_init', [[] for _ in range(len(ctrlpts))])
array_init = kwargs.get('array_init', self._array_type([] for _ in range(len(ctrlpts))))
array_check_for = kwargs.get('array_check_for', (list, tuple))
callback_func = kwargs.get('callback', validate_and_clean)
self._dimension = kwargs.get('dimension', len(ctrlpts[0]))
Expand All @@ -479,6 +481,7 @@ def validate_and_clean(pts_in, check_for, dimension, pts_out, **kws):
# Set control points and sizes
self._control_points = callback_func(ctrlpts, array_check_for, self._dimension, array_init, **kwargs)
self._control_points_size = [int(arg) for arg in args]
self._control_points_valid = True

@abc.abstractmethod
def render(self, **kwargs):
Expand Down Expand Up @@ -890,6 +893,7 @@ def reset(self, **kwargs):
if reset_ctrlpts:
self._control_points = self._init_array()
self._bounding_box = self._init_array()
self._control_points_valid = False

if reset_evalpts:
self._eval_points = self._init_array()
Expand Down
30 changes: 30 additions & 0 deletions tests/test_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

Requires "pytest" to run.
"""
import math

from pytest import fixture, mark
from geomdl import BSpline
from geomdl import NURBS
from geomdl import evaluators
from geomdl import helpers
from geomdl import convert
Expand Down Expand Up @@ -234,6 +236,13 @@ def nurbs_curve(spline_curve):
curve.weights = [0.5, 1.0, 0.75, 1.0, 0.25, 1.0]
return curve

@fixture
def unit_circle_tri_ctrlpts():
r = 1.
a, h = 3. * r / math.sqrt(3.), 1.5 * r
ctrlpts = [(0., -r), (-a,-r), (-a/2,-r+h), (0., 2*h-r), (a/2, -r+h), (a, -r), (0., -r)]
return ctrlpts


def test_nurbs_curve2d_weights(nurbs_curve):
assert nurbs_curve.weights == [0.5, 1.0, 0.75, 1.0, 0.25, 1.0]
Expand All @@ -252,6 +261,27 @@ def test_nurbs_curve2d_eval(nurbs_curve, param, res):
assert abs(evalpt[1] - res[1]) < GEOMDL_DELTA


@mark.parametrize("param, res", [
(0.0, (0.0, -1.0)),
(0.2, (-0.9571859726038534, -0.2894736842105261)),
(0.5, (1.1102230246251568e-16, 1.0)),
(0.95, (0.27544074447012257, -0.9613180515759312))
])
def test_nurbs_curve2d_slice_eval(unit_circle_tri_ctrlpts, param, res):
crv = NURBS.Curve()
crv.degree = 2
crv.ctrlpts = unit_circle_tri_ctrlpts
crv.knotvector = [0.,0.,0., 1./3, 1./3, 2./3, 2./3, 1.,1.,1.]
crv.weights[1::2] = [0.5, 0.5, 0.5]

evalpt = crv.evaluate_single(param)

assert abs(evalpt[0] - res[0]) < GEOMDL_DELTA
assert abs(evalpt[1] - res[1]) < GEOMDL_DELTA


# TODO: derivative of a circle is a circle
@mark.xfail
@mark.parametrize("param, order, res", [
(0.0, 1, ((5.0, 5.0), (90.9090, 90.9090))),
(0.2, 2, ((13.8181, 11.5103), (40.0602, 17.3878), (104.4062, -29.3672))),
Expand Down