@@ -144,29 +144,32 @@ class Plot(object):
144
144
- surface_color : function which returns a float.
145
145
"""
146
146
147
- def __init__ (self , * args , ** kwargs ):
147
+ def __init__ (self , * args ,
148
+ title = None , xlabel = None , ylabel = None , aspect_ratio = 'auto' ,
149
+ xlim = None , ylim = None , axis_center = 'auto' , axis = True ,
150
+ xscale = 'linear' , yscale = 'linear' , legend = False , autoscale = True ,
151
+ margin = 0 , annotations = None , markers = None , rectangles = None ,
152
+ fill = None , backend = 'default' , ** kwargs ):
148
153
super (Plot , self ).__init__ ()
149
154
150
155
# Options for the graph as a whole.
151
156
# The possible values for each option are described in the docstring of
152
157
# Plot. They are based purely on convention, no checking is done.
153
- self .title = None
154
- self .xlabel = None
155
- self .ylabel = None
156
- self .aspect_ratio = 'auto'
157
- self .xlim = None
158
- self .ylim = None
159
- self .axis_center = 'auto'
160
- self .axis = True
161
- self .xscale = 'linear'
162
- self .yscale = 'linear'
163
- self .legend = False
164
- self .autoscale = True
165
- self .margin = 0
166
- self .annotations = None
167
- self .markers = None
168
- self .rectangles = None
169
- self .fill = None
158
+ self .title = title
159
+ self .xlabel = xlabel
160
+ self .ylabel = ylabel
161
+ self .aspect_ratio = aspect_ratio
162
+ self .axis_center = axis_center
163
+ self .axis = axis
164
+ self .xscale = xscale
165
+ self .yscale = yscale
166
+ self .legend = legend
167
+ self .autoscale = autoscale
168
+ self .margin = margin
169
+ self .annotations = annotations
170
+ self .markers = markers
171
+ self .rectangles = rectangles
172
+ self .fill = fill
170
173
171
174
# Contains the data objects to be plotted. The backend should be smart
172
175
# enough to iterate over this list.
@@ -176,13 +179,32 @@ def __init__(self, *args, **kwargs):
176
179
# The backend type. On every show() a new backend instance is created
177
180
# in self._backend which is tightly coupled to the Plot instance
178
181
# (thanks to the parent attribute of the backend).
179
- self .backend = plot_backends [kwargs . pop ( ' backend' , 'default' ) ]
182
+ self .backend = plot_backends [backend ]
180
183
184
+ is_real = \
185
+ lambda lim : all (getattr (i , 'is_real' , True ) for i in lim )
186
+ is_finite = \
187
+ lambda lim : all (getattr (i , 'is_finite' , True ) for i in lim )
188
+
189
+ self .xlim = None
190
+ self .ylim = None
191
+ if xlim :
192
+ if not is_real (xlim ):
193
+ raise ValueError (
194
+ "All numbers from xlim={} must be real" .format (xlim ))
195
+ if not is_finite (xlim ):
196
+ raise ValueError (
197
+ "All numbers from xlim={} must be finite" .format (xlim ))
198
+ self .xlim = (float (xlim [0 ]), float (xlim [1 ]))
199
+ if ylim :
200
+ if not is_real (ylim ):
201
+ raise ValueError (
202
+ "All numbers from ylim={} must be real" .format (ylim ))
203
+ if not is_finite (ylim ):
204
+ raise ValueError (
205
+ "All numbers from ylim={} must be finite" .format (ylim ))
206
+ self .ylim = (float (ylim [0 ]), float (ylim [1 ]))
181
207
182
- # The keyword arguments should only contain options for the plot.
183
- for key , val in kwargs .items ():
184
- if hasattr (self , key ):
185
- setattr (self , key , val )
186
208
187
209
def show (self ):
188
210
# TODO move this to the backend (also for save)
@@ -847,14 +869,28 @@ def get_parameter_points(self):
847
869
return np .linspace (self .start , self .end , num = self .nb_of_points )
848
870
849
871
def get_points (self ):
872
+ np = import_module ('numpy' )
850
873
param = self .get_parameter_points ()
851
874
fx = vectorized_lambdify ([self .var ], self .expr_x )
852
875
fy = vectorized_lambdify ([self .var ], self .expr_y )
853
876
fz = vectorized_lambdify ([self .var ], self .expr_z )
877
+
854
878
list_x = fx (param )
855
879
list_y = fy (param )
856
880
list_z = fz (param )
857
- return (list_x , list_y , list_z )
881
+
882
+ list_x = np .array (list_x , dtype = np .float64 )
883
+ list_y = np .array (list_y , dtype = np .float64 )
884
+ list_z = np .array (list_z , dtype = np .float64 )
885
+
886
+ list_x = np .ma .masked_invalid (list_x )
887
+ list_y = np .ma .masked_invalid (list_y )
888
+ list_z = np .ma .masked_invalid (list_z )
889
+
890
+ self ._xlim = (np .amin (list_x ), np .amax (list_x ))
891
+ self ._ylim = (np .amin (list_y ), np .amax (list_y ))
892
+ self ._zlim = (np .amin (list_z ), np .amax (list_z ))
893
+ return list_x , list_y , list_z
858
894
859
895
860
896
### Surfaces
@@ -906,6 +942,9 @@ def __init__(self, expr, var_start_end_x, var_start_end_y, **kwargs):
906
942
self .nb_of_points_y = kwargs .get ('nb_of_points_y' , 50 )
907
943
self .surface_color = kwargs .get ('surface_color' , None )
908
944
945
+ self ._xlim = (self .start_x , self .end_x )
946
+ self ._ylim = (self .start_y , self .end_y )
947
+
909
948
def __str__ (self ):
910
949
return ('cartesian surface: %s for'
911
950
' %s over %s and %s over %s' ) % (
@@ -922,7 +961,11 @@ def get_meshes(self):
922
961
np .linspace (self .start_y , self .end_y ,
923
962
num = self .nb_of_points_y ))
924
963
f = vectorized_lambdify ((self .var_x , self .var_y ), self .expr )
925
- return (mesh_x , mesh_y , f (mesh_x , mesh_y ))
964
+ mesh_z = f (mesh_x , mesh_y )
965
+ mesh_z = np .array (mesh_z , dtype = np .float64 )
966
+ mesh_z = np .ma .masked_invalid (mesh_z )
967
+ self ._zlim = (np .amin (mesh_z ), np .amax (mesh_z ))
968
+ return mesh_x , mesh_y , mesh_z
926
969
927
970
928
971
class ParametricSurfaceSeries (SurfaceBaseSeries ):
@@ -967,11 +1010,30 @@ def get_parameter_meshes(self):
967
1010
num = self .nb_of_points_v ))
968
1011
969
1012
def get_meshes (self ):
1013
+ np = import_module ('numpy' )
1014
+
970
1015
mesh_u , mesh_v = self .get_parameter_meshes ()
971
1016
fx = vectorized_lambdify ((self .var_u , self .var_v ), self .expr_x )
972
1017
fy = vectorized_lambdify ((self .var_u , self .var_v ), self .expr_y )
973
1018
fz = vectorized_lambdify ((self .var_u , self .var_v ), self .expr_z )
974
- return (fx (mesh_u , mesh_v ), fy (mesh_u , mesh_v ), fz (mesh_u , mesh_v ))
1019
+
1020
+ mesh_x = fx (mesh_u , mesh_v )
1021
+ mesh_y = fy (mesh_u , mesh_v )
1022
+ mesh_z = fz (mesh_u , mesh_v )
1023
+
1024
+ mesh_x = np .array (mesh_x , dtype = np .float64 )
1025
+ mesh_y = np .array (mesh_y , dtype = np .float64 )
1026
+ mesh_z = np .array (mesh_z , dtype = np .float64 )
1027
+
1028
+ mesh_x = np .ma .masked_invalid (mesh_x )
1029
+ mesh_y = np .ma .masked_invalid (mesh_y )
1030
+ mesh_z = np .ma .masked_invalid (mesh_z )
1031
+
1032
+ self ._xlim = (np .amin (mesh_x ), np .amax (mesh_x ))
1033
+ self ._ylim = (np .amin (mesh_y ), np .amax (mesh_y ))
1034
+ self ._zlim = (np .amin (mesh_z ), np .amax (mesh_z ))
1035
+
1036
+ return mesh_x , mesh_y , mesh_z
975
1037
976
1038
977
1039
### Contours
@@ -996,6 +1058,9 @@ def __init__(self, expr, var_start_end_x, var_start_end_y):
996
1058
997
1059
self .get_points = self .get_meshes
998
1060
1061
+ self ._xlim = (self .start_x , self .end_x )
1062
+ self ._ylim = (self .start_y , self .end_y )
1063
+
999
1064
def __str__ (self ):
1000
1065
return ('contour: %s for '
1001
1066
'%s over %s and %s over %s' ) % (
@@ -1068,12 +1133,18 @@ def __init__(self, parent):
1068
1133
self .ax [i ].spines ['right' ].set_color ('none' )
1069
1134
self .ax [i ].spines ['bottom' ].set_position ('zero' )
1070
1135
self .ax [i ].spines ['top' ].set_color ('none' )
1071
- self .ax [i ].spines ['left' ].set_smart_bounds (True )
1072
- self .ax [i ].spines ['bottom' ].set_smart_bounds (False )
1073
1136
self .ax [i ].xaxis .set_ticks_position ('bottom' )
1074
1137
self .ax [i ].yaxis .set_ticks_position ('left' )
1075
1138
1076
1139
def _process_series (self , series , ax , parent ):
1140
+ np = import_module ('numpy' )
1141
+ mpl_toolkits = import_module (
1142
+ 'mpl_toolkits' , import_kwargs = {'fromlist' : ['mplot3d' ]})
1143
+
1144
+ # XXX Workaround for matplotlib issue
1145
+ # https://github.com/matplotlib/matplotlib/issues/17130
1146
+ xlims , ylims , zlims = [], [], []
1147
+
1077
1148
for s in series :
1078
1149
# Create the collections
1079
1150
if s .is_2Dline :
@@ -1083,24 +1154,22 @@ def _process_series(self, series, ax, parent):
1083
1154
ax .contour (* s .get_meshes ())
1084
1155
elif s .is_3Dline :
1085
1156
# TODO too complicated, I blame matplotlib
1086
- mpl_toolkits = import_module ('mpl_toolkits' ,
1087
- import_kwargs = {'fromlist' : ['mplot3d' ]})
1088
1157
art3d = mpl_toolkits .mplot3d .art3d
1089
1158
collection = art3d .Line3DCollection (s .get_segments ())
1090
1159
ax .add_collection (collection )
1091
1160
x , y , z = s .get_points ()
1092
- ax . set_xlim (( min ( x ), max ( x )) )
1093
- ax . set_ylim (( min ( y ), max ( y )) )
1094
- ax . set_zlim (( min ( z ), max ( z )) )
1161
+ xlims . append ( s . _xlim )
1162
+ ylims . append ( s . _ylim )
1163
+ zlims . append ( s . _zlim )
1095
1164
elif s .is_3Dsurface :
1096
1165
x , y , z = s .get_meshes ()
1097
1166
collection = ax .plot_surface (x , y , z ,
1098
1167
cmap = getattr (self .cm , 'viridis' , self .cm .jet ),
1099
1168
rstride = 1 , cstride = 1 , linewidth = 0.1 )
1169
+ xlims .append (s ._xlim )
1170
+ ylims .append (s ._ylim )
1171
+ zlims .append (s ._zlim )
1100
1172
elif s .is_implicit :
1101
- # Smart bounds have to be set to False for implicit plots.
1102
- ax .spines ['left' ].set_smart_bounds (False )
1103
- ax .spines ['bottom' ].set_smart_bounds (False )
1104
1173
points = s .get_raster ()
1105
1174
if len (points ) == 2 :
1106
1175
# interval math plotting
@@ -1118,9 +1187,10 @@ def _process_series(self, series, ax, parent):
1118
1187
else :
1119
1188
ax .contourf (xarray , yarray , zarray , cmap = colormap )
1120
1189
else :
1121
- raise ValueError ('The matplotlib backend supports only '
1122
- 'is_2Dline, is_3Dline, is_3Dsurface and '
1123
- 'is_contour objects.' )
1190
+ raise NotImplementedError (
1191
+ '{} is not supported in the sympy plotting module '
1192
+ 'with matplotlib backend. Please report this issue.'
1193
+ .format (ax ))
1124
1194
1125
1195
# Customise the collections with the corresponding per-series
1126
1196
# options.
@@ -1142,12 +1212,38 @@ def _process_series(self, series, ax, parent):
1142
1212
else :
1143
1213
collection .set_color (s .surface_color )
1144
1214
1215
+ Axes3D = mpl_toolkits .mplot3d .Axes3D
1216
+ if not isinstance (ax , Axes3D ):
1217
+ ax .autoscale_view (
1218
+ scalex = ax .get_autoscalex_on (),
1219
+ scaley = ax .get_autoscaley_on ())
1220
+ else :
1221
+ # XXX Workaround for matplotlib issue
1222
+ # https://github.com/matplotlib/matplotlib/issues/17130
1223
+ if xlims :
1224
+ xlims = np .array (xlims )
1225
+ xlim = (np .amin (xlims [:, 0 ]), np .amax (xlims [:, 1 ]))
1226
+ ax .set_xlim (xlim )
1227
+ else :
1228
+ ax .set_xlim ([0 , 1 ])
1229
+
1230
+ if ylims :
1231
+ ylims = np .array (ylims )
1232
+ ylim = (np .amin (ylims [:, 0 ]), np .amax (ylims [:, 1 ]))
1233
+ ax .set_ylim (ylim )
1234
+ else :
1235
+ ax .set_ylim ([0 , 1 ])
1236
+
1237
+ if zlims :
1238
+ zlims = np .array (zlims )
1239
+ zlim = (np .amin (zlims [:, 0 ]), np .amax (zlims [:, 1 ]))
1240
+ ax .set_zlim (zlim )
1241
+ else :
1242
+ ax .set_zlim ([0 , 1 ])
1243
+
1145
1244
# Set global options.
1146
1245
# TODO The 3D stuff
1147
1246
# XXX The order of those is important.
1148
- mpl_toolkits = import_module ('mpl_toolkits' ,
1149
- import_kwargs = {'fromlist' : ['mplot3d' ]})
1150
- Axes3D = mpl_toolkits .mplot3d .Axes3D
1151
1247
if parent .xscale and not isinstance (ax , Axes3D ):
1152
1248
ax .set_xscale (parent .xscale )
1153
1249
if parent .yscale and not isinstance (ax , Axes3D ):
@@ -1205,38 +1301,9 @@ def _process_series(self, series, ax, parent):
1205
1301
# xlim and ylim shoulld always be set at last so that plot limits
1206
1302
# doesn't get altered during the process.
1207
1303
if parent .xlim :
1208
- from sympy .core .basic import Basic
1209
- xlim = parent .xlim
1210
- if any (isinstance (i , Basic ) and not i .is_real for i in xlim ):
1211
- raise ValueError (
1212
- "All numbers from xlim={} must be real" .format (xlim ))
1213
- if any (isinstance (i , Basic ) and not i .is_finite for i in xlim ):
1214
- raise ValueError (
1215
- "All numbers from xlim={} must be finite" .format (xlim ))
1216
- xlim = (float (i ) for i in xlim )
1217
- ax .set_xlim (xlim )
1218
- else :
1219
- if parent ._series and all (isinstance (s , LineOver1DRangeSeries ) for s in parent ._series ):
1220
- starts = [s .start for s in parent ._series ]
1221
- ends = [s .end for s in parent ._series ]
1222
- ax .set_xlim (min (starts ), max (ends ))
1223
-
1304
+ ax .set_xlim (parent .xlim )
1224
1305
if parent .ylim :
1225
- from sympy .core .basic import Basic
1226
- ylim = parent .ylim
1227
- if any (isinstance (i ,Basic ) and not i .is_real for i in ylim ):
1228
- raise ValueError (
1229
- "All numbers from ylim={} must be real" .format (ylim ))
1230
- if any (isinstance (i ,Basic ) and not i .is_finite for i in ylim ):
1231
- raise ValueError (
1232
- "All numbers from ylim={} must be finite" .format (ylim ))
1233
- ylim = (float (i ) for i in ylim )
1234
- ax .set_ylim (ylim )
1235
-
1236
- if not isinstance (ax , Axes3D ):
1237
- ax .autoscale_view (
1238
- scalex = ax .get_autoscalex_on (),
1239
- scaley = ax .get_autoscaley_on ())
1306
+ ax .set_ylim (parent .ylim )
1240
1307
1241
1308
1242
1309
def process_series (self ):
0 commit comments