Skip to content

Commit a95fc3e

Browse files
committed
Manually set bounding box for 3d plots
1 parent 68e9626 commit a95fc3e

File tree

4 files changed

+559
-466
lines changed

4 files changed

+559
-466
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ matrix:
5353
# all of the dependencies are supported on 3.8.
5454
env:
5555
- TEST_ASCII="true"
56-
- TEST_OPT_DEPENDENCY="matchpy numpy scipy gmpy2 matplotlib<3.2 theano llvmlite autowrap cython wurlitzer python-symengine=0.5.1 tensorflow numexpr ipython antlr-python-runtime>=4.7,<4.8 antlr>=4.7,<4.8 cloudpickle pyglet pycosat lfortran python-clang lxml"
56+
- TEST_OPT_DEPENDENCY="matchpy numpy scipy gmpy2 matplotlib theano llvmlite autowrap cython wurlitzer python-symengine=0.5.1 tensorflow numexpr ipython antlr-python-runtime>=4.7,<4.8 antlr>=4.7,<4.8 cloudpickle pyglet pycosat lfortran python-clang lxml"
5757
- TEST_SAGE="true"
5858
- SYMPY_STRICT_COMPILER_CHECKS=1
5959
addons:

sympy/plotting/plot.py

Lines changed: 140 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -144,29 +144,32 @@ class Plot(object):
144144
- surface_color : function which returns a float.
145145
"""
146146

147-
def __init__(self, *args, **kwargs):
147+
def __init__(self, *args,
148+
title=None, xlabel=None, ylabel=None, aspect_ratio='auto',
149+
xlim=None, ylim=None, axis_center='auto', axis=True,
150+
xscale='linear', yscale='linear', legend=False, autoscale=True,
151+
margin=0, annotations=None, markers=None, rectangles=None,
152+
fill=None, backend='default', **kwargs):
148153
super(Plot, self).__init__()
149154

150155
# Options for the graph as a whole.
151156
# The possible values for each option are described in the docstring of
152157
# Plot. They are based purely on convention, no checking is done.
153-
self.title = None
154-
self.xlabel = None
155-
self.ylabel = None
156-
self.aspect_ratio = 'auto'
157-
self.xlim = None
158-
self.ylim = None
159-
self.axis_center = 'auto'
160-
self.axis = True
161-
self.xscale = 'linear'
162-
self.yscale = 'linear'
163-
self.legend = False
164-
self.autoscale = True
165-
self.margin = 0
166-
self.annotations = None
167-
self.markers = None
168-
self.rectangles = None
169-
self.fill = None
158+
self.title = title
159+
self.xlabel = xlabel
160+
self.ylabel = ylabel
161+
self.aspect_ratio = aspect_ratio
162+
self.axis_center = axis_center
163+
self.axis = axis
164+
self.xscale = xscale
165+
self.yscale = yscale
166+
self.legend = legend
167+
self.autoscale = autoscale
168+
self.margin = margin
169+
self.annotations = annotations
170+
self.markers = markers
171+
self.rectangles = rectangles
172+
self.fill = fill
170173

171174
# Contains the data objects to be plotted. The backend should be smart
172175
# enough to iterate over this list.
@@ -176,13 +179,32 @@ def __init__(self, *args, **kwargs):
176179
# The backend type. On every show() a new backend instance is created
177180
# in self._backend which is tightly coupled to the Plot instance
178181
# (thanks to the parent attribute of the backend).
179-
self.backend = plot_backends[kwargs.pop('backend', 'default')]
182+
self.backend = plot_backends[backend]
180183

184+
is_real = \
185+
lambda lim: all(getattr(i, 'is_real', True) for i in lim)
186+
is_finite = \
187+
lambda lim: all(getattr(i, 'is_finite', True) for i in lim)
188+
189+
self.xlim = None
190+
self.ylim = None
191+
if xlim:
192+
if not is_real(xlim):
193+
raise ValueError(
194+
"All numbers from xlim={} must be real".format(xlim))
195+
if not is_finite(xlim):
196+
raise ValueError(
197+
"All numbers from xlim={} must be finite".format(xlim))
198+
self.xlim = (float(xlim[0]), float(xlim[1]))
199+
if ylim:
200+
if not is_real(ylim):
201+
raise ValueError(
202+
"All numbers from ylim={} must be real".format(ylim))
203+
if not is_finite(ylim):
204+
raise ValueError(
205+
"All numbers from ylim={} must be finite".format(ylim))
206+
self.ylim = (float(ylim[0]), float(ylim[1]))
181207

182-
# The keyword arguments should only contain options for the plot.
183-
for key, val in kwargs.items():
184-
if hasattr(self, key):
185-
setattr(self, key, val)
186208

187209
def show(self):
188210
# TODO move this to the backend (also for save)
@@ -847,14 +869,28 @@ def get_parameter_points(self):
847869
return np.linspace(self.start, self.end, num=self.nb_of_points)
848870

849871
def get_points(self):
872+
np = import_module('numpy')
850873
param = self.get_parameter_points()
851874
fx = vectorized_lambdify([self.var], self.expr_x)
852875
fy = vectorized_lambdify([self.var], self.expr_y)
853876
fz = vectorized_lambdify([self.var], self.expr_z)
877+
854878
list_x = fx(param)
855879
list_y = fy(param)
856880
list_z = fz(param)
857-
return (list_x, list_y, list_z)
881+
882+
list_x = np.array(list_x, dtype=np.float64)
883+
list_y = np.array(list_y, dtype=np.float64)
884+
list_z = np.array(list_z, dtype=np.float64)
885+
886+
list_x = np.ma.masked_invalid(list_x)
887+
list_y = np.ma.masked_invalid(list_y)
888+
list_z = np.ma.masked_invalid(list_z)
889+
890+
self._xlim = (np.amin(list_x), np.amax(list_x))
891+
self._ylim = (np.amin(list_y), np.amax(list_y))
892+
self._zlim = (np.amin(list_z), np.amax(list_z))
893+
return list_x, list_y, list_z
858894

859895

860896
### Surfaces
@@ -906,6 +942,9 @@ def __init__(self, expr, var_start_end_x, var_start_end_y, **kwargs):
906942
self.nb_of_points_y = kwargs.get('nb_of_points_y', 50)
907943
self.surface_color = kwargs.get('surface_color', None)
908944

945+
self._xlim = (self.start_x, self.end_x)
946+
self._ylim = (self.start_y, self.end_y)
947+
909948
def __str__(self):
910949
return ('cartesian surface: %s for'
911950
' %s over %s and %s over %s') % (
@@ -922,7 +961,11 @@ def get_meshes(self):
922961
np.linspace(self.start_y, self.end_y,
923962
num=self.nb_of_points_y))
924963
f = vectorized_lambdify((self.var_x, self.var_y), self.expr)
925-
return (mesh_x, mesh_y, f(mesh_x, mesh_y))
964+
mesh_z = f(mesh_x, mesh_y)
965+
mesh_z = np.array(mesh_z, dtype=np.float64)
966+
mesh_z = np.ma.masked_invalid(mesh_z)
967+
self._zlim = (np.amin(mesh_z), np.amax(mesh_z))
968+
return mesh_x, mesh_y, mesh_z
926969

927970

928971
class ParametricSurfaceSeries(SurfaceBaseSeries):
@@ -967,11 +1010,30 @@ def get_parameter_meshes(self):
9671010
num=self.nb_of_points_v))
9681011

9691012
def get_meshes(self):
1013+
np = import_module('numpy')
1014+
9701015
mesh_u, mesh_v = self.get_parameter_meshes()
9711016
fx = vectorized_lambdify((self.var_u, self.var_v), self.expr_x)
9721017
fy = vectorized_lambdify((self.var_u, self.var_v), self.expr_y)
9731018
fz = vectorized_lambdify((self.var_u, self.var_v), self.expr_z)
974-
return (fx(mesh_u, mesh_v), fy(mesh_u, mesh_v), fz(mesh_u, mesh_v))
1019+
1020+
mesh_x = fx(mesh_u, mesh_v)
1021+
mesh_y = fy(mesh_u, mesh_v)
1022+
mesh_z = fz(mesh_u, mesh_v)
1023+
1024+
mesh_x = np.array(mesh_x, dtype=np.float64)
1025+
mesh_y = np.array(mesh_y, dtype=np.float64)
1026+
mesh_z = np.array(mesh_z, dtype=np.float64)
1027+
1028+
mesh_x = np.ma.masked_invalid(mesh_x)
1029+
mesh_y = np.ma.masked_invalid(mesh_y)
1030+
mesh_z = np.ma.masked_invalid(mesh_z)
1031+
1032+
self._xlim = (np.amin(mesh_x), np.amax(mesh_x))
1033+
self._ylim = (np.amin(mesh_y), np.amax(mesh_y))
1034+
self._zlim = (np.amin(mesh_z), np.amax(mesh_z))
1035+
1036+
return mesh_x, mesh_y, mesh_z
9751037

9761038

9771039
### Contours
@@ -996,6 +1058,9 @@ def __init__(self, expr, var_start_end_x, var_start_end_y):
9961058

9971059
self.get_points = self.get_meshes
9981060

1061+
self._xlim = (self.start_x, self.end_x)
1062+
self._ylim = (self.start_y, self.end_y)
1063+
9991064
def __str__(self):
10001065
return ('contour: %s for '
10011066
'%s over %s and %s over %s') % (
@@ -1068,12 +1133,18 @@ def __init__(self, parent):
10681133
self.ax[i].spines['right'].set_color('none')
10691134
self.ax[i].spines['bottom'].set_position('zero')
10701135
self.ax[i].spines['top'].set_color('none')
1071-
self.ax[i].spines['left'].set_smart_bounds(True)
1072-
self.ax[i].spines['bottom'].set_smart_bounds(False)
10731136
self.ax[i].xaxis.set_ticks_position('bottom')
10741137
self.ax[i].yaxis.set_ticks_position('left')
10751138

10761139
def _process_series(self, series, ax, parent):
1140+
np = import_module('numpy')
1141+
mpl_toolkits = import_module(
1142+
'mpl_toolkits', import_kwargs={'fromlist': ['mplot3d']})
1143+
1144+
# XXX Workaround for matplotlib issue
1145+
# https://github.com/matplotlib/matplotlib/issues/17130
1146+
xlims, ylims, zlims = [], [], []
1147+
10771148
for s in series:
10781149
# Create the collections
10791150
if s.is_2Dline:
@@ -1083,24 +1154,22 @@ def _process_series(self, series, ax, parent):
10831154
ax.contour(*s.get_meshes())
10841155
elif s.is_3Dline:
10851156
# TODO too complicated, I blame matplotlib
1086-
mpl_toolkits = import_module('mpl_toolkits',
1087-
import_kwargs={'fromlist': ['mplot3d']})
10881157
art3d = mpl_toolkits.mplot3d.art3d
10891158
collection = art3d.Line3DCollection(s.get_segments())
10901159
ax.add_collection(collection)
10911160
x, y, z = s.get_points()
1092-
ax.set_xlim((min(x), max(x)))
1093-
ax.set_ylim((min(y), max(y)))
1094-
ax.set_zlim((min(z), max(z)))
1161+
xlims.append(s._xlim)
1162+
ylims.append(s._ylim)
1163+
zlims.append(s._zlim)
10951164
elif s.is_3Dsurface:
10961165
x, y, z = s.get_meshes()
10971166
collection = ax.plot_surface(x, y, z,
10981167
cmap=getattr(self.cm, 'viridis', self.cm.jet),
10991168
rstride=1, cstride=1, linewidth=0.1)
1169+
xlims.append(s._xlim)
1170+
ylims.append(s._ylim)
1171+
zlims.append(s._zlim)
11001172
elif s.is_implicit:
1101-
# Smart bounds have to be set to False for implicit plots.
1102-
ax.spines['left'].set_smart_bounds(False)
1103-
ax.spines['bottom'].set_smart_bounds(False)
11041173
points = s.get_raster()
11051174
if len(points) == 2:
11061175
# interval math plotting
@@ -1118,9 +1187,10 @@ def _process_series(self, series, ax, parent):
11181187
else:
11191188
ax.contourf(xarray, yarray, zarray, cmap=colormap)
11201189
else:
1121-
raise ValueError('The matplotlib backend supports only '
1122-
'is_2Dline, is_3Dline, is_3Dsurface and '
1123-
'is_contour objects.')
1190+
raise NotImplementedError(
1191+
'{} is not supported in the sympy plotting module '
1192+
'with matplotlib backend. Please report this issue.'
1193+
.format(ax))
11241194

11251195
# Customise the collections with the corresponding per-series
11261196
# options.
@@ -1142,12 +1212,38 @@ def _process_series(self, series, ax, parent):
11421212
else:
11431213
collection.set_color(s.surface_color)
11441214

1215+
Axes3D = mpl_toolkits.mplot3d.Axes3D
1216+
if not isinstance(ax, Axes3D):
1217+
ax.autoscale_view(
1218+
scalex=ax.get_autoscalex_on(),
1219+
scaley=ax.get_autoscaley_on())
1220+
else:
1221+
# XXX Workaround for matplotlib issue
1222+
# https://github.com/matplotlib/matplotlib/issues/17130
1223+
if xlims:
1224+
xlims = np.array(xlims)
1225+
xlim = (np.amin(xlims[:, 0]), np.amax(xlims[:, 1]))
1226+
ax.set_xlim(xlim)
1227+
else:
1228+
ax.set_xlim([0, 1])
1229+
1230+
if ylims:
1231+
ylims = np.array(ylims)
1232+
ylim = (np.amin(ylims[:, 0]), np.amax(ylims[:, 1]))
1233+
ax.set_ylim(ylim)
1234+
else:
1235+
ax.set_ylim([0, 1])
1236+
1237+
if zlims:
1238+
zlims = np.array(zlims)
1239+
zlim = (np.amin(zlims[:, 0]), np.amax(zlims[:, 1]))
1240+
ax.set_zlim(zlim)
1241+
else:
1242+
ax.set_zlim([0, 1])
1243+
11451244
# Set global options.
11461245
# TODO The 3D stuff
11471246
# XXX The order of those is important.
1148-
mpl_toolkits = import_module('mpl_toolkits',
1149-
import_kwargs={'fromlist': ['mplot3d']})
1150-
Axes3D = mpl_toolkits.mplot3d.Axes3D
11511247
if parent.xscale and not isinstance(ax, Axes3D):
11521248
ax.set_xscale(parent.xscale)
11531249
if parent.yscale and not isinstance(ax, Axes3D):
@@ -1205,38 +1301,9 @@ def _process_series(self, series, ax, parent):
12051301
# xlim and ylim shoulld always be set at last so that plot limits
12061302
# doesn't get altered during the process.
12071303
if parent.xlim:
1208-
from sympy.core.basic import Basic
1209-
xlim = parent.xlim
1210-
if any(isinstance(i, Basic) and not i.is_real for i in xlim):
1211-
raise ValueError(
1212-
"All numbers from xlim={} must be real".format(xlim))
1213-
if any(isinstance(i, Basic) and not i.is_finite for i in xlim):
1214-
raise ValueError(
1215-
"All numbers from xlim={} must be finite".format(xlim))
1216-
xlim = (float(i) for i in xlim)
1217-
ax.set_xlim(xlim)
1218-
else:
1219-
if parent._series and all(isinstance(s, LineOver1DRangeSeries) for s in parent._series):
1220-
starts = [s.start for s in parent._series]
1221-
ends = [s.end for s in parent._series]
1222-
ax.set_xlim(min(starts), max(ends))
1223-
1304+
ax.set_xlim(parent.xlim)
12241305
if parent.ylim:
1225-
from sympy.core.basic import Basic
1226-
ylim = parent.ylim
1227-
if any(isinstance(i,Basic) and not i.is_real for i in ylim):
1228-
raise ValueError(
1229-
"All numbers from ylim={} must be real".format(ylim))
1230-
if any(isinstance(i,Basic) and not i.is_finite for i in ylim):
1231-
raise ValueError(
1232-
"All numbers from ylim={} must be finite".format(ylim))
1233-
ylim = (float(i) for i in ylim)
1234-
ax.set_ylim(ylim)
1235-
1236-
if not isinstance(ax, Axes3D):
1237-
ax.autoscale_view(
1238-
scalex=ax.get_autoscalex_on(),
1239-
scaley=ax.get_autoscaley_on())
1306+
ax.set_ylim(parent.ylim)
12401307

12411308

12421309
def process_series(self):

sympy/plotting/tests/test_experimental_lambdify.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,25 @@
1-
from sympy.core.symbol import symbols
1+
from sympy.core.symbol import symbols, Symbol
2+
from sympy.functions import Max
23
from sympy.plotting.experimental_lambdify import experimental_lambdify
34
from sympy.plotting.intervalmath.interval_arithmetic import \
45
interval, intervalMembership
56

67

8+
# Tests for exception handling in experimental_lambdify
9+
def test_experimental_lambify():
10+
x = Symbol('x')
11+
f = experimental_lambdify([x], Max(x, 5))
12+
# XXX should f be tested? If f(2) is attempted, an
13+
# error is raised because a complex produced during wrapping of the arg
14+
# is being compared with an int.
15+
assert Max(2, 5) == 5
16+
assert Max(5, 7) == 7
17+
18+
x = Symbol('x-3')
19+
f = experimental_lambdify([x], x + 1)
20+
assert f(1) == 2
21+
22+
723
def test_composite_boolean_region():
824
x, y = symbols('x y')
925

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy