Skip to content

Commit 226b942

Browse files
authored
Merge pull request sympy#19215 from jlherren/improve-2x2-block-inversion
Improve 2x2 block matrix inversion with off-diagonal invertible blocks
2 parents 0b42f52 + 6008c81 commit 226b942

File tree

4 files changed

+234
-39
lines changed

4 files changed

+234
-39
lines changed

sympy/assumptions/handlers/matrices.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,53 @@ def MatrixSlice(expr, assumptions):
155155
else:
156156
return ask(Q.invertible(expr.parent), assumptions)
157157

158+
@staticmethod
159+
def MatrixBase(expr, assumptions):
160+
if not expr.is_square:
161+
return False
162+
return expr.rank() == expr.rows
163+
164+
@staticmethod
165+
def MatrixExpr(expr, assumptions):
166+
if not expr.is_square:
167+
return False
168+
return None
169+
170+
@staticmethod
171+
def BlockMatrix(expr, assumptions):
172+
from sympy.matrices.expressions.blockmatrix import reblock_2x2
173+
if not expr.is_square:
174+
return False
175+
if expr.blockshape == (1, 1):
176+
return ask(Q.invertible(expr.blocks[0, 0]), assumptions)
177+
expr = reblock_2x2(expr)
178+
if expr.blockshape == (2, 2):
179+
[[A, B], [C, D]] = expr.blocks.tolist()
180+
if ask(Q.invertible(A), assumptions) == True:
181+
invertible = ask(Q.invertible(D - C * A.I * B), assumptions)
182+
if invertible is not None:
183+
return invertible
184+
if ask(Q.invertible(B), assumptions) == True:
185+
invertible = ask(Q.invertible(C - D * B.I * A), assumptions)
186+
if invertible is not None:
187+
return invertible
188+
if ask(Q.invertible(C), assumptions) == True:
189+
invertible = ask(Q.invertible(B - A * C.I * D), assumptions)
190+
if invertible is not None:
191+
return invertible
192+
if ask(Q.invertible(D), assumptions) == True:
193+
invertible = ask(Q.invertible(A - B * D.I * C), assumptions)
194+
if invertible is not None:
195+
return invertible
196+
return None
197+
198+
@staticmethod
199+
def BlockDiagMatrix(expr, assumptions):
200+
if expr.rowblocksizes != expr.colblocksizes:
201+
return None
202+
return fuzzy_and([ask(Q.invertible(a), assumptions) for a in expr.diag])
203+
204+
158205
class AskOrthogonalHandler(CommonHandler):
159206
"""
160207
Handler for key 'orthogonal'

sympy/assumptions/tests/test_matrices.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from sympy import Q, ask, Symbol, DiagMatrix, DiagonalMatrix
2+
from sympy.matrices.dense import Matrix
23
from sympy.matrices.expressions import (MatrixSymbol, Identity, ZeroMatrix,
3-
OneMatrix, Trace, MatrixSlice, Determinant)
4+
OneMatrix, Trace, MatrixSlice, Determinant, BlockMatrix, BlockDiagMatrix)
45
from sympy.matrices.expressions.factorizations import LofLU
56
from sympy.testing.pytest import XFAIL
67

@@ -43,6 +44,40 @@ def test_invertible_fullrank():
4344
assert ask(Q.invertible(X), Q.fullrank(X)) is True
4445

4546

47+
def test_invertible_BlockMatrix():
48+
assert ask(Q.invertible(BlockMatrix([Identity(3)]))) == True
49+
assert ask(Q.invertible(BlockMatrix([ZeroMatrix(3, 3)]))) == False
50+
51+
X = Matrix([[1, 2, 3], [3, 5, 4]])
52+
Y = Matrix([[4, 2, 7], [2, 3, 5]])
53+
# non-invertible A block
54+
assert ask(Q.invertible(BlockMatrix([
55+
[Matrix.ones(3, 3), Y.T],
56+
[X, Matrix.eye(2)],
57+
]))) == True
58+
# non-invertible B block
59+
assert ask(Q.invertible(BlockMatrix([
60+
[Y.T, Matrix.ones(3, 3)],
61+
[Matrix.eye(2), X],
62+
]))) == True
63+
# non-invertible C block
64+
assert ask(Q.invertible(BlockMatrix([
65+
[X, Matrix.eye(2)],
66+
[Matrix.ones(3, 3), Y.T],
67+
]))) == True
68+
# non-invertible D block
69+
assert ask(Q.invertible(BlockMatrix([
70+
[Matrix.eye(2), X],
71+
[Y.T, Matrix.ones(3, 3)],
72+
]))) == True
73+
74+
75+
def test_invertible_BlockDiagMatrix():
76+
assert ask(Q.invertible(BlockDiagMatrix(Identity(3), Identity(5)))) == True
77+
assert ask(Q.invertible(BlockDiagMatrix(ZeroMatrix(3, 3), Identity(5)))) == False
78+
assert ask(Q.invertible(BlockDiagMatrix(Identity(3), OneMatrix(5, 5)))) == False
79+
80+
4681
def test_symmetric():
4782
assert ask(Q.symmetric(X), Q.symmetric(X))
4883
assert ask(Q.symmetric(X*Z), Q.symmetric(X)) is None
@@ -236,7 +271,6 @@ def test_matrix_element_sets():
236271

237272

238273
def test_matrix_element_sets_slices_blocks():
239-
from sympy.matrices.expressions import BlockMatrix
240274
X = MatrixSymbol('X', 4, 4)
241275
assert ask(Q.integer_elements(X[:, 3]), Q.integer_elements(X))
242276
assert ask(Q.integer_elements(BlockMatrix([[X], [X]])),

sympy/matrices/expressions/blockmatrix.py

Lines changed: 76 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -556,38 +556,71 @@ def blockinverse_1x1(expr):
556556
return BlockMatrix(mat)
557557
return expr
558558

559+
559560
def blockinverse_2x2(expr):
560-
if isinstance(expr.arg, BlockMatrix) and expr.arg.blockshape == (2, 2) and expr.arg.blocks[0, 0].is_square:
561-
# Cite: The Matrix Cookbook Section 9.1.3
561+
if isinstance(expr.arg, BlockMatrix) and expr.arg.blockshape == (2, 2):
562+
# See: Inverses of 2x2 Block Matrices, Tzon-Tzer Lu and Sheng-Hua Shiou
562563
[[A, B],
563564
[C, D]] = expr.arg.blocks.tolist()
564565

565-
# Use one or the other formula, depending on whether A or D is known to be invertible or at least not known
566-
# to not be invertible. Note that invertAbility of the other expressions M is not checked.
567-
A_invertible = ask(Q.invertible(A))
568-
D_invertible = ask(Q.invertible(D))
569-
if A_invertible == True:
570-
invert_A = True
571-
elif D_invertible == True:
572-
invert_A = False
573-
elif A_invertible != False:
574-
invert_A = True
575-
elif D_invertible != False:
576-
invert_A = False
577-
else:
578-
invert_A = True
566+
formula = _choose_2x2_inversion_formula(A, B, C, D)
579567

580-
if invert_A:
568+
if formula == 'A':
581569
AI = A.I
582570
MI = (D - C * AI * B).I
583571
return BlockMatrix([[AI + AI * B * MI * C * AI, -AI * B * MI], [-MI * C * AI, MI]])
584-
else:
572+
if formula == 'B':
573+
BI = B.I
574+
MI = (C - D * BI * A).I
575+
return BlockMatrix([[-MI * D * BI, MI], [BI + BI * A * MI * D * BI, -BI * A * MI]])
576+
if formula == 'C':
577+
CI = C.I
578+
MI = (B - A * CI * D).I
579+
return BlockMatrix([[-CI * D * MI, CI + CI * D * MI * A * CI], [MI, -MI * A * CI]])
580+
if formula == 'D':
585581
DI = D.I
586582
MI = (A - B * DI * C).I
587583
return BlockMatrix([[MI, -MI * B * DI], [-DI * C * MI, DI + DI * C * MI * B * DI]])
588584

589585
return expr
590586

587+
588+
def _choose_2x2_inversion_formula(A, B, C, D):
589+
"""
590+
Assuming [[A, B], [C, D]] would form a valid square block matrix, find
591+
which of the classical 2x2 block matrix inversion formulas would be
592+
best suited.
593+
594+
Returns 'A', 'B', 'C', 'D' to represent the algorithm involving inversion
595+
of the given argument or None if the matrix cannot be inverted using
596+
any of those formulas.
597+
"""
598+
# Try to find a known invertible matrix. Note that the Schur complement
599+
# is currently not being considered for this
600+
A_inv = ask(Q.invertible(A))
601+
if A_inv == True:
602+
return 'A'
603+
B_inv = ask(Q.invertible(B))
604+
if B_inv == True:
605+
return 'B'
606+
C_inv = ask(Q.invertible(C))
607+
if C_inv == True:
608+
return 'C'
609+
D_inv = ask(Q.invertible(D))
610+
if D_inv == True:
611+
return 'D'
612+
# Otherwise try to find a matrix that isn't known to be non-invertible
613+
if A_inv != False:
614+
return 'A'
615+
if B_inv != False:
616+
return 'B'
617+
if C_inv != False:
618+
return 'C'
619+
if D_inv != False:
620+
return 'D'
621+
return None
622+
623+
591624
def deblock(B):
592625
""" Flatten a BlockMatrix of BlockMatrices """
593626
if not isinstance(B, BlockMatrix) or not B.blocks.has(BlockMatrix):
@@ -609,15 +642,33 @@ def deblock(B):
609642
return B
610643

611644

612-
613-
def reblock_2x2(B):
614-
""" Reblock a BlockMatrix so that it has 2x2 blocks of block matrices """
615-
if not isinstance(B, BlockMatrix) or not all(d > 2 for d in B.blocks.shape):
616-
return B
645+
def reblock_2x2(expr):
646+
"""
647+
Reblock a BlockMatrix so that it has 2x2 blocks of block matrices. If
648+
possible in such a way that the matrix continues to be invertible using the
649+
classical 2x2 block inversion formulas.
650+
"""
651+
if not isinstance(expr, BlockMatrix) or not all(d > 2 for d in expr.blockshape):
652+
return expr
617653

618654
BM = BlockMatrix # for brevity's sake
619-
return BM([[ B.blocks[0, 0], BM(B.blocks[0, 1:])],
620-
[BM(B.blocks[1:, 0]), BM(B.blocks[1:, 1:])]])
655+
rowblocks, colblocks = expr.blockshape
656+
blocks = expr.blocks
657+
for i in range(1, rowblocks):
658+
for j in range(1, colblocks):
659+
# try to split rows at i and cols at j
660+
A = bc_unpack(BM(blocks[:i, :j]))
661+
B = bc_unpack(BM(blocks[:i, j:]))
662+
C = bc_unpack(BM(blocks[i:, :j]))
663+
D = bc_unpack(BM(blocks[i:, j:]))
664+
665+
formula = _choose_2x2_inversion_formula(A, B, C, D)
666+
if formula is not None:
667+
return BlockMatrix([[A, B], [C, D]])
668+
669+
# else: nothing worked, just split upper left corner
670+
return BM([[blocks[0, 0], BM(blocks[0, 1:])],
671+
[BM(blocks[1:, 0]), BM(blocks[1:, 1:])]])
621672

622673

623674
def bounds(sizes):

sympy/matrices/expressions/tests/test_blockmatrix.py

Lines changed: 75 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from sympy import Trace
2-
from sympy.testing.pytest import raises
2+
from sympy.testing.pytest import raises, slow
33
from sympy.matrices.expressions.blockmatrix import (
44
block_collapse, bc_matmul, bc_block_plus_ident, BlockDiagMatrix,
55
BlockMatrix, bc_dist, bc_matadd, bc_transpose, bc_inverse,
66
blockcut, reblock_2x2, deblock)
77
from sympy.matrices.expressions import (MatrixSymbol, Identity,
8-
Inverse, trace, Transpose, det, ZeroMatrix, OneMatrix)
8+
Inverse, trace, Transpose, det, ZeroMatrix)
99
from sympy.matrices.common import NonInvertibleMatrixError
1010
from sympy.matrices import (
1111
Matrix, ImmutableMatrix, ImmutableSparseMatrix)
@@ -161,35 +161,98 @@ def test_squareBlockMatrix():
161161
Z = BlockMatrix([[Identity(n), B], [C, D]])
162162
assert not Z.is_Identity
163163

164-
def test_BlockMatrix_inverse():
164+
165+
def test_BlockMatrix_2x2_inverse_symbolic():
165166
A = MatrixSymbol('A', n, m)
166-
B = MatrixSymbol('B', n, n)
167-
C = MatrixSymbol('C', m, m)
168-
D = MatrixSymbol('D', m, n)
167+
B = MatrixSymbol('B', n, k - m)
168+
C = MatrixSymbol('C', k - n, m)
169+
D = MatrixSymbol('D', k - n, k - m)
169170
X = BlockMatrix([[A, B], [C, D]])
170-
assert X.is_square
171-
assert isinstance(block_collapse(X.inverse()), Inverse) # Can't inverse when A, D aren't square
171+
assert X.is_square and X.shape == (k, k)
172+
assert isinstance(block_collapse(X.I), Inverse) # Can't invert when none of the blocks is square
172173

173-
# test code path for non-invertible D matrix
174+
# test code path where only A is invertible
174175
A = MatrixSymbol('A', n, n)
175176
B = MatrixSymbol('B', n, m)
176177
C = MatrixSymbol('C', m, n)
177-
D = OneMatrix(m, m)
178+
D = ZeroMatrix(m, m)
178179
X = BlockMatrix([[A, B], [C, D]])
179180
assert block_collapse(X.inverse()) == BlockMatrix([
180181
[A.I + A.I * B * (D - C * A.I * B).I * C * A.I, -A.I * B * (D - C * A.I * B).I],
181182
[-(D - C * A.I * B).I * C * A.I, (D - C * A.I * B).I],
182183
])
183184

184-
# test code path for non-invertible A matrix
185-
A = OneMatrix(n, n)
185+
# test code path where only B is invertible
186+
A = MatrixSymbol('A', n, m)
187+
B = MatrixSymbol('B', n, n)
188+
C = ZeroMatrix(m, m)
189+
D = MatrixSymbol('D', m, n)
190+
X = BlockMatrix([[A, B], [C, D]])
191+
assert block_collapse(X.inverse()) == BlockMatrix(([
192+
[-(C - D * B.I * A).I * D * B.I, (C - D * B.I * A).I],
193+
[B.I + B.I * A * (C - D * B.I * A).I * D * B.I, -B.I * A * (C - D * B.I * A).I],
194+
]))
195+
196+
# test code path where only C is invertible
197+
A = MatrixSymbol('A', n, m)
198+
B = ZeroMatrix(n, n)
199+
C = MatrixSymbol('C', m, m)
200+
D = MatrixSymbol('D', m, n)
201+
X = BlockMatrix([[A, B], [C, D]])
202+
assert block_collapse(X.inverse()) == BlockMatrix([
203+
[-C.I * D * (B - A * C.I * D).I, C.I + C.I * D * (B - A * C.I * D).I * A * C.I],
204+
[(B - A * C.I * D).I, -(B - A * C.I * D).I * A * C.I],
205+
])
206+
207+
# test code path where only D is invertible
208+
A = ZeroMatrix(n, n)
209+
B = MatrixSymbol('B', n, m)
210+
C = MatrixSymbol('C', m, n)
186211
D = MatrixSymbol('D', m, m)
187212
X = BlockMatrix([[A, B], [C, D]])
188213
assert block_collapse(X.inverse()) == BlockMatrix([
189214
[(A - B * D.I * C).I, -(A - B * D.I * C).I * B * D.I],
190215
[-D.I * C * (A - B * D.I * C).I, D.I + D.I * C * (A - B * D.I * C).I * B * D.I],
191216
])
192217

218+
219+
def test_BlockMatrix_2x2_inverse_numeric():
220+
"""Test 2x2 block matrix inversion numerically for all 4 formulas"""
221+
M = Matrix([[1, 2], [3, 4]])
222+
# rank deficient matrices that have full rank when two of them combined
223+
D1 = Matrix([[1, 2], [2, 4]])
224+
D2 = Matrix([[1, 3], [3, 9]])
225+
D3 = Matrix([[1, 4], [4, 16]])
226+
assert D1.rank() == D2.rank() == D3.rank() == 1
227+
assert (D1 + D2).rank() == (D2 + D3).rank() == (D3 + D1).rank() == 2
228+
229+
# Only A is invertible
230+
K = BlockMatrix([[M, D1], [D2, D3]])
231+
assert block_collapse(K.inv()).as_explicit() == K.as_explicit().inv()
232+
# Only B is invertible
233+
K = BlockMatrix([[D1, M], [D2, D3]])
234+
assert block_collapse(K.inv()).as_explicit() == K.as_explicit().inv()
235+
# Only C is invertible
236+
K = BlockMatrix([[D1, D2], [M, D3]])
237+
assert block_collapse(K.inv()).as_explicit() == K.as_explicit().inv()
238+
# Only D is invertible
239+
K = BlockMatrix([[D1, D2], [D3, M]])
240+
assert block_collapse(K.inv()).as_explicit() == K.as_explicit().inv()
241+
242+
243+
@slow
244+
def test_BlockMatrix_3x3_symbolic():
245+
# Only test one of these, instead of all permutations, because it's slow
246+
rowblocksizes = (n, m, k)
247+
colblocksizes = (m, k, n)
248+
K = BlockMatrix([
249+
[MatrixSymbol('M%s%s' % (rows, cols), rows, cols) for cols in colblocksizes]
250+
for rows in rowblocksizes
251+
])
252+
collapse = block_collapse(K.I)
253+
assert isinstance(collapse, BlockMatrix)
254+
255+
193256
def test_BlockDiagMatrix():
194257
A = MatrixSymbol('A', n, n)
195258
B = MatrixSymbol('B', m, m)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy