diff --git a/spatialmath/DualQuaternion.py b/spatialmath/DualQuaternion.py index 3b945d7c..f4fa0ece 100644 --- a/spatialmath/DualQuaternion.py +++ b/spatialmath/DualQuaternion.py @@ -75,7 +75,7 @@ def __init__(self, real: Quaternion = None, dual: Quaternion = None): @classmethod def Pure(cls, x: ArrayLike3) -> Self: x = base.getvector(x, 3) - return cls(UnitQuaternion(), Quaternion.Pure(x)) + return cls(UnitQuaternion.identity(), Quaternion.Pure(x)) def __repr__(self) -> str: return str(self) @@ -354,13 +354,3 @@ def SE3(self) -> SE3: # w = self.real.v # v = self.dual.v # theta = base.norm(w) - - -if __name__ == "__main__": # pragma: no cover - - from spatialmath import SE3, UnitDualQuaternion - - print(UnitDualQuaternion(SE3())) - # import pathlib - - # exec(open(pathlib.Path(__file__).parent.parent.absolute() / "tests" / "test_dualquaternion.py").read()) # pylint: disable=exec-used diff --git a/spatialmath/base/animate.py b/spatialmath/base/animate.py index a2e31f72..730ce18b 100755 --- a/spatialmath/base/animate.py +++ b/spatialmath/base/animate.py @@ -882,36 +882,3 @@ def set_xlabel(self, *args, **kwargs): def set_ylabel(self, *args, **kwargs): self.ax.set_ylabel(*args, **kwargs) - - -if __name__ == "__main__": - # from spatialmath import UnitQuaternion - # from spatialmath.base import tranimate, r2t - - # J = np.array([[2, -1, 0], [-1, 4, 0], [0, 0, 3]]) - # dt = 0.05 - # def attitude(): - # attitude = UnitQuaternion() - # w = 0.2 * np.r_[1, 2, 2].T - # for t in np.arange(0, 3, dt): - # wd = -np.linalg.inv(J) @ (np.cross(w, J @ w)) - # w += wd * dt - # attitude.increment(w * dt) - # yield attitude.R - # plt.figure() - # plotvol3(2) - # tranimate(attitude()) - - from spatialmath import base - - # T = smb.rpy2r(0.3, 0.4, 0.5) - # # smb.tranimate(T, wait=True) - # s = smb.tranimate(T, movie=True) - # with open("zz.html", "w") as f: - # print(f"{s}", file=f) - - T = smb.rot2(2) - # smb.tranimate2(T, wait=True) - s = smb.tranimate2(T, movie=True) - with open("zz.html", "w") as f: - print(f"{s}", file=f) diff --git a/spatialmath/base/argcheck.py b/spatialmath/base/argcheck.py index 40f94336..b8e3001b 100644 --- a/spatialmath/base/argcheck.py +++ b/spatialmath/base/argcheck.py @@ -679,16 +679,3 @@ def islistof(value: Any, what: Union[Type, Callable], n: Optional[int] = None): return all([what(x) for x in value]) else: raise ValueError("bad value of what") - - -if __name__ == "__main__": - import pathlib - - exec( - open( - pathlib.Path(__file__).parent.parent.parent.absolute() - / "tests" - / "base" - / "test_argcheck.py" - ).read() - ) # pylint: disable=exec-used diff --git a/spatialmath/base/graphics.py b/spatialmath/base/graphics.py index 2ce18dc8..e25eb606 100644 --- a/spatialmath/base/graphics.py +++ b/spatialmath/base/graphics.py @@ -1752,18 +1752,6 @@ def isnotebook() -> bool: except NameError: return False # Probably standard Python interpreter - if __name__ == "__main__": - import pathlib - - exec( - open( - pathlib.Path(__file__).parent.parent.parent.absolute() - / "tests" - / "base" - / "test_graphics.py" - ).read() - ) # pylint: disable=exec-used - except ImportError: # pragma: no cover def plot_text(*args, **kwargs) -> None: diff --git a/spatialmath/base/numeric.py b/spatialmath/base/numeric.py index 748086fa..52839fda 100644 --- a/spatialmath/base/numeric.py +++ b/spatialmath/base/numeric.py @@ -429,18 +429,3 @@ def gauss2d(mu: ArrayLike2, P: NDArray, X: NDArray, Y: NDArray) -> NDArray: * np.exp(-0.5 * (x**2 * Pi[0, 0] + y**2 * Pi[1, 1] + 2 * x * y * Pi[0, 1])) ) return g.reshape(X.shape) - - -if __name__ == "__main__": - r = np.linspace(-4, 4, 6) - x, y = np.meshgrid(r, r) - print(gauss2d([0, 0], np.diag([1, 2]), x, y)) - # print(bresenham([2,2], [2,4])) - # print(bresenham([2,2], [2,-4])) - # print(bresenham([2,2], [4,2])) - # print(bresenham([2,2], [-4,2])) - # print(bresenham([2,2], [2,2])) - # print(bresenham([2,2], [3,6])) # steep - # print(bresenham([2,2], [6,3])) # shallow - # print(bresenham([2,2], [3,6])) # steep - # print(bresenham([2,2], [6,3])) # shallow diff --git a/spatialmath/base/quaternions.py b/spatialmath/base/quaternions.py index 8f33bc1c..ed788f0d 100755 --- a/spatialmath/base/quaternions.py +++ b/spatialmath/base/quaternions.py @@ -1134,16 +1134,3 @@ def qprint( file.write(s + "\n") else: return s - - -if __name__ == "__main__": # pragma: no cover - import pathlib - - exec( - open( - pathlib.Path(__file__).parent.parent.parent.absolute() - / "tests" - / "base" - / "test_quaternions.py" - ).read() - ) # pylint: disable=exec-used diff --git a/spatialmath/base/transforms2d.py b/spatialmath/base/transforms2d.py index 682ea0ca..aaf09382 100644 --- a/spatialmath/base/transforms2d.py +++ b/spatialmath/base/transforms2d.py @@ -1515,58 +1515,3 @@ def tranimate2(T: Union[SO2Array, SE2Array], **kwargs): anim = smb.animate.Animate2(dims=dims, axes=ax, **kwargs) anim.trplot2(T, **kwargs) return anim.run(**kwargs) - - -if __name__ == "__main__": # pragma: no cover - import pathlib - import matplotlib.pyplot as plt - - # trplot2( transl2(1,2), frame='A', rviz=True, width=1) - # trplot2( transl2(3,1), color='red', arrow=True, width=3, frame='B') - # trplot2( transl2(4, 3)@trot2(math.pi/3), color='green', frame='c') - # plt.grid(True) - - # fig, ax = plt.subplots(3,3, figsize=(10,10)) - # text_opts = dict(bbox=dict(boxstyle="round", - # fc="w", - # alpha=0.9), - # zorder=20, - # family='monospace', - # fontsize=8, - # verticalalignment='top') - # T = transl2(2, 1)@trot2(math.pi/3) - # trplot2(T, ax=ax[0][0], dims=[0,4,0,4]) - # ax[0][0].text(0.2, 3.8, "trplot2(T)", **text_opts) - - # trplot2(T, ax=ax[0][1], dims=[0,4,0,4], originsize=0) - # ax[0][1].text(0.2, 3.8, "trplot2(T, originsize=0)", **text_opts) - - # trplot2(T, ax=ax[0][2], dims=[0,4,0,4], arrow=False) - # ax[0][2].text(0.2, 3.8, "trplot2(T, arrow=False)", **text_opts) - - # trplot2(T, ax=ax[1][0], dims=[0,4,0,4], axislabel=False) - # ax[1][0].text(0.2, 3.8, "trplot2(T, axislabel=False)", **text_opts) - - # trplot2(T, ax=ax[1][1], dims=[0,4,0,4], width=3) - # ax[1][1].text(0.2, 3.8, "trplot2(T, width=3)", **text_opts) - - # trplot2(T, ax=ax[1][2], dims=[0,4,0,4], frame='B') - # ax[1][2].text(0.2, 3.8, "trplot2(T, frame='B')", **text_opts) - - # trplot2(T, ax=ax[2][0], dims=[0,4,0,4], color='r', textcolor='k') - # ax[2][0].text(0.2, 3.8, "trplot2(T, color='r',\n textcolor='k')", **text_opts) - - # trplot2(T, ax=ax[2][1], dims=[0,4,0,4], labels=("u", "v")) - # ax[2][1].text(0.2, 3.8, "trplot2(T, labels=('u', 'v'))", **text_opts) - - # trplot2(T, ax=ax[2][2], dims=[0,4,0,4], rviz=True) - # ax[2][2].text(0.2, 3.8, "trplot2(T, rviz=True)", **text_opts) - - exec( - open( - pathlib.Path(__file__).parent.parent.parent.absolute() - / "tests" - / "base" - / "test_transforms2d.py" - ).read() - ) # pylint: disable=exec-used diff --git a/spatialmath/base/transforms3d.py b/spatialmath/base/transforms3d.py index 3617f965..b7a4edfe 100644 --- a/spatialmath/base/transforms3d.py +++ b/spatialmath/base/transforms3d.py @@ -339,7 +339,7 @@ def transl(x, y=None, z=None): # SE(3) -> R3 return x[:3, 3] else: - raise ValueError("bad argument") + raise ValueError(f"bad argument {x}") if t.dtype != "O": t = t.astype("float64") @@ -3414,50 +3414,3 @@ def tranimate(T: Union[SO3Array, SE3Array], **kwargs) -> str: anim = Animate(dim=dim, ax=ax, **kwargs) anim.trplot(T, **kwargs) return anim.run(**kwargs) - - -if __name__ == "__main__": # pragma: no cover - # import sympy - # from spatialmath.base.symbolic import * - - # p, q, r = symbol('phi theta psi') - # print(p) - - # print(angvelxform([p, q, r], representation='eul')) - - import pathlib - - # exec( - # open( - # pathlib.Path(__file__).parent.parent.parent.absolute() - # / "tests" - # / "base" - # / "test_transforms3d.py" - # ).read() - # ) # pylint: disable=exec-used - - # exec( - # open( - # pathlib.Path(__file__).parent.parent.parent.absolute() - # / "tests" - # / "base" - # / "test_transforms3d_plot.py" - # # ).read() - # ) # pylint: disable=exec-used - import numpy as np - - T = np.array( - [ - [1, 3.881e-14, 0, -1.985e-13], - [-3.881e-14, 1, 1.438e-11, 1.192e-13], - [0, -1.438e-11, 1, 0], - [0, 0, 0, 1], - ] - ) - # theta, vec = tr2angvec(T) - # print(theta, vec) - # print(trlog(T, twist=True)) - R = rotx(np.pi / 2) - s = tranimate(R, movie=True) - with open("z.html", "w") as f: - print(f"{s} float: return Matrix(m).det() else: return np.linalg.det(m) - - -if __name__ == "__main__": # pragma: no cover - import pathlib - - exec( - open( - pathlib.Path(__file__).parent.parent.parent.absolute() - / "tests" - / "base" - / "test_transformsNd.py" - ).read() - ) # pylint: disable=exec-used diff --git a/spatialmath/base/vectors.py b/spatialmath/base/vectors.py index f29740a3..7b6a880e 100644 --- a/spatialmath/base/vectors.py +++ b/spatialmath/base/vectors.py @@ -847,16 +847,3 @@ def orthogonalize(v1: ArrayLike3, v2: ArrayLike3, normalize: bool = True) -> Arr if normalize: v_orth = v_orth / np.linalg.norm(v_orth) return v_orth - - -if __name__ == "__main__": # pragma: no cover - import pathlib - - exec( - open( - pathlib.Path(__file__).parent.parent.parent.absolute() - / "tests" - / "base" - / "test_vectors.py" - ).read() - ) # pylint: disable=exec-used diff --git a/spatialmath/baseposelist.py b/spatialmath/baseposelist.py index d729b902..10dc051f 100644 --- a/spatialmath/baseposelist.py +++ b/spatialmath/baseposelist.py @@ -14,7 +14,7 @@ _numtypes = (int, np.int64, float, np.float64) -class BasePoseList(UserList, ABC): +class BasePoseList(ABC): """ List properties for spatial math classes @@ -38,21 +38,6 @@ class BasePoseList(UserList, ABC): syntax meaning ================== ============================================================ ``C()`` create a singleton instance of ``C`` with the identity value - ``C.Empty()`` create an instance of ``C`` with zero items - ``C.Alloc(n)`` create an instance of ``C`` with ``n`` identity items - ``len(x)`` return the number of items in ``x`` - ``x[i]`` return the ``i``'th item of ``x``, ``i`` is an index - or a slice. - ``x[i] = y`` set the ``i``'th item of ``x`` to the singleton instance - ``y`` and ``i`` is an index - ``x.append(y)`` append the value of singleton instance ``y`` to ``x`` - ``x.extend(y)`` append the items of ``y`` to ``x`` - ``x.pop()`` pop the first item of ``x`` - ``x.insert(i, y)`` insert the value of singleton instsance ``y`` into ``x`` - at position ``i``. - ``del x[i]`` delete the ``i``'th element of ``x`` - ``x.reverse()`` reverse the elements of ``x`` in place - ``x.clear()`` remove all items from ``x`` ================== ============================================================ where ``C`` is the class, and ``x`` and ``y`` are instances of ``C``. @@ -74,6 +59,10 @@ def shape(self): def isvalid(x, check=True): pass + @classmethod + def identity(cls) -> Self: + return cls(cls._identity()) + @abstractstaticmethod def _identity(): pass @@ -84,58 +73,6 @@ def _import(self, x, check=True): else: return None - @classmethod - def Empty(cls) -> Self: - """ - Construct an empty instance (BasePoseList superclass method) - - :return: pose instance with zero values - - Example:: - - >>> x = X.Empty() - >>> len(x) - 0 - - where ``X`` is any of the SMTB classes. - """ - x = cls() - x.data = [] - return x - - @classmethod - def Alloc(cls, n: Optional[int] = 1) -> Self: - """ - Construct an instance with N default values (BasePoseList superclass method) - - :param n: Number of values, defaults to 1 - :type n: int, optional - :return: pose instance with ``n`` default values - - ``X.Alloc(N)`` creates an instance of the pose class ``X`` with ``N`` - default values, ie. ``len(X)`` will be ``N``. - - ``X`` can be considered a vector of pose objects, and those elements - can be referenced ``X[i]`` or assigned to ``X[i] = ...``. - - .. note:: The default value depends on the pose class and is the result - of the empty constructor. For ``SO2``, - ``SE2``, ``SO3``, ``SE3`` it is an identity matrix, for a - twist class ``Twist2`` or ``Twist3`` it is a zero vector, - for a ``UnitQuaternion`` or ``Quaternion`` it is a zero - vector. - - Example:: - - >>> x = X.Alloc(10) - >>> len(x) - 10 - - where ``X`` is any of the SMTB classes. - """ - x = cls() - x.data = [cls._identity() for i in range(n)] # make n copies of the data - return x def arghandler( self, arg: Any, convertfrom: Tuple = (), check: Optional[bool] = True @@ -182,14 +119,14 @@ def arghandler( if arg is None: # empty constructor - self.data = [self._identity()] + raise TypeError("missing required argument (or argument is None)") elif isinstance(arg, np.ndarray): # it's a numpy array x = self._import(arg, check=check) if x is not None: - self.data = [x] + self.data = x else: return False @@ -197,25 +134,25 @@ def arghandler( # it's a list of things if isinstance(arg[0], np.ndarray): # possibly a list of numpy arrays - self.data = [self._import(x, check=check) for x in arg] + self.data = [self._import(x, check=check) for x in arg][0] # TODO: confirm elif type(arg[0]) == type(self): # possibly a list of objects of same type assert all( map(lambda x: type(x) == type(self), arg) ), "elements of list are incorrect type" - self.data = [x.A for x in arg] + self.data = [x.A for x in arg][0] # TODO: confirm elif ( isnumberlist(arg) and len(self.shape) == 1 and len(arg) == self.shape[0] ): - self.data = [np.array(arg)] + self.data = [np.array(arg)][0] # TODO: confirm else: # see what NumPy makes of it X = np.array(arg) if X.shape == self.shape: - self.data = [X] + self.data = X else: # no idea what was passed return False @@ -234,7 +171,7 @@ def arghandler( raise ValueError( "argument has no conversion method to this type" ) from None - self.data = [converter(arg).A] + self.data = converter(arg).A else: # don't know this argument, let object __init__ deal with it @@ -249,23 +186,20 @@ def __array_interface__(self): so that C extenstions with this spatial math class have direct access to the underlying numpy array """ - return self.data[0].__array_interface__ + return self.data.__array_interface__ @property - def _A(self) -> Union[List[NDArray], NDArray]: + def _A(self) -> NDArray: """ Spatial vector as an array :return: Moment vector :rtype: numpy.ndarray, shape=(3,) - ``X.v`` is a 3-vector """ - if len(self.data) == 1: - return self.data[0] - else: - return self.data + return self.data @property - def A(self) -> Union[List[NDArray], NDArray]: + def A(self) -> NDArray: """ Array value of an instance (BasePoseList superclass method) @@ -278,88 +212,13 @@ def A(self) -> Union[List[NDArray], NDArray]: .. note:: This assumes that ``len(X)`` == 1, ie. it is a single-valued instance. """ + return self.data - if len(self.data) == 1: - return self.data[0] - else: - return self.data - + def __len__(self) -> int: + return 1 + # ------------------------------------------------------------------------ # - def __getitem__(self, i: Union[int, slice]) -> BasePoseList: - """ - Access value of an instance (BasePoseList superclass method) - - :param i: index of element to return - :type i: int - :return: the specific element of the pose - :rtype: Quaternion or UnitQuaternion instance - :raises IndexError: if the element is out of bounds - - Note that only a single index is supported, slices are not. - - Example:: - - >>> x = X.Alloc(10) - >>> len(x) - 10 - >>> y = x[1] - >>> len(y) - 1 - >>> y = x[1:5] - >>> len(y) - 4 - - where ``X`` is any of the SMTB classes. - """ - - if isinstance(i, slice): - if i.stop is None: - # stop not given - end = len(self) - elif i.stop < 0: - # stop is negative, - - end = i.stop + len(self) + 1 - else: - # stop is positive, use it directly - end = i.stop - return self.__class__( - [self.data[k] for k in range(i.start or 0, end, i.step or 1)] - ) - else: - ret = self.__class__(self.data[i], check=False) - # ret.__array_interface__ = self.data[i].__array_interface__ - return ret - # return self.__class__(self.data[i], check=False) - - def __setitem__(self, i: int, value: BasePoseList) -> None: - """ - Assign a value to an instance (BasePoseList superclass method) - - :param i: index of element to assign to - :type i: int - :param value: the value to insert - :type value: Quaternion or UnitQuaternion instance - :raises ValueError: incorrect type of assigned value - - Assign the argument to an element of the object's internal list of values. - This supports the assignement operator, for example:: - - >>> x = X.Alloc(10) - >>> len(x) - 10 - >>> x[3] = X() # assign to position 3 in the list - - where ``X`` is any of the SMTB classes. - - """ - if not type(self) == type(value): - raise ValueError("can't insert different type of object") - if len(value) > 1: - raise ValueError( - "can't insert a multivalued element - must have len() == 1" - ) - self.data[i] = value.A # flag these binary operators as being not supported def __lt__(self, other: BasePoseList) -> Type[Exception]: @@ -374,125 +233,6 @@ def __gt__(self, other: BasePoseList) -> Type[Exception]: def __ge__(self, other: BasePoseList) -> Type[Exception]: return NotImplementedError - def append(self, item: BasePoseList) -> None: - """ - Append a value to an instance (BasePoseList superclass method) - - :param x: the value to append - :type x: Quaternion or UnitQuaternion instance - :raises ValueError: incorrect type of appended object - - Appends the argument to the object's internal list of values. - - Example:: - - >>> x = X.Alloc(10) - >>> len(x) - 10 - >>> x.append(X()) # append to the list - >>> len(x) - 11 - - where ``X`` is any of the SMTB classes. - """ - # print('in append method') - if not type(self) == type(item): - raise ValueError("can't append different type of object") - if len(item) > 1: - raise ValueError("can't append a multivalued instance - use extend") - super().append(item.A) - - def extend(self, iterable: BasePoseList) -> None: - """ - Extend sequence of values in an instance (BasePoseList superclass method) - - :param x: the value to extend - :type x: instance of same type - :raises ValueError: incorrect type of appended object - - Appends the argument's values to the object's internal list of values. - - Example:: - - >>> x = X.Alloc(10) - >>> len(x) - 10 - >>> x.append(X.Alloc(5)) # extend the list - >>> len(x) - 15 - - where ``X`` is any of the SMTB classes. - """ - # print('in extend method') - if not type(self) == type(iterable): - raise ValueError("can't append different type of object") - super().extend(iterable._A) - - def insert(self, i: int, item: BasePoseList) -> None: - """ - Insert a value to an instance (BasePoseList superclass method) - - :param i: element to insert value before - :type i: int - :param item: the value to insert - :type item: instance of same type - :raises ValueError: incorrect type of inserted value - - Inserts the argument into the object's internal list of values. - - Example:: - - >>> x = X.Alloc(10) - >>> len(x) - 10 - >>> x.insert(0, X()) # insert at start of list - >>> len(x) - 11 - >>> x.insert(10, X()) # append to the list - >>> len(x) - 11 - - where ``X`` is any of the SMTB classes. - - .. note:: If ``i`` is beyond the end of the list, the item is appended - to the list - """ - if not type(self) == type(item): - raise ValueError("can't insert different type of object") - if len(item) > 1: - raise ValueError( - "can't insert a multivalued instance - must have len() == 1" - ) - super().insert(i, item._A) - - def pop(self, i: Optional[int] = -1) -> Self: - """ - Pop value from an instance (BasePoseList superclass method) - - :param i: item in the list to pop, default is last - :type i: int - :return: the popped value - :rtype: instance of same type - :raises IndexError: if there are no values to pop - - Removes a value from the value list and returns it. The original - instance is modified. - - Example:: - - >>> x = X.Alloc(10) - >>> len(x) - 10 - >>> y = x.pop() # pop the last value x[9] - >>> len(x) - 9 - >>> y = x.pop(0) # pop the first value x[0] - >>> len(x) - 8 - - where ``X`` is any of the SMTB classes. - """ - return self.__class__(super().pop(i)) def binop( self, @@ -663,15 +403,6 @@ def unop( """ if matrix: - return np.vstack([op(x) for x in self.data]) + return np.vstack([op(self.data)]) else: - return [op(x) for x in self.data] - -if __name__ == "__main__": - from spatialmath import SO3, SO2 - - R = SO3([[1,0,0],[0,1,0],[0,0,1]]) - print(R.eulervec()) - - R = SO2([0.3, 0.4, 0.5]) - pass \ No newline at end of file + return [op(self.data)] diff --git a/spatialmath/baseposematrix.py b/spatialmath/baseposematrix.py index 1a850600..dba211a5 100644 --- a/spatialmath/baseposematrix.py +++ b/spatialmath/baseposematrix.py @@ -127,18 +127,6 @@ class BasePoseMatrix(BasePoseList): __array_ufunc__ = None # allow pose matrices operators with NumPy values - def __new__(cls, *args, **kwargs): - """ - Create the subclass instance (superclass method) - - Create a new instance and call the superclass initializer to enable the - ``UserList`` capabilities. - """ - - pose = super(BasePoseMatrix, cls).__new__(cls) # create a new instance - super().__init__(pose) # initialize UserList - return pose - # ------------------------------------------------------------------------ # @property @@ -323,15 +311,9 @@ def det(self) -> Tuple[float, Rn]: :SymPy: not supported """ if type(self).__name__ in ("SO3", "SE3"): - if len(self) == 1: - return np.linalg.det(self.A[:3, :3]) - else: - return [np.linalg.det(T[:3, :3]) for T in self.data] + return np.linalg.det(self.A[:3, :3]) elif type(self).__name__ in ("SO2", "SE2"): - if len(self) == 1: - return np.linalg.det(self.A[:2, :2]) - else: - return [np.linalg.det(T[:2, :2]) for T in self.data] + return np.linalg.det(self.A[:2, :2]) def log(self, twist: Optional[bool] = False) -> Union[NDArray, List[NDArray]]: """ @@ -369,15 +351,12 @@ def log(self, twist: Optional[bool] = False) -> Union[NDArray, List[NDArray]]: :SymPy: not supported """ if self.N == 2: - log = [smb.trlog2(x, twist=twist) for x in self.data] + log = smb.trlog2(self.data, twist=twist) else: - log = [smb.trlog(x, twist=twist) for x in self.data] - if len(log) == 1: - return log[0] - else: - return log + log = smb.trlog(self.data, twist=twist) + return log - def interp(self, end: Optional[bool] = None, s: Union[int, float] = None, shortest: bool = True) -> Self: + def interp(self, end: Optional[bool] = None, s: Union[int, float] = None, shortest: bool = True) -> list[Self]: """ Interpolate between poses (superclass method) @@ -423,9 +402,6 @@ def interp(self, end: Optional[bool] = None, s: Union[int, float] = None, shorte s = smb.getvector(s) s = np.clip(s, 0, 1) - if len(self) > 1: - raise ValueError("start pose must be a singleton") - if end is not None: if len(end) > 1: raise ValueError("end pose must be a singleton") @@ -433,88 +409,21 @@ def interp(self, end: Optional[bool] = None, s: Union[int, float] = None, shorte if self.N == 2: # SO(2) or SE(2) - return self.__class__( - [smb.trinterp2(start=self.A, end=end, s=_s, shortest=shortest) for _s in s] - ) - - elif self.N == 3: - # SO(3) or SE(3) - return self.__class__( - [smb.trinterp(start=self.A, end=end, s=_s, shortest=shortest) for _s in s] - ) - - def interp1(self, s: float = None) -> Self: - """ - Interpolate pose (superclass method) - - :param end: final pose - :type end: same as ``self`` - :param s: interpolation coefficient, range 0 to 1 - :type s: array_like - :return: interpolated pose - :rtype: SO2, SE2, SO3, SE3 instance - - - ``X.interp(s)`` interpolates pose between identity when s=0, and X when s=1. - - ====== ====== =========== =============================== - len(X) len(s) len(result) Result - ====== ====== =========== =============================== - 1 1 1 Y = interp(X, s) - M 1 M Y[i] = interp(X[i], s) - 1 M M Y[i] = interp(X, s[i]) - ====== ====== =========== =============================== - - Example:: - - >>> x = SE3.Rx(0.3) - >>> print(x.interp(0)) - SE3(array([[1., 0., 0., 0.], - [0., 1., 0., 0.], - [0., 0., 1., 0.], - [0., 0., 0., 1.]])) - >>> print(x.interp(1)) - SE3(array([[ 1. , 0. , 0. , 0. ], - [ 0. , 0.95533649, -0.29552021, 0. ], - [ 0. , 0.29552021, 0.95533649, 0. ], - [ 0. , 0. , 0. , 1. ]])) - >>> y = x.interp(x, np.linspace(0, 1, 10)) - >>> len(y) - 10 - >>> y[5] - SE3(array([[ 1. , 0. , 0. , 0. ], - [ 0. , 0.98614323, -0.16589613, 0. ], - [ 0. , 0.16589613, 0.98614323, 0. ], - [ 0. , 0. , 0. , 1. ]])) - - Notes: - - #. For SO3 and SE3 rotation is interpolated using quaternion spherical linear interpolation (slerp). - - :seealso: :func:`interp`, :func:`~spatialmath.base.transforms3d.trinterp`, :func:`~spatialmath.base.quaternions.qslerp`, :func:`~spatialmath.smb.transforms2d.trinterp2` - - :SymPy: not supported - """ - s = smb.getvector(s) - s = np.clip(s, 0, 1) - - if self.N == 2: - # SO(2) or SE(2) - if len(s) > 1: - assert len(self) == 1, "if len(s) > 1, len(X) must == 1" - return self.__class__([smb.trinterp2(start, self.A, s=_s) for _s in s]) - else: - return self.__class__( - [smb.trinterp2(start, x, s=s[0]) for x in self.data] + return [ + self.__class__( + smb.trinterp2(start=self.A, end=end, s=_s, shortest=shortest) ) + for _s in s + ] + elif self.N == 3: # SO(3) or SE(3) - if len(s) > 1: - assert len(self) == 1, "if len(s) > 1, len(X) must == 1" - return self.__class__([smb.trinterp(None, self.A, s=_s) for _s in s]) - else: - return self.__class__( - [smb.trinterp(None, x, s=s[0]) for x in self.data] + return [ + self.__class__( + smb.trinterp(start=self.A, end=end, s=_s, shortest=shortest) ) + for _s in s + ] def norm(self) -> Self: """ @@ -546,9 +455,9 @@ def norm(self) -> Self: :seealso: :func:`~spatialmath.base.transforms3d.trnorm`, :func:`~spatialmath.base.transforms2d.trnorm2` """ if self.N == 2: - return self.__class__([smb.trnorm2(x) for x in self.data]) + return self.__class__(smb.trnorm2(self.data)) else: - return self.__class__([smb.trnorm(x) for x in self.data]) + return self.__class__(smb.trnorm(self.data)) def simplify(self) -> Self: """ @@ -581,7 +490,7 @@ def simplify(self) -> Self: """ vf = np.vectorize(smb.sym.simplify) - return self.__class__([vf(x) for x in self.data], check=False) + return self.__class__(vf(self.data), check=False) def stack(self) -> NDArray: """ @@ -709,11 +618,9 @@ def printline(self, *args, **kwargs) -> None: :seealso: :meth:`strline` :func:`trprint`, :func:`trprint2` """ if self.N == 2: - for x in self.data: - smb.trprint2(x, *args, **kwargs) + smb.trprint2(self.data, *args, **kwargs) else: - for x in self.data: - smb.trprint(x, *args, **kwargs) + smb.trprint(self.data, *args, **kwargs) def strline(self, *args, **kwargs) -> str: """ @@ -771,11 +678,9 @@ def strline(self, *args, **kwargs) -> str: """ s = "" if self.N == 2: - for x in self.data: - s += smb.trprint2(x, *args, file=False, **kwargs) + s += smb.trprint2(self.data, *args, file=False, **kwargs) else: - for x in self.data: - s += smb.trprint(x, *args, file=False, **kwargs) + s += smb.trprint(self.data, *args, file=False, **kwargs) return s def __repr__(self) -> str: @@ -804,19 +709,8 @@ def trim(x): return smb.removesmall(x) name = type(self).__name__ - if len(self) == 0: - return name + "([])" - elif len(self) == 1: - # need to indent subsequent lines of the native repr string by 4 spaces - return name + "(" + trim(self.A).__repr__().replace("\n", "\n ") + ")" - else: - # format this as a list of ndarrays - return ( - name - + "([\n" - + ",\n".join([trim(v).__repr__() for v in self.data]) - + " ])" - ) + # need to indent subsequent lines of the native repr string by 4 spaces + return name + "(" + trim(self.A).__repr__().replace("\n", "\n ") + ")" def _repr_pretty_(self, p, cycle): """ @@ -835,11 +729,7 @@ def _repr_pretty_(self, p, cycle): """ # see https://ipython.org/ipython-doc/stable/api/generated/IPython.lib.pretty.html - if len(self) == 1: - p.text(str(self)) - else: - for i, x in enumerate(self): - p.text(f"{i}:\n{str(x)}") + p.text(str(self)) def __str__(self) -> str: """ @@ -876,7 +766,7 @@ def _string_matrix(self) -> str: if self._ansiformatter is None: self._ansiformatter = ANSIMatrix(style="thick") - return "\n".join([self._ansiformatter.str(A) for A in self.data]) + return self._ansiformatter.str(self.data) def _string_color(self, color: Optional[bool] = False) -> str: """ @@ -951,25 +841,7 @@ def mformat(self, X): out += rowstr + bgcol + " " + reset + "\n" return out - output_str = "" - - if len(self.data) == 0: - output_str = "[]" - elif len(self.data) == 1: - # single matrix case - output_str = mformat(self, self.A) - else: - # sequence case - for count, X in enumerate(self.data): - # add separator lines and the index - output_str += ( - indexcol - + "[{:d}] =".format(count) - + reset - + "\n" - + mformat(self, X) - ) - + output_str = mformat(self, self.A) return output_str # ----------------------- graphics @@ -1027,18 +899,11 @@ def animate(self, *args, start=None, **kwargs) -> None: if start is not None: start = start.A - if len(self) > 1: - # trajectory case - if self.N == 2: - return smb.tranimate2(self.data, *args, **kwargs) - else: - return smb.tranimate(self.data, *args, **kwargs) + # singleton case + if self.N == 2: + return smb.tranimate2(self.A, start=start, *args, **kwargs) else: - # singleton case - if self.N == 2: - return smb.tranimate2(self.A, start=start, *args, **kwargs) - else: - return smb.tranimate(self.A, start=start, *args, **kwargs) + return smb.tranimate(self.A, start=start, *args, **kwargs) # ------------------------------------------------------------------------ # def prod(self, norm=False, check=True) -> Self: @@ -1066,9 +931,7 @@ def prod(self, norm=False, check=True) -> Self: group. You can either disable membership checking by ``check=False`` which is risky, or normalize the result by ``norm=True``. """ - Tprod = self.__class__._identity() # identity value - for T in self.data: - Tprod = Tprod @ T + Tprod = self.data if norm: Tprod = smb.trnorm(Tprod) return self.__class__(Tprod, check=check) @@ -1097,7 +960,7 @@ def __pow__(self, n: int) -> Self: assert type(n) is int, "exponent must be an int" return self.__class__( - [np.linalg.matrix_power(x, n) for x in self.data], check=False + np.linalg.matrix_power(self.data, n), check=False ) # ----------------------- arithmetic @@ -1323,22 +1186,6 @@ def __rmul__(right, left): # pylint: disable=no-self-argument # an ``SE3`` using their own ``__rmul__`` methods. return NotImplemented - def __imul__(left, right): # noqa - """ - Overloaded ``*=`` operator (superclass method) - - :return: Product of two operands - :rtype: Pose instance or NumPy array - :raises ValueError: for incompatible arguments - - - ``X *= Y`` compounds the poses ``X`` and ``Y`` and places the result in ``X`` - - ``X *= s`` performs elementwise multiplication of the elements of ``X`` - and ``s`` and places the result in ``X`` - - :seealso: ``__mul__`` - """ - return left.__mul__(right) - def __truediv__(left, right): # pylint: disable=no-self-argument """ Overloaded ``/`` operator (superclass method) @@ -1391,22 +1238,6 @@ def __truediv__(left, right): # pylint: disable=no-self-argument else: raise ValueError("bad operands") - def __itruediv__(left, right): # pylint: disable=no-self-argument - """ - Overloaded ``/=`` operator (superclass method) - - :return: Product of right operand and inverse of left operand - :rtype: Pose instance or NumPy array - :raises ValueError: for incompatible arguments - - - ``X /= Y`` compounds the poses ``X`` and ``Y.inv()`` and places the result in ``X`` - - ``X /= s`` performs elementwise division of the elements of ``X`` - by ``s`` and places the result in ``X`` - - :seealso: ``__truediv__`` - """ - return left.__truediv__(right) - def __add__(left, right): # pylint: disable=no-self-argument """ Overloaded ``+`` operator (superclass method) @@ -1473,22 +1304,6 @@ def __radd__(right, left): # pylint: disable=no-self-argument """ return right.__add__(left) - def __iadd__(left, right): # pylint: disable=no-self-argument - """ - Overloaded ``+=`` operator (superclass method) - - :return: Sum of two operands - :rtype: NumPy array, shape=(N,N) - :raises ValueError: for incompatible arguments - - - ``X += Y`` adds the matrix values of ``X`` and ``Y`` and places the result in ``X`` - - ``X += s`` elementwise addition of the matrix elements of ``X`` - and ``s`` and places the result in ``X`` - - :seealso: ``__add__`` - """ - return left.__add__(right) - def __sub__(left, right): # pylint: disable=no-self-argument """ Overloaded ``-`` operator (superclass method) @@ -1555,23 +1370,6 @@ def __rsub__(right, left: Self): # pylint: disable=no-self-argument """ return -right.__sub__(left) - def __isub__(left, right: Self): # pylint: disable=no-self-argument - """ - Overloaded ``-=`` operator (superclass method) - - :return: Difference of two operands - :rtype: NumPy array, shape=(N,N) - :raises: ValueError - - - ``X -= Y`` is the element-wise difference of the matrix value of ``X`` - and ``Y`` and places the result in ``X`` - - ``X -= s`` is the element-wise difference of the matrix value of ``X`` - and the scalar ``s`` and places the result in ``X`` - - :seealso: ``__sub__`` - """ - return left.__sub__(right) - def __eq__(left, right: Self) -> bool: # pylint: disable=no-self-argument """ Overloaded ``==`` operator (superclass method) @@ -1689,23 +1487,3 @@ def _op2(left, right: Self, op: Callable): # pylint: disable=no-self-argument raise TypeError( f"Invalid type ({right.__class__}) for binary operation with {left.__class__}" ) - - -if __name__ == "__main__": - from spatialmath import SE3, SE2, SO2 - - C = SO2(0.5) - A = np.array([[10, 0], [0, 1]]) - - print(C * A) - print(C * A * C.inv()) - print(C.conjugation(A)) - - # x = SE3.Rand(N=6) - - # x.printline(orient="rpy/xyz", fmt="{:8.3g}") - - # d = np.diag([0.25, 0.25, 1]) - # a = SE2() - # print(a) - # print(d * a) diff --git a/spatialmath/geom2d.py b/spatialmath/geom2d.py index 55eccb2a..11df9d9d 100755 --- a/spatialmath/geom2d.py +++ b/spatialmath/geom2d.py @@ -1143,46 +1143,3 @@ def polygon(self, resolution=10) -> Polygon2: :seealso: :meth:`points` """ return Polygon2(smb.ellipse(self.E, self.centre, resolution=resolution - 1)) - - -if __name__ == "__main__": - pass - # print(Ellipse((500, 500), (100, 200))) - # p = Polygon2([(1, 2), (3, 2), (2, 4)]) - # p.transformed(SE2(0, 0, np.pi / 2)).vertices() - - # a = Line2.TwoPoints((1, 2), (7, 5)) - # print(a) - - # p = Polygon2(np.array([[4, 4, 6, 6], [2, 1, 1, 2]])) - # base.plotvol2([8]) - # p.plot(color="b", alpha=0.3) - # for theta in np.linspace(0, 2 * np.pi, 100): - # p.animate(SE2(0, 0, theta)) - # plt.show() - # plt.pause(0.05) - - # print(p) - # p.plot(alpha=0.5, color='b') - # print(p.contains([5.,5.])) - # print(p.contains([5,1.5])) - # print(p.contains([4, 2.1])) - - # print(p.vertices()) - # print(p.area()) - # print(p.centroid()) - # print(p.bbox()) - # print(p.radius()) - # print(p.vertices(closed=True)) - - # for e in p.edges(): - # print(e) - - # p2 = p.transformed(SE2(-5, -1.5, 0)) - # print(p2.vertices()) - # print(p2.area()) - - # p2.plot(alpha=0.5, facecolor='r') - - # p.move(SE2(0, 0, 0.7)) - # plt.show(block=True) diff --git a/spatialmath/geom3d.py b/spatialmath/geom3d.py index 8b191ebd..b4d3dfd8 100755 --- a/spatialmath/geom3d.py +++ b/spatialmath/geom3d.py @@ -300,8 +300,6 @@ def __init__(self, v=None, w=None, check=True): """ from spatialmath.pose3d import SE3 - super().__init__() # enable list powers - if w is None: # zero or one arguments passed if super().arghandler(v, convertfrom=(SE3,)): @@ -312,7 +310,7 @@ def __init__(self, v=None, w=None, check=True): if base.isvector(v, 3) and base.isvector(w, 3): if check and not base.iszero(np.dot(v, w)): raise ValueError("invalid Plucker coordinates") - self.data = [np.r_[v, w]] + self.data = np.r_[v, w] else: raise ValueError("invalid argument to Line3 constructor") @@ -438,14 +436,7 @@ def append(self, x: Line3): @property def A(self) -> R6: # get the underlying numpy array - if len(self.data) == 1: - return self.data[0] - else: - return self.data - - def __getitem__(self, i): - # print('getitem', i, 'class', self.__class__) - return self.__class__(self.data[i]) + return self.data @property def v(self) -> R3: @@ -459,7 +450,7 @@ def v(self) -> R3: :seealso: :meth:`w` """ - return self.data[0][0:3] + return self.data[0:3] @property def w(self) -> R3: @@ -473,7 +464,7 @@ def w(self) -> R3: :seealso: :meth:`v` :meth:`uw` """ - return self.data[0][3:6] + return self.data[3:6] @property def uw(self) -> R3: @@ -907,9 +898,6 @@ def closest_to_line( # https://web.cs.iastate.edu/~cs577/handouts/plucker-coordinates.pdf # but (20) (21) is the negative of correct answer - points = [] - dists = [] - def intersection(line1, line2): with np.errstate(divide="ignore", invalid="ignore"): # compute the distance between all pairs of lines @@ -929,30 +917,7 @@ def intersection(line1, line2): return p1, np.linalg.norm(p1 - p2) - if len(l1) == len(l2): - # two sets of lines of equal length - for line1, line2 in zip(l1, l2): - point, dist = intersection(line1, line2) - points.append(point) - dists.append(dist) - - elif len(l1) == 1 and len(l2) > 1: - for line in l2: - point, dist = intersection(l1, line) - points.append(point) - dists.append(dist) - - elif len(l1) > 1 and len(l2) == 1: - for line in l1: - point, dist = intersection(line, l2) - points.append(point) - dists.append(dist) - - if len(points) == 1: - # 1D case for self or line - return points[0], dists[0] - else: - return np.array(points).T, np.array(dists) + return intersection(line1, line2) def closest_to_point(self, x: ArrayLike3) -> Tuple[R3, float]: """ @@ -1203,8 +1168,9 @@ def intersect_volume(self, bounds: ArrayLike6) -> Tuple[Points3, Rn]: # PLOT AND DISPLAY # ------------------------------------------------------------------------- # + @staticmethod def plot( - self, + lines: List[Self], *pos, bounds: Optional[ArrayLike] = None, ax: Optional[plt.Axes] = None, @@ -1244,16 +1210,16 @@ def plot( ax.set_ylim(bounds[2:4]) ax.set_zlim(bounds[4:6]) - lines = [] - for line in self: + plot_lines = [] + for line in lines: P, lam = line.intersect_volume(bounds) if len(lam) > 0: l = ax.plot( tuple(P[0, :]), tuple(P[1, :]), tuple(P[2, :]), *pos, **kwargs ) - lines.append(l) - return lines + plot_lines.append(l) + return plot_lines def __str__(self) -> str: """ @@ -1274,13 +1240,8 @@ def __str__(self) -> str: """ - return "\n".join( - [ - "{{ {:.5g} {:.5g} {:.5g}; {:.5g} {:.5g} {:.5g}}}".format( - *list(base.removesmall(x.vec)) - ) - for x in self - ] + return "{{ {:.5g} {:.5g} {:.5g}; {:.5g} {:.5g} {:.5g}}}".format( + *list(base.removesmall(self.vec)) ) def __repr__(self) -> str: @@ -1295,23 +1256,9 @@ def __repr__(self) -> str: For a multi-valued ``Line3``, one line per value in ``Line3``. """ - if len(self) == 1: - return "Line3([{:.5g}, {:.5g}, {:.5g}, {:.5g}, {:.5g}, {:.5g}])".format( - *list(self.A) - ) - else: - return ( - "Line3([\n" - + ",\n".join( - [ - " [{:.5g}, {:.5g}, {:.5g}, {:.5g}, {:.5g}, {:.5g}]".format( - *list(tw) - ) - for tw in self.data - ] - ) - + "\n])" - ) + return "Line3([{:.5g}, {:.5g}, {:.5g}, {:.5g}, {:.5g}, {:.5g}])".format( + *list(self.A) + ) def _repr_pretty_(self, p, cycle): """ @@ -1328,29 +1275,7 @@ def _repr_pretty_(self, p, cycle): In [1]: x """ - if len(self) == 1: - p.text(str(self)) - else: - for i, x in enumerate(self): - if i > 0: - p.break_() - p.text(f"{i:3d}: {str(x)}") - - # function z = side(self1, pl2) - # Plucker.side Plucker side operator - # - # # X = SIDE(P1, P2) is the side operator which is zero whenever - # # the lines P1 and P2 intersect or are parallel. - # - # # See also Plucker.or. - # - # if ~isa(self2, 'Plucker') - # error('SMTB:Plucker:badarg', 'both arguments to | must be Plucker objects'); - # end - # L1 = pl1.line(); L2 = pl2.line(); - # - # z = L1([1 5 2 6 3 4]) * L2([5 1 6 2 4 3])'; - # end + p.text(str(self)) def side(self, other: Line3) -> float: """ @@ -1377,51 +1302,3 @@ def __init__(self, v=None, w=None): warnings.warn("use Line class instead", DeprecationWarning) super().__init__(v, w) - - -if __name__ == "__main__": # pragma: no cover - import pathlib - import os.path - - # L = Line3.TwoPoints((1,2,0), (1,2,1)) - # print(L) - # print(L.intersect_plane([0, 0, 1, 0])) - - # z = np.eye(6) * L - - # L2 = SE3(2, 1, 10) * L - # print(L2) - # print(L2.intersect_plane([0, 0, 1, 0])) - - # print('rx') - # L2 = SE3.Rx(np.pi/4) * L - # print(L2) - # print(L2.intersect_plane([0, 0, 1, 0])) - - # print('ry') - # L2 = SE3.Ry(np.pi/4) * L - # print(L2) - # print(L2.intersect_plane([0, 0, 1, 0])) - - # print('rz') - # L2 = SE3.Rz(np.pi/4) * L - # print(L2) - # print(L2.intersect_plane([0, 0, 1, 0])) - - # base.plotvol3(10) - # S = Twist3.UnitRevolute([0, 0, 1], [2, 3, 2], 0.5); - # L = S.line() - # L.plot('k:', linewidth=2) - - # a = Plane3([0.1, -1, -1, 2]) - # base.plotvol3(5) - # a.plot(color='r', alpha=0.3) - # plt.show(block=True) - - # a = SE3.Exp([2,0,0,0,0,0]) - - exec( - open( - pathlib.Path(__file__).parent.parent.absolute() / "tests" / "test_geom3d.py" - ).read() - ) # pylint: disable=exec-used diff --git a/spatialmath/pose2d.py b/spatialmath/pose2d.py index 57f1b6b7..280d043b 100644 --- a/spatialmath/pose2d.py +++ b/spatialmath/pose2d.py @@ -25,6 +25,7 @@ import spatialmath.base as smb from spatialmath.baseposematrix import BasePoseMatrix +from spatialmath.base.types import Self # ============================== SO2 =====================================# @@ -71,19 +72,14 @@ def __init__(self, arg=None, *, unit="rad", check=True): - ``SO2([X1, X2, ... XN])`` is an SO2 instance containing a sequence of N rotations, where each Xi is an SO2 instance. """ - super().__init__() - if isinstance(arg, SE2): - self.data = [smb.t2r(x) for x in arg.data] + self.data = smb.t2r(arg.data) elif super().arghandler(arg, check=check): return elif smb.isscalar(arg): - self.data = [smb.rot2(arg, unit=unit)] - - elif smb.isvector(arg): - self.data = [smb.rot2(x, unit=unit) for x in smb.getvector(arg)] + self.data = smb.rot2(arg, unit=unit) else: raise ValueError("bad argument to constructor") @@ -103,7 +99,7 @@ def shape(self): return (2, 2) @classmethod - def Rand(cls, N=1, arange=(0, 2 * math.pi), unit="rad"): + def Rand(cls, arange=(0, 2 * math.pi), unit="rad"): r""" Construct new SO(2) with random rotation @@ -125,8 +121,8 @@ def Rand(cls, N=1, arange=(0, 2 * math.pi), unit="rad"): """ rand = np.random.uniform( - low=arange[0], high=arange[1], size=N - ) # random values in the range + low=arange[0], high=arange[1], size=1 + )[0] # random values in the range return cls([smb.rot2(x) for x in smb.getunit(rand, unit)]) @classmethod @@ -180,10 +176,7 @@ def inv(self): - for elements of SO(2) this is the transpose. - if `x` contains a sequence, returns an `SO2` with a sequence of inverses """ - if len(self) == 1: - return SO2(self.A.T) - else: - return SO2([x.T for x in self.A]) + return SO2(self.A.T) @property def R(self): @@ -217,10 +210,7 @@ def theta(self, unit="rad"): else: conv = 1.0 - if len(self) == 1: - return conv * math.atan2(self.A[1, 0], self.A[0, 0]) - else: - return [conv * math.atan2(x.A[1, 0], x.A[0, 0]) for x in self] + return conv * math.atan2(self.A[1, 0], self.A[0, 0]) def SE2(self): """ @@ -296,16 +286,16 @@ def __init__(self, x=None, y=None, theta=None, *, unit="rad", check=True): return if isinstance(x, SO2): - self.data = [smb.r2t(_x) for _x in x.data] + self.data = smb.r2t(x.data) elif smb.isscalar(x): - self.data = [smb.trot2(x, unit=unit)] + self.data = mb.trot2(x, unit=unit) elif len(x) == 2: # SE2([x,y]) - self.data = [smb.transl2(x)] + self.data = smb.transl2(x) elif len(x) == 3: # SE2([x,y,theta]) - self.data = [smb.trot2(x[2], t=x[:2], unit=unit)] + self.data = smb.trot2(x[2], t=x[:2], unit=unit) else: raise ValueError("bad argument to constructor") @@ -313,11 +303,11 @@ def __init__(self, x=None, y=None, theta=None, *, unit="rad", check=True): elif x is not None: if y is not None and theta is None: # SE2(x, y) - self.data = [smb.transl2(x, y)] + self.data = smb.transl2(x, y) elif y is not None and theta is not None: # SE2(x, y, theta) - self.data = [smb.trot2(theta, t=[x, y], unit=unit)] + self.data = smb.trot2(theta, t=[x, y], unit=unit) else: raise ValueError("bad arguments to constructor") @@ -523,10 +513,7 @@ def t(self): - 1, return an ndarray with shape=(2,) - N>1, return an ndarray with shape=(N,2) """ - if len(self) == 1: - return self.A[:2, 2] - else: - return np.array([x[:2, 2] for x in self.A]) + return self.A[:2, 2] @property def x(self): @@ -543,10 +530,7 @@ def x(self): - 1, return an float - N>1, return an ndarray with shape=(N,) """ - if len(self) == 1: - return self.A[0, 2] - else: - return np.array([v[0, 2] for v in self.A]) + return self.A[0, 2] @property def y(self): @@ -563,10 +547,7 @@ def y(self): - 1, return an float - N>1, return an ndarray with shape=(N,) """ - if len(self) == 1: - return self.A[1, 2] - else: - return np.array([v[1, 2] for v in self.A]) + return self.A[1, 2] def xyt(self): r""" @@ -581,10 +562,7 @@ def xyt(self): - 1, return an ndarray with shape=(3,) - N>1, return an ndarray with shape=(N,3) """ - if len(self) == 1: - return smb.tr2xyt(self.A) - else: - return [smb.tr2xyt(x) for x in self.A] + return smb.tr2xyt(self.A) def inv(self): r""" @@ -601,10 +579,7 @@ def inv(self): - if `x` contains a sequence, returns an `SE2` with a sequence of inverses """ - if len(self) == 1: - return SE2(smb.rt2tr(self.R.T, -self.R.T @ self.t), check=False) - else: - return SE2([smb.rt2tr(x.R.T, -x.R.T @ x.t) for x in self], check=False) + return SE2(smb.rt2tr(self.R.T, -self.R.T @ self.t), check=False) def SE3(self, z=0): """ @@ -628,19 +603,9 @@ def lift3(x): y[2, 3] = z return y - return SE3([lift3(x) for x in self]) + return SE3(lift3(self)) def Twist2(self): from spatialmath.twist import Twist2 return Twist2(self.log(twist=True)) - - -if __name__ == "__main__": # pragma: no cover - import pathlib - - exec( - open( - pathlib.Path(__file__).parent.parent.absolute() / "tests" / "test_pose2d.py" - ).read() - ) # pylint: disable=exec-used diff --git a/spatialmath/pose3d.py b/spatialmath/pose3d.py index b4301d93..b1764dce 100644 --- a/spatialmath/pose3d.py +++ b/spatialmath/pose3d.py @@ -99,10 +99,8 @@ def __init__(self, arg=None, *, check=True): :SymPy: supported """ - super().__init__() - if isinstance(arg, SE3): - self.data = [smb.t2r(x) for x in arg.data] + self.data = smb.t2r(arg.data) elif not super().arghandler(arg, check=check): raise ValueError("bad argument to constructor") @@ -149,10 +147,7 @@ def R(self) -> SO3Array: :SymPy: supported """ - if len(self) == 1: - return self.A[:3, :3] # type: ignore - else: - return np.array([x[:3, :3] for x in self.A]) # type: ignore + return self.A[:3, :3] # type: ignore @property def n(self) -> R3: @@ -166,8 +161,6 @@ def n(self) -> R3: *normal vector*. It is parallel to the x-axis of the frame defined by this pose. """ - if len(self) != 1: - raise ValueError("can only determine n-vector for singleton pose") return self.A[:3, 0] # type: ignore @property @@ -182,8 +175,6 @@ def o(self) -> R3: the *orientation vector*. It is parallel to the y-axis of the frame defined by this pose. """ - if len(self) != 1: - raise ValueError("can only determine o-vector for singleton pose") return self.A[:3, 1] # type: ignore @property @@ -198,8 +189,6 @@ def a(self) -> R3: *approach vector*. It is parallel to the z-axis of the frame defined by this pose. """ - if len(self) != 1: - raise ValueError("can only determine a-vector for singleton pose") return self.A[:3, 2] # type: ignore # ------------------------------------------------------------------------ # @@ -215,10 +204,7 @@ def inv(self) -> Self: account the matrix structure. For an SO(3) matrix the inverse is the transpose. """ - if len(self) == 1: - return SO3(self.A.T, check=False) # type: ignore - else: - return SO3([x.T for x in self.A], check=False) + return SO3(self.A.T, check=False) # type: ignore def eul(self, unit: str = "rad", flip: bool = False) -> Union[R3, RNx3]: r""" @@ -241,10 +227,7 @@ def eul(self, unit: str = "rad", flip: bool = False) -> Union[R3, RNx3]: :seealso: :func:`~spatialmath.pose3d.SE3.Eul`, :func:`~spatialmath.base.transforms3d.tr2eul` :SymPy: not supported """ - if len(self) == 1: - return smb.tr2eul(self.A, unit=unit, flip=flip) # type: ignore - else: - return np.array([base.tr2eul(x, unit=unit, flip=flip) for x in self.A]) + return smb.tr2eul(self.A, unit=unit, flip=flip) # type: ignore def rpy(self, unit: str = "rad", order: str = "zyx") -> Union[R3, RNx3]: """ @@ -279,10 +262,7 @@ def rpy(self, unit: str = "rad", order: str = "zyx") -> Union[R3, RNx3]: :seealso: :func:`~spatialmath.pose3d.SE3.RPY`, :func:`~spatialmath.base.transforms3d.tr2rpy` :SymPy: not supported """ - if len(self) == 1: - return smb.tr2rpy(self.A, unit=unit, order=order) # type: ignore - else: - return np.array([smb.tr2rpy(x, unit=unit, order=order) for x in self.A]) + return smb.tr2rpy(self.A, unit=unit, order=order) # type: ignore def angvec(self, unit: str = "rad") -> Tuple[float, R3]: r""" @@ -455,12 +435,10 @@ def Rz(cls, theta, unit: str = "rad") -> Self: return cls([smb.rotz(x, unit=unit) for x in smb.getvector(theta)], check=False) @classmethod - def Rand(cls, N: int = 1, *, theta_range:Optional[ArrayLike2] = None, unit: str = "rad") -> Self: + def Rand(cls, *, theta_range:Optional[ArrayLike2] = None, unit: str = "rad") -> Self: """ Construct a new SO(3) from random rotation - :param N: number of random rotations - :type N: int :param theta_range: angular magnitude range [min,max], defaults to None. :type xrange: 2-element sequence, optional :param unit: angular units: 'rad' [default], or 'deg' @@ -481,7 +459,7 @@ def Rand(cls, N: int = 1, *, theta_range:Optional[ArrayLike2] = None, unit: str :seealso: :func:`spatialmath.quaternion.UnitQuaternion.Rand` """ - return cls([smb.q2r(smb.qrand(theta_range=theta_range, unit=unit)) for _ in range(0, N)], check=False) + return cls(smb.q2r(smb.qrand(theta_range=theta_range, unit=unit)), check=False) @overload @classmethod @@ -1008,22 +986,18 @@ def __init__(self, x=None, y=None, z=None, *, check=True): if super().arghandler(x, check=check): return elif isinstance(x, SO3): - self.data = [smb.r2t(_x) for _x in x.data] + self.data = smb.r2t(x.data) elif isinstance(x, SE2): # type(x).__name__ == "SE2": self.data = x.SE3().data elif smb.isvector(x, 3): # SE3( [x, y, z] ) - self.data = [smb.transl(x)] - elif isinstance(x, np.ndarray) and x.shape[1] == 3: - # SE3( Nx3 ) - self.data = [smb.transl(T) for T in x] - + self.data = smb.transl(x) else: raise ValueError("bad argument to constructor") elif y is not None and z is not None: # SE3(x, y, z) - self.data = [smb.transl(x, y, z)] + self.data = smb.transl(x, y, z) else: raise ValueError("Invalid arguments. See documentation for correct format.") @@ -1047,8 +1021,6 @@ def shape(self) -> Tuple[int, int]: @SO3.R.setter def R(self, r: SO3Array) -> None: - if len(self) > 1: - raise ValueError("can only assign rotation to length 1 object") so3 = SO3(r) self.A[:3, :3] = so3.R @@ -1075,15 +1047,10 @@ def t(self) -> R3: :SymPy: supported """ - if len(self) == 1: - return self.A[:3, 3] - else: - return np.array([x[:3, 3] for x in self.A]) + return self.A[:3, 3] @t.setter def t(self, v: ArrayLike3): - if len(self) > 1: - raise ValueError("can only assign translation to length 1 object") v = smb.getvector(v, 3) self.A[:3, 3] = v @@ -1109,15 +1076,10 @@ def x(self) -> float: :SymPy: supported """ - if len(self) == 1: - return self.A[0, 3] - else: - return np.array([v[0, 3] for v in self.A]) + return self.A[0, 3] @x.setter def x(self, x: float): - if len(self) > 1: - raise ValueError("can only assign elements to length 1 object") self.A[0, 3] = x @property @@ -1142,15 +1104,10 @@ def y(self) -> float: :SymPy: supported """ - if len(self) == 1: - return self.A[1, 3] - else: - return np.array([v[1, 3] for v in self.A]) + return self.A[1, 3] @y.setter def y(self, y: float): - if len(self) > 1: - raise ValueError("can only assign elements to length 1 object") self.A[1, 3] = y @property @@ -1175,15 +1132,10 @@ def z(self) -> float: :SymPy: supported """ - if len(self) == 1: - return self.A[2, 3] - else: - return np.array([v[2, 3] for v in self.A]) + return self.A[2, 3] @z.setter def z(self, z: float): - if len(self) > 1: - raise ValueError("can only assign elements to length 1 object") self.A[2, 3] = z # ------------------------------------------------------------------------ # @@ -1216,10 +1168,7 @@ def inv(self) -> SE3: :SymPy: supported """ - if len(self) == 1: - return SE3(smb.trinv(self.A), check=False) - else: - return SE3([smb.trinv(x) for x in self.A], check=False) + return SE3(smb.trinv(self.A), check=False) def yaw_SE2(self, order: str = "zyx") -> SE2: """ @@ -1244,15 +1193,12 @@ def yaw_SE2(self, order: str = "zyx") -> SE2: to the optic axis and x-axis parallel to the pixel rows. """ - if len(self) == 1: - if order == "zyx": - return SE2(self.x, self.y, self.rpy(order = order)[2]) - elif order == "xyz": - return SE2(self.z, self.y, self.rpy(order = order)[2]) - elif order == "yxz": - return SE2(self.z, self.x, self.rpy(order = order)[2]) - else: - return SE2([e.yaw_SE2() for e in self]) + if order == "zyx": + return SE2(self.x, self.y, self.rpy(order = order)[2]) + elif order == "xyz": + return SE2(self.z, self.y, self.rpy(order = order)[2]) + elif order == "yxz": + return SE2(self.z, self.x, self.rpy(order = order)[2]) def delta(self, X2: Optional[SE3] = None) -> R6: r""" @@ -1517,7 +1463,6 @@ def Rz( @classmethod def Rand( cls, - N: int = 1, xrange: Optional[ArrayLike2] = (-1, 1), yrange: Optional[ArrayLike2] = (-1, 1), zrange: Optional[ArrayLike2] = (-1, 1), @@ -1537,16 +1482,12 @@ def Rand( :type xrange: 2-element sequence, optional :param unit: angular units: 'rad' [default], or 'deg' :type unit: str - :param N: number of random transforms - :type N: int :return: SE(3) matrix :rtype: SE3 instance Return an SE3 instance with random rotation and translation. - ``SE3.Rand()`` is a random SE(3) translation. - - ``SE3.Rand(N)`` is an SE3 object containing a sequence of N random - poses. Example: @@ -1559,17 +1500,17 @@ def Rand( :seealso: :func:`~spatialmath.quaternions.UnitQuaternion.Rand` """ X = np.random.uniform( - low=xrange[0], high=xrange[1], size=N + low=xrange[0], high=xrange[1], size=1 ) # random values in the range Y = np.random.uniform( - low=yrange[0], high=yrange[1], size=N + low=yrange[0], high=yrange[1], size=1 ) # random values in the range Z = np.random.uniform( - low=zrange[0], high=zrange[1], size=N + low=zrange[0], high=zrange[1], size=1 ) # random values in the range - R = SO3.Rand(N=N, theta_range=theta_range, unit=unit) + R = SO3.Rand(theta_range=theta_range, unit=unit) return cls( - [smb.transl(x, y, z) @ smb.r2t(r.A) for (x, y, z, r) in zip(X, Y, Z, R)], + smb.transl(X[0], Y[0], Z[0]) @ smb.r2t(R.A), check=False, ) @@ -2101,13 +2042,3 @@ def angdist(self, other: SE3, metric: int = 6) -> float: # return cls(base.r2t(R)) # else: # return cls(base.rt2tr(R, t)) - - -if __name__ == "__main__": # pragma: no cover - import pathlib - - exec( - open( - pathlib.Path(__file__).parent.parent.absolute() / "tests" / "test_pose3d.py" - ).read() - ) # pylint: disable=exec-used diff --git a/spatialmath/quaternion.py b/spatialmath/quaternion.py index 51561036..40ff2895 100644 --- a/spatialmath/quaternion.py +++ b/spatialmath/quaternion.py @@ -75,8 +75,6 @@ def __init__(self, s: Any = None, v=None, check: Optional[bool] = True): >>> print(q) """ - super().__init__() - if s is None and smb.isvector(v, 4): v,s = (s,v) @@ -86,11 +84,11 @@ def __init__(self, s: Any = None, v=None, check: Optional[bool] = True): return elif smb.isvector(s, 4): - self.data = [smb.getvector(s)] + self.data = smb.getvector(s) elif smb.isscalar(s) and smb.isvector(v, 3): # Quaternion(s, v) - self.data = [np.r_[s, smb.getvector(v)]] + self.data = np.r_[s, smb.getvector(v)] else: raise ValueError("bad argument to Quaternion constructor") @@ -175,10 +173,7 @@ def s(self) -> float: >>> Quaternion([np.r_[1,2,3,4], np.r_[5,6,7,8]]).s """ - if len(self) == 1: - return self._A[0] - else: - return np.array([q.s for q in self]) + return self._A[0] @property def v(self) -> R3: @@ -202,10 +197,7 @@ def v(self) -> R3: >>> Quaternion([np.r_[1,2,3,4], np.r_[5,6,7,8]]).v """ - if len(self) == 1: - return self._A[1:4] - else: - return np.array([q.v for q in self]) + return self._A[1:4] @property def vec(self) -> R4: @@ -231,10 +223,7 @@ def vec(self) -> R4: >>> Quaternion([1,2,3,4]).vec >>> Quaternion([np.r_[1,2,3,4], np.r_[5,6,7,8]]).vec """ - if len(self) == 1: - return self._A - else: - return np.array([q._A for q in self]) + return self._A @property def vec_xyzs(self) -> R4: @@ -261,10 +250,7 @@ def vec_xyzs(self) -> R4: >>> Quaternion([1,2,3,4]).vec_xyzs >>> Quaternion([np.r_[1,2,3,4], np.r_[5,6,7,8]]).vec_xyzs """ - if len(self) == 1: - return np.roll(self._A, -1) - else: - return np.array([np.roll(q._A, -1) for q in self]) + return np.roll(self._A, -1) @property def matrix(self) -> R4x4: @@ -310,7 +296,7 @@ def conj(self) -> Quaternion: :seealso: :func:`~spatialmath.base.quaternions.qconj` """ - return self.__class__([smb.qconj(q._A) for q in self]) + return self.__class__(smb.qconj(self._A)) def norm(self) -> float: r""" @@ -332,10 +318,7 @@ def norm(self) -> float: :seealso: :func:`~spatialmath.base.quaternions.qnorm` """ - if len(self) == 1: - return smb.qnorm(self._A) - else: - return np.array([smb.qnorm(q._A) for q in self]) + return smb.qnorm(self._A) def unit(self) -> UnitQuaternion: r""" @@ -361,7 +344,7 @@ def unit(self) -> UnitQuaternion: :seealso: :func:`~spatialmath.base.quaternions.qnorm` """ - return UnitQuaternion([smb.qunit(q._A) for q in self], norm=False) + return UnitQuaternion(smb.qunit(self._A), norm=False) def log(self) -> Quaternion: r""" @@ -585,7 +568,7 @@ def __mul__( elif smb.isscalar(right): # quaternion * scalar case # print('scalar * quat') - return Quaternion([right * q._A for q in left]) + return Quaternion(right * left._A) else: raise ValueError("operands to * are of different types") @@ -613,35 +596,7 @@ def __rmul__( :seealso: :func:`__mul__` """ # scalar * quaternion case - return Quaternion([left * q._A for q in right]) - - def __imul__( - left, right: Quaternion - ) -> bool: # lgtm[py/not-named-self] pylint: disable=no-self-argument - """ - Overloaded ``*=`` operator - - :return: product - :rtype: Quaternion - :raises: ValueError - - ``q1 *= q2`` sets ``q1 := q1 * q2`` - ``q1 *= s`` sets ``q1 := q1 * s`` where ``s`` is a scalar - - Example: - - .. runblock:: pycon - - >>> from spatialmath import Quaternion - >>> q = Quaternion([1,2,3,4]) - >>> q *= Quaternion([5,6,7,8]) - >>> print(q) - >>> q *= 2 - >>> print(q) - - :seealso: :func:`__mul__` - """ - return left.__mul__(right) + return Quaternion(left * right._A) def __pow__(self, n: int) -> Quaternion: """ @@ -663,33 +618,7 @@ def __pow__(self, n: int) -> Quaternion: :seealso: :func:`~spatialmath.base.quaternions.qpow` """ - return self.__class__([smb.qpow(q._A, n) for q in self]) - - def __ipow__(self, n: int) -> Quaternion: - """ - Overloaded ``=**`` operator - - :rtype: Quaternion instance - - ``q **= N`` computes the product of ``q`` with itself ``N-1`` times, where ``N`` must be - an integer. If ``N``<0 the result is conjugated. - - Example: - - .. runblock:: pycon - - >>> from spatialmath import Quaternion - >>> q = Quaternion([1,2,3,4]) - >>> q **= 2 - >>> q - >>> q = Quaternion([np.r_[1,2,3,4], np.r_[5,6,7,8]]) - >>> q **= 2 - >>> q - - - :seealso: :func:`__pow__` - """ - return self.__pow__(n) + return self.__class__(smb.qpow(self._A, n)) def __truediv__(self, other: Quaternion): return NotImplemented # Quaternion division not supported @@ -819,9 +748,7 @@ def __neg__(self) -> Quaternion: >>> -Quaternion([np.r_[1,2,3,4], np.r_[5,6,7,8]]) """ - return UnitQuaternion( - [-x for x in self.data] - ) # pylint: disable=invalid-unary-operand-type + return UnitQuaternion(-self.data) # pylint: disable=invalid-unary-operand-type def __repr__(self) -> str: """ @@ -839,19 +766,8 @@ def __repr__(self) -> str: >>> q """ name = type(self).__name__ - if len(self) == 0: - return name + "([])" - elif len(self) == 1: - # need to indent subsequent lines of the native repr string by 4 spaces - return name + "(" + self._A.__repr__() + ")" - else: - # format this as a list of ndarrays - return ( - name - + "([\n " - + ",\n ".join([v.__repr__() for v in self.data]) - + " ])" - ) + # need to indent subsequent lines of the native repr string by 4 spaces + return name + "(" + self._A.__repr__() + ")" def _repr_pretty_(self, p, cycle): """ @@ -897,7 +813,7 @@ def __str__(self) -> str: delim = ("<<", ">>") else: delim = ("<", ">") - return "\n".join([smb.qprint(q, file=None, delim=delim) for q in self.data]) + return smb.qprint(self.data, file=None, delim=delim) # ========================================================================= # @@ -980,8 +896,6 @@ def __init__( >>> print(q) # str() """ - super().__init__() - # handle: UnitQuaternion(v)`` constructs a unit quaternion with specified elements # from ``v`` which is a 4-vector given as a list, tuple, or ndarray(4) if s is None and smb.isvector(v, 4): @@ -991,7 +905,7 @@ def __init__( # single argument if super().arghandler(s, check=check): # create unit quaternion - self.data = [smb.qunit(q) for q in self.data] + self.data = smb.qunit(self.data) elif isinstance(s, np.ndarray): # passed a NumPy array, it could be: @@ -1002,7 +916,7 @@ def __init__( if s.shape == (3, 3): if smb.isrot(s, check=check): # UnitQuaternion(R) R is 3x3 rotation matrix - self.data = [smb.r2q(s)] + self.data = smb.r2q(s) else: raise ValueError( "invalid rotation matrix provided to UnitQuaternion constructor" @@ -1010,25 +924,25 @@ def __init__( elif s.shape == (4,): # passed a 4-vector if norm: - self.data = [smb.qunit(s)] + self.data = smb.qunit(s) else: - self.data = [s] + self.data = s elif s.ndim == 2 and s.shape[1] == 4: if norm: - self.data = [smb.qunit(x) for x in s] + self.data = smb.qunit(s) else: # self.data = [smb.qpositive(x) for x in s] - self.data = [x for x in s] + self.data = s else: raise ValueError("array could not be interpreted as UnitQuaternion") elif isinstance(s, SO3): # UnitQuaternion(x) x is SO3 or SE3 (since SE3 is subclass of SO3) - self.data = [smb.r2q(x.R) for x in s] + self.data = smb.r2q(s.R) elif isinstance(s[0], SO3): # list of SO3 or SE3 - self.data = [smb.r2q(x.R) for x in s] + self.data = smb.r2q(s.R) else: raise ValueError("bad argument to UnitQuaternion constructor") @@ -1038,7 +952,7 @@ def __init__( q = np.r_[s, smb.getvector(v)] if norm: q = smb.qunit(q) - self.data = [q] + self.data = q else: raise ValueError("bad argument to UnitQuaternion constructor") @@ -1097,10 +1011,7 @@ def R(self) -> SO3Array: ``x[i]``. This is different to the MATLAB version where the i'th rotation matrix is ``x(:,:,i)``. """ - if len(self) > 1: - return np.array([smb.q2r(q) for q in self.data]) - else: - return smb.q2r(self._A) + return smb.q2r(self._A) @property def vec3(self) -> R3: @@ -1501,7 +1412,7 @@ def inv(self) -> UnitQuaternion: :seealso: :func:`~spatialmath.base.quaternions.qinv` """ - return UnitQuaternion([smb.qconj(q._A) for q in self]) + return UnitQuaternion(smb.qconj(self._A)) @staticmethod def qvmul(qv1: ArrayLike3, qv2: ArrayLike3) -> R3: @@ -1667,30 +1578,6 @@ def __mul__( else: raise ValueError("UnitQuaternion: operands to * are of different types") - def __imul__( - left, right: UnitQuaternion - ) -> UnitQuaternion: # lgtm[py/not-named-self] pylint: disable=no-self-argument - """ - Multiply unit quaternion in place - - :return: product - :rtype: UnitQuaternion, Quaternion - :raises: ValueError - - Multiplies a quaternion in place. If the right operand is a list, - the result will be a list. - - Example:: - - >>> q = UQ.Rx(0.3) - >>> q *= UQ.Rx(0.3) - >>> q - - :seealso: :func:`__mul__` - - """ - return left.__mul__(right) - def __truediv__( left, right: UnitQuaternion ) -> UnitQuaternion: # lgtm[py/not-named-self] pylint: disable=no-self-argument @@ -1840,7 +1727,7 @@ def __matmul__( def interp( self, end: UnitQuaternion, s: float = 0, shortest: Optional[bool] = False - ) -> UnitQuaternion: + ) -> List[UnitQuaternion]: """ Interpolate between two unit quaternions @@ -1915,9 +1802,9 @@ def interp( s1 = float(math.cos(theta) - dot * math.sin(theta) / math.sin(theta_0)) s2 = math.sin(theta) / math.sin(theta_0) out = (q1 * s1) + (q2 * s2) - qi.append(out) + qi.append(UnitQuaternion(out)) - return UnitQuaternion(qi) + return qi def interp1(self, s: float = 0, shortest: Optional[bool] = False) -> UnitQuaternion: """ @@ -1987,9 +1874,9 @@ def interp1(self, s: float = 0, shortest: Optional[bool] = False) -> UnitQuatern s1 = float(math.cos(theta) - dot * math.sin(theta) / math.sin(theta_0)) s2 = math.sin(theta) / math.sin(theta_0) out = np.r_[s1, 0, 0, 0] + (q * s2) - qi.append(out) + qi.append(UnitQuaternion(out)) - return UnitQuaternion(qi) + return qi def increment(self, w: ArrayLike3, normalize: Optional[bool] = False) -> None: """ @@ -2016,7 +1903,7 @@ def increment(self, w: ArrayLike3, normalize: Optional[bool] = False) -> None: updated = smb.qqmul(self.A, np.r_[ds, dv]) if normalize: updated = smb.qunit(updated) - self.data = [updated] + self.data = updated def plot(self, *args: List, **kwargs): """ @@ -2059,10 +1946,7 @@ def animate(self, *args: List, **kwargs): :see :func:`~spatialmath.base.transforms3d.tranimate` :func:`~spatialmath.base.transforms3d.trplot` """ - if len(self) > 1: - return smb.tranimate([smb.q2r(q) for q in self.data], *args, **kwargs) - else: - return smb.tranimate(smb.q2r(self._A), *args, **kwargs) + return smb.tranimate(smb.q2r(self._A), *args, **kwargs) def rpy( self, unit: Optional[str] = "rad", order: Optional[str] = "zyx" @@ -2106,10 +1990,7 @@ def rpy( :seealso: :meth:`SE3.RPY` :func:`~spatialmath.base.transforms3d.tr2rpy` """ - if len(self) == 1: - return smb.tr2rpy(self.R, unit=unit, order=order) - else: - return np.array([smb.tr2rpy(q.R, unit=unit, order=order) for q in self]) + return smb.tr2rpy(self.R, unit=unit, order=order) def eul(self, unit: Optional[str] = "rad") -> Union[R3, RNx3]: r""" @@ -2142,10 +2023,7 @@ def eul(self, unit: Optional[str] = "rad") -> Union[R3, RNx3]: :seealso: :meth:`SE3.Eul` :func:`~spatialmath.base.transforms3d.tr2eul` """ - if len(self) == 1: - return smb.tr2eul(self.R, unit=unit) - else: - return np.array([smb.tr2eul(q.R, unit=unit) for q in self]) + return smb.tr2eul(self.R, unit=unit) def angvec(self, unit: Optional[str] = "rad") -> Tuple[float, R3]: r""" @@ -2316,17 +2194,3 @@ def SE3(self) -> SE3: """ return SE3(smb.r2t(self.R), check=False) - - -if __name__ == "__main__": # pragma: no cover - import pathlib - - a = UnitQuaternion([0, 1, 0, 0]) - - exec( - open( - pathlib.Path(__file__).parent.parent.absolute() - / "tests" - / "test_quaternion.py" - ).read() - ) # pylint: disable=exec-used diff --git a/spatialmath/spatialvector.py b/spatialmath/spatialvector.py index f839e359..05173a09 100644 --- a/spatialmath/spatialvector.py +++ b/spatialmath/spatialvector.py @@ -18,6 +18,7 @@ import numpy as np from spatialmath.baseposelist import BasePoseList from spatialmath import base +from spatialmath.base.types import Self from spatialmath.pose3d import SE3 from spatialmath.twist import Twist3 @@ -93,16 +94,13 @@ def __init__(self, value): """ # print('spatialVec6 init') - super().__init__() if base.isvector(value, 6): - self.data = [np.array(value)] + self.data = np.array(value) elif base.isvector(value, 3): - self.data = [np.r_[value, 0, 0, 0]] + self.data = np.r_[value, 0, 0, 0] elif isinstance(value, SpatialVector): - self.data = [value.A] - elif base.ismatrix(value, (6, None)): - self.data = [x for x in value.T] + self.data = value.A elif not super().arghandler(value): raise ValueError("bad argument to constructor") @@ -144,9 +142,6 @@ def shape(self): """ return (6,) - def __getitem__(self, i): - return self.__class__(self.data[i]) - # ------------------------------------------------------------------------ # def __repr__(self): @@ -180,12 +175,7 @@ def __str__(self): line per element. """ typ = type(self).__name__ - return "\n".join( - [ - "{:s}[{:.5g} {:.5g} {:.5g}; {:.5g} {:.5g} {:.5g}]".format(typ, *list(x)) - for x in self.data - ] - ) + return "{:s}[{:.5g} {:.5g} {:.5g}; {:.5g} {:.5g} {:.5g}]".format(typ, *self.data) def __neg__(self): """ @@ -203,7 +193,7 @@ def __neg__(self): # for i=1:numel(obj) # y(i) = obj.new(-obj(i).vw); - return self.__class__([-x for x in self.data]) + return self.__class__(-self.data) def __add__( left, right @@ -229,7 +219,7 @@ def __add__( if len(left) != len(right): raise ValueError("can only add equal length arrays of spatial vectors") - return left.__class__([x + y for x, y in zip(left.data, right.data)]) + return left.__class__(left.data + right.data) def __sub__( left, right @@ -254,7 +244,7 @@ def __sub__( if len(left) != len(right): raise ValueError("can only add equal length arrays of spatial vectors") - return left.__class__([x - y for x, y in zip(left.data, right.data)]) + return left.__class__(left.data - right.data) def __rmul__( right, left @@ -304,10 +294,6 @@ class SpatialM6(SpatialVector): :seealso: :func:`~spatialmath.spatialvector.SpatialVelocity`, :func:`~spatialmath.spatialvector.SpatialAcceleration` """ - @abstractmethod - def __init__(self, value): - super().__init__(value) - def cross(self, other): r""" Spatial vector cross product @@ -364,10 +350,6 @@ class SpatialF6(SpatialVector): :seealso: :func:`~spatialmath.spatialvector.SpatialForce`, :func:`~spatialmath.spatialvector.SpatialMomentum`. """ - @abstractmethod - def __init__(self, value): - super().__init__(value) - def dot(self, value): return np.dot(self.A, base.getvector(value, 6)) @@ -390,9 +372,6 @@ class SpatialVelocity(SpatialM6): """ - def __init__(self, value=None): - super().__init__(value) - # def cross(self, other): # r""" # Spatial vector cross product @@ -457,9 +436,6 @@ class SpatialAcceleration(SpatialM6): """ - def __init__(self, value=None): - super().__init__(value) - # ------------------------------------------------------------------------- # @@ -478,9 +454,6 @@ class SpatialForce(SpatialF6): :seealso: :func:`~spatialmath.spatialvector.SpatialF6`, :func:`~spatialmath.spatialvector.SpatialMomentum` """ - def __init__(self, value=None): - super().__init__(value) - # n = SpatialForce(val); def __rmul__( @@ -508,9 +481,6 @@ class SpatialMomentum(SpatialF6): :seealso: :func:`~spatialmath.spatialvector.SpatialF6`, :func:`~spatialmath.spatialvector.SpatialForce` """ - def __init__(self, value=None): - super().__init__(value) - # ------------------------------------------------------------------------- # @@ -552,7 +522,6 @@ def __init__(self, m=None, r=None, I=None): :SymPy: supported """ - super().__init__() if m is None and r is None and I is None: # no arguments @@ -571,7 +540,7 @@ def __init__(self, m=None, r=None, I=None): else: raise ValueError("bad values") - self.data = [I] + self.data = I @staticmethod def _identity(): @@ -600,9 +569,6 @@ def shape(self): """ return (6, 6) - def __getitem__(self, i): - return SpatialInertia(self.data[i]) - def __repr__(self): """ Convert to string @@ -679,55 +645,3 @@ def __rmul__( - ``v * I`` is the SpatialMomemtum of a body with SpatialInertia ``I`` and SpatialVelocity ``v``. """ return right.__mul__(left) - - -if __name__ == "__main__": - import numpy.testing as nt - import pathlib - - v = SpatialVelocity() - print(v) - print(len(v)) - v.append(v) - print(v) - print(len(v)) - - v = SpatialVelocity(np.r_[1, 2, 3, 4, 5, 6]) - print(v) - v = SpatialVelocity(np.r_[1, 2, 3]) - print(v) - - a = v + v - print(a) - - vj = SpatialVelocity() - - x = vj @ vj - print(x) - - # I = SpatialInertia() - # print(I) - # print(len(I)) - # I.append(I) - # print(I) - # print(len(I)) - - # z = SpatialForce([1,2,3,4,5,6]) - # print(z) - # z = SpatialMomentum([1,2,3,4,5,6]) - # print(z) - - v = SpatialVelocity() - a = SpatialAcceleration() - I = SpatialInertia() - x = I * v - print(I * v) - print(I * a) - - exec( - open( - pathlib.Path(__file__).parent.parent.absolute() - / "tests" - / "test_spatialvector.py" - ).read() - ) # pylint: disable=exec-used diff --git a/spatialmath/spline.py b/spatialmath/spline.py index 0a472ecc..176d991b 100644 --- a/spatialmath/spline.py +++ b/spatialmath/spline.py @@ -34,6 +34,7 @@ def visualize( animate: bool = False, repeat: bool = True, ax: Optional[plt.Axes] = None, + **kwargs, # pass through to tranimate ) -> None: """Displays an animation of the trajectory with the control poses against an optional input trajectory. @@ -62,7 +63,7 @@ def visualize( if animate: tranimate( - samples, length=pose_marker_length, wait=True, repeat=repeat + samples, length=pose_marker_length, wait=True, repeat=repeat, **kwargs, ) # animate pose along trajectory else: plt.show() @@ -98,7 +99,6 @@ def __init__( string options: ["not-a-knot" (default), "clamped", "natural", "periodic"]. For tuple options and details see the scipy docs link above. """ - super().__init__() self.control_poses = control_poses self.timepoints = np.array(timepoints) diff --git a/spatialmath/twist.py b/spatialmath/twist.py index f84a0f1b..b2bcb662 100644 --- a/spatialmath/twist.py +++ b/spatialmath/twist.py @@ -2,8 +2,10 @@ # Copyright (c) 2000 Peter Corke # MIT Licence, see details in top-level file: LICENCE +from functools import reduce import numpy as np import spatialmath.base as smb +from spatialmath.base.types import Self from spatialmath.baseposelist import BasePoseList from spatialmath.geom3d import Line3 @@ -55,9 +57,6 @@ class BaseTwist(BasePoseList): """ - def __init__(self): - super().__init__() # enable UserList superpowers - @property def S(self): """ @@ -76,10 +75,7 @@ def S(self): - if ``len(X)`` > 1 then return a list of vectors. """ # get the underlying numpy array - if len(self.data) == 1: - return self.data[0] - else: - return self.data + return self.data @property def isprismatic(self): @@ -102,10 +98,7 @@ def isprismatic(self): >>> x.isprismatic """ - if len(self) == 1: - return smb.iszerovec(self.w) - else: - return [smb.iszerovec(x.w) for x in self.data] + return smb.iszerovec(self.w) @property def isrevolute(self): @@ -128,10 +121,7 @@ def isrevolute(self): >>> x.isrevolute """ - if len(self) == 1: - return smb.iszerovec(self.v) - else: - return [smb.iszerovec(x.v) for x in self.data] + return smb.iszerovec(self.v) @property def isunit(self): @@ -154,10 +144,7 @@ def isunit(self): >>> S.isunit() """ - if len(self) == 1: - return smb.isunitvec(self.S) - else: - return [smb.isunitvec(x) for x in self.data] + return smb.isunitvec(self.S) @property def theta(self): @@ -192,9 +179,10 @@ def inv(self): >>> S.inv() >>> S * S.inv() """ - return self.__class__([-t for t in self.data]) + return self.__class__(-self.data) - def prod(self): + @classmethod + def prod(cls, twists) : r""" Product of twists (superclass method) @@ -214,17 +202,18 @@ def prod(self): >>> S.prod() >>> Twist3.Rx(0.9) """ - if self.N == 2: + if cls.N == 2: log = smb.trlog2 exp = smb.trexp2 else: log = smb.trlog exp = smb.trexp - twprod = exp(self.data[0]) - for tw in self.data[1:]: - twprod = twprod @ exp(tw) - return self.__class__(log(twprod)) + exp_twprod = reduce( + lambda exp_twprod, exp_tw: exp_twprod @ exp_tw, + [exp(tw.A) for tw in twists], + ) + return cls(log(exp_twprod)) def __eq__(left, right): # lgtm[py/not-named-self] pylint: disable=no-self-argument """ @@ -325,25 +314,22 @@ def __init__(self, arg=None, w=None, check=True): """ from spatialmath.pose3d import SE3 - super().__init__() - if w is None: # zero or one arguments passed if super().arghandler(arg, check=check): return elif isinstance(arg, SE3): - self.data = [arg.twist().A] + self.data = arg.twist().A elif w is not None and smb.isvector(w, 3) and smb.isvector(arg, 3): # Twist(v, w) - self.data = [np.r_[arg, w]] + self.data = np.r_[arg, w] return else: raise ValueError("bad value to Twist constructor") # ------------------------ SMUserList required ---------------------------# - @staticmethod def _identity(): return np.zeros((6,)) @@ -410,27 +396,18 @@ def shape(self): """ return (6,) - @property - def N(self): - """ - Dimension of the object's group + # Dimension of the object's group - :return: dimension - :rtype: int + # Dimension of the group is 3 for ``Twist3`` and corresponds to the + # dimension of the space (3D in this case) to which these + # rigid-body motions apply. - Dimension of the group is 3 for ``Twist3`` and corresponds to the - dimension of the space (3D in this case) to which these - rigid-body motions apply. + # Example: - Example: - - .. runblock:: pycon - - >>> from spatialmath import Twist3 - >>> x = Twist3() - >>> x.N - """ - return 3 + # >>> from spatialmath import Twist3 + # >>> x = Twist3() + # >>> x.N + N = 3 @property def v(self): @@ -450,7 +427,7 @@ def v(self): >>> t = Twist3([1, 2, 3, 4, 5, 6]) >>> t.v """ - return self.data[0][:3] + return self.data[:3] @property def w(self): @@ -471,7 +448,7 @@ def w(self): >>> t.w """ - return self.data[0][3:6] + return self.data[3:6] # -------------------- variant constructors ----------------------------# @@ -915,10 +892,7 @@ def skewa(self): >>> se >>> smb.trexp(se) """ - if len(self) == 1: - return smb.skewa(self.S) - else: - return [smb.skewa(x.S) for x in self] + return smb.skewa(self.S) @property def pitch(self): @@ -965,7 +939,7 @@ def line(self): >>> S = Twist3(T) >>> S.line() """ - return Line3([Line3(-tw.v + tw.pitch * tw.w, tw.w) for tw in self]) + return Line3(-self.v + self.pitch * self.w, self.w) @property def pole(self): @@ -1013,17 +987,8 @@ def SE3(self, theta=1, unit="rad"): theta = smb.getunit(theta, unit) - if len(theta) == 1: - # theta is a scalar - return SE3(smb.trexp(self.S * theta)) - else: - # theta is a vector - if len(self) == 1: - return SE3([smb.trexp(self.S * t) for t in theta]) - elif len(self) == len(theta): - return SE3([smb.trexp(S * t) for S, t in zip(self.data, theta)]) - else: - raise ValueError("length of twist and theta not consistent") + # theta is a scalar + return SE3(smb.trexp(self.S * theta)) def exp(self, theta=1, unit="rad"): """ @@ -1071,12 +1036,7 @@ def exp(self, theta=1, unit="rad"): theta = smb.getunit(theta, unit) - if len(self) == 1: - return SE3([smb.trexp(self.S * t) for t in theta], check=False) - elif len(self) == len(theta): - return SE3([smb.trexp(s * t) for s, t in zip(self.S, theta)], check=False) - else: - raise ValueError("length mismatch") + return SE3(smb.trexp(self.S * theta), check=False) # ------------------------- arithmetic -------------------------------# @@ -1185,13 +1145,8 @@ def __str__(self): >>> x = Twist3.R([1,2,3], [4,5,6]) >>> print(x) """ - return "\n".join( - [ - "({:.5g} {:.5g} {:.5g}; {:.5g} {:.5g} {:.5g})".format( - *list(smb.removesmall(tw.S)) - ) - for tw in self - ] + return "({:.5g} {:.5g} {:.5g}; {:.5g} {:.5g} {:.5g})".format( + *list(smb.removesmall(self.S)) ) def __repr__(self): @@ -1212,25 +1167,9 @@ def __repr__(self): >>> a """ - if len(self) == 0: - return "Twist([])" - elif len(self) == 1: - return "Twist3([{:.5g}, {:.5g}, {:.5g}, {:.5g}, {:.5g}, {:.5g}])".format( - *list(self.S) - ) - else: - return ( - "Twist3([\n" - + ",\n".join( - [ - " [{:.5g}, {:.5g}, {:.5g}, {:.5g}, {:.5g}, {:.5g}]".format( - *list(tw) - ) - for tw in self.data - ] - ) - + "\n])" - ) + return "Twist3([{:.5g}, {:.5g}, {:.5g}, {:.5g}, {:.5g}, {:.5g}])".format( + *list(self.S) + ) def _repr_pretty_(self, p, cycle): """ @@ -1243,13 +1182,7 @@ def _repr_pretty_(self, p, cycle): itself. """ - if len(self) == 1: - p.text(str(self)) - else: - for i, x in enumerate(self): - if i > 0: - p.break_() - p.text(f"{i:3d}: {str(x)}") + p.text(str(self)) # ======================================================================== # @@ -1283,8 +1216,6 @@ def __init__(self, arg=None, w=None, check=True): """ from spatialmath.pose2d import SE2 - super().__init__() - if w is None: # zero or one arguments passed if super().arghandler(arg, convertfrom=(SE2,), check=check): @@ -1292,7 +1223,7 @@ def __init__(self, arg=None, w=None, check=True): elif w is not None and smb.isscalar(w) and smb.isvector(arg, 2): # Twist(v, w) - self.data = [np.r_[arg, w]] + self.data = np.r_[arg, w] return raise ValueError("bad twist value") @@ -1412,27 +1343,18 @@ def UnitPrismatic(cls, a): # ------------------------ properties ---------------------------# - @property - def N(self): - """ - Dimension of the object's group - - :return: dimension - :rtype: int - - Dimension of the group is 2 for ``Twist2`` and corresponds to the - dimension of the space (2D in this case) to which these - rigid-body motions apply. + # Dimension of the object's group - Example: + # Dimension of the group is 2 for ``Twist2`` and corresponds to the + # dimension of the space (2D in this case) to which these + # rigid-body motions apply. - .. runblock:: pycon + # Example: - >>> from spatialmath import Twist2 - >>> x = Twist2() - >>> x.N - """ - return 2 + # >>> from spatialmath import Twist2 + # >>> x = Twist2() + # >>> x.N + N = 2 @property def v(self): @@ -1453,7 +1375,7 @@ def v(self): >>> t.v """ - return self.data[0][:2] + return self.data[:2] @property def w(self): @@ -1474,7 +1396,7 @@ def w(self): >>> t.w """ - return self.data[0][2] + return self.data[2] @property def pole(self): @@ -1557,10 +1479,7 @@ def skewa(self): >>> se >>> smb.trexp2(se) """ - if len(self) == 1: - return smb.skewa(self.S) - else: - return [smb.skewa(x.S) for x in self] + return smb.skewa(self.S) def exp(self, theta=1, unit="rad"): r""" @@ -1599,12 +1518,7 @@ def exp(self, theta=1, unit="rad"): theta = smb.getunit(theta, unit) - if len(self) == 1: - return SE2([smb.trexp2(self.S * t) for t in theta], check=False) - elif len(self) == len(theta): - return SE2([smb.trexp2(s * t) for s, t in zip(self.S, theta)], check=False) - else: - raise ValueError("length mismatch") + return SE2(smb.trexp2(self.S * theta), check=False) def unit(self): """ @@ -1792,7 +1706,7 @@ def __str__(self): >>> x = Twist2([1,2,3]) >>> print(x) """ - return "\n".join(["({:.5g} {:.5g}; {:.5g})".format(*list(tw.S)) for tw in self]) + return "({:.5g} {:.5g}; {:.5g})".format(*list(self.S)) def __repr__(self): """ @@ -1813,16 +1727,7 @@ def __repr__(self): """ - if len(self) == 1: - return "Twist2([{:.5g}, {:.5g}, {:.5g}])".format(*list(self.S)) - else: - return ( - "Twist2([\n" - + ",\n".join( - [" [{:.5g}, {:.5g}, {:.5g}}]".format(*list(tw.S)) for tw in self] - ) - + "\n])" - ) + return "Twist2([{:.5g}, {:.5g}, {:.5g}])".format(*list(self.S)) def _repr_pretty_(self, p, cycle): """ @@ -1835,20 +1740,4 @@ def _repr_pretty_(self, p, cycle): itself. """ - if len(self) == 1: - p.text(str(self)) - else: - for i, x in enumerate(self): - if i > 0: - p.break_() - p.text(f"{i:3d}: {str(x)}") - - -if __name__ == "__main__": # pragma: no cover - import pathlib - - exec( - open( - pathlib.Path(__file__).parent.parent.absolute() / "tests" / "test_twist.py" - ).read() - ) # pylint: disable=exec-used + p.text(str(self)) diff --git a/tests/base/test_argcheck.py b/tests/base/test_argcheck.py index 685393b5..4dec0889 100755 --- a/tests/base/test_argcheck.py +++ b/tests/base/test_argcheck.py @@ -6,34 +6,33 @@ @author: corkep """ -import unittest import numpy as np import numpy.testing as nt - +import pytest from spatialmath.base.argcheck import * -class Test_check(unittest.TestCase): +class Test_check: def test_ismatrix(self): a = np.eye(3, 3) - self.assertTrue(ismatrix(a, (3, 3))) - self.assertFalse(ismatrix(a, (4, 3))) - self.assertFalse(ismatrix(a, (3, 4))) - self.assertFalse(ismatrix(a, (4, 4))) + assert ismatrix(a, (3, 3)) + assert not ismatrix(a, (4, 3)) + assert not ismatrix(a, (3, 4)) + assert not ismatrix(a, (4, 4)) - self.assertTrue(ismatrix(a, (-1, 3))) - self.assertTrue(ismatrix(a, (3, -1))) - self.assertTrue(ismatrix(a, (-1, -1))) + assert ismatrix(a, (-1, 3)) + assert ismatrix(a, (3, -1)) + assert ismatrix(a, (-1, -1)) - self.assertFalse(ismatrix(1, (-1, -1))) + assert not ismatrix(1, (-1, -1)) def test_assertmatrix(self): - with self.assertRaises(TypeError): + with pytest.raises(TypeError): assertmatrix(3) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): assertmatrix("not a matrix") - with self.assertRaises(TypeError): + with pytest.raises(TypeError): a = np.eye(3, 3, dtype=complex) assertmatrix(a) @@ -44,89 +43,89 @@ def test_assertmatrix(self): assertmatrix(a, (None, 3)) assertmatrix(a, (3, None)) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): assertmatrix(a, (4, 3)) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): assertmatrix(a, (4, None)) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): assertmatrix(a, (None, 4)) def test_getmatrix(self): a = np.random.rand(4, 3) - self.assertEqual(getmatrix(a, (4, 3)).shape, (4, 3)) - self.assertEqual(getmatrix(a, (None, 3)).shape, (4, 3)) - self.assertEqual(getmatrix(a, (4, None)).shape, (4, 3)) - self.assertEqual(getmatrix(a, (None, None)).shape, (4, 3)) - with self.assertRaises(ValueError): + assert getmatrix(a, (4, 3)).shape == (4, 3) + assert getmatrix(a, (None, 3)).shape == (4, 3) + assert getmatrix(a, (4, None)).shape == (4, 3) + assert getmatrix(a, (None, None)).shape == (4, 3) + with pytest.raises(ValueError): m = getmatrix(a, (5, 3)) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): m = getmatrix(a, (5, None)) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): m = getmatrix(a, (None, 4)) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): m = getmatrix({}, (4, 3)) a = np.r_[1, 2, 3, 4] - self.assertEqual(getmatrix(a, (1, 4)).shape, (1, 4)) - self.assertEqual(getmatrix(a, (4, 1)).shape, (4, 1)) - self.assertEqual(getmatrix(a, (2, 2)).shape, (2, 2)) - with self.assertRaises(ValueError): + assert getmatrix(a, (1, 4)).shape == (1, 4) + assert getmatrix(a, (4, 1)).shape == (4, 1) + assert getmatrix(a, (2, 2)).shape == (2, 2) + with pytest.raises(ValueError): m = getmatrix(a, (5, None)) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): m = getmatrix(a, (None, 5)) a = [1, 2, 3, 4] - self.assertEqual(getmatrix(a, (1, 4)).shape, (1, 4)) - self.assertEqual(getmatrix(a, (4, 1)).shape, (4, 1)) - self.assertEqual(getmatrix(a, (2, 2)).shape, (2, 2)) - with self.assertRaises(ValueError): + assert getmatrix(a, (1, 4)).shape == (1, 4) + assert getmatrix(a, (4, 1)).shape == (4, 1) + assert getmatrix(a, (2, 2)).shape == (2, 2) + with pytest.raises(ValueError): m = getmatrix(a, (5, None)) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): m = getmatrix(a, (None, 5)) a = 7 - self.assertEqual(getmatrix(a, (1, 1)).shape, (1, 1)) - self.assertEqual(getmatrix(a, (None, None)).shape, (1, 1)) - with self.assertRaises(ValueError): + assert getmatrix(a, (1, 1)).shape == (1, 1) + assert getmatrix(a, (None, None)).shape == (1, 1) + with pytest.raises(ValueError): m = getmatrix(a, (2, 1)) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): m = getmatrix(a, (1, 2)) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): m = getmatrix(a, (None, 2)) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): m = getmatrix(a, (2, None)) a = 7.0 - self.assertEqual(getmatrix(a, (1, 1)).shape, (1, 1)) - self.assertEqual(getmatrix(a, (None, None)).shape, (1, 1)) - with self.assertRaises(ValueError): + assert getmatrix(a, (1, 1)).shape == (1, 1) + assert getmatrix(a, (None, None)).shape == (1, 1) + with pytest.raises(ValueError): m = getmatrix(a, (2, 1)) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): m = getmatrix(a, (1, 2)) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): m = getmatrix(a, (None, 2)) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): m = getmatrix(a, (2, None)) def test_verifymatrix(self): - with self.assertRaises(TypeError): + with pytest.raises(TypeError): assertmatrix(3) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): verifymatrix([3, 4]) a = np.eye(3, 3) verifymatrix(a, (3, 3)) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): verifymatrix(a, (3, 4)) def test_unit(self): - self.assertIsInstance(getunit(1), np.ndarray) - self.assertIsInstance(getunit([1, 2]), np.ndarray) - self.assertIsInstance(getunit((1, 2)), np.ndarray) - self.assertIsInstance(getunit(np.r_[1, 2]), np.ndarray) - self.assertIsInstance(getunit(1.0, dim=0), float) + assert isinstance(getunit(1), np.ndarray) + assert isinstance(getunit([1, 2]), np.ndarray) + assert isinstance(getunit((1, 2)), np.ndarray) + assert isinstance(getunit(np.r_[1, 2]), np.ndarray) + assert isinstance(getunit(1.0, dim=0), float) nt.assert_equal(getunit(5, "rad"), 5) nt.assert_equal(getunit(5, "deg"), 5 * math.pi / 180.0) @@ -148,31 +147,31 @@ def test_unit(self): def test_isvector(self): # no length specified - self.assertTrue(isvector(2)) - self.assertTrue(isvector(2.0)) - self.assertTrue(isvector([1, 2, 3])) - self.assertTrue(isvector((1, 2, 3))) - self.assertTrue(isvector(np.array([1, 2, 3]))) - self.assertTrue(isvector(np.array([[1, 2, 3]]))) - self.assertTrue(isvector(np.array([[1], [2], [3]]))) + assert isvector(2) + assert isvector(2.0) + assert isvector([1, 2, 3]) + assert isvector((1, 2, 3)) + assert isvector(np.array([1, 2, 3])) + assert isvector(np.array([[1, 2, 3]])) + assert isvector(np.array([[1], [2], [3]])) # length specified - self.assertTrue(isvector(2, 1)) - self.assertTrue(isvector(2.0, 1)) - self.assertTrue(isvector([1, 2, 3], 3)) - self.assertTrue(isvector((1, 2, 3), 3)) - self.assertTrue(isvector(np.array([1, 2, 3]), 3)) - self.assertTrue(isvector(np.array([[1, 2, 3]]), 3)) - self.assertTrue(isvector(np.array([[1], [2], [3]]), 3)) + assert isvector(2, 1) + assert isvector(2.0, 1) + assert isvector([1, 2, 3], 3) + assert isvector((1, 2, 3), 3) + assert isvector(np.array([1, 2, 3]), 3) + assert isvector(np.array([[1, 2, 3]]), 3) + assert isvector(np.array([[1], [2], [3]]), 3) # wrong length specified - self.assertFalse(isvector(2, 4)) - self.assertFalse(isvector(2.0, 4)) - self.assertFalse(isvector([1, 2, 3], 4)) - self.assertFalse(isvector((1, 2, 3), 4)) - self.assertFalse(isvector(np.array([1, 2, 3]), 4)) - self.assertFalse(isvector(np.array([[1, 2, 3]]), 4)) - self.assertFalse(isvector(np.array([[1], [2], [3]]), 4)) + assert not isvector(2, 4) + assert not isvector(2.0, 4) + assert not isvector([1, 2, 3], 4) + assert not isvector((1, 2, 3), 4) + assert not isvector(np.array([1, 2, 3]), 4) + assert not isvector(np.array([[1, 2, 3]]), 4) + assert not isvector(np.array([[1], [2], [3]]), 4) def test_isvector(self): l = [1, 2, 3] @@ -187,43 +186,43 @@ def test_getvector(self): # input is list v = getvector(l) - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(len(v), 3) v = getvector(l, 3) - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(len(v), 3) v = getvector(l, out="sequence") - self.assertIsInstance(v, list) + assert isinstance(v, list) nt.assert_equal(len(v), 3) v = getvector(l, 3, out="sequence") - self.assertIsInstance(v, list) + assert isinstance(v, list) nt.assert_equal(len(v), 3) v = getvector(l, out="array") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (3,)) v = getvector(l, 3, out="array") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (3,)) v = getvector(l, out="row") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (1, 3)) v = getvector(l, 3, out="row") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (1, 3)) v = getvector(l, out="col") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (3, 1)) v = getvector(l, 3, out="col") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (3, 1)) nt.assert_raises(ValueError, getvector, l, 4) @@ -235,43 +234,43 @@ def test_getvector(self): # input is tuple v = getvector(t) - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(len(v), 3) v = getvector(t, 3) - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(len(v), 3) v = getvector(t, out="sequence") - self.assertIsInstance(v, tuple) + assert isinstance(v, tuple) nt.assert_equal(len(v), 3) v = getvector(t, 3, out="sequence") - self.assertIsInstance(v, tuple) + assert isinstance(v, tuple) nt.assert_equal(len(v), 3) v = getvector(t, out="array") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (3,)) v = getvector(t, 3, out="array") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (3,)) v = getvector(t, out="row") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (1, 3)) v = getvector(t, 3, out="row") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (1, 3)) v = getvector(t, out="col") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (3, 1)) v = getvector(t, 3, out="col") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (3, 1)) nt.assert_raises(ValueError, getvector, t, 4) @@ -283,43 +282,43 @@ def test_getvector(self): # input is array v = getvector(a) - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(len(v), 3) v = getvector(a, 3) - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(len(v), 3) v = getvector(a, out="sequence") - self.assertIsInstance(v, list) + assert isinstance(v, list) nt.assert_equal(len(v), 3) v = getvector(a, 3, out="sequence") - self.assertIsInstance(v, list) + assert isinstance(v, list) nt.assert_equal(len(v), 3) v = getvector(a, out="array") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (3,)) v = getvector(a, 3, out="array") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (3,)) v = getvector(a, out="row") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (1, 3)) v = getvector(a, 3, out="row") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (1, 3)) v = getvector(a, out="col") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (3, 1)) v = getvector(a, 3, out="col") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (3, 1)) nt.assert_raises(ValueError, getvector, a, 4) @@ -331,43 +330,43 @@ def test_getvector(self): # input is row v = getvector(r) - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(len(v), 3) v = getvector(r, 3) - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(len(v), 3) v = getvector(r, out="sequence") - self.assertIsInstance(v, list) + assert isinstance(v, list) nt.assert_equal(len(v), 3) v = getvector(r, 3, out="sequence") - self.assertIsInstance(v, list) + assert isinstance(v, list) nt.assert_equal(len(v), 3) v = getvector(r, out="array") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (3,)) v = getvector(r, 3, out="array") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (3,)) v = getvector(r, out="row") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (1, 3)) v = getvector(r, 3, out="row") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (1, 3)) v = getvector(r, out="col") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (3, 1)) v = getvector(r, 3, out="col") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (3, 1)) nt.assert_raises(ValueError, getvector, r, 4) @@ -379,43 +378,43 @@ def test_getvector(self): # input is col v = getvector(c) - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(len(v), 3) v = getvector(c, 3) - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(len(v), 3) v = getvector(c, out="sequence") - self.assertIsInstance(v, list) + assert isinstance(v, list) nt.assert_equal(len(v), 3) v = getvector(c, 3, out="sequence") - self.assertIsInstance(v, list) + assert isinstance(v, list) nt.assert_equal(len(v), 3) v = getvector(c, out="array") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (3,)) v = getvector(c, 3, out="array") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (3,)) v = getvector(c, out="row") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (1, 3)) v = getvector(c, 3, out="row") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (1, 3)) v = getvector(c, out="col") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (3, 1)) v = getvector(c, 3, out="col") - self.assertIsInstance(v, np.ndarray) + assert isinstance(v, np.ndarray) nt.assert_equal(v.shape, (3, 1)) nt.assert_raises(ValueError, getvector, c, 4) @@ -435,30 +434,25 @@ def test_isnumberlist(self): def test_isvectorlist(self): a = [np.r_[1, 2], np.r_[3, 4], np.r_[5, 6]] - self.assertTrue(isvectorlist(a, 2)) + assert isvectorlist(a, 2) a = [(1, 2), (3, 4), (5, 6)] - self.assertFalse(isvectorlist(a, 2)) + assert not isvectorlist(a, 2) a = [np.r_[1, 2], np.r_[3, 4], np.r_[5, 6, 7]] - self.assertFalse(isvectorlist(a, 2)) + assert not isvectorlist(a, 2) def test_islistof(self): a = [3, 4, 5] - self.assertTrue(islistof(a, int)) - self.assertFalse(islistof(a, float)) - self.assertTrue(islistof(a, lambda x: isinstance(x, int))) + assert islistof(a, int) + assert not islistof(a, float) + assert islistof(a, lambda x: isinstance(x, int)) - self.assertTrue(islistof(a, int, 3)) - self.assertFalse(islistof(a, int, 2)) + assert islistof(a, int, 3) + assert not islistof(a, int, 2) a = [3, 4.5, 5.6] - self.assertFalse(islistof(a, int)) - self.assertTrue(islistof(a, (int, float))) + assert not islistof(a, int) + assert islistof(a, (int, float)) a = [[1, 2], [3, 4], [5, 6]] - self.assertTrue(islistof(a, lambda x: islistof(x, int, 2))) - - -# ---------------------------------------------------------------------------------------# -if __name__ == "__main__": # pragma: no cover - unittest.main() + assert islistof(a, lambda x: islistof(x, int, 2)) diff --git a/tests/base/test_graphics.py b/tests/base/test_graphics.py index 552ebdb0..125ec79c 100644 --- a/tests/base/test_graphics.py +++ b/tests/base/test_graphics.py @@ -1,5 +1,6 @@ -import unittest import numpy as np +import matplotlib +matplotlib.use("AGG") import matplotlib.pyplot as plt import pytest import sys @@ -8,8 +9,7 @@ # test graphics primitives # TODO check they actually create artists - -class TestGraphics(unittest.TestCase): +class TestGraphics: def teardown_method(self, method): plt.close("all") @@ -143,8 +143,3 @@ def test_cone(self): resolution=5, color="red", ) - - -# ---------------------------------------------------------------------------------------# -if __name__ == "__main__": - unittest.main(buffer=True) diff --git a/tests/base/test_numeric.py b/tests/base/test_numeric.py index e6b9de50..1ed794fa 100755 --- a/tests/base/test_numeric.py +++ b/tests/base/test_numeric.py @@ -9,14 +9,14 @@ import numpy as np import numpy.testing as nt -import unittest +import pytest import math from spatialmath.base.numeric import * -class TestNumeric(unittest.TestCase): +class TestNumeric: def test_numjac(self): pass @@ -26,45 +26,45 @@ def test_array2str(self): x = [1.2345678] s = array2str(x) - self.assertIsInstance(s, str) - self.assertEqual(s, "[ 1.23 ]") + assert isinstance(s, str) + assert s == "[ 1.23 ]" s = array2str(x, fmt="{:.5f}") - self.assertEqual(s, "[ 1.23457 ]") + assert s == "[ 1.23457 ]" s = array2str([1, 2, 3]) - self.assertEqual(s, "[ 1, 2, 3 ]") + assert s == "[ 1, 2, 3 ]" s = array2str([1, 2, 3], valuesep=":") - self.assertEqual(s, "[ 1:2:3 ]") + assert s == "[ 1:2:3 ]" s = array2str([1, 2, 3], brackets=("<< ", " >>")) - self.assertEqual(s, "<< 1, 2, 3 >>") + assert s == "<< 1, 2, 3 >>" s = array2str([1, 2e-8, 3]) - self.assertEqual(s, "[ 1, 2e-08, 3 ]") + assert s == "[ 1, 2e-08, 3 ]" s = array2str([1, -2e-14, 3]) - self.assertEqual(s, "[ 1, 0, 3 ]") + assert s == "[ 1, 0, 3 ]" x = np.array([[1, 2, 3], [4, 5, 6]]) s = array2str(x) - self.assertEqual(s, "[ 1, 2, 3 | 4, 5, 6 ]") + assert s == "[ 1, 2, 3 | 4, 5, 6 ]" def test_bresenham(self): x, y = bresenham((-10, -10), (20, 10)) - self.assertIsInstance(x, np.ndarray) - self.assertEqual(x.ndim, 1) - self.assertIsInstance(y, np.ndarray) - self.assertEqual(y.ndim, 1) - self.assertEqual(len(x), len(y)) + assert isinstance(x, np.ndarray) + assert x.ndim == 1 + assert isinstance(y, np.ndarray) + assert y.ndim == 1 + assert len(x) == len(y) # test points are no more than sqrt(2) apart z = np.array([x, y]) d = np.diff(z, axis=1) d = np.linalg.norm(d, axis=0) - self.assertTrue(all(d <= np.sqrt(2))) + assert all(d <= np.sqrt(2)) x, y = bresenham((20, 10), (-10, -10)) @@ -72,7 +72,7 @@ def test_bresenham(self): z = np.array([x, y]) d = np.diff(z, axis=1) d = np.linalg.norm(d, axis=0) - self.assertTrue(all(d <= np.sqrt(2))) + assert all(d <= np.sqrt(2)) x, y = bresenham((-10, -10), (10, 20)) @@ -80,7 +80,7 @@ def test_bresenham(self): z = np.array([x, y]) d = np.diff(z, axis=1) d = np.linalg.norm(d, axis=0) - self.assertTrue(all(d <= np.sqrt(2))) + assert all(d <= np.sqrt(2)) x, y = bresenham((10, 20), (-10, -10)) @@ -88,25 +88,25 @@ def test_bresenham(self): z = np.array([x, y]) d = np.diff(z, axis=1) d = np.linalg.norm(d, axis=0) - self.assertTrue(all(d <= np.sqrt(2))) + assert all(d <= np.sqrt(2)) def test_mpq(self): data = np.array([[-1, 1, 1, -1], [-1, -1, 1, 1]]) - self.assertEqual(mpq_point(data, 0, 0), 4) - self.assertEqual(mpq_point(data, 1, 0), 0) - self.assertEqual(mpq_point(data, 0, 1), 0) + assert mpq_point(data, 0, 0) == 4 + assert mpq_point(data, 1, 0) == 0 + assert mpq_point(data, 0, 1) == 0 def test_gauss1d(self): x = np.arange(-10, 10, 0.02) y = gauss1d(2, 1, x) - self.assertEqual(len(x), len(y)) + assert len(x) == len(y) m = np.argmax(y) - self.assertAlmostEqual(x[m], 2) + assert x[m] == pytest.approx(2) def test_gauss2d(self): @@ -115,11 +115,5 @@ def test_gauss2d(self): Z = gauss2d([2, 3], np.eye(2), X, Y) m = np.unravel_index(np.argmax(Z, axis=None), Z.shape) - self.assertAlmostEqual(r[m[0]], 3) - self.assertAlmostEqual(r[m[1]], 2) - - -# ---------------------------------------------------------------------------------------# -if __name__ == "__main__": - - unittest.main() + assert r[m[0]] == pytest.approx(3) + assert r[m[1]] == pytest.approx(2) diff --git a/tests/base/test_quaternions.py b/tests/base/test_quaternions.py index c512c6d2..2ae73ccd 100644 --- a/tests/base/test_quaternions.py +++ b/tests/base/test_quaternions.py @@ -30,7 +30,7 @@ # 3. Peter Corke, 2020 import numpy.testing as nt -import unittest +import pytest from spatialmath.base.vectors import * import spatialmath.base as tr @@ -38,7 +38,7 @@ import spatialmath as sm -class TestQuaternion(unittest.TestCase): +class TestQuaternion: def test_ops(self): nt.assert_array_almost_equal(qeye(), np.r_[1, 0, 0, 0]) @@ -134,9 +134,9 @@ def test_rotation(self): large_rotation = math.pi + 0.01 q1 = r2q(tr.rotx(large_rotation), shortest=False) q2 = r2q(tr.rotx(large_rotation), shortest=True) - self.assertLess(q1[0], 0) - self.assertGreater(q2[0], 0) - self.assertTrue(qisequal(q1=q1, q2=q2, unitq=True)) + assert q1[0] < 0 + assert q2[0] > 0 + assert qisequal(q1=q1, q2=q2, unitq=True) def test_slerp(self): q1 = np.r_[0, 1, 0, 0] @@ -222,7 +222,7 @@ def test_r2q(self): nt.assert_array_almost_equal(q1a, r2q(r1.R, order="sxyz")) nt.assert_array_almost_equal(q1b, r2q(r1.R, order="xyzs")) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): nt.assert_array_almost_equal(q1a, r2q(r1.R, order="aaa")) def test_qangle(self): @@ -234,7 +234,3 @@ def test_qangle(self): q1 = [1., 0, 0, 0] q2 = [1 / np.sqrt(2), 1 / np.sqrt(2), 0, 0] # 90deg rotation about x-axis nt.assert_almost_equal(qangle(q1, q2), np.pi / 2) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/base/test_symbolic.py b/tests/base/test_symbolic.py index 7a503b5b..4106381f 100644 --- a/tests/base/test_symbolic.py +++ b/tests/base/test_symbolic.py @@ -1,4 +1,4 @@ -import unittest +import pytest import math try: @@ -11,76 +11,70 @@ from spatialmath.base.symbolic import * -class Test_symbolic(unittest.TestCase): - @unittest.skipUnless(_symbolics, "sympy required") +class Test_symbolic: + @pytest.mark.skipif(not _symbolics, reason="sympy required") def test_symbol(self): theta = symbol("theta") - self.assertTrue(isinstance(theta, sp.Expr)) - self.assertTrue(theta.is_real) + assert isinstance(theta, sp.Expr) + assert theta.is_real theta = symbol("theta", real=False) - self.assertTrue(isinstance(theta, sp.Expr)) - self.assertFalse(theta.is_real) + assert isinstance(theta, sp.Expr) + assert not theta.is_real theta, psi = symbol("theta, psi") - self.assertTrue(isinstance(theta, sp.Expr)) - self.assertTrue(isinstance(psi, sp.Expr)) + assert isinstance(theta, sp.Expr) + assert isinstance(psi, sp.Expr) theta, psi = symbol("theta psi") - self.assertTrue(isinstance(theta, sp.Expr)) - self.assertTrue(isinstance(psi, sp.Expr)) + assert isinstance(theta, sp.Expr) + assert isinstance(psi, sp.Expr) q = symbol("q:6") - self.assertEqual(len(q), 6) + assert len(q) == 6 for _ in q: - self.assertTrue(isinstance(_, sp.Expr)) - self.assertTrue(_.is_real) + assert isinstance(_, sp.Expr) + assert _.is_real - @unittest.skipUnless(_symbolics, "sympy required") + @pytest.mark.skipif(not _symbolics, reason="sympy required") def test_issymbol(self): theta = symbol("theta") - self.assertFalse(issymbol(3)) - self.assertFalse(issymbol("not a symbol")) - self.assertFalse(issymbol([1, 2])) - self.assertTrue(issymbol(theta)) + assert not issymbol(3) + assert not issymbol("not a symbol") + assert not issymbol([1, 2]) + assert issymbol(theta) - @unittest.skipUnless(_symbolics, "sympy required") + @pytest.mark.skipif(not _symbolics, reason="sympy required") def test_functions(self): theta = symbol("theta") - self.assertTrue(isinstance(sin(theta), sp.Expr)) - self.assertTrue(isinstance(sin(1.0), float)) + assert isinstance(sin(theta), sp.Expr) + assert isinstance(sin(1.0), float) - self.assertTrue(isinstance(cos(theta), sp.Expr)) - self.assertTrue(isinstance(cos(1.0), float)) + assert isinstance(cos(theta), sp.Expr) + assert isinstance(cos(1.0), float) - self.assertTrue(isinstance(sqrt(theta), sp.Expr)) - self.assertTrue(isinstance(sqrt(1.0), float)) + assert isinstance(sqrt(theta), sp.Expr) + assert isinstance(sqrt(1.0), float) x = (theta - 1) * (theta + 1) - theta ** 2 - self.assertTrue(math.isclose(simplify(x).evalf(), -1)) + assert math.isclose(simplify(x).evalf(), -1) - @unittest.skipUnless(_symbolics, "sympy required") + @pytest.mark.skipif(not _symbolics, reason="sympy required") def test_constants(self): x = zero() - self.assertTrue(isinstance(x, sp.Expr)) - self.assertTrue(math.isclose(x.evalf(), 0)) + assert isinstance(x, sp.Expr) + assert math.isclose(x.evalf(), 0) x = one() - self.assertTrue(isinstance(x, sp.Expr)) - self.assertTrue(math.isclose(x.evalf(), 1)) + assert isinstance(x, sp.Expr) + assert math.isclose(x.evalf(), 1) x = negative_one() - self.assertTrue(isinstance(x, sp.Expr)) - self.assertTrue(math.isclose(x.evalf(), -1)) + assert isinstance(x, sp.Expr) + assert math.isclose(x.evalf(), -1) x = pi() - self.assertTrue(isinstance(x, sp.Expr)) - self.assertTrue(math.isclose(x.evalf(), math.pi)) - - -# ---------------------------------------------------------------------------------------# -if __name__ == "__main__": # pragma: no cover - - unittest.main() + assert isinstance(x, sp.Expr) + assert math.isclose(x.evalf(), math.pi) diff --git a/tests/base/test_transforms.py b/tests/base/test_transforms.py index 71b01bb3..62fed783 100755 --- a/tests/base/test_transforms.py +++ b/tests/base/test_transforms.py @@ -10,7 +10,6 @@ import numpy as np import numpy.testing as nt -import unittest from math import pi import math from scipy.linalg import logm, expm @@ -18,10 +17,12 @@ from spatialmath.base import * from spatialmath.base import sym +import matplotlib +matplotlib.use("AGG") import matplotlib.pyplot as plt -class TestLie(unittest.TestCase): +class TestLie: def test_vex(self): S = np.array([[0, -3], [3, 0]]) @@ -319,9 +320,3 @@ def test_trexp2(self): def test_trnorm(self): T0 = transl(-1, -2, -3) @ trotx(-0.3) nt.assert_array_almost_equal(trnorm(T0), T0) - - -# ---------------------------------------------------------------------------------------# -if __name__ == "__main__": - - unittest.main() diff --git a/tests/base/test_transforms2d.py b/tests/base/test_transforms2d.py index ff099930..24c284a2 100755 --- a/tests/base/test_transforms2d.py +++ b/tests/base/test_transforms2d.py @@ -9,7 +9,6 @@ import numpy as np import numpy.testing as nt -import unittest from math import pi import math from scipy.linalg import logm, expm @@ -29,10 +28,12 @@ ) from spatialmath.base.numeric import numjac +import matplotlib +matplotlib.use("AGG") import matplotlib.pyplot as plt -class Test2D(unittest.TestCase): +class Test2D: def test_rot2(self): R = np.array([[1, 0], [0, 1]]) nt.assert_array_almost_equal(rot2(0), R) @@ -100,14 +101,14 @@ def test_trnorm2(self): R = rot2(0.4) R = np.round(R, 3) # approx SO(2) R = trnorm2(R) - self.assertTrue(isrot2(R, check=True)) + assert isrot2(R, check=True) R = rot2(0.4) R = np.round(R, 3) # approx SO(2) T = rt2tr(R, [3, 4]) T = trnorm2(T) - self.assertTrue(ishom2(T, check=True)) + assert ishom2(T, check=True) nt.assert_almost_equal(T[:2, 2], [3, 4]) def test_transl2(self): @@ -180,8 +181,8 @@ def test_print2(self): T = transl2(1, 2) @ trot2(0.3) s = trprint2(T, file=None) - self.assertIsInstance(s, str) - self.assertEqual(len(s), 15) + assert isinstance(s, str) + assert len(s) == 15 def test_checks(self): # 2D case, with rotation matrix @@ -297,8 +298,3 @@ def test_plot(self): transl2(4, 3) @ trot2(math.pi / 3), block=False, color="green", frame="c" ) plt.close("all") - - -# ---------------------------------------------------------------------------------------# -if __name__ == "__main__": - unittest.main() diff --git a/tests/base/test_transforms3d.py b/tests/base/test_transforms3d.py index 2f1e6049..213502f7 100755 --- a/tests/base/test_transforms3d.py +++ b/tests/base/test_transforms3d.py @@ -10,7 +10,7 @@ import numpy as np import numpy.testing as nt -import unittest +import pytest from math import pi import math from scipy.linalg import logm, expm @@ -19,7 +19,7 @@ from spatialmath.base.transformsNd import isR, t2r, r2t, rt2tr, skew -class Test3D(unittest.TestCase): +class Test3D: def test_checks(self): # 2D case, with rotation matrix R = np.eye(2) @@ -442,14 +442,14 @@ def test_trnorm(self): R = rpy2r(0.2, 0.3, 0.4) R = np.round(R, 3) # approx SO(3) R = trnorm(R) - self.assertTrue(isrot(R, check=True)) + assert isrot(R, check=True) R = rpy2r(0.2, 0.3, 0.4) R = np.round(R, 3) # approx SO(3) T = rt2tr(R, [3, 4, 5]) T = trnorm(T) - self.assertTrue(ishom(T, check=True)) + assert ishom(T, check=True) nt.assert_almost_equal(T[:3, 3], [3, 4, 5]) def test_tr2eul(self): @@ -528,7 +528,7 @@ def test_tr2angvec(self): # check a rotation matrix that should fail badR = roty(true_ang) + eps - with self.assertRaises(ValueError): + with pytest.raises(ValueError): tr2angvec(badR, check=True) # run without check @@ -539,27 +539,27 @@ def test_tr2angvec(self): def test_print(self): R = rotx(0.3) @ roty(0.4) s = trprint(R, file=None) - self.assertIsInstance(s, str) - self.assertEqual(len(s), 30) + assert isinstance(s, str) + assert len(s) == 30 T = transl(1, 2, 3) @ trotx(0.3) @ troty(0.4) s = trprint(T, file=None) - self.assertIsInstance(s, str) - self.assertEqual(len(s), 42) - self.assertTrue("rpy" in s) - self.assertTrue("zyx" in s) + assert isinstance(s, str) + assert len(s) == 42 + assert "rpy" in s + assert "zyx" in s s = trprint(T, file=None, orient="rpy/xyz") - self.assertIsInstance(s, str) - self.assertEqual(len(s), 39) - self.assertTrue("rpy" in s) - self.assertTrue("xyz" in s) + assert isinstance(s, str) + assert len(s) == 39 + assert "rpy" in s + assert "xyz" in s s = trprint(T, file=None, orient="eul") - self.assertIsInstance(s, str) - self.assertEqual(len(s), 37) - self.assertTrue("eul" in s) - self.assertFalse("zyx" in s) + assert isinstance(s, str) + assert len(s) == 37 + assert "eul" in s + assert not "zyx" in s def test_trinterp(self): R0 = rotx(-0.3) @@ -805,7 +805,3 @@ def test_x2tr(self): nt.assert_array_almost_equal( x2tr(x, representation="exp"), transl(t) @ r2t(trexp(gamma)) ) - -# ---------------------------------------------------------------------------------------# -if __name__ == "__main__": - unittest.main() diff --git a/tests/base/test_transforms3d_plot.py b/tests/base/test_transforms3d_plot.py index f250df4a..517b616d 100755 --- a/tests/base/test_transforms3d_plot.py +++ b/tests/base/test_transforms3d_plot.py @@ -10,7 +10,6 @@ import numpy as np import numpy.testing as nt -import unittest from math import pi import math from scipy.linalg import logm, expm @@ -20,10 +19,12 @@ from spatialmath.base.transforms3d import * from spatialmath.base.transformsNd import isR, t2r, r2t, rt2tr +import matplotlib +matplotlib.use("AGG") import matplotlib.pyplot as plt -class Test3D(unittest.TestCase): +class Test3D: @pytest.mark.skipif( sys.platform.startswith("darwin") and sys.version_info < (3, 11), reason="tkinter bug with mac", @@ -76,16 +77,11 @@ def test_plot(self): reason="tkinter bug with mac", ) def test_animate(self): - tranimate(transl(1, 2, 3), repeat=False, wait=True) + tranimate(transl(1, 2, 3), repeat=False, wait=True, movie=True) - tranimate(transl(1, 2, 3), repeat=False, wait=True) + tranimate(transl(1, 2, 3), repeat=False, wait=True, movie=True) # run again, with axes already created - tranimate(transl(1, 2, 3), repeat=False, wait=True, dims=[0, 10, 0, 10, 0, 10]) + tranimate(transl(1, 2, 3), repeat=False, wait=True, dims=[0, 10, 0, 10, 0, 10], movie=True) plt.close("all") # test animate with line not arrow, text, test with SO(3) - - -# ---------------------------------------------------------------------------------------# -if __name__ == "__main__": - unittest.main() diff --git a/tests/base/test_transformsNd.py b/tests/base/test_transformsNd.py index 92d9e2a3..829a827f 100755 --- a/tests/base/test_transformsNd.py +++ b/tests/base/test_transformsNd.py @@ -9,7 +9,7 @@ import numpy as np import numpy.testing as nt -import unittest +import pytest from math import pi import math from scipy.linalg import logm, expm @@ -25,20 +25,22 @@ from spatialmath.base.symbolic import symbol except ImportError: _symbolics = False +import matplotlib +matplotlib.use("AGG") import matplotlib.pyplot as plt -class TestND(unittest.TestCase): +class TestND: def test_iseye(self): - self.assertTrue(iseye(np.eye(1))) - self.assertTrue(iseye(np.eye(2))) - self.assertTrue(iseye(np.eye(3))) - self.assertTrue(iseye(np.eye(5))) + assert iseye(np.eye(1)) + assert iseye(np.eye(2)) + assert iseye(np.eye(3)) + assert iseye(np.eye(5)) - self.assertFalse(iseye(2 * np.eye(3))) - self.assertFalse(iseye(-np.eye(3))) - self.assertFalse(iseye(np.array([[1, 0, 0], [0, 1, 0]]))) - self.assertFalse(iseye(np.array([1, 0, 0]))) + assert not iseye(2 * np.eye(3)) + assert not iseye(-np.eye(3)) + assert not iseye(np.array([[1, 0, 0], [0, 1, 0]])) + assert not iseye(np.array([1, 0, 0])) def test_r2t(self): # 3D @@ -53,32 +55,32 @@ def test_r2t(self): nt.assert_array_almost_equal(T[0:2, 2], np.r_[0, 0]) nt.assert_array_almost_equal(T[:2, :2], R) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): r2t(3) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): r2t(np.eye(3, 4)) _ = r2t(np.ones((3, 3)), check=False) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): r2t(np.ones((3, 3)), check=True) - @unittest.skipUnless(_symbolics, "sympy required") + @pytest.mark.skipif(not _symbolics, reason="sympy required") def test_r2t_sym(self): theta = symbol("theta") R = rot2(theta) T = r2t(R) - self.assertEqual(r2t(R).dtype, "O") + assert r2t(R).dtype == "O" nt.assert_array_almost_equal(T[0:2, 2], np.r_[0, 0]) nt.assert_array_almost_equal(T[:2, :2], R) theta = symbol("theta") R = rotx(theta) T = r2t(R) - self.assertEqual(r2t(R).dtype, "O") + assert r2t(R).dtype == "O" nt.assert_array_almost_equal(T[0:3, 3], np.r_[0, 0, 0]) # nt.assert_array_almost_equal(T[:3,:3], R) - self.assertTrue((T[:3, :3] == R).all()) + assert (T[:3, :3] == R).all() def test_t2r(self): # 3D @@ -95,10 +97,10 @@ def test_t2r(self): nt.assert_array_almost_equal(T[:2, :2], R) nt.assert_array_almost_equal(transl2(T), np.array(t)) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): t2r(3) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): r2t(np.eye(3, 4)) def test_rt2tr(self): @@ -116,28 +118,28 @@ def test_rt2tr(self): nt.assert_array_almost_equal(t2r(T), R) nt.assert_array_almost_equal(transl2(T), np.array(t)) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): rt2tr(3, 4) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): rt2tr(np.eye(3, 4), [1, 2, 3, 4]) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): rt2tr(np.eye(4, 4), [1, 2, 3, 4]) _ = rt2tr(np.ones((3, 3)), [1, 2, 3], check=False) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): rt2tr(np.ones((3, 3)), [1, 2, 3], check=True) - @unittest.skipUnless(_symbolics, "sympy required") + @pytest.mark.skipif(not _symbolics, reason="sympy required") def test_rt2tr_sym(self): theta = symbol("theta") R = rotx(theta) - self.assertEqual(r2t(R).dtype, "O") + assert r2t(R).dtype == "O" theta = symbol("theta") R = rot2(theta) - self.assertEqual(r2t(R).dtype, "O") + assert r2t(R).dtype == "O" def test_tr2rt(self): # 3D @@ -152,10 +154,10 @@ def test_tr2rt(self): nt.assert_array_almost_equal(T[:2, :2], R) nt.assert_array_almost_equal(T[:2, 2], t) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): R, t = tr2rt(3) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): R, t = tr2rt(np.eye(3, 4)) def test_Ab2M(self): @@ -175,78 +177,76 @@ def test_Ab2M(self): nt.assert_array_almost_equal(T[:2, 2], np.array(t)) nt.assert_array_almost_equal(T[2, :], np.array([0, 0, 0])) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Ab2M(3, 4) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Ab2M(np.eye(3, 4), [1, 2, 3, 4]) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Ab2M(np.eye(4, 4), [1, 2, 3, 4]) def test_checks(self): # 3D case, with rotation matrix R = np.eye(3) - self.assertTrue(isR(R)) - self.assertFalse(isrot2(R)) - self.assertTrue(isrot(R)) - self.assertFalse(ishom(R)) - self.assertTrue(ishom2(R)) - self.assertFalse(isrot2(R, True)) - self.assertTrue(isrot(R, True)) - self.assertFalse(ishom(R, True)) - self.assertTrue(ishom2(R, True)) + assert isR(R) + assert not isrot2(R) + assert isrot(R) + assert not ishom(R) + assert ishom2(R) + assert not isrot2(R, True) + assert isrot(R, True) + assert not ishom(R, True) + assert ishom2(R, True) # 3D case, invalid rotation matrix R = np.eye(3) R[0, 1] = 2 - self.assertFalse(isR(R)) - self.assertFalse(isrot2(R)) - self.assertTrue(isrot(R)) - self.assertFalse(ishom(R)) - self.assertTrue(ishom2(R)) - self.assertFalse(isrot2(R, True)) - self.assertFalse(isrot(R, True)) - self.assertFalse(ishom(R, True)) - self.assertFalse(ishom2(R, True)) + assert not isR(R) + assert not isrot2(R) + assert isrot(R) + assert not ishom(R) + assert ishom2(R) + assert not isrot2(R, True) + assert not isrot(R, True) + assert not ishom(R, True) + assert not ishom2(R, True) # 3D case, with rotation matrix T = np.array([[1, 0, 0, 3], [0, 1, 0, 4], [0, 0, 1, 5], [0, 0, 0, 1]]) - self.assertFalse(isR(T)) - self.assertFalse(isrot2(T)) - self.assertFalse(isrot(T)) - self.assertTrue(ishom(T)) - self.assertFalse(ishom2(T)) - self.assertFalse(isrot2(T, True)) - self.assertFalse(isrot(T, True)) - self.assertTrue(ishom(T, True)) - self.assertFalse(ishom2(T, True)) + assert not isR(T) + assert not isrot2(T) + assert not isrot(T) + assert ishom(T) + assert not ishom2(T) + assert not isrot2(T, True) + assert not isrot(T, True) + assert ishom(T, True) + assert not ishom2(T, True) # 3D case, invalid rotation matrix T = np.array([[1, 0, 0, 3], [0, 1, 1, 4], [0, 0, 1, 5], [0, 0, 0, 1]]) - self.assertFalse(isR(T)) - self.assertFalse(isrot2(T)) - self.assertFalse(isrot(T)) - self.assertTrue( - ishom(T), - ) - self.assertFalse(ishom2(T)) - self.assertFalse(isrot2(T, True)) - self.assertFalse(isrot(T, True)) - self.assertFalse(ishom(T, True)) - self.assertFalse(ishom2(T, True)) + assert not isR(T) + assert not isrot2(T) + assert not isrot(T) + assert ishom(T) + assert not ishom2(T) + assert not isrot2(T, True) + assert not isrot(T, True) + assert not ishom(T, True) + assert not ishom2(T, True) # 3D case, invalid bottom row T = np.array([[1, 0, 0, 3], [0, 1, 1, 4], [0, 0, 1, 5], [9, 0, 0, 1]]) - self.assertFalse(isR(T)) - self.assertFalse(isrot2(T)) - self.assertFalse(isrot(T)) - self.assertTrue(ishom(T)) - self.assertFalse(ishom2(T)) - self.assertFalse(isrot2(T, True)) - self.assertFalse(isrot(T, True)) - self.assertFalse(ishom(T, True)) - self.assertFalse(ishom2(T, True)) + assert not isR(T) + assert not isrot2(T) + assert not isrot(T) + assert ishom(T) + assert not ishom2(T) + assert not isrot2(T, True) + assert not isrot(T, True) + assert not ishom(T, True) + assert not ishom2(T, True) # skew matrices S = np.array([[0, 2], [-2, 0]]) @@ -283,7 +283,7 @@ def test_homtrans(self): v2 = homtrans(T, v) nt.assert_almost_equal(v2, np.c_[[-11, 12], [5, -1]]) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): T = trotx(pi / 2, t=[1, 2, 3]) v = [10, 12] v2 = homtrans(T, v) @@ -291,21 +291,21 @@ def test_homtrans(self): def test_skew(self): # 3D sk = skew([1, 2, 3]) - self.assertEqual(sk.shape, (3, 3)) + assert sk.shape == (3, 3) nt.assert_almost_equal(sk + sk.T, np.zeros((3, 3))) - self.assertEqual(sk[2, 1], 1) - self.assertEqual(sk[0, 2], 2) - self.assertEqual(sk[1, 0], 3) + assert sk[2, 1] == 1 + assert sk[0, 2] == 2 + assert sk[1, 0] == 3 nt.assert_almost_equal(sk.diagonal(), np.r_[0, 0, 0]) # 2D sk = skew([1]) - self.assertEqual(sk.shape, (2, 2)) + assert sk.shape == (2, 2) nt.assert_almost_equal(sk + sk.T, np.zeros((2, 2))) - self.assertEqual(sk[1, 0], 1) + assert sk[1, 0] == 1 nt.assert_almost_equal(sk.diagonal(), np.r_[0, 0]) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): sk = skew([1, 2]) def test_vex(self): @@ -320,51 +320,51 @@ def test_vex(self): nt.assert_almost_equal(vex(sk), t) _ = vex(np.ones((3, 3)), check=False) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): _ = vex(np.ones((3, 3)), check=True) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): _ = vex(np.eye(4, 4)) def test_isskew(self): t = [3, 4, 5] sk = skew(t) - self.assertTrue(isskew(sk)) + assert isskew(sk) sk[0, 0] = 3 - self.assertFalse(isskew(sk)) + assert not isskew(sk) # 2D t = [3] sk = skew(t) - self.assertTrue(isskew(sk)) + assert isskew(sk) sk[0, 0] = 3 - self.assertFalse(isskew(sk)) + assert not isskew(sk) def test_isskewa(self): # 3D t = [3, 4, 5, 6, 7, 8] sk = skewa(t) - self.assertTrue(isskewa(sk)) + assert isskewa(sk) sk[0, 0] = 3 - self.assertFalse(isskew(sk)) + assert not isskew(sk) sk = skewa(t) sk[3, 3] = 3 - self.assertFalse(isskew(sk)) + assert not isskew(sk) # 2D t = [3, 4, 5] sk = skew(t) - self.assertTrue(isskew(sk)) + assert isskew(sk) sk[0, 0] = 3 - self.assertFalse(isskew(sk)) + assert not isskew(sk) sk = skewa(t) sk[2, 2] = 3 - self.assertFalse(isskew(sk)) + assert not isskew(sk) def test_skewa(self): # 3D sk = skewa([1, 2, 3, 4, 5, 6]) - self.assertEqual(sk.shape, (4, 4)) + assert sk.shape == (4, 4) nt.assert_almost_equal(sk.diagonal(), np.r_[0, 0, 0, 0]) nt.assert_almost_equal(sk[-1, :], np.r_[0, 0, 0, 0]) nt.assert_almost_equal(sk[:3, 3], [1, 2, 3]) @@ -372,13 +372,13 @@ def test_skewa(self): # 2D sk = skewa([1, 2, 3]) - self.assertEqual(sk.shape, (3, 3)) + assert sk.shape == (3, 3) nt.assert_almost_equal(sk.diagonal(), np.r_[0, 0, 0]) nt.assert_almost_equal(sk[-1, :], np.r_[0, 0, 0]) nt.assert_almost_equal(sk[:2, 2], [1, 2]) nt.assert_almost_equal(vex(sk[:2, :2]), [3]) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): sk = skew([1, 2]) def test_vexa(self): @@ -394,15 +394,10 @@ def test_vexa(self): def test_det(self): a = np.array([[1, 2], [3, 4]]) - self.assertAlmostEqual(np.linalg.det(a), det(a)) + assert np.linalg.det(a) == pytest.approx(det(a)) - @unittest.skipUnless(_symbolics, "sympy required") + @pytest.mark.skipif(not _symbolics, reason="sympy required") def test_det_sym(self): x, y = symbol("x y") a = np.array([[x, y], [y, x]]) - self.assertEqual(det(a), x**2 - y**2) - - -# ---------------------------------------------------------------------------------------# -if __name__ == "__main__": - unittest.main() + assert det(a) == x**2 - y**2 diff --git a/tests/base/test_vectors.py b/tests/base/test_vectors.py index 592c2d16..b18b5a2a 100755 --- a/tests/base/test_vectors.py +++ b/tests/base/test_vectors.py @@ -9,7 +9,7 @@ import numpy as np import numpy.testing as nt -import unittest +import pytest from math import pi import math from scipy.linalg import logm, expm @@ -24,12 +24,14 @@ _symbolics = True except ImportError: _symbolics = False +import matplotlib +matplotlib.use("AGG") import matplotlib.pyplot as plt from math import pi -class TestVector(unittest.TestCase): +class TestVector: @classmethod def tearDownClass(cls): plt.close("all") @@ -51,54 +53,54 @@ def test_unit(self): nt.assert_array_almost_equal(unitvec([0, 9, 0]), np.r_[0, 1, 0]) nt.assert_array_almost_equal(unitvec([0, 0, 9]), np.r_[0, 0, 1]) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): unitvec([0, 0, 0]) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): unitvec([0]) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): unitvec(0) def test_colvec(self): t = np.r_[1, 2, 3] cv = colvec(t) - self.assertEqual(cv.shape, (3, 1)) + assert cv.shape == (3, 1) nt.assert_array_almost_equal(cv.flatten(), t) def test_isunitvec(self): - self.assertTrue(isunitvec([1, 0, 0])) - self.assertTrue(isunitvec((1, 0, 0))) - self.assertTrue(isunitvec(np.r_[1, 0, 0])) + assert isunitvec([1, 0, 0]) + assert isunitvec((1, 0, 0)) + assert isunitvec(np.r_[1, 0, 0]) - self.assertFalse(isunitvec([9, 0, 0])) - self.assertFalse(isunitvec((9, 0, 0))) - self.assertFalse(isunitvec(np.r_[9, 0, 0])) + assert not isunitvec([9, 0, 0]) + assert not isunitvec((9, 0, 0)) + assert not isunitvec(np.r_[9, 0, 0]) - self.assertTrue(isunitvec(1)) - self.assertTrue(isunitvec([1])) - self.assertTrue(isunitvec(-1)) - self.assertTrue(isunitvec([-1])) + assert isunitvec(1) + assert isunitvec([1]) + assert isunitvec(-1) + assert isunitvec([-1]) - self.assertFalse(isunitvec(2)) - self.assertFalse(isunitvec([2])) - self.assertFalse(isunitvec(-2)) - self.assertFalse(isunitvec([-2])) + assert not isunitvec(2) + assert not isunitvec([2]) + assert not isunitvec(-2) + assert not isunitvec([-2]) def test_norm(self): - self.assertAlmostEqual(norm([0, 0, 0]), 0) - self.assertAlmostEqual(norm([1, 2, 3]), math.sqrt(14)) - self.assertAlmostEqual(norm(np.r_[1, 2, 3]), math.sqrt(14)) + assert norm([0, 0, 0]) == pytest.approx(0) + assert norm([1, 2, 3]) == pytest.approx(math.sqrt(14)) + assert norm(np.r_[1, 2, 3]) == pytest.approx(math.sqrt(14)) def test_normsq(self): - self.assertAlmostEqual(normsq([0, 0, 0]), 0) - self.assertAlmostEqual(normsq([1, 2, 3]), 14) - self.assertAlmostEqual(normsq(np.r_[1, 2, 3]), 14) + assert normsq([0, 0, 0]) == pytest.approx(0) + assert normsq([1, 2, 3]) == pytest.approx(14) + assert normsq(np.r_[1, 2, 3]) == pytest.approx(14) - @unittest.skipUnless(_symbolics, "sympy required") + @pytest.mark.skipif(not _symbolics, reason="sympy required") def test_norm_sym(self): x, y = symbol("x y") v = [x, y] - self.assertEqual(norm(v), sqrt(x**2 + y**2)) - self.assertEqual(norm(np.r_[v]), sqrt(x**2 + y**2)) + assert norm(v) == sqrt(x**2 + y**2) + assert norm(np.r_[v]) == sqrt(x**2 + y**2) def test_cross(self): A = np.eye(3) @@ -106,41 +108,41 @@ def test_cross(self): for i in range(0, 3): j = (i + 1) % 3 k = (i + 2) % 3 - self.assertTrue(all(cross(A[:, i], A[:, j]) == A[:, k])) + assert all(cross(A[:, i], A[:, j]) == A[:, k]) def test_isunittwist(self): # 3D # unit rotational twist - self.assertTrue(isunittwist([1, 2, 3, 1, 0, 0])) - self.assertTrue(isunittwist((1, 2, 3, 1, 0, 0))) - self.assertTrue(isunittwist(np.r_[1, 2, 3, 1, 0, 0])) + assert isunittwist([1, 2, 3, 1, 0, 0]) + assert isunittwist((1, 2, 3, 1, 0, 0)) + assert isunittwist(np.r_[1, 2, 3, 1, 0, 0]) # not a unit rotational twist - self.assertFalse(isunittwist([1, 2, 3, 1, 0, 1])) + assert not isunittwist([1, 2, 3, 1, 0, 1]) # unit translation twist - self.assertTrue(isunittwist([1, 0, 0, 0, 0, 0])) + assert isunittwist([1, 0, 0, 0, 0, 0]) # not a unit translation twist - self.assertFalse(isunittwist([2, 0, 0, 0, 0, 0])) + assert not isunittwist([2, 0, 0, 0, 0, 0]) # 2D # unit rotational twist - self.assertTrue(isunittwist2([1, 2, 1])) + assert isunittwist2([1, 2, 1]) # not a unit rotational twist - self.assertFalse(isunittwist2([1, 2, 3])) + assert not isunittwist2([1, 2, 3]) # unit translation twist - self.assertTrue(isunittwist2([1, 0, 0])) + assert isunittwist2([1, 0, 0]) # not a unit translation twist - self.assertFalse(isunittwist2([2, 0, 0])) + assert not isunittwist2([2, 0, 0]) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): isunittwist([3, 4]) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): isunittwist2([3, 4]) def test_unittwist(self): @@ -174,7 +176,7 @@ def test_unittwist(self): unittwist([0, 0, -2, 0, 0, 0]), np.r_[0, 0, -1, 0, 0, 0] ) - self.assertIsNone(unittwist([0, 0, 0, 0, 0, 0])) + assert unittwist([0, 0, 0, 0, 0, 0]) is None def test_unittwist_norm(self): a = unittwist_norm([0, 0, 0, 1, 0, 0]) @@ -214,8 +216,8 @@ def test_unittwist_norm(self): nt.assert_array_almost_equal(a[1], 2) a = unittwist_norm([0, 0, 0, 0, 0, 0]) - self.assertIsNone(a[0]) - self.assertIsNone(a[1]) + assert a[0] is None + assert a[1] is None def test_unittwist2(self): nt.assert_array_almost_equal( @@ -231,7 +233,7 @@ def test_unittwist2(self): unittwist2([2, 0, -2]), np.r_[1, 0, -1] ) - self.assertIsNone(unittwist2([0, 0, 0])) + assert unittwist2([0, 0, 0]) is None def test_unittwist2_norm(self): a = unittwist2_norm([1, 0, 0]) @@ -251,75 +253,75 @@ def test_unittwist2_norm(self): nt.assert_array_almost_equal(a[1], 2) a = unittwist2_norm([0, 0, 0]) - self.assertIsNone(a[0]) - self.assertIsNone(a[1]) + assert a[0] is None + assert a[1] is None def test_iszerovec(self): - self.assertTrue(iszerovec([0])) - self.assertTrue(iszerovec([0, 0])) - self.assertTrue(iszerovec([0, 0, 0])) + assert iszerovec([0]) + assert iszerovec([0, 0]) + assert iszerovec([0, 0, 0]) - self.assertFalse(iszerovec([1]), False) - self.assertFalse(iszerovec([0, 1]), False) - self.assertFalse(iszerovec([0, 1, 0]), False) + assert not iszerovec([1]), False + assert not iszerovec([0, 1]), False + assert not iszerovec([0, 1, 0]), False def test_iszero(self): - self.assertTrue(iszero(0)) - self.assertFalse(iszero(1)) + assert iszero(0) + assert not iszero(1) def test_angdiff(self): - self.assertEqual(angdiff(0, 0), 0) - self.assertIsInstance(angdiff(0, 0), float) - self.assertEqual(angdiff(pi, 0), -pi) - self.assertEqual(angdiff(-pi, pi), 0) + assert angdiff(0, 0) == 0 + assert isinstance(angdiff(0, 0), float) + assert angdiff(pi, 0) == -pi + assert angdiff(-pi, pi) == 0 x = angdiff([0, -pi, pi], 0) nt.assert_array_almost_equal(x, [0, -pi, -pi]) - self.assertIsInstance(x, np.ndarray) + assert isinstance(x, np.ndarray) nt.assert_array_almost_equal(angdiff([0, -pi, pi], pi), [-pi, 0, 0]) x = angdiff(0, [0, -pi, pi]) nt.assert_array_almost_equal(x, [0, -pi, -pi]) - self.assertIsInstance(x, np.ndarray) + assert isinstance(x, np.ndarray) nt.assert_array_almost_equal(angdiff(pi, [0, -pi, pi]), [-pi, 0, 0]) x = angdiff([1, 2, 3], [1, 2, 3]) nt.assert_array_almost_equal(x, [0, 0, 0]) - self.assertIsInstance(x, np.ndarray) + assert isinstance(x, np.ndarray) def test_wrap(self): - self.assertAlmostEqual(wrap_0_2pi(0), 0) - self.assertAlmostEqual(wrap_0_2pi(2 * pi), 0) - self.assertAlmostEqual(wrap_0_2pi(3 * pi), pi) - self.assertAlmostEqual(wrap_0_2pi(-pi), pi) + assert wrap_0_2pi(0) == pytest.approx(0) + assert wrap_0_2pi(2 * pi) == pytest.approx(0) + assert wrap_0_2pi(3 * pi) == pytest.approx(pi) + assert wrap_0_2pi(-pi) == pytest.approx(pi) nt.assert_array_almost_equal( wrap_0_2pi([0, 2 * pi, 3 * pi, -pi]), [0, 0, pi, pi] ) - self.assertAlmostEqual(wrap_mpi_pi(0), 0) - self.assertAlmostEqual(wrap_mpi_pi(-pi), -pi) - self.assertAlmostEqual(wrap_mpi_pi(pi), -pi) - self.assertAlmostEqual(wrap_mpi_pi(2 * pi), 0) - self.assertAlmostEqual(wrap_mpi_pi(1.5 * pi), -0.5 * pi) - self.assertAlmostEqual(wrap_mpi_pi(-1.5 * pi), 0.5 * pi) + assert wrap_mpi_pi(0) == pytest.approx(0) + assert wrap_mpi_pi(-pi) == pytest.approx(-pi) + assert wrap_mpi_pi(pi) == pytest.approx(-pi) + assert wrap_mpi_pi(2 * pi) == pytest.approx(0) + assert wrap_mpi_pi(1.5 * pi) == pytest.approx(-0.5 * pi) + assert wrap_mpi_pi(-1.5 * pi) == pytest.approx(0.5 * pi) nt.assert_array_almost_equal( wrap_mpi_pi([0, -pi, pi, 2 * pi, 1.5 * pi, -1.5 * pi]), [0, -pi, -pi, 0, -0.5 * pi, 0.5 * pi], ) - self.assertAlmostEqual(wrap_0_pi(0), 0) - self.assertAlmostEqual(wrap_0_pi(pi), pi) - self.assertAlmostEqual(wrap_0_pi(1.2 * pi), 0.8 * pi) - self.assertAlmostEqual(wrap_0_pi(-0.2 * pi), 0.2 * pi) + assert wrap_0_pi(0) == pytest.approx(0) + assert wrap_0_pi(pi) == pytest.approx(pi) + assert wrap_0_pi(1.2 * pi) == pytest.approx(0.8 * pi) + assert wrap_0_pi(-0.2 * pi) == pytest.approx(0.2 * pi) nt.assert_array_almost_equal( wrap_0_pi([0, pi, 1.2 * pi, -0.2 * pi]), [0, pi, 0.8 * pi, 0.2 * pi] ) - self.assertAlmostEqual(wrap_mpi2_pi2(0), 0) - self.assertAlmostEqual(wrap_mpi2_pi2(-0.5 * pi), -0.5 * pi) - self.assertAlmostEqual(wrap_mpi2_pi2(0.5 * pi), 0.5 * pi) - self.assertAlmostEqual(wrap_mpi2_pi2(0.6 * pi), 0.4 * pi) - self.assertAlmostEqual(wrap_mpi2_pi2(-0.6 * pi), -0.4 * pi) + assert wrap_mpi2_pi2(0) == pytest.approx(0) + assert wrap_mpi2_pi2(-0.5 * pi) == pytest.approx(-0.5 * pi) + assert wrap_mpi2_pi2(0.5 * pi) == pytest.approx(0.5 * pi) + assert wrap_mpi2_pi2(0.6 * pi) == pytest.approx(0.4 * pi) + assert wrap_mpi2_pi2(-0.6 * pi) == pytest.approx(-0.4 * pi) nt.assert_array_almost_equal( wrap_mpi2_pi2([0, -0.5 * pi, 0.5 * pi, 0.6 * pi, -0.6 * pi]), [0, -0.5 * pi, 0.5 * pi, 0.4 * pi, -0.4 * pi], @@ -327,27 +329,27 @@ def test_wrap(self): for angle_factor in (0, 0.3, 0.5, 0.8, 1.0, 1.3, 1.5, 1.7, 2): theta = angle_factor * pi - self.assertAlmostEqual(angle_wrap(theta), wrap_mpi_pi(theta)) - self.assertAlmostEqual(angle_wrap(-theta), wrap_mpi_pi(-theta)) - self.assertAlmostEqual(angle_wrap(theta=theta, mode="-pi:pi"), wrap_mpi_pi(theta)) - self.assertAlmostEqual(angle_wrap(theta=-theta, mode="-pi:pi"), wrap_mpi_pi(-theta)) - self.assertAlmostEqual(angle_wrap(theta=theta, mode="0:2pi"), wrap_0_2pi(theta)) - self.assertAlmostEqual(angle_wrap(theta=-theta, mode="0:2pi"), wrap_0_2pi(-theta)) - self.assertAlmostEqual(angle_wrap(theta=theta, mode="0:pi"), wrap_0_pi(theta)) - self.assertAlmostEqual(angle_wrap(theta=-theta, mode="0:pi"), wrap_0_pi(-theta)) - self.assertAlmostEqual(angle_wrap(theta=theta, mode="-pi/2:pi/2"), wrap_mpi2_pi2(theta)) - self.assertAlmostEqual(angle_wrap(theta=-theta, mode="-pi/2:pi/2"), wrap_mpi2_pi2(-theta)) - with self.assertRaises(ValueError): + assert angle_wrap(theta) == pytest.approx(wrap_mpi_pi(theta)) + assert angle_wrap(-theta) == pytest.approx(wrap_mpi_pi(-theta)) + assert angle_wrap(theta=theta, mode="-pi:pi") == pytest.approx(wrap_mpi_pi(theta)) + assert angle_wrap(theta=-theta, mode="-pi:pi") == pytest.approx(wrap_mpi_pi(-theta)) + assert angle_wrap(theta=theta, mode="0:2pi") == pytest.approx(wrap_0_2pi(theta)) + assert angle_wrap(theta=-theta, mode="0:2pi") == pytest.approx(wrap_0_2pi(-theta)) + assert angle_wrap(theta=theta, mode="0:pi") == pytest.approx(wrap_0_pi(theta)) + assert angle_wrap(theta=-theta, mode="0:pi") == pytest.approx(wrap_0_pi(-theta)) + assert angle_wrap(theta=theta, mode="-pi/2:pi/2") == pytest.approx(wrap_mpi2_pi2(theta)) + assert angle_wrap(theta=-theta, mode="-pi/2:pi/2") == pytest.approx(wrap_mpi2_pi2(-theta)) + with pytest.raises(ValueError): angle_wrap(theta=theta, mode="foo") def test_angle_stats(self): theta = np.linspace(3 * pi / 2, 5 * pi / 2, 50) - self.assertAlmostEqual(angle_mean(theta), 0) - self.assertAlmostEqual(angle_std(theta), 0.9717284050981313) + assert angle_mean(theta) == pytest.approx(0) + assert angle_std(theta) == pytest.approx(0.9717284050981313) theta = np.linspace(pi / 2, 3 * pi / 2, 50) - self.assertAlmostEqual(angle_mean(theta), pi) - self.assertAlmostEqual(angle_std(theta), 0.9717284050981313) + assert angle_mean(theta) == pytest.approx(pi) + assert angle_std(theta) == pytest.approx(0.9717284050981313) def test_removesmall(self): v = np.r_[1, 2, 3] @@ -364,8 +366,3 @@ def test_removesmall(self): v = np.r_[1, 2, 3, 1e-10, -1e-10] nt.assert_array_almost_equal(removesmall(v, tol=1e8), [1, 2, 3, 0, 0]) - - -# ---------------------------------------------------------------------------------------# -if __name__ == "__main__": - unittest.main() diff --git a/tests/base/test_velocity.py b/tests/base/test_velocity.py index 13ee35e8..65519a16 100644 --- a/tests/base/test_velocity.py +++ b/tests/base/test_velocity.py @@ -10,7 +10,6 @@ import numpy as np import numpy.testing as nt -import unittest from math import pi import math from scipy.linalg import logm, expm @@ -19,10 +18,12 @@ from spatialmath.base.numeric import * from spatialmath.base.transformsNd import isR, t2r, r2t, rt2tr +import matplotlib +matplotlib.use("AGG") import matplotlib.pyplot as plt -class TestVelocity(unittest.TestCase): +class TestVelocity: def test_numjac(self): # test on algebraic example def f(X): @@ -99,28 +100,28 @@ def test_exp2jac(self): # gamma = [0.1, 0.2, 0.3] # R = rpy2r(gamma, order="zyx") # A = rotvelxform(R, representation="rpy/zyx") - # self.assertEqual(A.shape, (6, 6)) + # assert A.shape == (6, 6) # A3 = np.linalg.inv(A[3:6, 3:6]) # nt.assert_array_almost_equal(A3, rpy2jac(gamma, order="zyx")) # gamma = [0.1, 0.2, 0.3] # R = rpy2r(gamma, order="xyz") # A = rot2jac(R, representation="rpy/xyz") - # self.assertEqual(A.shape, (6, 6)) + # assert A.shape == (6, 6) # A3 = np.linalg.inv(A[3:6, 3:6]) # nt.assert_array_almost_equal(A3, rpy2jac(gamma, order="xyz")) # gamma = [0.1, 0.2, 0.3] # R = eul2r(gamma) # A = rot2jac(R, representation="eul") - # self.assertEqual(A.shape, (6, 6)) + # assert A.shape == (6, 6) # A3 = np.linalg.inv(A[3:6, 3:6]) # nt.assert_array_almost_equal(A3, eul2jac(gamma)) # gamma = [0.1, 0.2, 0.3] # R = trexp(gamma) # A = rot2jac(R, representation="exp") - # self.assertEqual(A.shape, (6, 6)) + # assert A.shape == (6, 6) # A3 = np.linalg.inv(A[3:6, 3:6]) # nt.assert_array_almost_equal(A3, exp2jac(gamma)) @@ -216,7 +217,7 @@ def test_angvelxform_dot_rpy_zyx(self): res = rotvelxform_inv_dot(gamma, gamma_d, representation=rep, full=False) nt.assert_array_almost_equal(Adot, res, decimal=4) - # @unittest.skip("bug in angvelxform_dot for exponential coordinates") + # @pytest.mark.skip("bug in angvelxform_dot for exponential coordinates") def test_angvelxform_dot_exp(self): rep = "exp" gamma = [0.1, 0.2, 0.3] @@ -270,8 +271,3 @@ def test_x_tr(self): # f = lambda gamma: angvelxform(gamma, options) # nt.assert_array_almost_equal(angvelxform_dot(gamma, options), numjac(f)) - - -# ---------------------------------------------------------------------------------------# -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_baseposelist.py b/tests/test_baseposelist.py index 30da5681..4e7193d6 100644 --- a/tests/test_baseposelist.py +++ b/tests/test_baseposelist.py @@ -1,12 +1,16 @@ -import unittest import numpy as np +import pytest from spatialmath.baseposelist import BasePoseList -# create a subclass to test with, its value is a scalar +# create a subclass to test with, its value is a list class X(BasePoseList): - def __init__(self, value=0, check=False): + def __init__(self, value=None, check=False): + if value is None: + value = 0 + elif not isinstance(value, list): + value = value super().__init__() - self.data = [value] + self.data = value @staticmethod def _identity(): @@ -20,127 +24,77 @@ def shape(self): def isvalid(x): return True -class TestBasePoseList(unittest.TestCase): +class TestBasePoseList: def test_constructor(self): x = X() - self.assertIsInstance(x, X) - self.assertEqual(len(x), 1) - - x = X.Empty() - self.assertIsInstance(x, X) - self.assertEqual(len(x), 0) - - x = X.Alloc(10) - self.assertIsInstance(x, X) - self.assertEqual(len(x), 10) - for xx in x: - self.assertEqual(xx.A, 0) - - def test_setget(self): - x = X.Alloc(10) - for i in range(0, 10): - x[i] = X(2 * i) - - for i,v in enumerate(x): - self.assertEqual(v.A, 2 * i) - - def test_append(self): - x = X.Empty() - for i in range(0, 10): - x.append(X(i+1)) - self.assertEqual(len(x), 10) - self.assertEqual([xx.A for xx in x], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) - - def test_extend(self): - x = X.Alloc(5) - for i in range(0, 5): - x[i] = X(i + 1) - y = X.Alloc(5) - for i in range(0, 5): - y[i] = X(i + 10) - x.extend(y) - self.assertEqual(len(x), 10) - self.assertEqual([xx.A for xx in x], [1, 2, 3, 4, 5, 10, 11, 12, 13, 14]) - - def test_insert(self): - x = X.Alloc(10) - for i in range(0, 10): - x[i] = X(i + 1) - x.insert(5, X(100)) - self.assertEqual(len(x), 11) - self.assertEqual([xx.A for xx in x], [1, 2, 3, 4, 5, 100, 6, 7, 8, 9, 10]) - - def test_pop(self): - x = X.Alloc(10) - for i in range(0, 10): - x[i] = X(i + 1) - - y = x.pop() - self.assertEqual(len(y), 1) - self.assertEqual(y.A, 10) - self.assertEqual(len(x), 9) - self.assertEqual([xx.A for xx in x], [1, 2, 3, 4, 5, 6, 7, 8, 9]) - - def test_clear(self): - x = X.Alloc(10) - x.clear() - self.assertEqual(len(x), 0) - - def test_reverse(self): - x = X.Alloc(5) - for i in range(0, 5): - x[i] = X(i + 1) - x.reverse() - self.assertEqual(len(x), 5) - self.assertEqual([xx.A for xx in x], [5, 4, 3, 2, 1]) - - def test_binop(self): - x = X(2) - y = X(3) - - # singelton x singleton - self.assertEqual(x.binop(y, lambda x, y: x * y), [6]) - self.assertEqual(x.binop(y, lambda x, y: x * y, list1=False), 6) - - y = X.Alloc(5) - for i in range(0, 5): - y[i] = X(i + 1) - - # singelton x non-singleton - self.assertEqual(x.binop(y, lambda x, y: x * y), [2, 4, 6, 8, 10]) - self.assertEqual(x.binop(y, lambda x, y: x * y, list1=False), [2, 4, 6, 8, 10]) - - # non-singelton x singleton - self.assertEqual(y.binop(x, lambda x, y: x * y), [2, 4, 6, 8, 10]) - self.assertEqual(y.binop(x, lambda x, y: x * y, list1=False), [2, 4, 6, 8, 10]) - - # non-singelton x non-singleton - self.assertEqual(y.binop(y, lambda x, y: x * y), [1, 4, 9, 16, 25]) - self.assertEqual(y.binop(y, lambda x, y: x * y, list1=False), [1, 4, 9, 16, 25]) - - def test_unop(self): - x = X(2) - - f = lambda x: 2 * x - - self.assertEqual(x.unop(f), [4]) - self.assertEqual(x.unop(f, matrix=True), np.r_[4]) - - x = X.Alloc(5) - for i in range(0, 5): - x[i] = X(i + 1) - - self.assertEqual(x.unop(f), [2, 4, 6, 8, 10]) - y = x.unop(f, matrix=True) - self.assertEqual(y.shape, (5,1)) - self.assertTrue(np.all(y - np.c_[2, 4, 6, 8, 10].T == 0)) - - def test_arghandler(self): - pass - -# ---------------------------------------------------------------------------------------# -if __name__ == '__main__': - - unittest.main() \ No newline at end of file + assert isinstance(x, X) + assert len(x) == 1 + + @pytest.mark.parametrize( + 'x,y,list1,expected', + [ + (X(2), X(3), True, [6]), + (X(2), X(3), False, 6), + ], + ) + def test_binop(self, x, y, list1, expected): + assert x.binop(y, lambda x, y: x * y, list1=list1) == expected + + @pytest.mark.parametrize( + 'x,matrix,expected', + [ + (X(2), False, [4]), + (X(2), True, np.array(4)), + ], + ) + def test_unop(self, x, matrix, expected): + result = x.unop(lambda x: 2*x, matrix=matrix) + if isinstance(result, np.ndarray): + assert (result == expected).all() + else: + assert result == expected + +class TestConcreteSubclasses: + """ + Check consistency of methods in concrete subclasses + """ + from spatialmath import ( + SO2, + SE2, + SO3, + SE3, + Quaternion, + UnitQuaternion, + Twist2, + Twist3, + SpatialVelocity, + SpatialAcceleration, + SpatialForce, + SpatialMomentum, + Line3, + ) + concrete_subclasses = [ + SO2, + SE2, + SO3, + SE3, + Quaternion, + UnitQuaternion, + Twist2, + Twist3, + SpatialVelocity, + SpatialAcceleration, + SpatialForce, + SpatialMomentum, + Line3, + ] + + @pytest.mark.parametrize( + 'cls', + concrete_subclasses, + ) + def test_bare_init(self, cls): + with pytest.raises(TypeError): + cls() diff --git a/tests/test_dualquaternion.py b/tests/test_dualquaternion.py index 39c5fc03..2bea5ec6 100644 --- a/tests/test_dualquaternion.py +++ b/tests/test_dualquaternion.py @@ -3,7 +3,6 @@ import numpy as np import numpy.testing as nt -import unittest from spatialmath import DualQuaternion, UnitDualQuaternion, Quaternion, SE3 from spatialmath import base @@ -20,7 +19,7 @@ def qcompare(x, y): y = y.A nt.assert_array_almost_equal(x, y) -class TestDualQuaternion(unittest.TestCase): +class TestDualQuaternion: def test_init(self): @@ -40,8 +39,8 @@ def test_pure(self): def test_strings(self): dq = DualQuaternion(Quaternion([1.,2,3,4]), Quaternion([5.,6,7,8])) - self.assertIsInstance(str(dq), str) - self.assertIsInstance(repr(dq), str) + assert isinstance(str(dq), str) + assert isinstance(repr(dq), str) def test_conj(self): dq = DualQuaternion(Quaternion([1.,2,3,4]), Quaternion([5.,6,7,8])) @@ -69,8 +68,8 @@ def test_matrix(self): dq1 = DualQuaternion(Quaternion([1.,2,3,4]), Quaternion([5.,6,7,8])) M = dq1.matrix() - self.assertIsInstance(M, np.ndarray) - self.assertEqual(M.shape, (8,8)) + assert isinstance(M, np.ndarray) + assert M.shape == (8,8) def test_multiply(self): dq1 = DualQuaternion(Quaternion([1.,2,3,4]), Quaternion([5.,6,7,8])) @@ -84,7 +83,7 @@ def test_unit(self): pass -class TestUnitDualQuaternion(unittest.TestCase): +class TestUnitDualQuaternion: def test_init(self): @@ -108,9 +107,3 @@ def test_multiply(self): d = d1 * d2 nt.assert_array_almost_equal(d.SE3().A, T.A) - - -# ---------------------------------------------------------------------------------------# -if __name__ == '__main__': # pragma: no cover - - unittest.main() diff --git a/tests/test_geom2d.py b/tests/test_geom2d.py index 49aa1d8b..638a98bf 100755 --- a/tests/test_geom2d.py +++ b/tests/test_geom2d.py @@ -9,20 +9,19 @@ from spatialmath.geom2d import * from spatialmath.pose2d import SE2 -import unittest import pytest import sys import numpy.testing as nt import spatialmath.base as smb -class Polygon2Test(unittest.TestCase): +class TestPolygon2: # Primitives def test_constructor1(self): p = Polygon2([(1, 2), (3, 2), (2, 4)]) - self.assertIsInstance(p, Polygon2) - self.assertEqual(len(p), 3) - self.assertEqual(str(p), "Polygon2 with 4 vertices") + assert isinstance(p, Polygon2) + assert len(p) == 3 + assert str(p) == "Polygon2 with 4 vertices" nt.assert_array_equal(p.vertices(), np.array([[1, 3, 2], [2, 2, 4]])) nt.assert_array_equal( p.vertices(unique=False), np.array([[1, 3, 2, 1], [2, 2, 4, 2]]) @@ -31,61 +30,61 @@ def test_constructor1(self): def test_methods(self): p = Polygon2(np.array([[-1, 1, 1, -1], [-1, -1, 1, 1]])) - self.assertEqual(p.area(), 4) - self.assertEqual(p.moment(0, 0), 4) - self.assertEqual(p.moment(1, 0), 0) - self.assertEqual(p.moment(0, 1), 0) + assert p.area() == 4 + assert p.moment(0, 0) == 4 + assert p.moment(1, 0) == 0 + assert p.moment(0, 1) == 0 nt.assert_array_equal(p.centroid(), np.r_[0, 0]) - self.assertEqual(p.radius(), np.sqrt(2)) + assert p.radius() == np.sqrt(2) nt.assert_array_equal(p.bbox(), np.r_[-1, -1, 1, 1]) def test_contains(self): p = Polygon2(np.array([[-1, 1, 1, -1], [-1, -1, 1, 1]])) - self.assertTrue(p.contains([0, 0], radius=1e-6)) - self.assertTrue(p.contains([1, 0], radius=1e-6)) - self.assertTrue(p.contains([-1, 0], radius=1e-6)) - self.assertTrue(p.contains([0, 1], radius=1e-6)) - self.assertTrue(p.contains([0, -1], radius=1e-6)) + assert p.contains([0, 0], radius=1e-6) + assert p.contains([1, 0], radius=1e-6) + assert p.contains([-1, 0], radius=1e-6) + assert p.contains([0, 1], radius=1e-6) + assert p.contains([0, -1], radius=1e-6) - self.assertFalse(p.contains([0, 1.1], radius=1e-6)) - self.assertFalse(p.contains([0, -1.1], radius=1e-6)) - self.assertFalse(p.contains([1.1, 0], radius=1e-6)) - self.assertFalse(p.contains([-1.1, 0], radius=1e-6)) + assert not p.contains([0, 1.1], radius=1e-6) + assert not p.contains([0, -1.1], radius=1e-6) + assert not p.contains([1.1, 0], radius=1e-6) + assert not p.contains([-1.1, 0], radius=1e-6) - self.assertTrue(p.contains(np.r_[0, -1], radius=1e-6)) - self.assertFalse(p.contains(np.r_[0, 1.1], radius=1e-6)) + assert p.contains(np.r_[0, -1], radius=1e-6) + assert not p.contains(np.r_[0, 1.1], radius=1e-6) def test_transform(self): p = Polygon2(np.array([[-1, 1, 1, -1], [-1, -1, 1, 1]])) p = p.transformed(SE2(2, 3)) - self.assertEqual(p.area(), 4) - self.assertEqual(p.moment(0, 0), 4) - self.assertEqual(p.moment(1, 0), 8) - self.assertEqual(p.moment(0, 1), 12) + assert p.area() == 4 + assert p.moment(0 == 0, 4) + assert p.moment(1 == 0, 8) + assert p.moment(0 == 1, 12) nt.assert_array_equal(p.centroid(), np.r_[2, 3]) def test_intersect(self): p1 = Polygon2(np.array([[-1, 1, 1, -1], [-1, -1, 1, 1]])) p2 = p1.transformed(SE2(2, 3)) - self.assertFalse(p1.intersects(p2)) + assert not p1.intersects(p2) p2 = p1.transformed(SE2(1, 1)) - self.assertTrue(p1.intersects(p2)) + assert p1.intersects(p2) - self.assertTrue(p1.intersects(p1)) + assert p1.intersects(p1) def test_intersect_line(self): p = Polygon2(np.array([[-1, 1, 1, -1], [-1, -1, 1, 1]])) l = Line2.Join((-10, 0), (10, 0)) - self.assertTrue(p.intersects(l)) + assert p.intersects(l) l = Line2.Join((-10, 1.1), (10, 1.1)) - self.assertFalse(p.intersects(l)) + assert not p.intersects(l) @pytest.mark.skipif( sys.platform.startswith("darwin") and sys.version_info < (3, 11), @@ -109,10 +108,10 @@ def test_edges(self): # p.move(SE2(0, 0, 0.7)) -class Line2Test(unittest.TestCase): +class TestLine2: def test_constructor(self): l = Line2([1, 2, 3]) - self.assertEqual(str(l), "Line2: [1. 2. 3.]") + assert str(l) == "Line2: [1. 2. 3.]" l = Line2.Join((0, 0), (1, 2)) nt.assert_equal(l.line, [-2, 1, 0]) @@ -123,55 +122,55 @@ def test_constructor(self): def test_contains(self): l = Line2.Join((0, 0), (1, 2)) - self.assertTrue(l.contains((0, 0))) - self.assertTrue(l.contains((1, 2))) - self.assertTrue(l.contains((2, 4))) + assert l.contains((0, 0)) + assert l.contains((1, 2)) + assert l.contains((2, 4)) def test_intersect(self): l1 = Line2.Join((0, 0), (2, 0)) # y = 0 l2 = Line2.Join((0, 1), (2, 1)) # y = 1 - self.assertFalse(l1.intersect(l2)) + assert not l1.intersect(l2) l2 = Line2.Join((2, 1), (2, -1)) # x = 2 - self.assertTrue(l1.intersect(l2)) + assert l1.intersect(l2) def test_intersect_segment(self): l1 = Line2.Join((0, 0), (2, 0)) # y = 0 - self.assertFalse(l1.intersect_segment((2, 1), (2, 3))) - self.assertTrue(l1.intersect_segment((2, 1), (2, -1))) + assert not l1.intersect_segment((2, 1), (2, 3)) + assert l1.intersect_segment((2, 1), (2, -1)) -class EllipseTest(unittest.TestCase): +class TestEllipse: def test_constructor(self): E = np.array([[1, 1], [1, 3]]) e = Ellipse(E=E) nt.assert_almost_equal(e.E, E) nt.assert_almost_equal(e.centre, [0, 0]) - self.assertAlmostEqual(e.theta, 1.1780972450961724) + assert e.theta == pytest.approx(1.1780972450961724) e = Ellipse(radii=(1, 2), theta=0) nt.assert_almost_equal(e.E, np.diag([1, 0.25])) nt.assert_almost_equal(e.centre, [0, 0]) nt.assert_almost_equal(e.radii, [1, 2]) - self.assertAlmostEqual(e.theta, 0) + assert e.theta == pytest.approx(0) e = Ellipse(radii=(1, 2), theta=np.pi / 2) nt.assert_almost_equal(e.E, np.diag([0.25, 1])) nt.assert_almost_equal(e.centre, [0, 0]) nt.assert_almost_equal(e.radii, [2, 1]) - self.assertAlmostEqual(e.theta, np.pi / 2) + assert e.theta == pytest.approx(np.pi / 2) E = np.array([[1, 1], [1, 3]]) e = Ellipse(E=E, centre=[3, 4]) nt.assert_almost_equal(e.E, E) nt.assert_almost_equal(e.centre, [3, 4]) - self.assertAlmostEqual(e.theta, 1.1780972450961724) + assert e.theta == pytest.approx(1.1780972450961724) e = Ellipse(radii=(1, 2), theta=0, centre=[3, 4]) nt.assert_almost_equal(e.E, np.diag([1, 0.25])) nt.assert_almost_equal(e.centre, [3, 4]) nt.assert_almost_equal(e.radii, [1, 2]) - self.assertAlmostEqual(e.theta, 0) + assert e.theta == pytest.approx(0) def test_Polynomial(self): e = Ellipse.Polynomial([2, 3, 1, 0, 0, -1]) @@ -224,24 +223,20 @@ def test_FromPoints(self): def test_misc(self): e = Ellipse(radii=(1, 2), theta=np.pi / 2) - self.assertIsInstance(str(e), str) + assert isinstance(str(e), str) - self.assertAlmostEqual(e.area, np.pi * 2) + assert e.area == pytest.approx(np.pi * 2) e = Ellipse(radii=(1, 2), theta=0) - self.assertTrue(e.contains((0, 0))) - self.assertTrue(e.contains((1, 0))) - self.assertTrue(e.contains((-1, 0))) - self.assertTrue(e.contains((0, 2))) - self.assertTrue(e.contains((0, -2))) - - self.assertFalse(e.contains((1.1, 0))) - self.assertFalse(e.contains((-1.1, 0))) - self.assertFalse(e.contains((0, 2.1))) - self.assertFalse(e.contains((0, -2.1))) - - self.assertEqual(e.contains(np.array([[0, 0], [3, 3]]).T), [True, False]) - - -if __name__ == "__main__": - unittest.main() + assert e.contains((0, 0)) + assert e.contains((1, 0)) + assert e.contains((-1, 0)) + assert e.contains((0, 2)) + assert e.contains((0, -2)) + + assert not e.contains((1.1, 0)) + assert not e.contains((-1.1, 0)) + assert not e.contains((0, 2.1)) + assert not e.contains((0, -2.1)) + + assert e.contains(np.array([[0, 0], [3, 3]]).T) == [True, False] diff --git a/tests/test_geom3d.py b/tests/test_geom3d.py index 7a743dd5..d2af40ae 100755 --- a/tests/test_geom3d.py +++ b/tests/test_geom3d.py @@ -9,35 +9,34 @@ from spatialmath.geom3d import * from spatialmath.pose3d import SE3 -import unittest import numpy.testing as nt import spatialmath.base as base import pytest import sys -class Line3Test(unittest.TestCase): +class TestLine3: # Primitives def test_constructor1(self): # construct from 6-vector - with self.assertRaises(ValueError): + with pytest.raises(ValueError): L = Line3([1, 2, 3, 4, 5, 6], check=True) L = Line3([1, 2, 3, 4, 5, 6], check=False) - self.assertIsInstance(L, Line3) + assert isinstance(L, Line3) nt.assert_array_almost_equal(L.v, np.r_[1, 2, 3]) nt.assert_array_almost_equal(L.w, np.r_[4, 5, 6]) # construct from object L2 = Line3(L, check=False) - self.assertIsInstance(L, Line3) + assert isinstance(L, Line3) nt.assert_array_almost_equal(L2.v, np.r_[1, 2, 3]) nt.assert_array_almost_equal(L2.w, np.r_[4, 5, 6]) # construct from point and direction L = Line3.PointDir([1, 2, 3], [4, 5, 6]) - self.assertTrue(L.contains([1, 2, 3])) + assert L.contains([1, 2, 3]) nt.assert_array_almost_equal(L.uw, base.unitvec([4, 5, 6])) def test_vec(self): @@ -56,33 +55,33 @@ def test_constructor2(self): # TODO, all combos of list and ndarray # test all possible input shapes # L2, = Line3(P, Q) - # self.assertEqual(double(L2), double(L)) + # assert double(L2) == double(L) # L2, = Line3(P, Q') - # self.assertEqual(double(L2), double(L)) + # assert double(L2) == double(L) # L2, = Line3(P', Q') - # self.assertEqual(double(L2), double(L)) + # assert double(L2) == double(L) # L2, = Line3(P, Q) - # self.assertEqual(double(L2), double(L)) + # assert double(L2) == double(L) # # planes constructor # P = [10, 11, 12]'; w = [1, 2, 3] # L = Line3.PointDir(P, w) - # self.assertEqual(double(L), [cross(w,P) w]'); %FAIL + # assertEqual(double(L), [cross(w,P) w]'); %FAIL # L2, = Line3.PointDir(P', w) - # self.assertEqual(double(L2), double(L)) + # assert double(L2) == double(L) # L2, = Line3.PointDir(P, w') - # self.assertEqual(double(L2), double(L)) + # assert double(L2) == double(L) # L2, = Line3.PointDir(P', w') - # self.assertEqual(double(L2), double(L)) + # assert double(L2) == double(L) def test_pp(self): # validate pp and ppd L = Line3.Join([-1, 1, 2], [1, 1, 2]) nt.assert_array_almost_equal(L.pp, np.r_[0, 1, 2]) - self.assertEqual(L.ppd, math.sqrt(5)) + assert L.ppd == math.sqrt(5) # validate pp - self.assertTrue(L.contains(L.pp)) + assert L.contains(L.pp) def test_contains(self): P = [2, 3, 7] @@ -90,9 +89,9 @@ def test_contains(self): L = Line3.Join(P, Q) # validate contains - self.assertTrue(L.contains([2, 3, 7])) - self.assertTrue(L.contains([2, 1, 0])) - self.assertFalse(L.contains([2, 1, 4])) + assert L.contains([2, 3, 7]) + assert L.contains([2, 1, 0]) + assert not L.contains([2, 1, 4]) def test_closest(self): P = [2, 3, 7] @@ -101,29 +100,29 @@ def test_closest(self): p, d = L.closest_to_point(P) nt.assert_array_almost_equal(p, P) - self.assertAlmostEqual(d, 0) + assert d == pytest.approx(0) # validate closest with given points and origin p, d = L.closest_to_point(Q) nt.assert_array_almost_equal(p, Q) - self.assertAlmostEqual(d, 0) + assert d == pytest.approx(0) L = Line3.Join([-1, 1, 2], [1, 1, 2]) p, d = L.closest_to_point([0, 1, 2]) nt.assert_array_almost_equal(p, np.r_[0, 1, 2]) - self.assertAlmostEqual(d, 0) + assert d == pytest.approx(0) p, d = L.closest_to_point([5, 1, 2]) nt.assert_array_almost_equal(p, np.r_[5, 1, 2]) - self.assertAlmostEqual(d, 0) + assert d == pytest.approx(0) p, d = L.closest_to_point([0, 0, 0]) nt.assert_array_almost_equal(p, L.pp) - self.assertEqual(d, L.ppd) + assert d == L.ppd p, d = L.closest_to_point([5, 1, 0]) nt.assert_array_almost_equal(p, [5, 1, 2]) - self.assertAlmostEqual(d, 2) + assert d == pytest.approx(2) @pytest.mark.skipif( sys.platform.startswith("darwin") and sys.version_info < (3, 11), @@ -140,7 +139,7 @@ def test_plot(self): ax.set_ylim3d(-10, 10) ax.set_zlim3d(-10, 10) - L.plot(color="red", linewidth=2) + Line3.plot([L], color="red", linewidth=2) def test_eq(self): w = np.r_[1, 2, 3] @@ -150,11 +149,11 @@ def test_eq(self): L2 = Line3.Join(P + 2 * w, P + 5 * w) L3 = Line3.Join(P + np.r_[1, 0, 0], P + w) - self.assertTrue(L1 == L2) - self.assertFalse(L1 == L3) + assert L1 == L2 + assert not L1 == L3 - self.assertFalse(L1 != L2) - self.assertTrue(L1 != L3) + assert not L1 != L2 + assert L1 != L3 def test_skew(self): P = [2, 3, 7] @@ -163,7 +162,7 @@ def test_skew(self): m = L.skew() - self.assertEqual(m.shape, (4, 4)) + assert m.shape == (4, 4) nt.assert_array_almost_equal(m + m.T, np.zeros((4, 4))) def test_rmul(self): @@ -173,7 +172,7 @@ def test_rmul(self): # check transformation by SE3 - L2 = SE3() * L + L2 = SE3.identity() * L p = L2.intersect_plane([0, 0, 1, 0])[0] # intersects z=0 nt.assert_array_almost_equal(p, [1, 2, 0]) @@ -196,15 +195,15 @@ def test_parallel(self): # L1, || L2, but doesnt intersect # L1, intersects L3 - self.assertTrue(L1.isparallel(L1)) - self.assertTrue(L1 | L1) + assert L1.isparallel(L1) + assert L1 | L1 - self.assertTrue(L1.isparallel(L2)) - self.assertTrue(L1 | L2) - self.assertTrue(L2.isparallel(L1)) - self.assertTrue(L2 | L1) - self.assertFalse(L1.isparallel(L3)) - self.assertFalse(L1 | L3) + assert L1.isparallel(L2) + assert L1 | L2 + assert L2.isparallel(L1) + assert L2 | L1 + assert not L1.isparallel(L3) + assert not L1 | L3 def test_intersect(self): L1 = Line3.PointDir([4, 5, 6], [1, 2, 3]) @@ -214,27 +213,23 @@ def test_intersect(self): # L1, || L2, but doesnt intersect # L3, intersects L4 - self.assertFalse( - L1 ^ L2, - ) + assert not L1 ^ L2 - self.assertTrue( - L3 ^ L4, - ) + assert L3 ^ L4 def test_commonperp(self): L1 = Line3.PointDir([4, 5, 6], [0, 0, 1]) L2 = Line3.PointDir([6, 5, 6], [0, 1, 0]) - self.assertFalse(L1 | L2) - self.assertFalse(L1 ^ L2) + assert not L1 | L2 + assert not L1 ^ L2 - self.assertEqual(L1.distance(L2), 2) + assert L1.distance(L2) == 2 L = L1.commonperp(L2) # common perp intersects both lines - self.assertTrue(L ^ L1) - self.assertTrue(L ^ L2) + assert L ^ L1 + assert L ^ L2 def test_line(self): # mindist @@ -252,9 +247,9 @@ def test_contains(self): Q = [2, 1, 0] L = Line3.Join(P, Q) - self.assertTrue(L.contains(L.point(0))) - self.assertTrue(L.contains(L.point(1))) - self.assertTrue(L.contains(L.point(-1))) + assert L.contains(L.point(0)) + assert L.contains(L.point(1)) + assert L.contains(L.point(-1)) def test_point(self): P = [2, 3, 7] @@ -272,7 +267,7 @@ def test_char(self): L = Line3.Join(P, Q) s = str(L) - self.assertIsInstance(s, str) + assert isinstance(s, str) def test_plane(self): xyplane = [0, 0, 1, 0] @@ -304,8 +299,8 @@ def test_methods(self): px1 = Line3.Join([0, 1, 0], [1, 1, 0]) # offset x-axis - self.assertEqual(px.ppd, 0) - self.assertEqual(px1.ppd, 1) + assert px.ppd == 0 + assert px1.ppd == 1 nt.assert_array_almost_equal(px1.pp, [0, 1, 0]) px.intersects(px) @@ -321,6 +316,5 @@ def test_methods(self): # px.intersect_plane(plane) # py.intersect_plane(plane) - -if __name__ == "__main__": - unittest.main() + def test_identity(self): + nt.assert_array_equal(Line3.identity().A, np.zeros(6,)) diff --git a/tests/test_pose2d.py b/tests/test_pose2d.py index d6d96813..19534d9b 100755 --- a/tests/test_pose2d.py +++ b/tests/test_pose2d.py @@ -1,6 +1,7 @@ import numpy.testing as nt +import matplotlib +matplotlib.use("AGG") import matplotlib.pyplot as plt -import unittest import sys import pytest @@ -30,18 +31,12 @@ def array_compare(x, y): nt.assert_array_almost_equal(x, y) -class TestSO2(unittest.TestCase): +class TestSO2: @classmethod def tearDownClass(cls): plt.close("all") def test_constructor(self): - # null case - x = SO2() - self.assertIsInstance(x, SO2) - self.assertEqual(len(x), 1) - array_compare(x.A, np.eye(2, 2)) - ## from angle array_compare(SO2(0).A, np.eye(2)) @@ -60,15 +55,7 @@ def test_constructor(self): array_compare(SO2(rot2(pi / 2)).R, rot2(pi / 2)) - ## vectorised forms of R - R = SO2.Empty() - for theta in [-pi / 2, 0, pi / 2, pi]: - R.append(SO2(theta)) - self.assertEqual(len(R), 4) - array_compare(R[0], rot2(-pi / 2)) - array_compare(R[3], rot2(pi)) - - # TODO self.assertEqual(SO2(R).R, R) + # TODO assert SO2(R).R == R ## copy constructor r = SO2(0.3) @@ -77,118 +64,60 @@ def test_constructor(self): r = SO2(0.4) array_compare(c, SO2(0.3)) - def test_concat(self): - x = SO2() - xx = SO2([x, x, x, x]) - - self.assertIsInstance(xx, SO2) - self.assertEqual(len(xx), 4) - def test_primitive_convert(self): # char - s = str(SO2()) - self.assertIsInstance(s, str) + s = str(SO2.identity()) + assert isinstance(s, str) def test_shape(self): - a = SO2() - self.assertEqual(a._A.shape, a.shape) + a = SO2.identity() + assert a._A.shape == a.shape def test_constructor_Exp(self): array_compare(SO2.Exp(skew(0.3)).R, rot2(0.3)) array_compare(SO2.Exp(0.3).R, rot2(0.3)) - x = SO2.Exp([0, 0.3, 1]) - self.assertEqual(len(x), 3) - array_compare(x[0], rot2(0)) - array_compare(x[1], rot2(0.3)) - array_compare(x[2], rot2(1)) - - x = SO2.Exp([skew(x) for x in [0, 0.3, 1]]) - self.assertEqual(len(x), 3) - array_compare(x[0], rot2(0)) - array_compare(x[1], rot2(0.3)) - array_compare(x[2], rot2(1)) - def test_isa(self): - self.assertTrue(SO2.isvalid(rot2(0))) + assert SO2.isvalid(rot2(0)) - self.assertFalse(SO2.isvalid(1)) + assert not SO2.isvalid(1) def test_resulttype(self): - r = SO2() - self.assertIsInstance(r, SO2) - - self.assertIsInstance(r * r, SO2) - - self.assertIsInstance(r / r, SO2) - - self.assertIsInstance(r.inv(), SO2) - - def test_multiply(self): - vx = np.r_[1, 0] - vy = np.r_[0, 1] - - r0 = SO2(0) - r1 = SO2(pi / 2) - r2 = SO2(pi) - u = SO2() - - ## SO2-SO2, product - # scalar x scalar - - array_compare(r0 * u, r0) - array_compare(u * r0, r0) - - # vector x vector - array_compare( - SO2([r0, r1, r2]) * SO2([r2, r0, r1]), SO2([r0 * r2, r1 * r0, r2 * r1]) - ) - - # scalar x vector - array_compare(r1 * SO2([r0, r1, r2]), SO2([r1 * r0, r1 * r1, r1 * r2])) - - # vector x scalar - array_compare(SO2([r0, r1, r2]) * r2, SO2([r0 * r2, r1 * r2, r2 * r2])) - - ## SO2-vector product - # scalar x scalar - - array_compare(r1 * vx, np.c_[vy]) - - # vector x vector - # array_compare(SO2([r0, r1, r0]) * np.c_[vy, vx, vx], np.c_[vy, vy, vx]) - - # scalar x vector - array_compare(r1 * np.c_[vx, vy, -vx], np.c_[vy, -vx, -vy]) - - # vector x scalar - array_compare(SO2([r0, r1, r2]) * vy, np.c_[vy, -vx, -vy]) + r = SO2.identity() + assert isinstance(r, SO2) - def test_divide(self): - r0 = SO2(0) - r1 = SO2(pi / 2) - r2 = SO2(pi) - u = SO2() + assert isinstance(r * r, SO2) - # scalar / scalar - # implicity tests inv + assert isinstance(r / r, SO2) - array_compare(r1 / u, r1) - array_compare(r1 / r1, u) + assert isinstance(r.inv(), SO2) - # vector / vector - array_compare( - SO2([r0, r1, r2]) / SO2([r2, r1, r0]), SO2([r0 / r2, r1 / r1, r2 / r0]) - ) - - # vector / scalar - array_compare(SO2([r0, r1, r2]) / r1, SO2([r0 / r1, r1 / r1, r2 / r1])) + @pytest.mark.parametrize( + 'left, right, expected', + [ + (SO2(0), SO2.identity(), SO2(0)), + (SO2.identity(), SO2(0), SO2(0)), + (SO2(pi/2), np.r_[1, 0], np.c_[np.r_[0, 1]]), + ], + ) + def test_multiply(self, left, right, expected): + array_compare(left * right, expected) + + @pytest.mark.parametrize( + 'left, right, expected', + [ + (SO2(pi/2), SO2.identity(), SO2(pi/2)), + (SO2(pi/2), SO2(pi/2), SO2.identity()), + ], + ) + def test_divide(self, left, right, expected): + array_compare(left / right, expected) def test_conversions(self): T = SO2(pi / 2).SE2() - self.assertIsInstance(T, SE2) + assert isinstance(T, SE2) ## Lie stuff th = 0.3 @@ -199,23 +128,18 @@ def test_miscellany(self): r = SO2( 0.3, ) - self.assertAlmostEqual(np.linalg.det(r.A), 1) + assert np.linalg.det(r.A) == pytest.approx(1) - self.assertEqual(r.N, 2) + assert r.N == 2 - self.assertFalse(r.isSE) + assert not r.isSE def test_printline(self): R = SO2(0.3) R.printline() # s = R.printline(file=None) - # self.assertIsInstance(s, str) - - R = SO2([0.3, 0.4, 0.5]) - s = R.printline(file=None) - # self.assertIsInstance(s, str) - # self.assertEqual(s.count('\n'), 2) + # assert isinstance(s, str) @pytest.mark.skipif( sys.platform.startswith("darwin") and sys.version_info < (3, 11), @@ -231,121 +155,88 @@ def test_plot(self): # R.animate() # R.animate(start=R2) + def test_identity(self): + array_compare(SO2.identity().A, np.eye(2, 2)) # ============================== SE2 =====================================# -class TestSE2(unittest.TestCase): +class TestSE2: @classmethod def tearDownClass(cls): plt.close("all") def test_constructor(self): - self.assertIsInstance(SE2(), SE2) - - ## null - array_compare(SE2().A, np.eye(3, 3)) - # from x,y x = SE2(2, 3) - self.assertIsInstance(x, SE2) - self.assertEqual(len(x), 1) + assert isinstance(x, SE2) + assert len(x) == 1 array_compare(x.A, np.array([[1, 0, 2], [0, 1, 3], [0, 0, 1]])) x = SE2([2, 3]) - self.assertIsInstance(x, SE2) - self.assertEqual(len(x), 1) + assert isinstance(x, SE2) + assert len(x) == 1 array_compare(x.A, np.array([[1, 0, 2], [0, 1, 3], [0, 0, 1]])) # from x,y,theta x = SE2(2, 3, pi / 2) - self.assertIsInstance(x, SE2) - self.assertEqual(len(x), 1) + assert isinstance(x, SE2) + assert len(x) == 1 array_compare(x.A, np.array([[0, -1, 2], [1, 0, 3], [0, 0, 1]])) x = SE2([2, 3, pi / 2]) - self.assertIsInstance(x, SE2) - self.assertEqual(len(x), 1) + assert isinstance(x, SE2) + assert len(x) == 1 array_compare(x.A, np.array([[0, -1, 2], [1, 0, 3], [0, 0, 1]])) x = SE2(2, 3, 90, unit="deg") - self.assertIsInstance(x, SE2) - self.assertEqual(len(x), 1) + assert isinstance(x, SE2) + assert len(x) == 1 array_compare(x.A, np.array([[0, -1, 2], [1, 0, 3], [0, 0, 1]])) x = SE2([2, 3, 90], unit="deg") - self.assertIsInstance(x, SE2) - self.assertEqual(len(x), 1) + assert isinstance(x, SE2) + assert len(x) == 1 array_compare(x.A, np.array([[0, -1, 2], [1, 0, 3], [0, 0, 1]])) ## T T = transl2(1, 2) @ trot2(0.3) x = SE2(T) - self.assertIsInstance(x, SE2) - self.assertEqual(len(x), 1) + assert isinstance(x, SE2) + assert len(x) == 1 array_compare(x.A, T) ## copy constructor TT = SE2(x) array_compare(SE2(TT).A, T) - x = SE2() + x = SE2.identity() array_compare(SE2(TT).A, T) - ## vectorised versions - - T1 = transl2(1, 2) @ trot2(0.3) - T2 = transl2(1, -2) @ trot2(-0.4) - - x = SE2([T1, T2, T1, T2]) - self.assertIsInstance(x, SE2) - self.assertEqual(len(x), 4) - array_compare(x[0], T1) - array_compare(x[1], T2) - def test_shape(self): - a = SE2() - self.assertEqual(a._A.shape, a.shape) - - def test_concat(self): - x = SE2() - xx = SE2([x, x, x, x]) - - self.assertIsInstance(xx, SE2) - self.assertEqual(len(xx), 4) + a = SE2.identity() + assert a._A.shape == a.shape def test_constructor_Exp(self): array_compare(SE2.Exp(skewa([1, 2, 0])), transl2(1, 2)) array_compare(SE2.Exp(np.r_[1, 2, 0]), transl2(1, 2)) - x = SE2.Exp([(1, 2, 0), (3, 4, 0), (5, 6, 0)]) - self.assertEqual(len(x), 3) - array_compare(x[0], transl2(1, 2)) - array_compare(x[1], transl2(3, 4)) - array_compare(x[2], transl2(5, 6)) - - x = SE2.Exp([skewa(x) for x in [(1, 2, 0), (3, 4, 0), (5, 6, 0)]]) - self.assertEqual(len(x), 3) - array_compare(x[0], transl2(1, 2)) - array_compare(x[1], transl2(3, 4)) - array_compare(x[2], transl2(5, 6)) - def test_isa(self): - self.assertTrue(SE2.isvalid(trot2(0))) - self.assertFalse(SE2.isvalid(1)) + assert SE2.isvalid(trot2(0)) + assert not SE2.isvalid(1) def test_resulttype(self): - t = SE2() - self.assertIsInstance(t, SE2) - self.assertIsInstance(t * t, SE2) - self.assertIsInstance(t / t, SE2) - self.assertIsInstance(t.inv(), SE2) - self.assertIsInstance(t + t, np.ndarray) - self.assertIsInstance(t + 1, np.ndarray) - self.assertIsInstance(t - 1, np.ndarray) - self.assertIsInstance(1 + t, np.ndarray) - self.assertIsInstance(1 - t, np.ndarray) - self.assertIsInstance(2 * t, np.ndarray) - self.assertIsInstance(t * 2, np.ndarray) + t = SE2.identity() + assert isinstance(t, SE2) + assert isinstance(t * t, SE2) + assert isinstance(t / t, SE2) + assert isinstance(t.inv(), SE2) + assert isinstance(t + t, np.ndarray) + assert isinstance(t + 1, np.ndarray) + assert isinstance(t - 1, np.ndarray) + assert isinstance(1 + t, np.ndarray) + assert isinstance(1 - t, np.ndarray) + assert isinstance(2 * t, np.ndarray) + assert isinstance(t * 2, np.ndarray) def test_inverse(self): T1 = transl2(1, 2) @ trot2(0.3) @@ -357,11 +248,6 @@ def test_inverse(self): array_compare(TT1 * TT1.inv(), np.eye(3)) array_compare(TT1.inv() * TT1, np.eye(3)) - # vector case - TT2 = SE2([TT1, TT1]) - u = [np.eye(3), np.eye(3)] - array_compare(TT2.inv() * TT1, u) - def test_Rt(self): TT1 = SE2.Rand() T1 = TT1.A @@ -371,11 +257,8 @@ def test_Rt(self): array_compare(TT1.A, T1) array_compare(TT1.R, R1) array_compare(TT1.t, t1) - self.assertEqual(TT1.x, t1[0]) - self.assertEqual(TT1.y, t1[1]) - - TT = SE2([TT1, TT1, TT1]) - array_compare(TT.t, [t1, t1, t1]) + assert TT1.x == t1[0] + assert TT1.y == t1[1] def test_arith(self): TT1 = SE2.Rand() @@ -383,7 +266,7 @@ def test_arith(self): TT2 = SE2.Rand() T2 = TT2.A - I = SE2() + I = SE2.identity() ## SE2, * SE2, product # scalar x scalar @@ -393,18 +276,6 @@ def test_arith(self): array_compare(TT1 * I, T1) array_compare(TT2 * I, TT2) - # vector x vector - array_compare( - SE2([TT1, TT1, TT2]) * SE2([TT2, TT1, TT1]), - SE2([TT1 * TT2, TT1 * TT1, TT2 * TT1]), - ) - - # scalar x vector - array_compare(TT1 * SE2([TT2, TT1]), SE2([TT1 * TT2, TT1 * TT1])) - - # vector x scalar - array_compare(SE2([TT1, TT2]) * TT2, SE2([TT1 * TT2, TT2 * TT2])) - ## SE2, * vector product vx = np.r_[1, 0] vy = np.r_[0, 1] @@ -413,18 +284,6 @@ def test_arith(self): array_compare(TT1 * vy, h2e(T1 @ e2h(vy))) - # # vector x vector - # array_compare(SE2([TT1, TT2]) * np.c_[vx, vy], np.c_[h2e(T1 @ e2h(vx)), h2e(T2 @ e2h(vy))]) - - # scalar x vector - array_compare(TT1 * np.c_[vx, vy], h2e(T1 @ e2h(np.c_[vx, vy]))) - - # vector x scalar - array_compare( - SE2([TT1, TT2, TT1]) * vy, - np.c_[h2e(T1 @ e2h(vy)), h2e(T2 @ e2h(vy)), h2e(T1 @ e2h(vy))], - ) - def test_defs(self): # log # x = SE2.Exp([2, 3, 0.5]) @@ -443,37 +302,37 @@ def test_conversions(self): ## Lie stuff x = TT.log() - self.assertTrue(isskewa(x)) + assert isskewa(x) def test_interp(self): TT = SE2(2, -4, 0.6) - I = SE2() + I = SE2.identity() - z = I.interp(TT, s=0) - self.assertIsInstance(z, SE2) + z = I.interp(TT, s=0)[0] + assert isinstance(z, SE2) - array_compare(I.interp(TT, s=0), I) - array_compare(I.interp(TT, s=1), TT) - array_compare(I.interp(TT, s=0.5), SE2(1, -2, 0.3)) + array_compare(I.interp(TT, s=0)[0], I) + array_compare(I.interp(TT, s=1)[0], TT) + array_compare(I.interp(TT, s=0.5)[0], SE2(1, -2, 0.3)) R1 = SO2(math.pi - 0.1) R2 = SO2(-math.pi + 0.2) - array_compare(R1.interp(R2, s=0.5, shortest=False), SO2(0.05)) - array_compare(R1.interp(R2, s=0.5, shortest=True), SO2(-math.pi + 0.05)) + array_compare(R1.interp(R2, s=0.5, shortest=False)[0], SO2(0.05)) + array_compare(R1.interp(R2, s=0.5, shortest=True)[0], SO2(-math.pi + 0.05)) T1 = SE2(0, 0, math.pi - 0.1) T2 = SE2(0, 0, -math.pi + 0.2) - array_compare(T1.interp(T2, s=0.5, shortest=False), SE2(0, 0, 0.05)) - array_compare(T1.interp(T2, s=0.5, shortest=True), SE2(0, 0, -math.pi + 0.05)) + array_compare(T1.interp(T2, s=0.5, shortest=False)[0], SE2(0, 0, 0.05)) + array_compare(T1.interp(T2, s=0.5, shortest=True)[0], SE2(0, 0, -math.pi + 0.05)) def test_miscellany(self): TT = SE2(1, 2, 0.3) - self.assertEqual(TT.A.shape, (3, 3)) + assert TT.A.shape == (3, 3) - self.assertTrue(TT.isSE) + assert TT.isSE - self.assertIsInstance(TT, SE2) + assert isinstance(TT, SE2) def test_display(self): T1 = SE2.Rand() @@ -494,7 +353,5 @@ def test_graphics(self): T1.animate(repeat=False, dims=[-2, 2], nframes=10) T1.animate(T0=T2, repeat=False, dims=[-2, 2], nframes=10) - -# ---------------------------------------------------------------------------------------# -if __name__ == "__main__": - unittest.main(buffer=True) + def test_identity(self): + array_compare(SE2.identity().A, np.eye(3)) diff --git a/tests/test_pose3d.py b/tests/test_pose3d.py index d6a941c3..e6a10d7a 100755 --- a/tests/test_pose3d.py +++ b/tests/test_pose3d.py @@ -1,6 +1,7 @@ import numpy.testing as nt +import matplotlib +matplotlib.use("AGG") import matplotlib.pyplot as plt -import unittest import sys import pytest @@ -27,63 +28,52 @@ def array_compare(x, y): nt.assert_array_almost_equal(x, y) -class TestSO3(unittest.TestCase): +class TestSO3: @classmethod def tearDownClass(cls): plt.close("all") def test_constructor(self): - # null constructor - R = SO3() - nt.assert_equal(len(R), 1) - array_compare(R, np.eye(3)) - self.assertIsInstance(R, SO3) - - # empty constructor - R = SO3.Empty() - nt.assert_equal(len(R), 0) - self.assertIsInstance(R, SO3) - # construct from matrix R = SO3(rotx(0.2)) nt.assert_equal(len(R), 1) array_compare(R, rotx(0.2)) - self.assertIsInstance(R, SO3) + assert isinstance(R, SO3) # construct from canonic rotation R = SO3.Rx(0.2) nt.assert_equal(len(R), 1) array_compare(R, rotx(0.2)) - self.assertIsInstance(R, SO3) + assert isinstance(R, SO3) R = SO3.Ry(0.2) nt.assert_equal(len(R), 1) array_compare(R, roty(0.2)) - self.assertIsInstance(R, SO3) + assert isinstance(R, SO3) R = SO3.Rz(0.2) nt.assert_equal(len(R), 1) array_compare(R, rotz(0.2)) - self.assertIsInstance(R, SO3) + assert isinstance(R, SO3) # OA R = SO3.OA([0, 1, 0], [0, 0, 1]) nt.assert_equal(len(R), 1) array_compare(R, np.eye(3)) - self.assertIsInstance(R, SO3) + assert isinstance(R, SO3) np.random.seed(32) # random R = SO3.Rand() nt.assert_equal(len(R), 1) - self.assertIsInstance(R, SO3) + assert isinstance(R, SO3) # random constrained R = SO3.Rand(theta_range=(0.1, 0.7)) - self.assertIsInstance(R, SO3) - self.assertEqual(R.A.shape, (3, 3)) - self.assertLessEqual(R.angvec()[0], 0.7) - self.assertGreaterEqual(R.angvec()[0], 0.1) + assert isinstance(R, SO3) + assert R.A.shape == (3, 3) + assert R.angvec()[0] <= 0.7 + assert R.angvec()[0] >= 0.1 # copy constructor R = SO3.Rx(pi / 2) @@ -95,124 +85,88 @@ def test_constructor_Eul(self): R = SO3.Eul([0.1, 0.2, 0.3]) nt.assert_equal(len(R), 1) array_compare(R, eul2r([0.1, 0.2, 0.3])) - self.assertIsInstance(R, SO3) + assert isinstance(R, SO3) R = SO3.Eul(0.1, 0.2, 0.3) nt.assert_equal(len(R), 1) array_compare(R, eul2r([0.1, 0.2, 0.3])) - self.assertIsInstance(R, SO3) + assert isinstance(R, SO3) R = SO3.Eul(np.r_[0.1, 0.2, 0.3]) nt.assert_equal(len(R), 1) array_compare(R, eul2r([0.1, 0.2, 0.3])) - self.assertIsInstance(R, SO3) + assert isinstance(R, SO3) R = SO3.Eul([10, 20, 30], unit="deg") nt.assert_equal(len(R), 1) array_compare(R, eul2r([10, 20, 30], unit="deg")) - self.assertIsInstance(R, SO3) + assert isinstance(R, SO3) R = SO3.Eul(10, 20, 30, unit="deg") nt.assert_equal(len(R), 1) array_compare(R, eul2r([10, 20, 30], unit="deg")) - self.assertIsInstance(R, SO3) - - # matrix input - - angles = np.array( - [[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]] - ) - R = SO3.Eul(angles) - self.assertIsInstance(R, SO3) - nt.assert_equal(len(R), 4) - for i in range(4): - array_compare(R[i], eul2r(angles[i, :])) - - angles *= 10 - R = SO3.Eul(angles, unit="deg") - self.assertIsInstance(R, SO3) - nt.assert_equal(len(R), 4) - for i in range(4): - array_compare(R[i], eul2r(angles[i, :], unit="deg")) + assert isinstance(R, SO3) def test_constructor_RPY(self): R = SO3.RPY(0.1, 0.2, 0.3, order="zyx") nt.assert_equal(len(R), 1) array_compare(R, rpy2r([0.1, 0.2, 0.3], order="zyx")) - self.assertIsInstance(R, SO3) + assert isinstance(R, SO3) R = SO3.RPY(10, 20, 30, unit="deg", order="zyx") nt.assert_equal(len(R), 1) array_compare(R, rpy2r([10, 20, 30], order="zyx", unit="deg")) - self.assertIsInstance(R, SO3) + assert isinstance(R, SO3) R = SO3.RPY([0.1, 0.2, 0.3], order="zyx") nt.assert_equal(len(R), 1) array_compare(R, rpy2r([0.1, 0.2, 0.3], order="zyx")) - self.assertIsInstance(R, SO3) + assert isinstance(R, SO3) R = SO3.RPY(np.r_[0.1, 0.2, 0.3], order="zyx") nt.assert_equal(len(R), 1) array_compare(R, rpy2r([0.1, 0.2, 0.3], order="zyx")) - self.assertIsInstance(R, SO3) + assert isinstance(R, SO3) # check default R = SO3.RPY([0.1, 0.2, 0.3]) nt.assert_equal(len(R), 1) array_compare(R, rpy2r([0.1, 0.2, 0.3], order="zyx")) - self.assertIsInstance(R, SO3) + assert isinstance(R, SO3) # XYZ order R = SO3.RPY(0.1, 0.2, 0.3, order="xyz") nt.assert_equal(len(R), 1) array_compare(R, rpy2r([0.1, 0.2, 0.3], order="xyz")) - self.assertIsInstance(R, SO3) + assert isinstance(R, SO3) R = SO3.RPY(10, 20, 30, unit="deg", order="xyz") nt.assert_equal(len(R), 1) array_compare(R, rpy2r([10, 20, 30], order="xyz", unit="deg")) - self.assertIsInstance(R, SO3) + assert isinstance(R, SO3) R = SO3.RPY([0.1, 0.2, 0.3], order="xyz") nt.assert_equal(len(R), 1) array_compare(R, rpy2r([0.1, 0.2, 0.3], order="xyz")) - self.assertIsInstance(R, SO3) + assert isinstance(R, SO3) R = SO3.RPY(np.r_[0.1, 0.2, 0.3], order="xyz") nt.assert_equal(len(R), 1) array_compare(R, rpy2r([0.1, 0.2, 0.3], order="xyz")) - self.assertIsInstance(R, SO3) - - # matrix input - - angles = np.array( - [[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]] - ) - R = SO3.RPY(angles, order="zyx") - self.assertIsInstance(R, SO3) - nt.assert_equal(len(R), 4) - for i in range(4): - array_compare(R[i], rpy2r(angles[i, :], order="zyx")) - - angles *= 10 - R = SO3.RPY(angles, unit="deg", order="zyx") - self.assertIsInstance(R, SO3) - nt.assert_equal(len(R), 4) - for i in range(4): - array_compare(R[i], rpy2r(angles[i, :], unit="deg", order="zyx")) + assert isinstance(R, SO3) def test_constructor_AngVec(self): # angvec R = SO3.AngVec(0.2, [1, 0, 0]) nt.assert_equal(len(R), 1) array_compare(R, rotx(0.2)) - self.assertIsInstance(R, SO3) + assert isinstance(R, SO3) R = SO3.AngVec(0.3, [0, 1, 0]) nt.assert_equal(len(R), 1) array_compare(R, roty(0.3)) - self.assertIsInstance(R, SO3) + assert isinstance(R, SO3) def test_constructor_TwoVec(self): # Randomly selected vectors @@ -222,21 +176,21 @@ def test_constructor_TwoVec(self): # x and y given R = SO3.TwoVectors(x=v1, y=v2) - self.assertIsInstance(R, SO3) + assert isinstance(R, SO3) nt.assert_almost_equal(R.det(), 1, 5) # x axis should equal normalized x vector nt.assert_almost_equal(R.R[:, 0], v1 / np.linalg.norm(v1), 5) # y and z given R = SO3.TwoVectors(y=v2, z=v3) - self.assertIsInstance(R, SO3) + assert isinstance(R, SO3) nt.assert_almost_equal(R.det(), 1, 5) # y axis should equal normalized y vector nt.assert_almost_equal(R.R[:, 1], v2 / np.linalg.norm(v2), 5) # x and z given R = SO3.TwoVectors(x=v3, z=v1) - self.assertIsInstance(R, SO3) + assert isinstance(R, SO3) nt.assert_almost_equal(R.det(), 1, 5) # x axis should equal normalized x vector nt.assert_almost_equal(R.R[:, 0], v3 / np.linalg.norm(v3), 5) @@ -256,35 +210,35 @@ def test_conversion(self): def test_shape(self): - a = SO3() - self.assertEqual(a._A.shape, a.shape) + a = SO3.identity() + assert a._A.shape == a.shape def test_about(self): - R = SO3() + R = SO3.identity() R.about def test_str(self): - R = SO3() + R = SO3.identity() s = str(R) - self.assertIsInstance(s, str) - self.assertEqual(s.count("\n"), 3) + assert isinstance(s, str) + assert s.count("\n") == 3 s = repr(R) - self.assertIsInstance(s, str) - self.assertEqual(s.count("\n"), 2) + assert isinstance(s, str) + assert s.count("\n") == 2 def test_printline(self): R = SO3.Rx(0.3) R.printline() # s = R.printline(file=None) - # self.assertIsInstance(s, str) + # assert isinstance(s, str) R = SO3.Rx([0.3, 0.4, 0.5]) s = R.printline(file=None) - # self.assertIsInstance(s, str) - # self.assertEqual(s.count('\n'), 2) + # assert isinstance(s, str) + # assert s.count('\n') == 2 @pytest.mark.skipif( sys.platform.startswith("darwin") and sys.version_info < (3, 11), @@ -300,47 +254,19 @@ def test_plot(self): # R.animate() # R.animate(start=R.inv()) - def test_listpowers(self): - R = SO3() - R1 = SO3.Rx(0.2) - R2 = SO3.Ry(0.3) - - R.append(R1) - R.append(R2) - nt.assert_equal(len(R), 3) - self.assertIsInstance(R, SO3) - - array_compare(R[0], np.eye(3)) - array_compare(R[1], R1) - array_compare(R[2], R2) - - R = SO3([rotx(0.1), rotx(0.2), rotx(0.3)]) - nt.assert_equal(len(R), 3) - self.assertIsInstance(R, SO3) - array_compare(R[0], rotx(0.1)) - array_compare(R[1], rotx(0.2)) - array_compare(R[2], rotx(0.3)) - - R = SO3([SO3.Rx(0.1), SO3.Rx(0.2), SO3.Rx(0.3)]) - nt.assert_equal(len(R), 3) - self.assertIsInstance(R, SO3) - array_compare(R[0], rotx(0.1)) - array_compare(R[1], rotx(0.2)) - array_compare(R[2], rotx(0.3)) - def test_tests(self): - R = SO3() + R = SO3.identity() - self.assertEqual(R.isrot(), True) - self.assertEqual(R.isrot2(), False) - self.assertEqual(R.ishom(), False) - self.assertEqual(R.ishom2(), False) + assert R.isrot() == True + assert R.isrot2() == False + assert R.ishom() == False + assert R.ishom2() == False def test_properties(self): - R = SO3() + R = SO3.identity() - self.assertEqual(R.isSO, True) - self.assertEqual(R.isSE, False) + assert R.isSO == True + assert R.isSE == False array_compare(R.n, np.r_[1, 0, 0]) array_compare(R.n, np.r_[1, 0, 0]) @@ -353,76 +279,76 @@ def test_properties(self): array_compare(R.inv() * R, np.eye(3, 3)) def test_arith(self): - R = SO3() + R = SO3.identity() # sum a = R + R - self.assertNotIsInstance(a, SO3) + assert not isinstance(a, SO3) array_compare(a, np.array([[2, 0, 0], [0, 2, 0], [0, 0, 2]])) a = R + 1 - self.assertNotIsInstance(a, SO3) + assert not isinstance(a, SO3) array_compare(a, np.array([[2, 1, 1], [1, 2, 1], [1, 1, 2]])) # a = 1 + R - # self.assertNotIsInstance(a, SO3) + # assert not isinstance(a, SO3) # array_compare(a, np.array([ [2,1,1], [1,2,1], [1,1,2]])) a = R + np.eye(3) - self.assertNotIsInstance(a, SO3) + assert not isinstance(a, SO3) array_compare(a, np.array([[2, 0, 0], [0, 2, 0], [0, 0, 2]])) # a = np.eye(3) + R - # self.assertNotIsInstance(a, SO3) + # assert not isinstance(a, SO3) # array_compare(a, np.array([ [2,0,0], [0,2,0], [0,0,2]])) # this invokes the __add__ method for numpy # difference - R = SO3() + R = SO3.identity() a = R - R - self.assertNotIsInstance(a, SO3) + assert not isinstance(a, SO3) array_compare(a, np.zeros((3, 3))) a = R - 1 - self.assertNotIsInstance(a, SO3) + assert not isinstance(a, SO3) array_compare(a, np.array([[0, -1, -1], [-1, 0, -1], [-1, -1, 0]])) # a = 1 - R - # self.assertNotIsInstance(a, SO3) + # assert not isinstance(a, SO3) # array_compare(a, -np.array([ [0,-1,-1], [-1,0,-1], [-1,-1,0]])) a = R - np.eye(3) - self.assertNotIsInstance(a, SO3) + assert not isinstance(a, SO3) array_compare(a, np.zeros((3, 3))) # a = np.eye(3) - R - # self.assertNotIsInstance(a, SO3) + # assert not isinstance(a, SO3) # array_compare(a, np.zeros((3,3))) # multiply - R = SO3() + R = SO3.identity() a = R * R - self.assertIsInstance(a, SO3) + assert isinstance(a, SO3) array_compare(a, R) a = R * 2 - self.assertNotIsInstance(a, SO3) + assert not isinstance(a, SO3) array_compare(a, 2 * np.eye(3)) a = 2 * R - self.assertNotIsInstance(a, SO3) + assert not isinstance(a, SO3) array_compare(a, 2 * np.eye(3)) - R = SO3() + R = SO3.identity() R *= SO3.Rx(pi / 2) - self.assertIsInstance(R, SO3) + assert isinstance(R, SO3) array_compare(R, rotx(pi / 2)) - R = SO3() + R = SO3.identity() R *= 2 - self.assertNotIsInstance(R, SO3) + assert not isinstance(R, SO3) array_compare(R, 2 * np.eye(3)) array_compare(SO3.Rx(pi / 2) * SO3.Ry(pi / 2) * SO3.Rx(-pi / 2), SO3.Rz(pi / 2)) @@ -456,11 +382,11 @@ def cv(v): # divide R = SO3.Ry(0.3) a = R / R - self.assertIsInstance(a, SO3) + assert isinstance(a, SO3) array_compare(a, np.eye(3)) a = R / 2 - self.assertNotIsInstance(a, SO3) + assert not isinstance(a, SO3) array_compare(a, roty(0.3) / 2) # power @@ -481,211 +407,6 @@ def cv(v): R **= -2 array_compare(R, SO3.Rx(-pi / 2)) - def test_arith_vect(self): - rx = SO3.Rx(pi / 2) - ry = SO3.Ry(pi / 2) - rz = SO3.Rz(pi / 2) - u = SO3() - - # multiply - R = SO3([rx, ry, rz]) - a = R * rx - self.assertIsInstance(a, SO3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx * rx) - array_compare(a[1], ry * rx) - array_compare(a[2], rz * rx) - - a = rx * R - self.assertIsInstance(a, SO3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx * rx) - array_compare(a[1], rx * ry) - array_compare(a[2], rx * rz) - - a = R * R - self.assertIsInstance(a, SO3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx * rx) - array_compare(a[1], ry * ry) - array_compare(a[2], rz * rz) - - a = R * 2 - self.assertNotIsInstance(a, SO3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx * 2) - array_compare(a[1], ry * 2) - array_compare(a[2], rz * 2) - - a = 2 * R - self.assertNotIsInstance(a, SO3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx * 2) - array_compare(a[1], ry * 2) - array_compare(a[2], rz * 2) - - a = R - a *= rx - self.assertIsInstance(a, SO3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx * rx) - array_compare(a[1], ry * rx) - array_compare(a[2], rz * rx) - - a = rx - a *= R - self.assertIsInstance(a, SO3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx * rx) - array_compare(a[1], rx * ry) - array_compare(a[2], rx * rz) - - a = R - a *= R - self.assertIsInstance(a, SO3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx * rx) - array_compare(a[1], ry * ry) - array_compare(a[2], rz * rz) - - a = R - a *= 2 - self.assertNotIsInstance(a, SO3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx * 2) - array_compare(a[1], ry * 2) - array_compare(a[2], rz * 2) - - # SO3 x vector - vx = np.r_[1, 0, 0] - vy = np.r_[0, 1, 0] - vz = np.r_[0, 0, 1] - - a = R * vx - array_compare(a[:, 0], (rx * vx).flatten()) - array_compare(a[:, 1], (ry * vx).flatten()) - array_compare(a[:, 2], (rz * vx).flatten()) - - a = rx * np.vstack((vx, vy, vz)).T - array_compare(a[:, 0], (rx * vx).flatten()) - array_compare(a[:, 1], (rx * vy).flatten()) - array_compare(a[:, 2], (rx * vz).flatten()) - - # divide - R = SO3([rx, ry, rz]) - a = R / rx - self.assertIsInstance(a, SO3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx / rx) - array_compare(a[1], ry / rx) - array_compare(a[2], rz / rx) - - a = rx / R - self.assertIsInstance(a, SO3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx / rx) - array_compare(a[1], rx / ry) - array_compare(a[2], rx / rz) - - a = R / R - self.assertIsInstance(a, SO3) - nt.assert_equal(len(a), 3) - array_compare(a[0], np.eye(3)) - array_compare(a[1], np.eye(3)) - array_compare(a[2], np.eye(3)) - - a = R / 2 - self.assertNotIsInstance(a, SO3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx / 2) - array_compare(a[1], ry / 2) - array_compare(a[2], rz / 2) - - a = R - a /= rx - self.assertIsInstance(a, SO3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx / rx) - array_compare(a[1], ry / rx) - array_compare(a[2], rz / rx) - - a = rx - a /= R - self.assertIsInstance(a, SO3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx / rx) - array_compare(a[1], rx / ry) - array_compare(a[2], rx / rz) - - a = R - a /= R - self.assertIsInstance(a, SO3) - nt.assert_equal(len(a), 3) - array_compare(a[0], np.eye(3)) - array_compare(a[1], np.eye(3)) - array_compare(a[2], np.eye(3)) - - a = R - a /= 2 - self.assertNotIsInstance(a, SO3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx / 2) - array_compare(a[1], ry / 2) - array_compare(a[2], rz / 2) - - # add - R = SO3([rx, ry, rz]) - a = R + rx - self.assertNotIsInstance(a, SO3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx + rx) - array_compare(a[1], ry + rx) - array_compare(a[2], rz + rx) - - a = rx + R - self.assertNotIsInstance(a, SO3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx + rx) - array_compare(a[1], rx + ry) - array_compare(a[2], rx + rz) - - a = R + R - self.assertNotIsInstance(a, SO3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx + rx) - array_compare(a[1], ry + ry) - array_compare(a[2], rz + rz) - - a = R + 1 - self.assertNotIsInstance(a, SO3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx + 1) - array_compare(a[1], ry + 1) - array_compare(a[2], rz + 1) - - # subtract - R = SO3([rx, ry, rz]) - a = R - rx - self.assertNotIsInstance(a, SO3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx - rx) - array_compare(a[1], ry - rx) - array_compare(a[2], rz - rx) - - a = rx - R - self.assertNotIsInstance(a, SO3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx - rx) - array_compare(a[1], rx - ry) - array_compare(a[2], rx - rz) - - a = R - R - self.assertNotIsInstance(a, SO3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx - rx) - array_compare(a[1], ry - ry) - array_compare(a[2], rz - rz) - def test_functions(self): # inv # .T @@ -697,15 +418,6 @@ def test_functions(self): nt.assert_equal(poseSE3.x, poseSE2.x) nt.assert_equal(poseSE3.y, poseSE2.y) - posesSE3 = SE3([poseSE3, poseSE3]) - posesSE2 = posesSE3.yaw_SE2() - nt.assert_equal(len(posesSE2), 2) - - def test_functions_vect(self): - # inv - # .T - pass - def test_functions_lie(self): R = SO3.EulerVec([0.42, 0.73, -1.17]) @@ -717,126 +429,122 @@ def test_functions_lie(self): nt.assert_equal(R, SO3.EulerVec(R.eulervec())) np.testing.assert_equal((R.inv() * R).eulervec(), np.zeros(3)) + def test_identity(self): + nt.assert_equal(SO3.identity().A, np.eye(3)) # ============================== SE3 =====================================# -class TestSE3(unittest.TestCase): +class TestSE3: @classmethod def tearDownClass(cls): plt.close("all") def test_constructor(self): - # null constructor - R = SE3() - nt.assert_equal(len(R), 1) - array_compare(R, np.eye(4)) - self.assertIsInstance(R, SE3) - # construct from matrix R = SE3(trotx(0.2)) nt.assert_equal(len(R), 1) array_compare(R, trotx(0.2)) - self.assertIsInstance(R, SE3) + assert isinstance(R, SE3) # construct from canonic rotation R = SE3.Rx(0.2) nt.assert_equal(len(R), 1) array_compare(R, trotx(0.2)) - self.assertIsInstance(R, SE3) + assert isinstance(R, SE3) R = SE3.Ry(0.2) nt.assert_equal(len(R), 1) array_compare(R, troty(0.2)) - self.assertIsInstance(R, SE3) + assert isinstance(R, SE3) R = SE3.Rz(0.2) nt.assert_equal(len(R), 1) array_compare(R, trotz(0.2)) - self.assertIsInstance(R, SE3) + assert isinstance(R, SE3) # construct from canonic translation R = SE3.Tx(0.2) nt.assert_equal(len(R), 1) array_compare(R, transl(0.2, 0, 0)) - self.assertIsInstance(R, SE3) + assert isinstance(R, SE3) R = SE3.Ty(0.2) nt.assert_equal(len(R), 1) array_compare(R, transl(0, 0.2, 0)) - self.assertIsInstance(R, SE3) + assert isinstance(R, SE3) R = SE3.Tz(0.2) nt.assert_equal(len(R), 1) array_compare(R, transl(0, 0, 0.2)) - self.assertIsInstance(R, SE3) + assert isinstance(R, SE3) # triple angle R = SE3.Eul([0.1, 0.2, 0.3]) nt.assert_equal(len(R), 1) array_compare(R, eul2tr([0.1, 0.2, 0.3])) - self.assertIsInstance(R, SE3) + assert isinstance(R, SE3) R = SE3.Eul(np.r_[0.1, 0.2, 0.3]) nt.assert_equal(len(R), 1) array_compare(R, eul2tr([0.1, 0.2, 0.3])) - self.assertIsInstance(R, SE3) + assert isinstance(R, SE3) R = SE3.Eul([10, 20, 30], unit="deg") nt.assert_equal(len(R), 1) array_compare(R, eul2tr([10, 20, 30], unit="deg")) - self.assertIsInstance(R, SE3) + assert isinstance(R, SE3) R = SE3.RPY([0.1, 0.2, 0.3]) nt.assert_equal(len(R), 1) array_compare(R, rpy2tr([0.1, 0.2, 0.3])) - self.assertIsInstance(R, SE3) + assert isinstance(R, SE3) R = SE3.RPY(np.r_[0.1, 0.2, 0.3]) nt.assert_equal(len(R), 1) array_compare(R, rpy2tr([0.1, 0.2, 0.3])) - self.assertIsInstance(R, SE3) + assert isinstance(R, SE3) R = SE3.RPY([10, 20, 30], unit="deg") nt.assert_equal(len(R), 1) array_compare(R, rpy2tr([10, 20, 30], unit="deg")) - self.assertIsInstance(R, SE3) + assert isinstance(R, SE3) R = SE3.RPY([0.1, 0.2, 0.3], order="xyz") nt.assert_equal(len(R), 1) array_compare(R, rpy2tr([0.1, 0.2, 0.3], order="xyz")) - self.assertIsInstance(R, SE3) + assert isinstance(R, SE3) # angvec R = SE3.AngVec(0.2, [1, 0, 0]) nt.assert_equal(len(R), 1) array_compare(R, trotx(0.2)) - self.assertIsInstance(R, SE3) + assert isinstance(R, SE3) R = SE3.AngVec(0.3, [0, 1, 0]) nt.assert_equal(len(R), 1) array_compare(R, troty(0.3)) - self.assertIsInstance(R, SE3) + assert isinstance(R, SE3) # OA R = SE3.OA([0, 1, 0], [0, 0, 1]) nt.assert_equal(len(R), 1) array_compare(R, np.eye(4)) - self.assertIsInstance(R, SE3) + assert isinstance(R, SE3) np.random.seed(65) # random R = SE3.Rand() nt.assert_equal(len(R), 1) - self.assertIsInstance(R, SE3) + assert isinstance(R, SE3) # random T = SE3.Rand() R = T.R t = T.t T = SE3.Rt(R, t) - self.assertIsInstance(T, SE3) - self.assertEqual(T.A.shape, (4, 4)) + assert isinstance(T, SE3) + assert T.A.shape == (4, 4) nt.assert_equal(T.R, R) nt.assert_equal(T.t, t) @@ -845,23 +553,12 @@ def test_constructor(self): nt.assert_equal(T.y, t[1]) nt.assert_equal(T.z, t[2]) - TT = SE3([T, T, T]) - desired_shape = (3,) - nt.assert_equal(TT.x.shape, desired_shape) - nt.assert_equal(TT.y.shape, desired_shape) - nt.assert_equal(TT.z.shape, desired_shape) - - ones = np.ones(desired_shape) - nt.assert_equal(TT.x, ones * t[0]) - nt.assert_equal(TT.y, ones * t[1]) - nt.assert_equal(TT.z, ones * t[2]) - # random constrained T = SE3.Rand(theta_range=(0.1, 0.7)) - self.assertIsInstance(T, SE3) - self.assertEqual(T.A.shape, (4, 4)) - self.assertLessEqual(T.angvec()[0], 0.7) - self.assertGreaterEqual(T.angvec()[0], 0.1) + assert isinstance(T, SE3) + assert T.A.shape == (4, 4) + assert T.angvec()[0] <= 0.7 + assert T.angvec()[0] >= 0.1 # copy constructor R = SE3.Rx(pi / 2) @@ -870,69 +567,41 @@ def test_constructor(self): array_compare(R2, trotx(pi / 2)) # SO3 - T = SE3(SO3()) + T = SE3(SO3.identity()) nt.assert_equal(len(T), 1) - self.assertIsInstance(T, SE3) + assert isinstance(T, SE3) nt.assert_equal(T.A, np.eye(4)) # SE2 T = SE3(SE2(1, 2, 0.4)) nt.assert_equal(len(T), 1) - self.assertIsInstance(T, SE3) - self.assertEqual(T.A.shape, (4, 4)) + assert isinstance(T, SE3) + assert T.A.shape == (4, 4) nt.assert_equal(T.t, [1, 2, 0]) # Bad number of arguments - with self.assertRaises(ValueError): + with pytest.raises(ValueError): T = SE3(1.0, 0.0) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): T = SE3(1.0, 0.0, 0.0, 0.0) def test_shape(self): - a = SE3() - self.assertEqual(a._A.shape, a.shape) - - def test_listpowers(self): - R = SE3() - R1 = SE3.Rx(0.2) - R2 = SE3.Ry(0.3) - - R.append(R1) - R.append(R2) - nt.assert_equal(len(R), 3) - self.assertIsInstance(R, SE3) - - array_compare(R[0], np.eye(4)) - array_compare(R[1], R1) - array_compare(R[2], R2) - - R = SE3([trotx(0.1), trotx(0.2), trotx(0.3)]) - nt.assert_equal(len(R), 3) - self.assertIsInstance(R, SE3) - array_compare(R[0], trotx(0.1)) - array_compare(R[1], trotx(0.2)) - array_compare(R[2], trotx(0.3)) - - R = SE3([SE3.Rx(0.1), SE3.Rx(0.2), SE3.Rx(0.3)]) - nt.assert_equal(len(R), 3) - self.assertIsInstance(R, SE3) - array_compare(R[0], trotx(0.1)) - array_compare(R[1], trotx(0.2)) - array_compare(R[2], trotx(0.3)) + a = SE3.identity() + assert a._A.shape == a.shape def test_tests(self): - R = SE3() + R = SE3.identity() - self.assertEqual(R.isrot(), False) - self.assertEqual(R.isrot2(), False) - self.assertEqual(R.ishom(), True) - self.assertEqual(R.ishom2(), False) + assert R.isrot() == False + assert R.isrot2() == False + assert R.ishom() == True + assert R.ishom2() == False def test_properties(self): - R = SE3() + R = SE3.identity() - self.assertEqual(R.isSO, False) - self.assertEqual(R.isSE, True) + assert R.isSO == False + assert R.isSE == True array_compare(R.n, np.r_[1, 0, 0]) array_compare(R.n, np.r_[1, 0, 0]) @@ -946,10 +615,10 @@ def test_properties(self): pass_by_ref = SE3(mutable_array) pass_by_val = SE3.CopyFrom(mutable_array) mutable_array[0, 3] = 5.0 - nt.assert_allclose(pass_by_val.data[0], np.eye(4)) - nt.assert_allclose(pass_by_ref.data[0], mutable_array) + nt.assert_allclose(pass_by_val.data, np.eye(4)) + nt.assert_allclose(pass_by_ref.data, mutable_array) nt.assert_raises( - AssertionError, nt.assert_allclose, pass_by_val.data[0], pass_by_ref.data[0] + AssertionError, nt.assert_allclose, pass_by_val.data, pass_by_ref.data ) def test_arith(self): @@ -957,29 +626,29 @@ def test_arith(self): # sum a = T + T - self.assertNotIsInstance(a, SE3) + assert not isinstance(a, SE3) array_compare( a, np.array([[2, 0, 0, 2], [0, 2, 0, 4], [0, 0, 2, 6], [0, 0, 0, 2]]) ) a = T + 1 - self.assertNotIsInstance(a, SE3) + assert not isinstance(a, SE3) array_compare( a, np.array([[2, 1, 1, 2], [1, 2, 1, 3], [1, 1, 2, 4], [1, 1, 1, 2]]) ) # a = 1 + T - # self.assertNotIsInstance(a, SE3) + # assert not isinstance(a, SE3) # array_compare(a, np.array([ [2,1,1], [1,2,1], [1,1,2]])) a = T + np.eye(4) - self.assertNotIsInstance(a, SE3) + assert not isinstance(a, SE3) array_compare( a, np.array([[2, 0, 0, 1], [0, 2, 0, 2], [0, 0, 2, 3], [0, 0, 0, 2]]) ) # a = np.eye(3) + T - # self.assertNotIsInstance(a, SE3) + # assert not isinstance(a, SE3) # array_compare(a, np.array([ [2,0,0], [0,2,0], [0,0,2]])) # this invokes the __add__ method for numpy @@ -987,60 +656,60 @@ def test_arith(self): T = SE3(1, 2, 3) a = T - T - self.assertNotIsInstance(a, SE3) + assert not isinstance(a, SE3) array_compare(a, np.zeros((4, 4))) a = T - 1 - self.assertNotIsInstance(a, SE3) + assert not isinstance(a, SE3) array_compare( a, np.array([[0, -1, -1, 0], [-1, 0, -1, 1], [-1, -1, 0, 2], [-1, -1, -1, 0]]), ) # a = 1 - T - # self.assertNotIsInstance(a, SE3) + # assert not isinstance(a, SE3) # array_compare(a, -np.array([ [0,-1,-1], [-1,0,-1], [-1,-1,0]])) a = T - np.eye(4) - self.assertNotIsInstance(a, SE3) + assert not isinstance(a, SE3) array_compare( a, np.array([[0, 0, 0, 1], [0, 0, 0, 2], [0, 0, 0, 3], [0, 0, 0, 0]]) ) # a = np.eye(3) - T - # self.assertNotIsInstance(a, SE3) + # assert not isinstance(a, SE3) # array_compare(a, np.zeros((3,3))) a = T a -= T - self.assertNotIsInstance(a, SE3) + assert not isinstance(a, SE3) array_compare(a, np.zeros((4, 4))) # multiply T = SE3(1, 2, 3) a = T * T - self.assertIsInstance(a, SE3) + assert isinstance(a, SE3) array_compare(a, transl(2, 4, 6)) a = T * 2 - self.assertNotIsInstance(a, SE3) + assert not isinstance(a, SE3) array_compare(a, 2 * transl(1, 2, 3)) a = 2 * T - self.assertNotIsInstance(a, SE3) + assert not isinstance(a, SE3) array_compare(a, 2 * transl(1, 2, 3)) T = SE3(1, 2, 3) T *= SE3.Ry(pi / 2) - self.assertIsInstance(T, SE3) + assert isinstance(T, SE3) array_compare( T, np.array([[0, 0, 1, 1], [0, 1, 0, 2], [-1, 0, 0, 3], [0, 0, 0, 1]]) ) - T = SE3() + T = SE3.identity() T *= 2 - self.assertNotIsInstance(T, SE3) + assert not isinstance(T, SE3) array_compare(T, 2 * np.eye(4)) array_compare(SE3.Rx(pi / 2) * SE3.Ry(pi / 2) * SE3.Rx(-pi / 2), SE3.Rz(pi / 2)) @@ -1071,262 +740,42 @@ def cv(v): # divide T = SE3.Ry(0.3) a = T / T - self.assertIsInstance(a, SE3) + assert isinstance(a, SE3) array_compare(a, np.eye(4)) a = T / 2 - self.assertNotIsInstance(a, SE3) + assert not isinstance(a, SE3) array_compare(a, troty(0.3) / 2) - def test_arith_vect(self): - rx = SE3.Rx(pi / 2) - ry = SE3.Ry(pi / 2) - rz = SE3.Rz(pi / 2) - u = SE3() - - # multiply - T = SE3([rx, ry, rz]) - a = T * rx - self.assertIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx * rx) - array_compare(a[1], ry * rx) - array_compare(a[2], rz * rx) - - a = rx * T - self.assertIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx * rx) - array_compare(a[1], rx * ry) - array_compare(a[2], rx * rz) - - a = T * T - self.assertIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx * rx) - array_compare(a[1], ry * ry) - array_compare(a[2], rz * rz) - - a = T * 2 - self.assertNotIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx * 2) - array_compare(a[1], ry * 2) - array_compare(a[2], rz * 2) - - a = 2 * T - self.assertNotIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx * 2) - array_compare(a[1], ry * 2) - array_compare(a[2], rz * 2) - - a = T - a *= rx - self.assertIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx * rx) - array_compare(a[1], ry * rx) - array_compare(a[2], rz * rx) - - a = rx - a *= T - self.assertIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx * rx) - array_compare(a[1], rx * ry) - array_compare(a[2], rx * rz) - - a = T - a *= T - self.assertIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx * rx) - array_compare(a[1], ry * ry) - array_compare(a[2], rz * rz) - - a = T - a *= 2 - self.assertNotIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx * 2) - array_compare(a[1], ry * 2) - array_compare(a[2], rz * 2) - - # SE3 x vector - vx = np.r_[1, 0, 0] - vy = np.r_[0, 1, 0] - vz = np.r_[0, 0, 1] - - a = T * vx - array_compare(a[:, 0], (rx * vx).flatten()) - array_compare(a[:, 1], (ry * vx).flatten()) - array_compare(a[:, 2], (rz * vx).flatten()) - - a = rx * np.vstack((vx, vy, vz)).T - array_compare(a[:, 0], (rx * vx).flatten()) - array_compare(a[:, 1], (rx * vy).flatten()) - array_compare(a[:, 2], (rx * vz).flatten()) - - # divide - T = SE3([rx, ry, rz]) - a = T / rx - self.assertIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx / rx) - array_compare(a[1], ry / rx) - array_compare(a[2], rz / rx) - - a = rx / T - self.assertIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx / rx) - array_compare(a[1], rx / ry) - array_compare(a[2], rx / rz) - - a = T / T - self.assertIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], np.eye(4)) - array_compare(a[1], np.eye(4)) - array_compare(a[2], np.eye(4)) - - a = T / 2 - self.assertNotIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx / 2) - array_compare(a[1], ry / 2) - array_compare(a[2], rz / 2) - - a = T - a /= rx - self.assertIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx / rx) - array_compare(a[1], ry / rx) - array_compare(a[2], rz / rx) - - a = rx - a /= T - self.assertIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx / rx) - array_compare(a[1], rx / ry) - array_compare(a[2], rx / rz) - - a = T - a /= T - self.assertIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], np.eye(4)) - array_compare(a[1], np.eye(4)) - array_compare(a[2], np.eye(4)) - - a = T - a /= 2 - self.assertNotIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx / 2) - array_compare(a[1], ry / 2) - array_compare(a[2], rz / 2) - - # add - T = SE3([rx, ry, rz]) - a = T + rx - self.assertNotIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx + rx) - array_compare(a[1], ry + rx) - array_compare(a[2], rz + rx) - - a = rx + T - self.assertNotIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx + rx) - array_compare(a[1], rx + ry) - array_compare(a[2], rx + rz) - - a = T + T - self.assertNotIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx + rx) - array_compare(a[1], ry + ry) - array_compare(a[2], rz + rz) - - a = T + 1 - self.assertNotIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx + 1) - array_compare(a[1], ry + 1) - array_compare(a[2], rz + 1) - - # subtract - T = SE3([rx, ry, rz]) - a = T - rx - self.assertNotIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx - rx) - array_compare(a[1], ry - rx) - array_compare(a[2], rz - rx) - - a = rx - T - self.assertNotIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx - rx) - array_compare(a[1], rx - ry) - array_compare(a[2], rx - rz) - - a = T - T - self.assertNotIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx - rx) - array_compare(a[1], ry - ry) - array_compare(a[2], rz - rz) - - a = T - 1 - self.assertNotIsInstance(a, SE3) - nt.assert_equal(len(a), 3) - array_compare(a[0], rx - 1) - array_compare(a[1], ry - 1) - array_compare(a[2], rz - 1) - def test_angle(self): # angle between SO3's r1 = SO3.Rx(0.1) r2 = SO3.Rx(0.2) for metric in range(6): - self.assertAlmostEqual(r1.angdist(other=r1, metric=metric), 0.0) - self.assertGreater(r1.angdist(other=r2, metric=metric), 0.0) - self.assertAlmostEqual( - r1.angdist(other=r2, metric=metric), r2.angdist(other=r1, metric=metric) - ) + assert r1.angdist(other=r1, metric=metric) == pytest.approx(0.0) + assert r1.angdist(other=r2, metric=metric) > 0.0 + assert r1.angdist(other=r2, metric=metric) == pytest.approx(r2.angdist(other=r1, metric=metric)) # angle between SE3's p1a, p1b = SE3.Rx(0.1), SE3.Rx(0.1, t=(1, 2, 3)) p2a, p2b = SE3.Rx(0.2), SE3.Rx(0.2, t=(3, 2, 1)) for metric in range(6): - self.assertAlmostEqual(p1a.angdist(other=p1a, metric=metric), 0.0) - self.assertGreater(p1a.angdist(other=p2a, metric=metric), 0.0) - self.assertAlmostEqual(p1a.angdist(other=p1b, metric=metric), 0.0) - self.assertAlmostEqual( - p1a.angdist(other=p2a, metric=metric), - p2a.angdist(other=p1a, metric=metric), - ) - self.assertAlmostEqual( - p1a.angdist(other=p2a, metric=metric), - p1a.angdist(other=p2b, metric=metric), - ) + assert p1a.angdist(other=p1a, metric=metric) == pytest.approx(0.0) + assert p1a.angdist(other=p2a, metric=metric) > 0.0 + assert p1a.angdist(other=p1b, metric=metric) == pytest.approx(0.0) + assert p1a.angdist(other=p2a, metric=metric) == pytest.approx(p2a.angdist(other=p1a, metric=metric)) + assert p1a.angdist(other=p2a, metric=metric) == pytest.approx(p1a.angdist(other=p2b, metric=metric)) # angdist is not implemented for mismatched types - with self.assertRaises(ValueError): + with pytest.raises(ValueError): _ = r1.angdist(p1a) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): _ = r1._op2(right=p1a, op=r1.angdist) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): _ = p1a._op2(right=r1, op=p1a.angdist) # in general, the _op2 interface enforces an isinstance check. - with self.assertRaises(TypeError): + with pytest.raises(TypeError): _ = r1._op2(right=(1, 0, 0), op=r1.angdist) def test_functions(self): @@ -1339,7 +788,5 @@ def test_functions_vect(self): # .T pass - -# ---------------------------------------------------------------------------------------# -if __name__ == "__main__": - unittest.main() + def test_identity(self): + nt.assert_equal(SE3.identity().A, np.eye(4)) diff --git a/tests/test_quaternion.py b/tests/test_quaternion.py index 73c1b090..a48e2534 100644 --- a/tests/test_quaternion.py +++ b/tests/test_quaternion.py @@ -1,7 +1,7 @@ import math from math import pi import numpy.testing as nt -import unittest +import pytest from spatialmath import * from spatialmath.base import * @@ -16,20 +16,22 @@ def qcompare(x, y): x = x.vec elif isinstance(x, BasePoseMatrix): x = x.A + elif isinstance(x, list) and isinstance(x[0], Quaternion): + x = [q.A for q in x] if isinstance(y, Quaternion): y = y.vec elif isinstance(y, BasePoseMatrix): y = y.A + elif isinstance(y, list) and isinstance(y[0], Quaternion): + y = [q.A for q in y] nt.assert_array_almost_equal(x, y) # straight port of the MATLAB unit tests -class TestUnitQuaternion(unittest.TestCase): +class TestUnitQuaternion: def test_constructor_variants(self): - nt.assert_array_almost_equal(UnitQuaternion().vec, np.r_[1, 0, 0, 0]) - nt.assert_array_almost_equal( UnitQuaternion.Rx(90, "deg").vec, np.r_[1, 1, 0, 0] / math.sqrt(2) ) @@ -51,20 +53,18 @@ def test_constructor_variants(self): np.random.seed(73) q = UnitQuaternion.Rand(theta_range=(0.1, 0.7)) - self.assertIsInstance(q, UnitQuaternion) - self.assertLessEqual(q.angvec()[0], 0.7) - self.assertGreaterEqual(q.angvec()[0], 0.1) + assert isinstance(q, UnitQuaternion) + assert q.angvec()[0] <= 0.7 + assert q.angvec()[0] >= 0.1 q = UnitQuaternion.Rand(theta_range=(0.1, 0.7)) - self.assertIsInstance(q, UnitQuaternion) - self.assertLessEqual(q.angvec()[0], 0.7) - self.assertGreaterEqual(q.angvec()[0], 0.1) + assert isinstance(q, UnitQuaternion) + assert q.angvec()[0] <= 0.7 + assert q.angvec()[0] >= 0.1 def test_constructor(self): - qcompare(UnitQuaternion(), [1, 0, 0, 0]) - # from S qcompare(UnitQuaternion([1, 0, 0, 0]), np.r_[1, 0, 0, 0]) qcompare(UnitQuaternion([0, 1, 0, 0]), np.r_[0, 1, 0, 0]) @@ -104,7 +104,7 @@ def test_constructor(self): # from SO3 - qcompare(UnitQuaternion(SO3()), np.r_[1, 0, 0, 0]) + qcompare(UnitQuaternion(SO3.identity()), np.r_[1, 0, 0, 0]) qcompare(UnitQuaternion(SO3.Rx(pi / 2)), np.r_[1, 1, 0, 0] / math.sqrt(2)) qcompare(UnitQuaternion(SO3.Ry(pi / 2)), np.r_[1, 0, 1, 0] / math.sqrt(2)) @@ -118,14 +118,9 @@ def test_constructor(self): qcompare(UnitQuaternion(SO3.Ry(pi)), np.r_[0, 0, 1, 0]) qcompare(UnitQuaternion(SO3.Rz(pi)), np.r_[0, 0, 0, 1]) - # vector of SO3 - q = UnitQuaternion([SO3.Rx(pi / 2), SO3.Ry(pi / 2), SO3.Rz(pi / 2)]) - self.assertEqual(len(q), 3) - qcompare(q, np.array([[1, 1, 0, 0], [1, 0, 1, 0], [1, 0, 0, 1]]) / math.sqrt(2)) - # from SE3 - qcompare(UnitQuaternion(SE3()), np.r_[1, 0, 0, 0]) + qcompare(UnitQuaternion(SE3.identity()), np.r_[1, 0, 0, 0]) qcompare(UnitQuaternion(SE3.Rx(pi / 2)), np.r_[1, 1, 0, 0] / math.sqrt(2)) qcompare(UnitQuaternion(SE3.Ry(pi / 2)), np.r_[1, 0, 1, 0] / math.sqrt(2)) @@ -139,30 +134,6 @@ def test_constructor(self): qcompare(UnitQuaternion(SE3.Ry(pi)), np.r_[0, 0, 1, 0]) qcompare(UnitQuaternion(SE3.Rz(pi)), np.r_[0, 0, 0, 1]) - # vector of SE3 - q = UnitQuaternion([SE3.Rx(pi / 2), SE3.Ry(pi / 2), SE3.Rz(pi / 2)]) - self.assertEqual(len(q), 3) - qcompare(q, np.array([[1, 1, 0, 0], [1, 0, 1, 0], [1, 0, 0, 1]]) / math.sqrt(2)) - - # from S - M = np.identity(4) - q = UnitQuaternion(M) - self.assertEqual(len(q), 4) - - qcompare(q[0], np.r_[1, 0, 0, 0]) - qcompare(q[1], np.r_[0, 1, 0, 0]) - qcompare(q[2], np.r_[0, 0, 1, 0]) - qcompare(q[3], np.r_[0, 0, 0, 1]) - - # # vectorised forms of R, T - # R = []; T = [] - # for theta in [-pi/2, 0, pi/2, pi]: - # R = cat(3, R, rotx(theta), roty(theta), rotz(theta)) - # T = cat(3, T, trotx(theta), troty(theta), trotz(theta)) - - # nt.assert_array_almost_equal(UnitQuaternion(R).R, R) - # nt.assert_array_almost_equal(UnitQuaternion(T).T, T) - # copy constructor q = UnitQuaternion(rotx(0.3)) qcompare(UnitQuaternion(q), q) @@ -170,38 +141,26 @@ def test_constructor(self): # fail when invalid arrays are provided # invalid rotation matrix R = 1.1 * np.eye(3) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): UnitQuaternion(R, check=True) # wrong shape to be anything R = np.zeros((5, 5)) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): UnitQuaternion(R, check=True) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): UnitQuaternion(R, check=False) - def test_concat(self): - u = UnitQuaternion() - uu = UnitQuaternion([u, u, u, u]) - - self.assertIsInstance(uu, UnitQuaternion) - self.assertEqual(len(uu), 4) - def test_string(self): - u = UnitQuaternion() + u = UnitQuaternion.identity() s = str(u) - self.assertIsInstance(s, str) - self.assertTrue(s.endswith(" >>")) - self.assertEqual(s.count("\n"), 0) - - q = UnitQuaternion.Rx([0.3, 0.4, 0.5]) - s = str(q) - self.assertIsInstance(s, str) - self.assertEqual(s.count("\n"), 2) + assert isinstance(s, str) + assert s.endswith(" >>") + assert s.count("\n") == 0 def test_properties(self): - u = UnitQuaternion() + u = UnitQuaternion.identity() # s,v nt.assert_array_almost_equal(UnitQuaternion([1, 0, 0, 0]).s, 1) @@ -361,40 +320,40 @@ def test_convert(self): def test_resulttype(self): q = Quaternion([2, 0, 0, 0]) - u = UnitQuaternion() + u = UnitQuaternion.identity() - self.assertIsInstance(q * q, Quaternion) - self.assertIsInstance(q * u, Quaternion) - self.assertIsInstance(u * q, Quaternion) - self.assertIsInstance(u * u, UnitQuaternion) + assert isinstance(q * q, Quaternion) + assert isinstance(q * u, Quaternion) + assert isinstance(u * q, Quaternion) + assert isinstance(u * u, UnitQuaternion) - # self.assertIsInstance(u.*u, UnitQuaternion) + # assert isinstance(u.*u, UnitQuaternion) # other combos all fail, test this? - self.assertIsInstance(u / u, UnitQuaternion) + assert isinstance(u / u, UnitQuaternion) - self.assertIsInstance(u.conj(), UnitQuaternion) - self.assertIsInstance(u.inv(), UnitQuaternion) - self.assertIsInstance(u.unit(), UnitQuaternion) - self.assertIsInstance(q.unit(), UnitQuaternion) + assert isinstance(u.conj(), UnitQuaternion) + assert isinstance(u.inv(), UnitQuaternion) + assert isinstance(u.unit(), UnitQuaternion) + assert isinstance(q.unit(), UnitQuaternion) - self.assertIsInstance(q.conj(), Quaternion) + assert isinstance(q.conj(), Quaternion) - self.assertIsInstance(q + q, Quaternion) - self.assertIsInstance(q - q, Quaternion) + assert isinstance(q + q, Quaternion) + assert isinstance(q - q, Quaternion) - self.assertIsInstance(u + u, Quaternion) - self.assertIsInstance(u - u, Quaternion) + assert isinstance(u + u, Quaternion) + assert isinstance(u - u, Quaternion) - # self.assertIsInstance(q+u, Quaternion) - # self.assertIsInstance(u+q, Quaternion) + # assert isinstance(q+u, Quaternion) + # assert isinstance(u+q, Quaternion) - # self.assertIsInstance(q-u, Quaternion) - # self.assertIsInstance(u-q, Quaternion) + # assert isinstance(q-u, Quaternion) + # assert isinstance(u-q, Quaternion) # TODO test for ValueError in these cases - self.assertIsInstance(u.SO3(), SO3) - self.assertIsInstance(u.SE3(), SE3) + assert isinstance(u.SO3(), SO3) + assert isinstance(u.SE3(), SE3) def test_multiply(self): vx = np.r_[1, 0, 0] @@ -403,7 +362,7 @@ def test_multiply(self): rx = UnitQuaternion.Rx(pi / 2) ry = UnitQuaternion.Ry(pi / 2) rz = UnitQuaternion.Rz(pi / 2) - u = UnitQuaternion() + u = UnitQuaternion.identity() # quat-quat product # scalar x scalar @@ -411,37 +370,11 @@ def test_multiply(self): qcompare(rx * u, rx) qcompare(u * rx, rx) - # vector x vector - qcompare( - UnitQuaternion([ry, rz, rx]) * UnitQuaternion([rx, ry, rz]), - UnitQuaternion([ry * rx, rz * ry, rx * rz]), - ) - - # scalar x vector - qcompare( - ry * UnitQuaternion([rx, ry, rz]), - UnitQuaternion([ry * rx, ry * ry, ry * rz]), - ) - - # vector x scalar - qcompare( - UnitQuaternion([rx, ry, rz]) * ry, - UnitQuaternion([rx * ry, ry * ry, rz * ry]), - ) - # quatvector product # scalar x scalar qcompare(rx * vy, vz) - # scalar x vector - nt.assert_array_almost_equal(ry * np.c_[vx, vy, vz], np.c_[-vz, vy, vx]) - - # vector x scalar - nt.assert_array_almost_equal( - UnitQuaternion([ry, rz, rx]) * vy, np.c_[vy, -vx, vz] - ) - def test_matmul(self): rx = UnitQuaternion.Rx(pi / 2) ry = UnitQuaternion.Ry(pi / 2) @@ -460,7 +393,7 @@ def test_matmul(self): # rx = UnitQuaternion.Rx(pi/2) # ry = UnitQuaternion.Ry(pi/2) # rz = UnitQuaternion.Rz(pi/2) - # u = UnitQuaternion() + # u = UnitQuaternion.identity() # # quat-quat product # # scalar x scalar @@ -485,7 +418,7 @@ def test_divide(self): rx = UnitQuaternion.Rx(pi / 2) ry = UnitQuaternion.Ry(pi / 2) rz = UnitQuaternion.Rz(pi / 2) - u = UnitQuaternion() + u = UnitQuaternion.identity() # scalar / scalar # implicity tests inv @@ -516,13 +449,10 @@ def test_angle(self): uq1 = UnitQuaternion.Rx(0.1) uq2 = UnitQuaternion.Ry(0.1) for metric in range(5): - self.assertEqual(uq1.angdist(other=uq1, metric=metric), 0.0) - self.assertEqual(uq2.angdist(other=uq2, metric=metric), 0.0) - self.assertEqual( - uq1.angdist(other=uq2, metric=metric), - uq2.angdist(other=uq1, metric=metric), - ) - self.assertTrue(uq1.angdist(other=uq2, metric=metric) > 0) + assert uq1.angdist(other=uq1, metric=metric) == 0.0 + assert uq2.angdist(other=uq2, metric=metric) == 0.0 + assert uq1.angdist(other=uq2, metric=metric) == uq2.angdist(other=uq1, metric=metric) + assert uq1.angdist(other=uq2, metric=metric) > 0 def test_conversions(self): # , 3 angle @@ -546,18 +476,18 @@ def test_conversions(self): th = 0.2 v = unitvec([1, 2, 3]) [a, b] = UnitQuaternion.AngVec(th, v).angvec() - self.assertAlmostEqual(a, th) + assert a == pytest.approx(th) nt.assert_array_almost_equal(b, v) [a, b] = UnitQuaternion.AngVec(-th, v).angvec() - self.assertAlmostEqual(a, th) + assert a == pytest.approx(th) nt.assert_array_almost_equal(b, -v) # null rotation case th = 0 v = unitvec([1, 2, 3]) [a, b] = UnitQuaternion.AngVec(th, v).angvec() - self.assertAlmostEqual(a, th) + assert a == pytest.approx(th) # SO3 convert to SO3 class # SE3 convert to SE3 class @@ -568,15 +498,13 @@ def test_miscellany(self): rx = UnitQuaternion.Rx(pi / 2) ry = UnitQuaternion.Ry(pi / 2) rz = UnitQuaternion.Rz(pi / 2) - u = UnitQuaternion() + u = UnitQuaternion.identity() # norm qcompare(rx.norm(), 1) - qcompare(UnitQuaternion([rx, ry, rz]).norm(), [1, 1, 1]) # unit qcompare(rx.unit(), rx) - qcompare(UnitQuaternion([rx, ry, rz]).unit(), UnitQuaternion([rx, ry, rz])) # inner nt.assert_array_almost_equal(u.inner(u), 1) @@ -590,11 +518,11 @@ def test_miscellany(self): qcompare(q**2, q * q) # angle - # self.assertEqual(angle(u, u), 0) - # self.assertEqual(angle(u, rx), pi/4) - # self.assertEqual(angle(u, [rx, u]), pi/4*np.r_[1, 0]) - # self.assertEqual(angle([rx, u], u), pi/4*np.r_[1, 0]) - # self.assertEqual(angle([rx, u], [u, rx]), pi/4*np.r_[1, 1]) + # assert angle(u, u) == 0 + # assert angle(u, rx) == pi/4 + # assert angle(u, [rx, u]) == pi/4*np.r_[1, 0] + # assert angle([rx, u], u) == pi/4*np.r_[1, 0] + # assert angle([rx, u], [u, rx]) == pi/4*np.r_[1, 1] # TODO angle # increment @@ -606,38 +534,38 @@ def test_interp(self): rx = UnitQuaternion.Rx(pi / 2) ry = UnitQuaternion.Ry(pi / 2) rz = UnitQuaternion.Rz(pi / 2) - u = UnitQuaternion() + u = UnitQuaternion.identity() q = UnitQuaternion.RPY([0.2, 0.3, 0.4]) # from null - qcompare(q.interp1(0), u) - qcompare(q.interp1(1), q) + qcompare(q.interp1(0), [u]) + qcompare(q.interp1(1), [q]) - # self.assertEqual(length(q.interp(linspace(0,1, 10))), 10) - # self.assertTrue(all( q.interp([0, 1]) == [u, q])) + # assert length(q.interp(linspace(0,1, 10))) == 10 + # assert all( q.interp([0, 1]) == [u, q]) # TODO vectorizing - q0_5 = q.interp1(0.5) + q0_5 = q.interp1(0.5)[0] qcompare(q0_5 * q0_5, q) qq = rx.interp1(11) - self.assertEqual(len(qq), 11) + assert len(qq) == 11 # between two quaternions - qcompare(q.interp(rx, 0), q) - qcompare(q.interp(rx, 1), rx) + qcompare(q.interp(rx, 0), [q]) + qcompare(q.interp(rx, 1), [rx]) # test vectorised results qq = q.interp(rx, [0, 1]) - self.assertEqual(len(qq), 2) + assert len(qq) == 2 qcompare(qq[0], q) qcompare(qq[1], rx) qq = rx.interp(q, 11) - self.assertEqual(len(qq), 11) + assert len(qq) == 11 - # self.assertTrue(all( q.interp([0, 1], dest=rx, ) == [q, rx])) + # assert all( q.interp([0, 1], dest=rx, ) == [q, rx]) # test shortest option # q1 = UnitQuaternion.Rx(0.9*pi) @@ -649,19 +577,19 @@ def test_interp(self): # TODO interp def test_increment(self): - q = UnitQuaternion() + q = UnitQuaternion.identity() q.increment([0, 0, 0]) - qcompare(q, UnitQuaternion()) + qcompare(q, UnitQuaternion.identity()) q.increment([0, 0, 0], normalize=True) - qcompare(q, UnitQuaternion()) + qcompare(q, UnitQuaternion.identity()) for i in range(10): q.increment([0.1, 0, 0]) qcompare(q, UnitQuaternion.Rx(1)) - q = UnitQuaternion() + q = UnitQuaternion.identity() for i in range(10): q.increment([0.1, 0, 0], normalize=True) qcompare(q, UnitQuaternion.Rx(1)) @@ -671,38 +599,23 @@ def test_eq(self): q2 = UnitQuaternion([0, -1, 0, 0]) q3 = UnitQuaternion.Rz(pi / 2) - self.assertTrue(q1 == q1) - self.assertTrue(q2 == q2) - self.assertTrue(q3 == q3) - self.assertTrue(q1 == q2) # because of double wrapping - self.assertFalse(q1 == q3) - - nt.assert_array_almost_equal( - UnitQuaternion([q1, q1, q1]) == UnitQuaternion([q1, q1, q1]), - [True, True, True], - ) - nt.assert_array_almost_equal( - UnitQuaternion([q1, q2, q3]) == UnitQuaternion([q1, q2, q3]), - [True, True, True], - ) - nt.assert_array_almost_equal( - UnitQuaternion([q1, q1, q3]) == q1, [True, True, False] - ) - nt.assert_array_almost_equal( - q3 == UnitQuaternion([q1, q1, q3]), [False, False, True] - ) + assert q1 == q1 + assert q2 == q2 + assert q3 == q3 + assert q1 == q2 # because of double wrapping + assert not q1 == q3 def test_logical(self): rx = UnitQuaternion.Rx(pi / 2) ry = UnitQuaternion.Ry(pi / 2) # equality tests - self.assertTrue(rx == rx) - self.assertFalse(rx != rx) - self.assertFalse(rx == ry) + assert rx == rx + assert not rx != rx + assert not rx == ry def test_dot(self): - q = UnitQuaternion() + q = UnitQuaternion.identity() omega = np.r_[1, 2, 3] nt.assert_array_almost_equal(q.dot(omega), np.r_[0, omega / 2]) @@ -742,15 +655,12 @@ def test_vec3(self): # ry.animate('rgb') # ry.animate( UnitQuaternion.Rx(pi/2), 'rgb' ) + def test_identity(self): + qcompare(UnitQuaternion.identity(), [1, 0, 0, 0]) -class TestQuaternion(unittest.TestCase): - def test_constructor(self): - q = Quaternion() - self.assertEqual(len(q), 1) - self.assertIsInstance(q, Quaternion) - - nt.assert_array_almost_equal(Quaternion().vec, [0, 0, 0, 0]) +class TestQuaternion: + def test_constructor(self): # from S nt.assert_array_almost_equal(Quaternion([1, 0, 0, 0]).vec, [1, 0, 0, 0]) nt.assert_array_almost_equal(Quaternion([0, 1, 0, 0]).vec, [0, 1, 0, 0]) @@ -794,22 +704,17 @@ def test_constructor(self): # tc.verifyError( @() Quaternion([1, 2, 3]), 'SMTB:Quaternion:badarg') def test_string(self): - u = Quaternion() + u = Quaternion.identity() s = str(u) - self.assertIsInstance(s, str) - self.assertTrue(s.endswith(" >")) - self.assertEqual(s.count("\n"), 0) - self.assertEqual(len(s), 37) - - q = Quaternion([u, u, u]) - s = str(q) - self.assertIsInstance(s, str) - self.assertEqual(s.count("\n"), 2) + assert isinstance(s, str) + assert s.endswith(" >") + assert s.count("\n") == 0 + assert len(s) == 37 def test_properties(self): q = Quaternion([1, 2, 3, 4]) - self.assertEqual(q.s, 1) + assert q.s == 1 nt.assert_array_almost_equal(q.v, np.r_[2, 3, 4]) nt.assert_array_almost_equal(q.vec, np.r_[1, 2, 3, 4]) @@ -824,19 +729,12 @@ def test_log(self): q1 = Quaternion([4, 3, 2, 1]) q2 = Quaternion([-1, 2, -3, 4]) - self.assertTrue(isscalar(q1.log().s)) - self.assertTrue(isvector(q1.log().v, 3)) + assert isscalar(q1.log().s) + assert isvector(q1.log().v, 3) nt.assert_array_almost_equal(q1.log().exp(), q1) nt.assert_array_almost_equal(q2.log().exp(), q2) - def test_concat(self): - u = Quaternion() - uu = Quaternion([u, u, u, u]) - - self.assertIsInstance(uu, Quaternion) - self.assertEqual(len(uu), 4) - def primitive_test_convert(self): # s,v nt.assert_array_almost_equal(Quaternion([1, 0, 0, 0]).s, 1) @@ -854,15 +752,15 @@ def primitive_test_convert(self): def test_resulttype(self): q = Quaternion([2, 0, 0, 0]) - self.assertIsInstance(q, Quaternion) + assert isinstance(q, Quaternion) # other combos all fail, test this? - self.assertIsInstance(q.conj(), Quaternion) - self.assertIsInstance(q.unit(), UnitQuaternion) + assert isinstance(q.conj(), Quaternion) + assert isinstance(q.unit(), UnitQuaternion) - self.assertIsInstance(q + q, Quaternion) - self.assertIsInstance(q + q, Quaternion) + assert isinstance(q + q, Quaternion) + assert isinstance(q + q, Quaternion) def test_multiply(self): q1 = Quaternion([1, 2, 3, 4]) @@ -930,30 +828,11 @@ def test_equality(self): q1 = Quaternion([1, 2, 3, 4]) q2 = Quaternion([-2, 1, -4, 3]) - self.assertTrue(q1 == q1) - self.assertFalse(q1 == q2) - - self.assertTrue(q1 != q2) - self.assertFalse(q2 != q2) - - qt1 = Quaternion([q1, q1, q2, q2]) - qt2 = Quaternion([q1, q2, q2, q1]) + assert q1 == q1 + assert not q1 == q2 - self.assertEqual(qt1 == q1, [True, True, False, False]) - self.assertEqual(q1 == qt1, [True, True, False, False]) - self.assertEqual(qt1 == qt1, [True, True, True, True]) - - self.assertEqual(qt2 == q1, [True, False, False, True]) - self.assertEqual(q1 == qt2, [True, False, False, True]) - self.assertEqual(qt1 == qt2, [True, False, True, False]) - - self.assertEqual(qt1 != q1, [False, False, True, True]) - self.assertEqual(q1 != qt1, [False, False, True, True]) - self.assertEqual(qt1 != qt1, [False, False, False, False]) - - self.assertEqual(qt2 != q1, [False, True, True, False]) - self.assertEqual(q1 != qt2, [False, True, True, False]) - self.assertEqual(qt1 != qt2, [False, True, False, True]) + assert q1 != q2 + assert not q2 != q2 # errors @@ -1002,23 +881,17 @@ def test_miscellany(self): # norm nt.assert_array_almost_equal(q.norm(), np.linalg.norm(v)) - nt.assert_array_almost_equal( - Quaternion([q, u, q]).norm(), [np.linalg.norm(v), 1, np.linalg.norm(v)] - ) # unit qu = q.unit() - uu = UnitQuaternion() - self.assertIsInstance(q, Quaternion) + uu = UnitQuaternion.identity() + assert isinstance(q, Quaternion) nt.assert_array_almost_equal(qu.vec, v / np.linalg.norm(v)) - qcompare(Quaternion([q, u, q]).unit(), UnitQuaternion([qu, uu, qu])) # inner nt.assert_equal(u.inner(u), 1) nt.assert_equal(q.inner(q), q.norm() ** 2) nt.assert_equal(q.inner(u), np.dot(q.vec, u.vec)) - -# ---------------------------------------------------------------------------------------# -if __name__ == "__main__": - unittest.main() + def test_identity(self): + qcompare(Quaternion.identity(), [0, 0, 0, 0]) diff --git a/tests/test_spatialvector.py b/tests/test_spatialvector.py index bca0f4c3..776f095d 100644 --- a/tests/test_spatialvector.py +++ b/tests/test_spatialvector.py @@ -1,190 +1,90 @@ -import unittest import numpy.testing as nt import numpy as np +import pytest from spatialmath.spatialvector import * -class TestSpatialVector(unittest.TestCase): - def test_list_powers(self): - x = SpatialVelocity.Empty() - self.assertEqual(len(x), 0) - x.append(SpatialVelocity([1, 2, 3, 4, 5, 6])) - self.assertEqual(len(x), 1) - - x.append(SpatialVelocity([7, 8, 9, 10, 11, 12])) - self.assertEqual(len(x), 2) - - y = x[0] - self.assertIsInstance(y, SpatialVelocity) - self.assertEqual(len(y), 1) - self.assertTrue(all(y.A == np.r_[1, 2, 3, 4, 5, 6])) - - y = x[1] - self.assertIsInstance(y, SpatialVelocity) - self.assertEqual(len(y), 1) - self.assertTrue(all(y.A == np.r_[7, 8, 9, 10, 11, 12])) - - x.insert(0, SpatialVelocity([20, 21, 22, 23, 24, 25])) - - y = x[0] - self.assertIsInstance(y, SpatialVelocity) - self.assertEqual(len(y), 1) - self.assertTrue(all(y.A == np.r_[20, 21, 22, 23, 24, 25])) - - y = x[1] - self.assertIsInstance(y, SpatialVelocity) - self.assertEqual(len(y), 1) - self.assertTrue(all(y.A == np.r_[1, 2, 3, 4, 5, 6])) - +class TestSpatialVector: def test_velocity(self): a = SpatialVelocity([1, 2, 3, 4, 5, 6]) - self.assertIsInstance(a, SpatialVelocity) - self.assertIsInstance(a, SpatialVector) - self.assertIsInstance(a, SpatialM6) - self.assertEqual(len(a), 1) - self.assertTrue(all(a.A == np.r_[1, 2, 3, 4, 5, 6])) + assert isinstance(a, SpatialVelocity) + assert isinstance(a, SpatialVector) + assert isinstance(a, SpatialM6) + assert len(a) == 1 + assert all(a.A == np.r_[1, 2, 3, 4, 5, 6]) a = SpatialVelocity(np.r_[1, 2, 3, 4, 5, 6]) - self.assertIsInstance(a, SpatialVelocity) - self.assertIsInstance(a, SpatialVector) - self.assertIsInstance(a, SpatialM6) - self.assertEqual(len(a), 1) - self.assertTrue(all(a.A == np.r_[1, 2, 3, 4, 5, 6])) - - s = str(a) - self.assertIsInstance(s, str) - self.assertEqual(s.count("\n"), 0) - self.assertTrue(s.startswith("SpatialVelocity")) - - r = np.random.rand(6, 10) - a = SpatialVelocity(r) - self.assertIsInstance(a, SpatialVelocity) - self.assertIsInstance(a, SpatialVector) - self.assertIsInstance(a, SpatialM6) - self.assertEqual(len(a), 10) - - b = a[3] - self.assertIsInstance(b, SpatialVelocity) - self.assertIsInstance(b, SpatialVector) - self.assertIsInstance(b, SpatialM6) - self.assertEqual(len(b), 1) - self.assertTrue(all(b.A == r[:, 3])) + assert isinstance(a, SpatialVelocity) + assert isinstance(a, SpatialVector) + assert isinstance(a, SpatialM6) + assert len(a) == 1 + assert all(a.A == np.r_[1, 2, 3, 4, 5, 6]) s = str(a) - self.assertIsInstance(s, str) - self.assertEqual(s.count("\n"), 9) + assert isinstance(s, str) + assert s.count("\n") == 0 + assert s.startswith("SpatialVelocity") def test_acceleration(self): a = SpatialAcceleration([1, 2, 3, 4, 5, 6]) - self.assertIsInstance(a, SpatialAcceleration) - self.assertIsInstance(a, SpatialVector) - self.assertIsInstance(a, SpatialM6) - self.assertEqual(len(a), 1) - self.assertTrue(all(a.A == np.r_[1, 2, 3, 4, 5, 6])) + assert isinstance(a, SpatialAcceleration) + assert isinstance(a, SpatialVector) + assert isinstance(a, SpatialM6) + assert len(a) == 1 + assert all(a.A == np.r_[1, 2, 3, 4, 5, 6]) a = SpatialAcceleration(np.r_[1, 2, 3, 4, 5, 6]) - self.assertIsInstance(a, SpatialAcceleration) - self.assertIsInstance(a, SpatialVector) - self.assertIsInstance(a, SpatialM6) - self.assertEqual(len(a), 1) - self.assertTrue(all(a.A == np.r_[1, 2, 3, 4, 5, 6])) - - s = str(a) - self.assertIsInstance(s, str) - self.assertEqual(s.count("\n"), 0) - self.assertTrue(s.startswith("SpatialAcceleration")) - - r = np.random.rand(6, 10) - a = SpatialAcceleration(r) - self.assertIsInstance(a, SpatialAcceleration) - self.assertIsInstance(a, SpatialVector) - self.assertIsInstance(a, SpatialM6) - self.assertEqual(len(a), 10) - - b = a[3] - self.assertIsInstance(b, SpatialAcceleration) - self.assertIsInstance(b, SpatialVector) - self.assertIsInstance(b, SpatialM6) - self.assertEqual(len(b), 1) - self.assertTrue(all(b.A == r[:, 3])) + assert isinstance(a, SpatialAcceleration) + assert isinstance(a, SpatialVector) + assert isinstance(a, SpatialM6) + assert len(a) == 1 + assert all(a.A == np.r_[1, 2, 3, 4, 5, 6]) s = str(a) - self.assertIsInstance(s, str) + assert isinstance(s, str) + assert s.count("\n") == 0 + assert s.startswith("SpatialAcceleration") def test_force(self): a = SpatialForce([1, 2, 3, 4, 5, 6]) - self.assertIsInstance(a, SpatialForce) - self.assertIsInstance(a, SpatialVector) - self.assertIsInstance(a, SpatialF6) - self.assertEqual(len(a), 1) - self.assertTrue(all(a.A == np.r_[1, 2, 3, 4, 5, 6])) + assert isinstance(a, SpatialForce) + assert isinstance(a, SpatialVector) + assert isinstance(a, SpatialF6) + assert len(a) == 1 + assert all(a.A == np.r_[1, 2, 3, 4, 5, 6]) a = SpatialForce(np.r_[1, 2, 3, 4, 5, 6]) - self.assertIsInstance(a, SpatialForce) - self.assertIsInstance(a, SpatialVector) - self.assertIsInstance(a, SpatialF6) - self.assertEqual(len(a), 1) - self.assertTrue(all(a.A == np.r_[1, 2, 3, 4, 5, 6])) - - s = str(a) - self.assertIsInstance(s, str) - self.assertEqual(s.count("\n"), 0) - self.assertTrue(s.startswith("SpatialForce")) - - r = np.random.rand(6, 10) - a = SpatialForce(r) - self.assertIsInstance(a, SpatialForce) - self.assertIsInstance(a, SpatialVector) - self.assertIsInstance(a, SpatialF6) - self.assertEqual(len(a), 10) - - b = a[3] - self.assertIsInstance(b, SpatialForce) - self.assertIsInstance(b, SpatialVector) - self.assertIsInstance(b, SpatialF6) - self.assertEqual(len(b), 1) - self.assertTrue(all(b.A == r[:, 3])) + assert isinstance(a, SpatialForce) + assert isinstance(a, SpatialVector) + assert isinstance(a, SpatialF6) + assert len(a) == 1 + assert all(a.A == np.r_[1, 2, 3, 4, 5, 6]) s = str(a) - self.assertIsInstance(s, str) + assert isinstance(s, str) + assert s.count("\n") == 0 + assert s.startswith("SpatialForce") def test_momentum(self): a = SpatialMomentum([1, 2, 3, 4, 5, 6]) - self.assertIsInstance(a, SpatialMomentum) - self.assertIsInstance(a, SpatialVector) - self.assertIsInstance(a, SpatialF6) - self.assertEqual(len(a), 1) - self.assertTrue(all(a.A == np.r_[1, 2, 3, 4, 5, 6])) + assert isinstance(a, SpatialMomentum) + assert isinstance(a, SpatialVector) + assert isinstance(a, SpatialF6) + assert len(a) == 1 + assert all(a.A == np.r_[1, 2, 3, 4, 5, 6]) a = SpatialMomentum(np.r_[1, 2, 3, 4, 5, 6]) - self.assertIsInstance(a, SpatialMomentum) - self.assertIsInstance(a, SpatialVector) - self.assertIsInstance(a, SpatialF6) - self.assertEqual(len(a), 1) - self.assertTrue(all(a.A == np.r_[1, 2, 3, 4, 5, 6])) - - s = str(a) - self.assertIsInstance(s, str) - self.assertEqual(s.count("\n"), 0) - self.assertTrue(s.startswith("SpatialMomentum")) - - r = np.random.rand(6, 10) - a = SpatialMomentum(r) - self.assertIsInstance(a, SpatialMomentum) - self.assertIsInstance(a, SpatialVector) - self.assertIsInstance(a, SpatialF6) - self.assertEqual(len(a), 10) - - b = a[3] - self.assertIsInstance(b, SpatialMomentum) - self.assertIsInstance(b, SpatialVector) - self.assertIsInstance(b, SpatialF6) - self.assertEqual(len(b), 1) - self.assertTrue(all(b.A == r[:, 3])) + assert isinstance(a, SpatialMomentum) + assert isinstance(a, SpatialVector) + assert isinstance(a, SpatialF6) + assert len(a) == 1 + assert all(a.A == np.r_[1, 2, 3, 4, 5, 6]) s = str(a) - self.assertIsInstance(s, str) + assert isinstance(s, str) + assert s.count("\n") == 0 + assert s.startswith("SpatialMomentum") def test_arith(self): # just test SpatialVelocity since all types derive from same superclass @@ -194,15 +94,11 @@ def test_arith(self): a1 = SpatialVelocity(r1) a2 = SpatialVelocity(r2) - self.assertTrue(all((a1 + a2).A == r1 + r2)) - self.assertTrue(all((a1 - a2).A == r1 - r2)) - self.assertTrue(all((-a1).A == -r1)) + assert all((a1 + a2).A == r1 + r2) + assert all((a1 - a2).A == r1 - r2) + assert all((-a1).A == -r1) def test_inertia(self): - # constructor - i0 = SpatialInertia() - nt.assert_equal(i0.A, np.zeros((6, 6))) - i1 = SpatialInertia(np.eye(6, 6)) nt.assert_equal(i1.A, np.eye(6, 6)) @@ -219,7 +115,7 @@ def test_inertia(self): nt.assert_almost_equal((i4a + i4b).A, SpatialInertia(m=m_a + m_b, r=r).A) # isvalid - note this method is very barebone, to be improved - self.assertTrue(SpatialInertia().isvalid(np.ones((6, 6)), check=False)) + assert SpatialInertia.identity().isvalid(np.ones((6, 6)), check=False) def test_products(self): # v x v = a *, v x F6 = a @@ -228,7 +124,12 @@ def test_products(self): # twist x v, twist x a, twist x F pass + @pytest.mark.parametrize( + 'cls', + [SpatialVelocity, SpatialAcceleration, SpatialForce, SpatialMomentum], + ) + def test_identity(self, cls): + nt.assert_equal(cls.identity().A, np.zeros((6,))) -# ---------------------------------------------------------------------------------------# -if __name__ == "__main__": - unittest.main() + def test_spatial_inertia_identity(self): + nt.assert_equal(SpatialInertia.identity().A, np.zeros((6,6))) diff --git a/tests/test_spline.py b/tests/test_spline.py index 361bc28f..5d49f087 100644 --- a/tests/test_spline.py +++ b/tests/test_spline.py @@ -1,12 +1,13 @@ import numpy.testing as nt import numpy as np +import matplotlib +matplotlib.use("AGG") import matplotlib.pyplot as plt -import unittest from spatialmath import BSplineSE3, SE3, InterpSplineSE3, SplineFit, SO3 -class TestBSplineSE3(unittest.TestCase): +class TestBSplineSE3: control_poses = [ SE3.Trans([e, 2 * np.cos(e / 2 * np.pi), 2 * np.sin(e / 2 * np.pi)]) * SE3.Ry(e / 8 * np.pi) @@ -27,7 +28,7 @@ def test_evaluation(self): def test_visualize(self): spline = BSplineSE3(self.control_poses) - spline.visualize(sample_times= np.linspace(0, 1.0, 100), animate=True, repeat=False) + spline.visualize(sample_times= np.linspace(0, 1.0, 100), animate=True, repeat=False, movie=True) class TestInterpSplineSE3: waypoints = [ @@ -62,7 +63,7 @@ def test_small_delta_t(self): def test_visualize(self): spline = InterpSplineSE3(self.times, self.waypoints) - spline.visualize(sample_times= np.linspace(0, self.time_horizon, 100), animate=True, repeat=False) + spline.visualize(sample_times= np.linspace(0, self.time_horizon, 100), animate=True, repeat=False, movie=True) class TestSplineFit: @@ -92,4 +93,4 @@ def test_spline_fit(self): assert( fit.max_angular_error() < np.deg2rad(5.0) ) assert( fit.max_angular_error() < 0.1 ) - spline.visualize(sample_times= np.linspace(0, self.time_horizon, 100), animate=True, repeat=False) \ No newline at end of file + spline.visualize(sample_times= np.linspace(0, self.time_horizon, 100), animate=True, repeat=False, movie=True) diff --git a/tests/test_twist.py b/tests/test_twist.py index 12660c7d..d7386ad5 100755 --- a/tests/test_twist.py +++ b/tests/test_twist.py @@ -1,6 +1,7 @@ import numpy.testing as nt +import matplotlib +matplotlib.use("AGG") import matplotlib.pyplot as plt -import unittest """ we will assume that the primitives rotx,trotx, etc. all work @@ -25,14 +26,14 @@ def array_compare(x, y): nt.assert_array_almost_equal(x, y) -class Twist3dTest(unittest.TestCase): +class TestTwist3d: def test_constructor(self): s = [1, 2, 3, 4, 5, 6] x = Twist3(s) - self.assertIsInstance(x, Twist3) - self.assertEqual(len(x), 1) + assert isinstance(x, Twist3) + assert len(x) == 1 array_compare(x.v, [1, 2, 3]) array_compare(x.w, [4, 5, 6]) array_compare(x.S, s) @@ -45,26 +46,15 @@ def test_constructor(self): y = Twist3(x) array_compare(x, y) - x = Twist3(SE3()) + x = Twist3(SE3.identity()) array_compare(x, [0,0,0,0,0,0]) - - def test_list(self): - x = Twist3([1, 0, 0, 0, 0, 0]) - y = Twist3([1, 0, 0, 0, 0, 0]) - - a = Twist3(x) - a.append(y) - self.assertEqual(len(a), 2) - array_compare(a[0], x) - array_compare(a[1], y) - def test_conversion_SE3(self): T = SE3.Rx(0) tw = Twist3(T) array_compare(tw.SE3(), T) - self.assertIsInstance(tw.SE3(), SE3) - self.assertEqual(len(tw.SE3()), 1) + assert isinstance(tw.SE3(), SE3) + assert len(tw.SE3()) == 1 T = SE3.Rx(0) * SE3(1, 2, 3) array_compare(Twist3(T).SE3(), T) @@ -81,52 +71,26 @@ def test_conversion_se3(self): def test_conversion_Plucker(self): pass - def test_list_constuctor(self): - x = Twist3([1, 0, 0, 0, 0, 0]) - - a = Twist3([x,x,x,x]) - self.assertIsInstance(a, Twist3) - self.assertEqual(len(a), 4) - - a = Twist3([x.skewa(), x.skewa(), x.skewa(), x.skewa()]) - self.assertIsInstance(a, Twist3) - self.assertEqual(len(a), 4) - - a = Twist3([x.S, x.S, x.S, x.S]) - self.assertIsInstance(a, Twist3) - self.assertEqual(len(a), 4) - - s = np.r_[1, 2, 3, 4, 5, 6] - a = Twist3([s, s, s, s]) - self.assertIsInstance(a, Twist3) - self.assertEqual(len(a), 4) - def test_predicate(self): x = Twist3.UnitRevolute([1, 2, 3], [0, 0, 0]) - self.assertFalse(x.isprismatic) + assert not x.isprismatic # check prismatic twist x = Twist3.UnitPrismatic([1, 2, 3]) - self.assertTrue(x.isprismatic) + assert x.isprismatic - self.assertTrue(Twist3.isvalid(x.skewa())) - self.assertTrue(Twist3.isvalid(x.S)) + assert Twist3.isvalid(x.skewa()) + assert Twist3.isvalid(x.S) - self.assertFalse(Twist3.isvalid(2)) - self.assertFalse(Twist3.isvalid(np.eye(4))) + assert not Twist3.isvalid(2) + assert not Twist3.isvalid(np.eye(4)) def test_str(self): x = Twist3([1, 2, 3, 4, 5, 6]) s = str(x) - self.assertIsInstance(s, str) - self.assertEqual(len(s), 14) - self.assertEqual(s.count('\n'), 0) - - x.append(x) - s = str(x) - self.assertIsInstance(s, str) - self.assertEqual(len(s), 29) - self.assertEqual(s.count('\n'), 1) + assert isinstance(s, str) + assert len(s) == 14 + assert s.count('\n') == 0 def test_variant_constructors(self): @@ -188,18 +152,20 @@ def test_prod(self): x1 = Twist3(T1) x2 = Twist3(T2) - x = Twist3([x1, x2]) - array_compare( x.prod().SE3(), T1 * T2) - + array_compare(Twist3.prod([x1,x2]).SE3(), T1 * T2) + + def test_identity(self): + array_compare(Twist3.identity(), np.zeros(6,)) + -class Twist2dTest(unittest.TestCase): +class TestTwist2d: def test_constructor(self): s = [1, 2, 3] x = Twist2(s) - self.assertIsInstance(x, Twist2) - self.assertEqual(len(x), 1) + assert isinstance(x, Twist2) + assert len(x) == 1 array_compare(x.v, [1, 2]) array_compare(x.w, [3]) array_compare(x.S, s) @@ -213,7 +179,7 @@ def test_constructor(self): array_compare(x, y) # construct from SE2 - x = Twist2(SE2()) + x = Twist2(SE2.identity()) array_compare(x, [0,0,0]) x = Twist2( SE2(0, 0, pi / 2)) @@ -225,17 +191,6 @@ def test_constructor(self): x = Twist2( SE2(1, 2, pi / 2)) array_compare(x, np.r_[3 * pi / 4, pi / 4, pi / 2]) - - def test_list(self): - x = Twist2([1, 0, 0]) - y = Twist2([1, 0, 0]) - - a = Twist2(x) - a.append(y) - self.assertEqual(len(a), 2) - array_compare(a[0], x) - array_compare(a[1], y) - def test_variant_constructors(self): # check rotational twist @@ -250,8 +205,8 @@ def test_conversion_SE2(self): T = SE2(1, 2, 0.3) tw = Twist2(T) array_compare(tw.SE2(), T) - self.assertIsInstance(tw.SE2(), SE2) - self.assertEqual(len(tw.SE2()), 1) + assert isinstance(tw.SE2(), SE2) + assert len(tw.SE2()) == 1 def test_conversion_se2(self): s = [1, 2, 3] @@ -261,56 +216,29 @@ def test_conversion_se2(self): [ 3., 0., 2.], [ 0., 0., 0.]])) - def test_list_constuctor(self): - x = Twist2([1, 0, 0]) - - a = Twist2([x,x,x,x]) - self.assertIsInstance(a, Twist2) - self.assertEqual(len(a), 4) - - a = Twist2([x.skewa(), x.skewa(), x.skewa(), x.skewa()]) - self.assertIsInstance(a, Twist2) - self.assertEqual(len(a), 4) - - a = Twist2([x.S, x.S, x.S, x.S]) - self.assertIsInstance(a, Twist2) - self.assertEqual(len(a), 4) - - s = np.r_[1, 2, 3] - a = Twist2([s, s, s, s]) - self.assertIsInstance(a, Twist2) - self.assertEqual(len(a), 4) - def test_predicate(self): x = Twist2.UnitRevolute([1, 2]) - self.assertFalse(x.isprismatic) + assert not x.isprismatic # check prismatic twist x = Twist2.UnitPrismatic([1, 2]) - self.assertTrue(x.isprismatic) + assert x.isprismatic - self.assertTrue(Twist2.isvalid(x.skewa())) - self.assertTrue(Twist2.isvalid(x.S)) + assert Twist2.isvalid(x.skewa()) + assert Twist2.isvalid(x.S) - self.assertFalse(Twist2.isvalid(2)) - self.assertFalse(Twist2.isvalid(np.eye(3))) + assert not Twist2.isvalid(2) + assert not Twist2.isvalid(np.eye(3)) def test_str(self): x = Twist2([1, 2, 3]) s = str(x) - self.assertIsInstance(s, str) - self.assertEqual(len(s), 8) - self.assertEqual(s.count('\n'), 0) - - x.append(x) - s = str(x) - self.assertIsInstance(s, str) - self.assertEqual(len(s), 17) - self.assertEqual(s.count('\n'), 1) - + assert isinstance(s, str) + assert len(s) == 8 + assert s.count('\n') == 0 def test_SE2_twists(self): - tw = Twist2( SE2() ) + tw = Twist2( SE2.identity() ) array_compare(tw, np.r_[0, 0, 0]) tw = Twist2( SE2(0, 0, pi / 2) ) @@ -358,11 +286,7 @@ def test_prod(self): x1 = Twist2(T1) x2 = Twist2(T2) - x = Twist2([x1, x2]) - array_compare( x.prod().SE2(), T1 * T2) - -# ---------------------------------------------------------------------------------------# -if __name__ == '__main__': - + array_compare(Twist2.prod([x1, x2]).SE2(), T1 * T2) - unittest.main() + def test_identity(self): + array_compare(Twist3.identity(), np.zeros(6,)) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy