diff --git a/control/ctrlplot.py b/control/ctrlplot.py index c8c30880d..6d31664a0 100644 --- a/control/ctrlplot.py +++ b/control/ctrlplot.py @@ -5,6 +5,7 @@ from os.path import commonprefix +import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np @@ -12,6 +13,28 @@ __all__ = ['suptitle', 'get_plot_axes'] +# +# Style parameters +# + +_ctrlplot_rcParams = mpl.rcParams.copy() +_ctrlplot_rcParams.update({ + 'axes.labelsize': 'small', + 'axes.titlesize': 'small', + 'figure.titlesize': 'medium', + 'legend.fontsize': 'x-small', + 'xtick.labelsize': 'small', + 'ytick.labelsize': 'small', +}) + + +# +# User functions +# +# The functions below can be used by users to modify ctrl plots or get +# information about them. +# + def suptitle( title, fig=None, frame='axes', **kwargs): @@ -35,7 +58,7 @@ def suptitle( Additional keywords (passed to matplotlib). """ - rcParams = config._get_param('freqplot', 'rcParams', kwargs, pop=True) + rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True) if fig is None: fig = plt.gcf() @@ -61,10 +84,10 @@ def suptitle( def get_plot_axes(line_array): """Get a list of axes from an array of lines. - This function can be used to return the set of axes corresponding to - the line array that is returned by `time_response_plot`. This is useful for - generating an axes array that can be passed to subsequent plotting - calls. + This function can be used to return the set of axes corresponding + to the line array that is returned by `time_response_plot`. This + is useful for generating an axes array that can be passed to + subsequent plotting calls. Parameters ---------- @@ -89,6 +112,125 @@ def get_plot_axes(line_array): # # Utility functions # +# These functions are used by plotting routines to provide a consistent way +# of processing and displaying information. +# + + +def _process_ax_keyword( + axs, shape=(1, 1), rcParams=None, squeeze=False, clear_text=False): + """Utility function to process ax keyword to plotting commands. + + This function processes the `ax` keyword to plotting commands. If no + ax keyword is passed, the current figure is checked to see if it has + the correct shape. If the shape matches the desired shape, then the + current figure and axes are returned. Otherwise a new figure is + created with axes of the desired shape. + + Legacy behavior: some of the older plotting commands use a axes label + to identify the proper axes for plotting. This behavior is supported + through the use of the label keyword, but will only work if shape == + (1, 1) and squeeze == True. + + """ + if axs is None: + fig = plt.gcf() # get current figure (or create new one) + axs = fig.get_axes() + + # Check to see if axes are the right shape; if not, create new figure + # Note: can't actually check the shape, just the total number of axes + if len(axs) != np.prod(shape): + with plt.rc_context(rcParams): + if len(axs) != 0: + # Create a new figure + fig, axs = plt.subplots(*shape, squeeze=False) + else: + # Create new axes on (empty) figure + axs = fig.subplots(*shape, squeeze=False) + fig.set_layout_engine('tight') + fig.align_labels() + else: + # Use the existing axes, properly reshaped + axs = np.asarray(axs).reshape(*shape) + + if clear_text: + # Clear out any old text from the current figure + for text in fig.texts: + text.set_visible(False) # turn off the text + del text # get rid of it completely + else: + try: + axs = np.asarray(axs).reshape(shape) + except ValueError: + raise ValueError( + "specified axes are not the right shape; " + f"got {axs.shape} but expecting {shape}") + fig = axs[0, 0].figure + + # Process the squeeze keyword + if squeeze and shape == (1, 1): + axs = axs[0, 0] # Just return the single axes object + elif squeeze: + axs = axs.squeeze() + + return fig, axs + + +# Turn label keyword into array indexed by trace, output, input +# TODO: move to ctrlutil.py and update parameter names to reflect general use +def _process_line_labels(label, ntraces, ninputs=0, noutputs=0): + if label is None: + return None + + if isinstance(label, str): + label = [label] * ntraces # single label for all traces + + # Convert to an ndarray, if not done aleady + try: + line_labels = np.asarray(label) + except ValueError: + raise ValueError("label must be a string or array_like") + + # Turn the data into a 3D array of appropriate shape + # TODO: allow more sophisticated broadcasting (and error checking) + try: + if ninputs > 0 and noutputs > 0: + if line_labels.ndim == 1 and line_labels.size == ntraces: + line_labels = line_labels.reshape(ntraces, 1, 1) + line_labels = np.broadcast_to( + line_labels, (ntraces, ninputs, noutputs)) + else: + line_labels = line_labels.reshape(ntraces, ninputs, noutputs) + except ValueError: + if line_labels.shape[0] != ntraces: + raise ValueError("number of labels must match number of traces") + else: + raise ValueError("labels must be given for each input/output pair") + + return line_labels + + +# Get labels for all lines in an axes +def _get_line_labels(ax, use_color=True): + labels, lines = [], [] + last_color, counter = None, 0 # label unknown systems + for i, line in enumerate(ax.get_lines()): + label = line.get_label() + if use_color and label.startswith("Unknown"): + label = f"Unknown-{counter}" + if last_color is None: + last_color = line.get_color() + elif last_color != line.get_color(): + counter += 1 + last_color = line.get_color() + elif label[0] == '_': + continue + + if label not in labels: + lines.append(line) + labels.append(label) + + return lines, labels # Utility function to make legend labels @@ -160,3 +302,83 @@ def _find_axes_center(fig, axs): ylim = [min(ll[1], ylim[0]), max(ur[1], ylim[1])] return (np.sum(xlim)/2, np.sum(ylim)/2) + + +# Internal function to add arrows to a curve +def _add_arrows_to_line2D( + axes, line, arrow_locs=[0.2, 0.4, 0.6, 0.8], + arrowstyle='-|>', arrowsize=1, dir=1): + """ + Add arrows to a matplotlib.lines.Line2D at selected locations. + + Parameters: + ----------- + axes: Axes object as returned by axes command (or gca) + line: Line2D object as returned by plot command + arrow_locs: list of locations where to insert arrows, % of total length + arrowstyle: style of the arrow + arrowsize: size of the arrow + + Returns: + -------- + arrows: list of arrows + + Based on https://stackoverflow.com/questions/26911898/ + + """ + # Get the coordinates of the line, in plot coordinates + if not isinstance(line, mpl.lines.Line2D): + raise ValueError("expected a matplotlib.lines.Line2D object") + x, y = line.get_xdata(), line.get_ydata() + + # Determine the arrow properties + arrow_kw = {"arrowstyle": arrowstyle} + + color = line.get_color() + use_multicolor_lines = isinstance(color, np.ndarray) + if use_multicolor_lines: + raise NotImplementedError("multicolor lines not supported") + else: + arrow_kw['color'] = color + + linewidth = line.get_linewidth() + if isinstance(linewidth, np.ndarray): + raise NotImplementedError("multiwidth lines not supported") + else: + arrow_kw['linewidth'] = linewidth + + # Figure out the size of the axes (length of diagonal) + xlim, ylim = axes.get_xlim(), axes.get_ylim() + ul, lr = np.array([xlim[0], ylim[0]]), np.array([xlim[1], ylim[1]]) + diag = np.linalg.norm(ul - lr) + + # Compute the arc length along the curve + s = np.cumsum(np.sqrt(np.diff(x) ** 2 + np.diff(y) ** 2)) + + # Truncate the number of arrows if the curve is short + # TODO: figure out a smarter way to do this + frac = min(s[-1] / diag, 1) + if len(arrow_locs) and frac < 0.05: + arrow_locs = [] # too short; no arrows at all + elif len(arrow_locs) and frac < 0.2: + arrow_locs = [0.5] # single arrow in the middle + + # Plot the arrows (and return list if patches) + arrows = [] + for loc in arrow_locs: + n = np.searchsorted(s, s[-1] * loc) + + if dir == 1 and n == 0: + # Move the arrow forward by one if it is at start of a segment + n = 1 + + # Place the head of the arrow at the desired location + arrow_head = [x[n], y[n]] + arrow_tail = [x[n - dir], y[n - dir]] + + p = mpl.patches.FancyArrowPatch( + arrow_tail, arrow_head, transform=axes.transData, lw=0, + **arrow_kw) + axes.add_patch(p) + arrows.append(p) + return arrows diff --git a/control/freqplot.py b/control/freqplot.py index 5ff690450..277de8a54 100644 --- a/control/freqplot.py +++ b/control/freqplot.py @@ -19,8 +19,9 @@ from . import config from .bdalg import feedback -from .ctrlplot import suptitle, _find_axes_center, _make_legend_labels, \ - _update_suptitle +from .ctrlplot import _add_arrows_to_line2D, _ctrlplot_rcParams, \ + _find_axes_center, _get_line_labels, _make_legend_labels, \ + _process_ax_keyword, _process_line_labels, _update_suptitle, suptitle from .ctrlutil import unwrap from .exception import ControlMIMONotImplemented from .frdata import FrequencyResponseData @@ -34,21 +35,9 @@ 'singular_values_plot', 'gangof4_plot', 'gangof4_response', 'bode', 'nyquist', 'gangof4'] -# Default font dictionary -# TODO: move common plotting params to 'ctrlplot' -_freqplot_rcParams = mpl.rcParams.copy() -_freqplot_rcParams.update({ - 'axes.labelsize': 'small', - 'axes.titlesize': 'small', - 'figure.titlesize': 'medium', - 'legend.fontsize': 'x-small', - 'xtick.labelsize': 'small', - 'ytick.labelsize': 'small', -}) - # Default values for module parameter variables _freqplot_defaults = { - 'freqplot.rcParams': _freqplot_rcParams, + 'freqplot.rcParams': _ctrlplot_rcParams, 'freqplot.feature_periphery_decades': 1, 'freqplot.number_of_samples': 1000, 'freqplot.dB': False, # Plot gain in dB @@ -1937,86 +1926,6 @@ def _parse_linestyle(style_name, allow_false=False): return out -# Internal function to add arrows to a curve -def _add_arrows_to_line2D( - axes, line, arrow_locs=[0.2, 0.4, 0.6, 0.8], - arrowstyle='-|>', arrowsize=1, dir=1): - """ - Add arrows to a matplotlib.lines.Line2D at selected locations. - - Parameters: - ----------- - axes: Axes object as returned by axes command (or gca) - line: Line2D object as returned by plot command - arrow_locs: list of locations where to insert arrows, % of total length - arrowstyle: style of the arrow - arrowsize: size of the arrow - - Returns: - -------- - arrows: list of arrows - - Based on https://stackoverflow.com/questions/26911898/ - - """ - # Get the coordinates of the line, in plot coordinates - if not isinstance(line, mpl.lines.Line2D): - raise ValueError("expected a matplotlib.lines.Line2D object") - x, y = line.get_xdata(), line.get_ydata() - - # Determine the arrow properties - arrow_kw = {"arrowstyle": arrowstyle} - - color = line.get_color() - use_multicolor_lines = isinstance(color, np.ndarray) - if use_multicolor_lines: - raise NotImplementedError("multicolor lines not supported") - else: - arrow_kw['color'] = color - - linewidth = line.get_linewidth() - if isinstance(linewidth, np.ndarray): - raise NotImplementedError("multiwidth lines not supported") - else: - arrow_kw['linewidth'] = linewidth - - # Figure out the size of the axes (length of diagonal) - xlim, ylim = axes.get_xlim(), axes.get_ylim() - ul, lr = np.array([xlim[0], ylim[0]]), np.array([xlim[1], ylim[1]]) - diag = np.linalg.norm(ul - lr) - - # Compute the arc length along the curve - s = np.cumsum(np.sqrt(np.diff(x) ** 2 + np.diff(y) ** 2)) - - # Truncate the number of arrows if the curve is short - # TODO: figure out a smarter way to do this - frac = min(s[-1] / diag, 1) - if len(arrow_locs) and frac < 0.05: - arrow_locs = [] # too short; no arrows at all - elif len(arrow_locs) and frac < 0.2: - arrow_locs = [0.5] # single arrow in the middle - - # Plot the arrows (and return list if patches) - arrows = [] - for loc in arrow_locs: - n = np.searchsorted(s, s[-1] * loc) - - if dir == 1 and n == 0: - # Move the arrow forward by one if it is at start of a segment - n = 1 - - # Place the head of the arrow at the desired location - arrow_head = [x[n], y[n]] - arrow_tail = [x[n - dir], y[n - dir]] - - p = mpl.patches.FancyArrowPatch( - arrow_tail, arrow_head, transform=axes.transData, lw=0, - **arrow_kw) - axes.add_patch(p) - arrows.append(p) - return arrows - - # # Function to compute Nyquist curve offsets # @@ -2672,122 +2581,6 @@ def _default_frequency_range(syslist, Hz=None, number_of_samples=None, return omega -# Get labels for all lines in an axes -def _get_line_labels(ax, use_color=True): - labels, lines = [], [] - last_color, counter = None, 0 # label unknown systems - for i, line in enumerate(ax.get_lines()): - label = line.get_label() - if use_color and label.startswith("Unknown"): - label = f"Unknown-{counter}" - if last_color is None: - last_color = line.get_color() - elif last_color != line.get_color(): - counter += 1 - last_color = line.get_color() - elif label[0] == '_': - continue - - if label not in labels: - lines.append(line) - labels.append(label) - - return lines, labels - - -# Turn label keyword into array indexed by trace, output, input -# TODO: move to ctrlutil.py and update parameter names to reflect general use -def _process_line_labels(label, ntraces, ninputs=0, noutputs=0): - if label is None: - return None - - if isinstance(label, str): - label = [label] * ntraces # single label for all traces - - # Convert to an ndarray, if not done aleady - try: - line_labels = np.asarray(label) - except: - raise ValueError("label must be a string or array_like") - - # Turn the data into a 3D array of appropriate shape - # TODO: allow more sophisticated broadcasting (and error checking) - try: - if ninputs > 0 and noutputs > 0: - if line_labels.ndim == 1 and line_labels.size == ntraces: - line_labels = line_labels.reshape(ntraces, 1, 1) - line_labels = np.broadcast_to( - line_labels, (ntraces, ninputs, noutputs)) - else: - line_labels = line_labels.reshape(ntraces, ninputs, noutputs) - except: - if line_labels.shape[0] != ntraces: - raise ValueError("number of labels must match number of traces") - else: - raise ValueError("labels must be given for each input/output pair") - - return line_labels - - -def _process_ax_keyword( - axs, shape=(1, 1), rcParams=None, squeeze=False, clear_text=False): - """Utility function to process ax keyword to plotting commands. - - This function processes the `ax` keyword to plotting commands. If no - ax keyword is passed, the current figure is checked to see if it has - the correct shape. If the shape matches the desired shape, then the - current figure and axes are returned. Otherwise a new figure is - created with axes of the desired shape. - - Legacy behavior: some of the older plotting commands use a axes label - to identify the proper axes for plotting. This behavior is supported - through the use of the label keyword, but will only work if shape == - (1, 1) and squeeze == True. - - """ - if axs is None: - fig = plt.gcf() # get current figure (or create new one) - axs = fig.get_axes() - - # Check to see if axes are the right shape; if not, create new figure - # Note: can't actually check the shape, just the total number of axes - if len(axs) != np.prod(shape): - with plt.rc_context(rcParams): - if len(axs) != 0: - # Create a new figure - fig, axs = plt.subplots(*shape, squeeze=False) - else: - # Create new axes on (empty) figure - axs = fig.subplots(*shape, squeeze=False) - fig.set_layout_engine('tight') - fig.align_labels() - else: - # Use the existing axes, properly reshaped - axs = np.asarray(axs).reshape(*shape) - - if clear_text: - # Clear out any old text from the current figure - for text in fig.texts: - text.set_visible(False) # turn off the text - del text # get rid of it completely - else: - try: - axs = np.asarray(axs).reshape(shape) - except ValueError: - raise ValueError( - "specified axes are not the right shape; " - f"got {axs.shape} but expecting {shape}") - fig = axs[0, 0].figure - - # Process the squeeze keyword - if squeeze and shape == (1, 1): - axs = axs[0, 0] # Just return the single axes object - elif squeeze: - axs = axs.squeeze() - - return fig, axs - - # # Utility functions to create nice looking labels (KLD 5/23/11) # diff --git a/control/grid.py b/control/grid.py index ef9995947..dfe8f9a3e 100644 --- a/control/grid.py +++ b/control/grid.py @@ -141,18 +141,6 @@ def sgrid(scaling=None): return ax, fig -# Utility function used by all grid code -def _final_setup(ax, scaling=None): - ax.set_xlabel('Real') - ax.set_ylabel('Imaginary') - ax.axhline(y=0, color='black', lw=0.25) - ax.axvline(x=0, color='black', lw=0.25) - - # Set up the scaling for the axes - scaling = 'equal' if scaling is None else scaling - plt.axis(scaling) - - # If not grid is given, at least separate stable/unstable regions def nogrid(dt=None, ax=None, scaling=None): fig = plt.gcf() @@ -226,3 +214,15 @@ def zgrid(zetas=None, wns=None, ax=None, scaling=None): _final_setup(ax, scaling=scaling) return ax, fig + + +# Utility function used by all grid code +def _final_setup(ax, scaling=None): + ax.set_xlabel('Real') + ax.set_ylabel('Imaginary') + ax.axhline(y=0, color='black', lw=0.25) + ax.axvline(x=0, color='black', lw=0.25) + + # Set up the scaling for the axes + scaling = 'equal' if scaling is None else scaling + plt.axis(scaling) diff --git a/control/nichols.py b/control/nichols.py index 5eafa594f..78b03b315 100644 --- a/control/nichols.py +++ b/control/nichols.py @@ -18,10 +18,9 @@ import numpy as np from . import config -from .ctrlplot import suptitle +from .ctrlplot import _get_line_labels, _process_ax_keyword, suptitle from .ctrlutil import unwrap -from .freqplot import _default_frequency_range, _freqplot_defaults, \ - _get_line_labels, _process_ax_keyword +from .freqplot import _default_frequency_range, _freqplot_defaults from .lti import frequency_response from .statesp import StateSpace from .xferfcn import TransferFunction diff --git a/control/phaseplot.py b/control/phaseplot.py index a885f2d5c..c7ccd1d1e 100644 --- a/control/phaseplot.py +++ b/control/phaseplot.py @@ -36,8 +36,8 @@ from scipy.integrate import odeint from . import config +from .ctrlplot import _add_arrows_to_line2D from .exception import ControlNotImplemented -from .freqplot import _add_arrows_to_line2D from .nlsys import NonlinearIOSystem, find_eqpt, input_output_response __all__ = ['phase_plane_plot', 'phase_plot', 'box_grid'] diff --git a/control/pzmap.py b/control/pzmap.py index dd3f9e42b..c7082db1d 100644 --- a/control/pzmap.py +++ b/control/pzmap.py @@ -18,7 +18,8 @@ from numpy import cos, exp, imag, linspace, real, sin, sqrt from . import config -from .freqplot import _freqplot_defaults, _get_line_labels +from .ctrlplot import _get_line_labels +from .freqplot import _freqplot_defaults from .grid import nogrid, sgrid, zgrid from .iosys import isctime, isdtime from .lti import LTI diff --git a/control/tests/ctrlplot_test.py b/control/tests/ctrlplot_test.py new file mode 100644 index 000000000..05970bdd1 --- /dev/null +++ b/control/tests/ctrlplot_test.py @@ -0,0 +1,42 @@ +# ctrlplot_test.py - test out control plotting utilities +# RMM, 27 Jun 2024 + +import pytest +import control as ct +import matplotlib.pyplot as plt + +@pytest.mark.usefixtures('mplcleanup') +def test_rcParams(): + sys = ct.rss(2, 2, 2) + + # Create new set of rcParams + my_rcParams = {} + for key in [ + 'axes.labelsize', 'axes.titlesize', 'figure.titlesize', + 'legend.fontsize', 'xtick.labelsize', 'ytick.labelsize']: + match plt.rcParams[key]: + case 8 | 9 | 10: + my_rcParams[key] = plt.rcParams[key] + 1 + case 'medium': + my_rcParams[key] = 11.5 + case 'large': + my_rcParams[key] = 9.5 + case _: + raise ValueError(f"unknown rcParam type for {key}") + + # Generate a figure with the new rcParams + out = ct.step_response(sys).plot(rcParams=my_rcParams) + ax = out[0, 0][0].axes + fig = ax.figure + + # Check to make sure new settings were used + assert ax.xaxis.get_label().get_fontsize() == my_rcParams['axes.labelsize'] + assert ax.yaxis.get_label().get_fontsize() == my_rcParams['axes.labelsize'] + assert ax.title.get_fontsize() == my_rcParams['axes.titlesize'] + assert ax.get_xticklabels()[0].get_fontsize() == \ + my_rcParams['xtick.labelsize'] + assert ax.get_yticklabels()[0].get_fontsize() == \ + my_rcParams['ytick.labelsize'] + assert fig._suptitle.get_fontsize() == my_rcParams['figure.titlesize'] + + diff --git a/control/tests/timeplot_test.py b/control/tests/timeplot_test.py index 0fcc159be..6c124c48f 100644 --- a/control/tests/timeplot_test.py +++ b/control/tests/timeplot_test.py @@ -397,41 +397,6 @@ def test_linestyles(): assert lines[7].get_color() == 'green' and lines[7].get_linestyle() == '--' -@pytest.mark.usefixtures('mplcleanup') -def test_rcParams(): - sys = ct.rss(2, 2, 2) - - # Create new set of rcParams - my_rcParams = {} - for key in [ - 'axes.labelsize', 'axes.titlesize', 'figure.titlesize', - 'legend.fontsize', 'xtick.labelsize', 'ytick.labelsize']: - match plt.rcParams[key]: - case 8 | 9 | 10: - my_rcParams[key] = plt.rcParams[key] + 1 - case 'medium': - my_rcParams[key] = 11.5 - case 'large': - my_rcParams[key] = 9.5 - case _: - raise ValueError(f"unknown rcParam type for {key}") - - # Generate a figure with the new rcParams - out = ct.step_response(sys).plot(rcParams=my_rcParams) - ax = out[0, 0][0].axes - fig = ax.figure - - # Check to make sure new settings were used - assert ax.xaxis.get_label().get_fontsize() == my_rcParams['axes.labelsize'] - assert ax.yaxis.get_label().get_fontsize() == my_rcParams['axes.labelsize'] - assert ax.title.get_fontsize() == my_rcParams['axes.titlesize'] - assert ax.get_xticklabels()[0].get_fontsize() == \ - my_rcParams['xtick.labelsize'] - assert ax.get_yticklabels()[0].get_fontsize() == \ - my_rcParams['ytick.labelsize'] - assert fig._suptitle.get_fontsize() == my_rcParams['figure.titlesize'] - - @pytest.mark.parametrize("resp_fcn", [ ct.step_response, ct.initial_response, ct.impulse_response, ct.forced_response, ct.input_output_response]) diff --git a/control/timeplot.py b/control/timeplot.py index 2eb7aec9b..01b5c7945 100644 --- a/control/timeplot.py +++ b/control/timeplot.py @@ -15,24 +15,13 @@ import numpy as np from . import config -from .ctrlplot import _make_legend_labels, _update_suptitle +from .ctrlplot import _ctrlplot_rcParams, _make_legend_labels, _update_suptitle __all__ = ['time_response_plot', 'combine_time_responses'] -# Default font dictionary -_timeplot_rcParams = mpl.rcParams.copy() -_timeplot_rcParams.update({ - 'axes.labelsize': 'small', - 'axes.titlesize': 'small', - 'figure.titlesize': 'medium', - 'legend.fontsize': 'x-small', - 'xtick.labelsize': 'small', - 'ytick.labelsize': 'small', -}) - # Default values for module parameter variables _timeplot_defaults = { - 'timeplot.rcParams': _timeplot_rcParams, + 'timeplot.rcParams': _ctrlplot_rcParams, 'timeplot.trace_props': [ {'linestyle': s} for s in ['-', '--', ':', '-.']], 'timeplot.output_props': [ @@ -162,7 +151,7 @@ def time_response_plot( config.defaults[''timeplot.rcParams']. """ - from .freqplot import _process_ax_keyword, _process_line_labels + from .ctrlplot import _process_ax_keyword, _process_line_labels from .iosys import InputOutputSystem from .timeresp import TimeResponseData
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: