Skip to content

Commit 0a09bf9

Browse files
committed
fixing matrix derivative bugs
1 parent 5f01517 commit 0a09bf9

File tree

6 files changed

+63
-18
lines changed

6 files changed

+63
-18
lines changed

sympy/core/expr.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3420,6 +3420,10 @@ def round(self, p=0):
34203420
allow += 1
34213421
return Float(rv, allow)
34223422

3423+
def _eval_derivative_matrix_lines(self, x):
3424+
from sympy.matrices.expressions.matexpr import _LeftRightArgs
3425+
return [_LeftRightArgs(S.One, S.One, higher=self._eval_derivative(x))]
3426+
34233427

34243428
class AtomicExpr(Atom, Expr):
34253429
"""
@@ -3440,9 +3444,9 @@ def _eval_derivative(self, s):
34403444

34413445
def _eval_derivative_n_times(self, s, n):
34423446
from sympy import Piecewise, Eq
3443-
from sympy import Tuple
3447+
from sympy import Tuple, MatrixExpr
34443448
from sympy.matrices.common import MatrixCommon
3445-
if isinstance(s, (MatrixCommon, Tuple, Iterable)):
3449+
if isinstance(s, (MatrixCommon, Tuple, Iterable, MatrixExpr)):
34463450
return super(AtomicExpr, self)._eval_derivative_n_times(s, n)
34473451
if self == s:
34483452
return Piecewise((self, Eq(n, 0)), (1, Eq(n, 1)), (0, True))

sympy/core/function.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1184,8 +1184,8 @@ def _diff_wrt(self):
11841184
def __new__(cls, expr, *variables, **kwargs):
11851185

11861186
from sympy.matrices.common import MatrixCommon
1187-
from sympy import Integer
1188-
from sympy.tensor.array import Array, NDimArray
1187+
from sympy import Integer, MatrixExpr
1188+
from sympy.tensor.array import Array, NDimArray, derive_by_array
11891189
from sympy.utilities.misc import filldedent
11901190

11911191
expr = sympify(expr)
@@ -1317,6 +1317,9 @@ def __new__(cls, expr, *variables, **kwargs):
13171317
if not expr.xreplace({v: D}).has(D):
13181318
zero = True
13191319
break
1320+
elif isinstance(v, MatrixExpr):
1321+
zero = False
1322+
break
13201323
elif isinstance(v, Symbol) and v not in free:
13211324
zero = True
13221325
break

sympy/core/mul.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -897,7 +897,7 @@ def _eval_derivative(self, s):
897897
# Note: reduce is used in step of Mul as Mul is unable to
898898
# handle subtypes and operation priority:
899899
terms.append(reduce(lambda x, y: x*y, (args[:i] + [d] + args[i + 1:]), S.One))
900-
return reduce(lambda x, y: x+y, terms, S.Zero)
900+
return Add.fromiter(terms)
901901

902902
@cacheit
903903
def _eval_derivative_n_times(self, s, n):

sympy/matrices/expressions/matexpr.py

Lines changed: 30 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -202,6 +202,23 @@ def _eval_derivative(self, x):
202202
def _eval_derivative_n_times(self, x, n):
203203
return Basic._eval_derivative_n_times(self, x, n)
204204

205+
def _visit_eval_derivative_scalar(self, x):
206+
# `x` is a scalar:
207+
if x.has(self):
208+
return _matrix_derivative(x, self)
209+
else:
210+
return ZeroMatrix(*self.shape)
211+
212+
def _visit_eval_derivative_array(self, x):
213+
if x.has(self):
214+
return _matrix_derivative(x, self)
215+
else:
216+
from sympy import Derivative
217+
return Derivative(x, self)
218+
219+
def _accept_eval_derivative(self, s):
220+
return s._visit_eval_derivative_array(self)
221+
205222
def _entry(self, i, j, **kwargs):
206223
raise NotImplementedError(
207224
"Indexing not implemented for %s" % self.__class__.__name__)
@@ -599,24 +616,17 @@ def _postprocessor(expr):
599616
"Add": [get_postprocessor(Add)],
600617
}
601618

619+
602620
def _matrix_derivative(expr, x):
603621
from sympy import Derivative
604622
lines = expr._eval_derivative_matrix_lines(x)
605623

606-
first = lines[0].first
607-
second = lines[0].second
608-
higher = lines[0].higher
609-
610624
ranks = [i.rank() for i in lines]
611625
assert len(set(ranks)) == 1
612626
rank = ranks[0]
613627

614628
if rank <= 2:
615-
return reduce(lambda x, y: x+y, [i.matrix_form() for i in lines])
616-
if first != 1:
617-
return reduce(lambda x,y: x+y, [lr.first * lr.second.T for lr in lines])
618-
elif higher != 1:
619-
return reduce(lambda x,y: x+y, [lr.higher for lr in lines])
629+
return Add.fromiter([i.matrix_form() for i in lines])
620630

621631
return Derivative(expr, x)
622632

@@ -748,8 +758,8 @@ def _eval_derivative_matrix_lines(self, x):
748758
transposed=False,
749759
)]
750760
else:
751-
first=Identity(self.shape[0])
752-
second=Identity(self.shape[1])
761+
first = Identity(self.shape[0])
762+
second = Identity(self.shape[1])
753763
return [_LeftRightArgs(
754764
first=first,
755765
second=second,
@@ -848,6 +858,7 @@ def __ne__(self, other):
848858
def __hash__(self):
849859
return super(GenericIdentity, self).__hash__()
850860

861+
851862
class ZeroMatrix(MatrixExpr):
852863
"""The Matrix Zero 0 - additive identity
853864
@@ -902,6 +913,7 @@ def __nonzero__(self):
902913

903914
__bool__ = __nonzero__
904915

916+
905917
class GenericZeroMatrix(ZeroMatrix):
906918
"""
907919
A zero matrix without a specified shape
@@ -937,6 +949,7 @@ def __ne__(self, other):
937949
def __hash__(self):
938950
return super(GenericZeroMatrix, self).__hash__()
939951

952+
940953
def matrix_symbols(expr):
941954
return [sym for sym in expr.free_symbols if sym.is_Matrix]
942955

@@ -976,6 +989,12 @@ def transpose(self):
976989
def matrix_form(self):
977990
if self.first != 1 and self.higher != 1:
978991
raise ValueError("higher dimensional array cannot be represented")
992+
# Remove one-dimensional identity matrices:
993+
# (this is needed by `a.diff(a)` where `a` is a vector)
994+
if self.first == Identity(1):
995+
return self.second.T
996+
if self.second == Identity(1):
997+
return self.first
979998
if self.first != 1:
980999
return self.first*self.second.T
9811000
else:

sympy/matrices/expressions/tests/test_derivatives.py

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,8 @@ def test_matrix_derivative_non_matrix_result():
3333

3434
def test_matrix_derivative_trivial_cases():
3535
# Cookbook example 33:
36-
assert X.diff(A) == 0
36+
# TODO: find a way to represent a four-dimensional zero-array:
37+
assert X.diff(A) == Derivative(X, A)
3738

3839

3940
def test_matrix_derivative_with_inverse():
@@ -58,6 +59,9 @@ def test_matrix_derivative_with_inverse():
5859

5960
def test_matrix_derivative_vectors_and_scalars():
6061

62+
assert x.diff(x) == Identity(k)
63+
assert x.T.diff(x) == Identity(k)
64+
6165
# Cookbook example 69:
6266
expr = x.T*a
6367
assert expr.diff(x) == a
@@ -246,9 +250,23 @@ def test_derivatives_of_complicated_matrix_expr():
246250

247251
def test_mixed_deriv_mixed_expressions():
248252

253+
expr = 3*Trace(A)
254+
assert expr.diff(A) == 3*Identity(k)
255+
256+
expr = k
257+
deriv = expr.diff(A)
258+
assert isinstance(deriv, ZeroMatrix)
259+
assert deriv == ZeroMatrix(k, k)
260+
261+
expr = Trace(A)**2
262+
assert expr.diff(A) == (2*Trace(A))*Identity(k)
263+
249264
expr = Trace(A)*A
250265
# TODO: this is not yet supported:
251266
assert expr.diff(A) == Derivative(expr, A)
252267

253268
expr = Trace(Trace(A)*A)
254269
assert expr.diff(A) == (2*Trace(A))*Identity(k)
270+
271+
expr = Trace(Trace(Trace(A)*A)*A)
272+
assert expr.diff(A) == (3*Trace(A)**2)*Identity(k)

sympy/matrices/tests/test_matrices.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2135,7 +2135,8 @@ def test_diff_by_matrix():
21352135
assert A.diff(a) == MutableDenseMatrix([[0, 0], [0, 0]])
21362136

21372137
B = ImmutableDenseMatrix([a, b])
2138-
assert A.diff(B) == A.zeros(2)
2138+
assert A.diff(B) == Array.zeros(2, 1, 2, 2)
2139+
assert A.diff(A) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]])
21392140

21402141
# Test diff with tuples:
21412142

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy