Skip to content

Commit 9bc7f8a

Browse files
authored
Merge pull request sympy#15121 from asmeurer/matexpr-cleanup
MatrixSymbol first arg is a Symbol (and other matrix expressions cleanup)
2 parents 5e4a199 + c1b2e1a commit 9bc7f8a

File tree

12 files changed

+53
-36
lines changed

12 files changed

+53
-36
lines changed

sympy/core/function.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2945,10 +2945,6 @@ def count_ops(expr, visual=False):
29452945
while args:
29462946
a = args.pop()
29472947

2948-
# XXX: This is a hack to support non-Basic args
2949-
if isinstance(a, string_types):
2950-
continue
2951-
29522948
if a.is_Rational:
29532949
#-1/3 = NEG + DIV
29542950
if a is not S.One:
@@ -3036,10 +3032,6 @@ def count_ops(expr, visual=False):
30363032
while args:
30373033
a = args.pop()
30383034

3039-
# XXX: This is a hack to support non-Basic args
3040-
if isinstance(a, string_types):
3041-
continue
3042-
30433035
if a.args:
30443036
o = Symbol(a.func.__name__.upper())
30453037
if a.is_Boolean:

sympy/core/tests/test_args.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2679,7 +2679,6 @@ def test_sympy__matrices__expressions__matexpr__MatrixElement():
26792679
from sympy import S
26802680
assert _test_args(MatrixElement(MatrixSymbol('A', 3, 5), S(2), S(3)))
26812681

2682-
@XFAIL
26832682
def test_sympy__matrices__expressions__matexpr__MatrixSymbol():
26842683
from sympy.matrices.expressions.matexpr import MatrixSymbol
26852684
assert _test_args(MatrixSymbol('A', 3, 5))

sympy/matrices/expressions/matadd.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,8 @@
1212
from sympy.matrices.expressions.matexpr import (MatrixExpr, ShapeError,
1313
ZeroMatrix, GenericZeroMatrix)
1414
from sympy.utilities import default_sort_key, sift
15-
from sympy.core.operations import AssocOp
16-
1715

16+
# XXX: MatAdd should perhaps not subclass directly from Add
1817
class MatAdd(MatrixExpr, Add):
1918
"""A Sum of Matrix Expressions
2019

sympy/matrices/expressions/matexpr.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
from sympy.core import S, Symbol, Tuple, Integer, Basic, Expr, Eq, Mul, Add
77
from sympy.core.decorators import call_highest_priority
8-
from sympy.core.compatibility import range, SYMPY_INTS, default_sort_key
9-
from sympy.core.sympify import SympifyError, sympify
8+
from sympy.core.compatibility import range, SYMPY_INTS, default_sort_key, string_types
9+
from sympy.core.sympify import SympifyError, _sympify
1010
from sympy.functions import conjugate, adjoint
1111
from sympy.functions.special.tensor_functions import KroneckerDelta
1212
from sympy.matrices import ShapeError
@@ -20,7 +20,7 @@ def deco(func):
2020
@wraps(func)
2121
def __sympifyit_wrapper(a, b):
2222
try:
23-
b = sympify(b, strict=True)
23+
b = _sympify(b)
2424
return func(a, b)
2525
except SympifyError:
2626
return retval
@@ -71,7 +71,7 @@ class MatrixExpr(Expr):
7171
is_symbol = False
7272

7373
def __new__(cls, *args, **kwargs):
74-
args = map(sympify, args)
74+
args = map(_sympify, args)
7575
return Basic.__new__(cls, *args, **kwargs)
7676

7777
# The following is adapted from the core Expr object
@@ -265,7 +265,7 @@ def __getitem__(self, key):
265265
if isinstance(i, slice) or isinstance(j, slice):
266266
from sympy.matrices.expressions.slice import MatrixSlice
267267
return MatrixSlice(self, i, j)
268-
i, j = sympify(i), sympify(j)
268+
i, j = _sympify(i), _sympify(j)
269269
if self.valid_index(i, j) != False:
270270
return self._entry(i, j)
271271
else:
@@ -278,7 +278,7 @@ def __getitem__(self, key):
278278
raise IndexError(filldedent('''
279279
Single indexing is only supported when the number
280280
of columns is known.'''))
281-
key = sympify(key)
281+
key = _sympify(key)
282282
i = key // cols
283283
j = key % cols
284284
if self.valid_index(i, j) != False:
@@ -640,12 +640,14 @@ class MatrixElement(Expr):
640640
is_commutative = True
641641

642642
def __new__(cls, name, n, m):
643-
n, m = map(sympify, (n, m))
643+
n, m = map(_sympify, (n, m))
644644
from sympy import MatrixBase
645645
if isinstance(name, (MatrixBase,)):
646646
if n.is_Integer and m.is_Integer:
647647
return name[n, m]
648-
name = sympify(name)
648+
if isinstance(name, string_types):
649+
name = Symbol(name)
650+
name = _sympify(name)
649651
obj = Expr.__new__(cls, name, n, m)
650652
return obj
651653

@@ -710,7 +712,9 @@ class MatrixSymbol(MatrixExpr):
710712
_diff_wrt = True
711713

712714
def __new__(cls, name, n, m):
713-
n, m = sympify(n), sympify(m)
715+
n, m = _sympify(n), _sympify(m)
716+
if isinstance(name, string_types):
717+
name = Symbol(name)
714718
obj = Basic.__new__(cls, name, n, m)
715719
return obj
716720

@@ -723,7 +727,7 @@ def shape(self):
723727

724728
@property
725729
def name(self):
726-
return self.args[0]
730+
return self.args[0].name
727731

728732
def _eval_subs(self, old, new):
729733
# only do substitutions in shape
@@ -783,7 +787,7 @@ class Identity(MatrixExpr):
783787
is_Identity = True
784788

785789
def __new__(cls, n):
786-
return super(Identity, cls).__new__(cls, sympify(n))
790+
return super(Identity, cls).__new__(cls, _sympify(n))
787791

788792
@property
789793
def rows(self):

sympy/matrices/expressions/matmul.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import print_function, division
22

33
from sympy import Number
4-
from sympy.core import Mul, Basic, sympify, Add
4+
from sympy.core import Mul, Basic, sympify
55
from sympy.core.compatibility import range
66
from sympy.functions import adjoint
77
from sympy.matrices.expressions.transpose import transpose
@@ -12,7 +12,7 @@
1212
from sympy.matrices.expressions.matpow import MatPow
1313
from sympy.matrices.matrices import MatrixBase
1414

15-
15+
# XXX: MatMul should perhaps not subclass directly from Mul
1616
class MatMul(MatrixExpr, Mul):
1717
"""
1818
A product of matrix expressions

sympy/matrices/expressions/tests/test_matexpr.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from sympy import (KroneckerDelta, diff, Piecewise, Sum, Dummy, factor,
22
expand, zeros, gcd_terms, Eq)
33

4-
from sympy.core import S, symbols, Add, Mul
4+
from sympy.core import S, symbols, Add, Mul, SympifyError
55
from sympy.core.compatibility import long
6-
from sympy.functions import transpose, sin, cos, sqrt, cbrt
6+
from sympy.functions import transpose, sin, cos, sqrt, cbrt, exp
77
from sympy.simplify import simplify
88
from sympy.matrices import (Identity, ImmutableMatrix, Inverse, MatAdd, MatMul,
99
MatPow, Matrix, MatrixExpr, MatrixSymbol, ShapeError, ZeroMatrix,
@@ -504,3 +504,16 @@ def test_simplify_matrix_expressions():
504504
a = gcd_terms(2*C*D + 4*D*C)
505505
assert type(a) == MatMul
506506
assert a.args == (2, (C*D + 2*D*C))
507+
508+
def test_exp():
509+
A = MatrixSymbol('A', 2, 2)
510+
B = MatrixSymbol('B', 2, 2)
511+
expr1 = exp(A)*exp(B)
512+
expr2 = exp(B)*exp(A)
513+
assert expr1 != expr2
514+
assert expr1 - expr2 != 0
515+
assert not isinstance(expr1, exp)
516+
assert not isinstance(expr2, exp)
517+
518+
def test_invalid_args():
519+
raises(SympifyError, lambda: MatrixSymbol(1, 2, 'A'))

sympy/matrices/tests/test_matrices.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3003,7 +3003,7 @@ def test_pinv_rank_deficient():
30033003
@XFAIL
30043004
def test_pinv_rank_deficient_when_diagonalization_fails():
30053005
# Test the four properties of the pseudoinverse for matrices when
3006-
# diagonalization of A.H*A fails.'
3006+
# diagonalization of A.H*A fails.
30073007
As = [Matrix([
30083008
[61, 89, 55, 20, 71, 0],
30093009
[62, 96, 85, 85, 16, 0],

sympy/polys/polytools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6629,7 +6629,7 @@ def cancel(f, *gens, **args):
66296629
x.is_commutative is True and not x.has(Piecewise),
66306630
binary=True)
66316631
nc = [cancel(i) for i in nc]
6632-
return f.func(cancel(f.func._from_args(c)), *nc)
6632+
return f.func(cancel(f.func(*c)), *nc)
66336633
else:
66346634
reps = []
66356635
pot = preorder_traversal(f)

sympy/printing/repr.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,10 @@ def _print_Add(self, expr, order=None):
4848
args = self._as_ordered_terms(expr, order=order)
4949
nargs = len(args)
5050
args = map(self._print, args)
51+
clsname = type(expr).__name__
5152
if nargs > 255: # Issue #10259, Python < 3.7
52-
return "Add(*[%s])" % ", ".join(args)
53-
return "Add(%s)" % ", ".join(args)
53+
return clsname + "(*[%s])" % ", ".join(args)
54+
return clsname + "(%s)" % ", ".join(args)
5455

5556
def _print_Cycle(self, expr):
5657
return expr.__repr__()
@@ -138,9 +139,10 @@ def _print_Mul(self, expr, order=None):
138139

139140
nargs = len(args)
140141
args = map(self._print, args)
142+
clsname = type(expr).__name__
141143
if nargs > 255: # Issue #10259, Python < 3.7
142-
return "Mul(*[%s])" % ", ".join(args)
143-
return "Mul(%s)" % ", ".join(args)
144+
return clsname + "(*[%s])" % ", ".join(args)
145+
return clsname + "(%s)" % ", ".join(args)
144146

145147
def _print_Rational(self, expr):
146148
return 'Rational(%s, %s)' % (self._print(expr.p), self._print(expr.q))

sympy/printing/tests/test_repr.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from sympy.utilities.pytest import raises
22
from sympy import (symbols, Function, Integer, Matrix, Abs,
33
Rational, Float, S, WildFunction, ImmutableDenseMatrix, sin, true, false, ones,
4-
sqrt, root, AlgebraicNumber, Symbol, Dummy, Wild)
4+
sqrt, root, AlgebraicNumber, Symbol, Dummy, Wild, MatrixSymbol)
55
from sympy.core.compatibility import exec_
66
from sympy.geometry import Point, Ellipse
77
from sympy.printing import srepr
@@ -274,3 +274,11 @@ def test_Naturals0():
274274

275275
def test_Reals():
276276
sT(S.Reals, "Reals")
277+
278+
def test_matrix_expressions():
279+
n = symbols('n', integer=True)
280+
A = MatrixSymbol("A", n, n)
281+
B = MatrixSymbol("B", n, n)
282+
sT(A, "MatrixSymbol(Symbol('A'), Symbol('n', integer=True), Symbol('n', integer=True))")
283+
sT(A*B, "MatMul(MatrixSymbol(Symbol('A'), Symbol('n', integer=True), Symbol('n', integer=True)), MatrixSymbol(Symbol('B'), Symbol('n', integer=True), Symbol('n', integer=True)))")
284+
sT(A + B, "MatAdd(MatrixSymbol(Symbol('A'), Symbol('n', integer=True), Symbol('n', integer=True)), MatrixSymbol(Symbol('B'), Symbol('n', integer=True), Symbol('n', integer=True)))")

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy