Skip to content

Commit ab77e02

Browse files
committed
Merge branch 'extend-tf2scipylti' into array-matrix-tests
2 parents 4824143 + a8aa41e commit ab77e02

File tree

4 files changed

+146
-17
lines changed

4 files changed

+146
-17
lines changed

control/statesp.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@
5959
from numpy.linalg import solve, eigvals, matrix_rank
6060
from numpy.linalg.linalg import LinAlgError
6161
import scipy as sp
62-
from scipy.signal import lti, cont2discrete
62+
from scipy.signal import cont2discrete
63+
from scipy.signal import StateSpace as signalStateSpace
6364
from warnings import warn
6465
from .lti import LTI, common_timebase, isdtime
6566
from . import config
@@ -802,26 +803,51 @@ def minreal(self, tol=0.0):
802803
else:
803804
return StateSpace(self)
804805

805-
806-
# TODO: add discrete time check
807-
def returnScipySignalLTI(self):
808-
"""Return a list of a list of scipy.signal.lti objects.
806+
def returnScipySignalLTI(self, strict=True):
807+
"""Return a list of a list of SISO scipy.signal.lti objects.
809808
810809
For instance,
811810
812811
>>> out = ssobject.returnScipySignalLTI()
813812
>>> out[3][5]
814813
815-
is a signal.scipy.lti object corresponding to the transfer function from
816-
the 6th input to the 4th output."""
814+
is a signal.scipy.lti object corresponding to the transfer function
815+
from the 6th input to the 4th output.
816+
817+
Parameters
818+
----------
819+
strict : bool, optional
820+
True (default):
821+
`ssobject` must be continuous or discrete. `tfobject.dt` cannot
822+
be None.
823+
False:
824+
if `ssobject.dt` is None, continuous time signal.StateSpace
825+
objects are returned
826+
827+
Returns
828+
-------
829+
out : list of list of scipy.signal.StateSpace
830+
"""
831+
if strict and self.dt is None:
832+
raise ValueError("with strict=True, dt cannot be None")
833+
834+
if self.dt:
835+
kwdt = {'dt': self.dt}
836+
else:
837+
# scipy convention for continuous time lti systems: call without
838+
# dt keyword argument
839+
kwdt = {}
817840

818841
# Preallocate the output.
819842
out = [[[] for _ in range(self.inputs)] for _ in range(self.outputs)]
820843

821844
for i in range(self.outputs):
822845
for j in range(self.inputs):
823-
out[i][j] = lti(asarray(self.A), asarray(self.B[:, j]),
824-
asarray(self.C[i, :]), self.D[i, j])
846+
out[i][j] = signalStateSpace(asarray(self.A),
847+
asarray(self.B[:, j]),
848+
asarray(self.C[i, :]),
849+
self.D[i, j],
850+
**kwdt)
825851

826852
return out
827853

control/tests/statesp_test.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99

1010

1111
import numpy as np
12-
from numpy.linalg import solve
1312
import pytest
13+
14+
from numpy.linalg import solve
1415
from scipy.linalg import eigvals, block_diag
1516

1617
from control.statesp import StateSpace, _convertToStateSpace, drss, rss, tf2ss
@@ -698,3 +699,49 @@ def test_pole(self, states, outputs, inputs):
698699
assert abs(z) < 1
699700

700701

702+
class TestLTIConverter:
703+
"""Test the LTI system return function"""
704+
705+
@pytest.fixture
706+
def mimoss(self, request):
707+
"""Test system with various dt values"""
708+
n = 5
709+
m = 3
710+
p = 2
711+
bx, bu = np.mgrid[1:n + 1, 1:m + 1]
712+
cy, cx = np.mgrid[1:p + 1, 1:n + 1]
713+
dy, du = np.mgrid[1:p + 1, 1:m + 1]
714+
return StateSpace(np.eye(5),
715+
bx * bu,
716+
cy * cx,
717+
dy * du,
718+
request.param)
719+
720+
@pytest.mark.parametrize("mimoss",
721+
[None,
722+
0,
723+
0.1,
724+
1,
725+
True],
726+
indirect=True)
727+
def test_returnScipySignalLTI(self, mimoss):
728+
"""Test returnScipySignalLTI method with strict=False"""
729+
sslti = mimoss.returnScipySignalLTI(strict=False)
730+
for i in range(2):
731+
for j in range(3):
732+
np.testing.assert_allclose(sslti[i][j].A, mimoss.A)
733+
np.testing.assert_allclose(sslti[i][j].B, mimoss.B[:, j])
734+
np.testing.assert_allclose(sslti[i][j].C, mimoss.C[i, :])
735+
np.testing.assert_allclose(sslti[i][j].D, mimoss.D[i, j])
736+
if mimoss.dt == 0:
737+
assert sslti[i][j].dt is None
738+
else:
739+
assert sslti[i][j].dt == mimoss.dt
740+
741+
@pytest.mark.parametrize("mimoss", [None], indirect=True)
742+
def test_returnScipySignalLTI_error(self, mimoss):
743+
"""Test returnScipySignalLTI method with dt=None and strict=True"""
744+
with pytest.raises(ValueError):
745+
mimoss.returnScipySignalLTI()
746+
with pytest.raises(ValueError):
747+
mimoss.returnScipySignalLTI(strict=True)

control/tests/xferfcn_test.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,3 +957,39 @@ def test_repr(self, Hargs, ref):
957957
np.testing.assert_array_almost_equal(H.num[p][m], H2.num[p][m])
958958
np.testing.assert_array_almost_equal(H.den[p][m], H2.den[p][m])
959959
assert H.dt == H2.dt
960+
961+
@pytest.fixture
962+
def mimotf(self, request):
963+
"""Test system with various dt values"""
964+
return TransferFunction([[[11], [12], [13]],
965+
[[21], [22], [23]]],
966+
[[[1, -1]] * 3] * 2,
967+
request.param)
968+
969+
@pytest.mark.parametrize("mimotf",
970+
[None,
971+
0,
972+
0.1,
973+
1,
974+
True],
975+
indirect=True)
976+
def test_returnScipySignalLTI(self, mimotf):
977+
"""Test returnScipySignalLTI method with strict=False"""
978+
sslti = mimotf.returnScipySignalLTI(strict=False)
979+
for i in range(2):
980+
for j in range(3):
981+
np.testing.assert_allclose(sslti[i][j].num, mimotf.num[i][j])
982+
np.testing.assert_allclose(sslti[i][j].den, mimotf.den[i][j])
983+
if mimotf.dt == 0:
984+
assert sslti[i][j].dt is None
985+
else:
986+
assert sslti[i][j].dt == mimotf.dt
987+
988+
@pytest.mark.parametrize("mimotf", [None], indirect=True)
989+
def test_returnScipySignalLTI_error(self, mimotf):
990+
"""Test returnScipySignalLTI method with dt=None and default strict=True"""
991+
with pytest.raises(ValueError):
992+
mimotf.returnScipySignalLTI()
993+
with pytest.raises(ValueError):
994+
mimotf.returnScipySignalLTI(strict=True)
995+

control/xferfcn.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@
5757
polyadd, polymul, polyval, roots, sqrt, zeros, squeeze, exp, pi, \
5858
where, delete, real, poly, nonzero
5959
import scipy as sp
60-
from scipy.signal import lti, tf2zpk, zpk2tf, cont2discrete
60+
from scipy.signal import tf2zpk, zpk2tf, cont2discrete
61+
from scipy.signal import TransferFunction as signalTransferFunction
6162
from copy import deepcopy
6263
from warnings import warn
6364
from itertools import chain
@@ -788,7 +789,7 @@ def minreal(self, tol=None):
788789
# end result
789790
return TransferFunction(num, den, self.dt)
790791

791-
def returnScipySignalLTI(self):
792+
def returnScipySignalLTI(self, strict=True):
792793
"""Return a list of a list of scipy.signal.lti objects.
793794
794795
For instance,
@@ -799,19 +800,38 @@ def returnScipySignalLTI(self):
799800
is a signal.scipy.lti object corresponding to the
800801
transfer function from the 6th input to the 4th output.
801802
803+
Parameters
804+
----------
805+
strict : bool, optional
806+
True (default):
807+
`tfobject` must be continuous or discrete.
808+
`tfobject.dt`cannot be None.
809+
False:
810+
if `tfobject.dt` is None, continuous time signal.TransferFunction
811+
objects are is returned
812+
813+
Returns
814+
-------
815+
out : list of list of scipy.signal.TransferFunction
802816
"""
817+
if strict and self.dt is None:
818+
raise ValueError("with strict=True, dt cannot be None")
803819

804-
# TODO: implement for discrete time systems
805-
if self.dt != 0 and self.dt is not None:
806-
raise NotImplementedError("Function not \
807-
implemented in discrete time")
820+
if self.dt:
821+
kwdt = {'dt': self.dt}
822+
else:
823+
# scipy convention for continuous time lti systems: call without
824+
# dt keyword argument
825+
kwdt = {}
808826

809827
# Preallocate the output.
810828
out = [[[] for j in range(self.inputs)] for i in range(self.outputs)]
811829

812830
for i in range(self.outputs):
813831
for j in range(self.inputs):
814-
out[i][j] = lti(self.num[i][j], self.den[i][j])
832+
out[i][j] = signalTransferFunction(self.num[i][j],
833+
self.den[i][j],
834+
**kwdt)
815835

816836
return out
817837

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy