diff --git a/control/statesp.py b/control/statesp.py index 03349b0ac..c12583111 100644 --- a/control/statesp.py +++ b/control/statesp.py @@ -1434,11 +1434,11 @@ def _convert_to_statespace(sys, **kw): # TODO: add discrete time option -def _rss_generate(states, inputs, outputs, type, strictly_proper=False): +def _rss_generate(states, inputs, outputs, cdtype, strictly_proper=False): """Generate a random state space. This does the actual random state space generation expected from rss and - drss. type is 'c' for continuous systems and 'd' for discrete systems. + drss. cdtype is 'c' for continuous systems and 'd' for discrete systems. """ @@ -1465,6 +1465,8 @@ def _rss_generate(states, inputs, outputs, type, strictly_proper=False): if outputs < 1 or outputs % 1: raise ValueError("outputs must be a positive integer. outputs = %g." % outputs) + if cdtype not in ['c', 'd']: + raise ValueError("cdtype must be `c` or `d`") # Make some poles for A. Preallocate a complex array. poles = zeros(states) + zeros(states) * 0.j @@ -1484,16 +1486,16 @@ def _rss_generate(states, inputs, outputs, type, strictly_proper=False): i += 2 elif rand() < pReal or i == states - 1: # No-oscillation pole. - if type == 'c': + if cdtype == 'c': poles[i] = -exp(randn()) + 0.j - elif type == 'd': + else: poles[i] = 2. * rand() - 1. i += 1 else: # Complex conjugate pair of oscillating poles. - if type == 'c': + if cdtype == 'c': poles[i] = complex(-exp(randn()), 3. * exp(randn())) - elif type == 'd': + else: mag = rand() phase = 2. * math.pi * rand() poles[i] = complex(mag * cos(phase), mag * sin(phase)) @@ -1546,7 +1548,11 @@ def _rss_generate(states, inputs, outputs, type, strictly_proper=False): C = C * Cmask D = D * Dmask if not strictly_proper else zeros(D.shape) - return StateSpace(A, B, C, D) + if cdtype == 'c': + ss_args = (A, B, C, D) + else: + ss_args = (A, B, C, D, True) + return StateSpace(*ss_args) # Convert a MIMO system to a SISO system @@ -1825,15 +1831,14 @@ def rss(states=1, outputs=1, inputs=1, strictly_proper=False): Parameters ---------- - states : integer + states : int Number of state variables - inputs : integer + inputs : int Number of system inputs - outputs : integer + outputs : int Number of system outputs strictly_proper : bool, optional - If set to 'True', returns a proper system (no direct term). Default - value is 'False'. + If set to 'True', returns a proper system (no direct term). Returns ------- @@ -1867,12 +1872,15 @@ def drss(states=1, outputs=1, inputs=1, strictly_proper=False): Parameters ---------- - states : integer + states : int Number of state variables inputs : integer Number of system inputs - outputs : integer + outputs : int Number of system outputs + strictly_proper: bool, optional + If set to 'True', returns a proper system (no direct term). + Returns ------- diff --git a/control/tests/statesp_test.py b/control/tests/statesp_test.py index 67cf950e7..71e7cc4bc 100644 --- a/control/tests/statesp_test.py +++ b/control/tests/statesp_test.py @@ -19,7 +19,7 @@ from control.dtime import sample_system from control.lti import evalfr from control.statesp import (StateSpace, _convert_to_statespace, drss, - rss, ss, tf2ss, _statesp_defaults) + rss, ss, tf2ss, _statesp_defaults, _rss_generate) from control.tests.conftest import ismatarrayout, slycotonly from control.xferfcn import TransferFunction, ss2tf @@ -855,6 +855,28 @@ def test_pole(self, states, outputs, inputs): for z in p: assert z.real < 0 + @pytest.mark.parametrize('strictly_proper', [True, False]) + def test_strictly_proper(self, strictly_proper): + """Test that the strictly_proper argument returns a correct D.""" + for i in range(100): + # The probability that drss(..., strictly_proper=False) returns an + # all zero D 100 times in a row is 0.5**100 = 7.89e-31 + sys = rss(1, 1, 1, strictly_proper=strictly_proper) + if np.all(sys.D == 0.) == strictly_proper: + break + assert np.all(sys.D == 0.) == strictly_proper + + @pytest.mark.parametrize('par, errmatch', + [((-1, 1, 1, 'c'), 'states must be'), + ((1, -1, 1, 'c'), 'inputs must be'), + ((1, 1, -1, 'c'), 'outputs must be'), + ((1, 1, 1, 'x'), 'cdtype must be'), + ]) + def test_rss_invalid(self, par, errmatch): + """Test invalid inputs for rss() and drss().""" + with pytest.raises(ValueError, match=errmatch): + _rss_generate(*par) + class TestDrss: """These are tests for the proper functionality of statesp.drss.""" @@ -873,6 +895,7 @@ def test_shape(self, states, outputs, inputs): assert sys.nstates == states assert sys.ninputs == inputs assert sys.noutputs == outputs + assert sys.dt is True @pytest.mark.parametrize('states', range(1, maxStates)) @pytest.mark.parametrize('outputs', range(1, maxIO)) @@ -884,6 +907,17 @@ def test_pole(self, states, outputs, inputs): for z in p: assert abs(z) < 1 + @pytest.mark.parametrize('strictly_proper', [True, False]) + def test_strictly_proper(self, strictly_proper): + """Test that the strictly_proper argument returns a correct D.""" + for i in range(100): + # The probability that drss(..., strictly_proper=False) returns an + # all zero D 100 times in a row is 0.5**100 = 7.89e-31 + sys = drss(1, 1, 1, strictly_proper=strictly_proper) + if np.all(sys.D == 0.) == strictly_proper: + break + assert np.all(sys.D == 0.) == strictly_proper + class TestLTIConverter: """Test returnScipySignalLTI method"""
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: