@@ -900,6 +900,23 @@ def irregular(cls, ntop, *matrices, **kwargs):
900
900
rows .append (r )
901
901
return cls ._new (rows )
902
902
903
+ @classmethod
904
+ def _handle_ndarray (cls , arg ):
905
+ # NumPy array or matrix or some other object that implements
906
+ # __array__. So let's first use this method to get a
907
+ # numpy.array() and then make a python list out of it.
908
+ arr = arg .__array__ ()
909
+ if len (arr .shape ) == 2 :
910
+ rows , cols = arr .shape [0 ], arr .shape [1 ]
911
+ flat_list = [cls ._sympify (i ) for i in arr .ravel ()]
912
+ return rows , cols , flat_list
913
+ elif len (arr .shape ) == 1 :
914
+ flat_list = [cls ._sympify (i ) for i in arr ]
915
+ return arr .shape [0 ], 1 , flat_list
916
+ else :
917
+ raise NotImplementedError (
918
+ "SymPy supports just 1D and 2D matrices" )
919
+
903
920
@classmethod
904
921
def _handle_creation_inputs (cls , * args , ** kwargs ):
905
922
"""Return the number of rows, cols and flat matrix elements.
@@ -973,23 +990,7 @@ def _handle_creation_inputs(cls, *args, **kwargs):
973
990
974
991
# Matrix(numpy.ones((2, 2)))
975
992
elif hasattr (args [0 ], "__array__" ):
976
- # NumPy array or matrix or some other object that implements
977
- # __array__. So let's first use this method to get a
978
- # numpy.array() and then make a python list out of it.
979
- arr = args [0 ].__array__ ()
980
- if len (arr .shape ) == 2 :
981
- rows , cols = arr .shape [0 ], arr .shape [1 ]
982
- flat_list = [cls ._sympify (i ) for i in arr .ravel ()]
983
- return rows , cols , flat_list
984
- elif len (arr .shape ) == 1 :
985
- rows , cols = arr .shape [0 ], 1
986
- flat_list = [cls .zero ] * rows
987
- for i in range (len (arr )):
988
- flat_list [i ] = cls ._sympify (arr [i ])
989
- return rows , cols , flat_list
990
- else :
991
- raise NotImplementedError (
992
- "SymPy supports just 1D and 2D matrices" )
993
+ return cls ._handle_ndarray (args [0 ])
993
994
994
995
# Matrix([1, 2, 3]) or Matrix([[1, 2], [3, 4]])
995
996
elif is_sequence (args [0 ]) \
@@ -1064,13 +1065,19 @@ def do(x):
1064
1065
if not is_sequence (row ) and \
1065
1066
not getattr (row , 'is_Matrix' , False ):
1066
1067
raise ValueError ('expecting list of lists' )
1067
- if not row :
1068
+
1069
+ if hasattr (row , '__array__' ):
1070
+ if 0 in row .shape :
1071
+ continue
1072
+ elif not row :
1068
1073
continue
1074
+
1069
1075
if evaluate and all (ismat (i ) for i in row ):
1070
1076
r , c , flatT = cls ._handle_creation_inputs (
1071
1077
[i .T for i in row ])
1072
1078
T = reshape (flatT , [c ])
1073
- flat = [T [i ][j ] for j in range (c ) for i in range (r )]
1079
+ flat = \
1080
+ [T [i ][j ] for j in range (c ) for i in range (r )]
1074
1081
r , c = c , r
1075
1082
else :
1076
1083
r = 1
0 commit comments