Skip to content

Commit 4eda9e6

Browse files
committed
Merge pull request scikit-learn#4770 from TomDLT/logistic
[MRG+3] improve parameter check in LogisticRegression
2 parents 549ecae + 355dc7c commit 4eda9e6

File tree

2 files changed

+77
-54
lines changed

2 files changed

+77
-54
lines changed

sklearn/linear_model/logistic.py

Lines changed: 35 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,28 @@ def hessp(v):
391391
return grad, hessp
392392

393393

394+
def _check_solver_option(solver, multi_class, penalty, dual):
395+
if solver not in ['liblinear', 'newton-cg', 'lbfgs']:
396+
raise ValueError("Logistic Regression supports only liblinear,"
397+
" newton-cg and lbfgs solvers, got %s" % solver)
398+
399+
if multi_class not in ['multinomial', 'ovr']:
400+
raise ValueError("multi_class should be either multinomial or "
401+
"ovr, got %s" % multi_class)
402+
403+
if multi_class == 'multinomial' and solver == 'liblinear':
404+
raise ValueError("Solver %s does not support "
405+
"a multinomial backend." % solver)
406+
407+
if solver != 'liblinear':
408+
if penalty != 'l2':
409+
raise ValueError("Solver %s supports only l2 penalties, "
410+
"got %s penalty." % (solver, penalty))
411+
if dual:
412+
raise ValueError("Solver %s supports only "
413+
"dual=False, got dual=%s" % (solver, dual))
414+
415+
394416
def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
395417
max_iter=100, tol=1e-4, verbose=0,
396418
solver='lbfgs', coef=None, copy=True,
@@ -501,25 +523,8 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
501523
if isinstance(Cs, numbers.Integral):
502524
Cs = np.logspace(-4, 4, Cs)
503525

504-
if multi_class not in ['multinomial', 'ovr']:
505-
raise ValueError("multi_class can be either 'multinomial' or 'ovr'"
506-
"got %s" % multi_class)
507-
508-
if solver not in ['liblinear', 'newton-cg', 'lbfgs']:
509-
raise ValueError("Logistic Regression supports only liblinear,"
510-
" newton-cg and lbfgs solvers. got %s" % solver)
511-
512-
if multi_class == 'multinomial' and solver == 'liblinear':
513-
raise ValueError("Solver %s cannot solve problems with "
514-
"a multinomial backend." % solver)
526+
_check_solver_option(solver, multi_class, penalty, dual)
515527

516-
if solver != 'liblinear':
517-
if penalty != 'l2':
518-
raise ValueError("newton-cg and lbfgs solvers support only "
519-
"l2 penalties, got %s penalty." % penalty)
520-
if dual:
521-
raise ValueError("newton-cg and lbfgs solvers support only "
522-
"dual=False, got dual=%s" % dual)
523528
# Preprocessing.
524529
X = check_array(X, accept_sparse='csr', dtype=np.float64)
525530
y = check_array(y, ensure_2d=False, copy=copy, dtype=None)
@@ -781,6 +786,7 @@ def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10,
781786
scores : ndarray, shape (n_cs,)
782787
Scores obtained for each Cs.
783788
"""
789+
_check_solver_option(solver, multi_class, penalty, dual)
784790

785791
log_reg = LogisticRegression(fit_intercept=fit_intercept)
786792

@@ -1015,18 +1021,9 @@ def fit(self, X, y):
10151021

10161022
X, y = check_X_y(X, y, accept_sparse='csr', dtype=np.float64, order="C")
10171023
self.classes_ = np.unique(y)
1018-
if self.solver not in ['liblinear', 'newton-cg', 'lbfgs']:
1019-
raise ValueError(
1020-
"Logistic Regression supports only liblinear, newton-cg and "
1021-
"lbfgs solvers, Got solver=%s" % self.solver
1022-
)
10231024

1024-
if self.solver == 'liblinear' and self.multi_class == 'multinomial':
1025-
raise ValueError("Solver %s does not support a multinomial "
1026-
"backend." % self.solver)
1027-
if self.multi_class not in ['ovr', 'multinomial']:
1028-
raise ValueError("multi_class should be either ovr or multinomial "
1029-
"got %s" % self.multi_class)
1025+
_check_solver_option(self.solver, self.multi_class, self.penalty,
1026+
self.dual)
10301027

10311028
if self.solver == 'liblinear':
10321029
self.coef_, self.intercept_, self.n_iter_ = _fit_liblinear(
@@ -1308,22 +1305,19 @@ def fit(self, X, y):
13081305
self : object
13091306
Returns self.
13101307
"""
1311-
if self.solver != 'liblinear':
1312-
if self.penalty != 'l2':
1313-
raise ValueError("newton-cg and lbfgs solvers support only "
1314-
"l2 penalties.")
1315-
if self.dual:
1316-
raise ValueError("newton-cg and lbfgs solvers support only "
1317-
"the primal form.")
1308+
_check_solver_option(self.solver, self.multi_class, self.penalty,
1309+
self.dual)
1310+
1311+
if not isinstance(self.max_iter, numbers.Number) or self.max_iter < 0:
1312+
raise ValueError("Maximum number of iteration must be positive;"
1313+
" got (max_iter=%r)" % self.max_iter)
1314+
if not isinstance(self.tol, numbers.Number) or self.tol < 0:
1315+
raise ValueError("Tolerance for stopping criteria must be "
1316+
"positive; got (tol=%r)" % self.tol)
13181317

13191318
X = check_array(X, accept_sparse='csr', dtype=np.float64)
13201319
y = check_array(y, ensure_2d=False, dtype=None)
13211320

1322-
if self.multi_class not in ['ovr', 'multinomial']:
1323-
raise ValueError("multi_class backend should be either "
1324-
"'ovr' or 'multinomial'"
1325-
" got %s" % self.multi_class)
1326-
13271321
if y.ndim == 2 and y.shape[1] == 1:
13281322
warnings.warn(
13291323
"A column-vector y was passed when a 1d array was"

sklearn/linear_model/tests/test_logistic.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from sklearn.utils.testing import assert_almost_equal
66
from sklearn.utils.testing import assert_array_equal
77
from sklearn.utils.testing import assert_array_almost_equal
8-
from sklearn.utils.testing import assert_raises_regexp
98
from sklearn.utils.testing import assert_equal
109
from sklearn.utils.testing import assert_greater
1110
from sklearn.utils.testing import assert_raises
@@ -69,22 +68,19 @@ def test_predict_2_classes():
6968
def test_error():
7069
# Test for appropriate exception on errors
7170
msg = "Penalty term must be positive"
72-
assert_raises_regexp(ValueError, msg,
71+
assert_raise_message(ValueError, msg,
7372
LogisticRegression(C=-1).fit, X, Y1)
74-
assert_raises_regexp(ValueError, msg,
73+
assert_raise_message(ValueError, msg,
7574
LogisticRegression(C="test").fit, X, Y1)
7675

77-
msg = "Tolerance for stopping criteria must be positive"
78-
assert_raises_regexp(ValueError, msg,
79-
LogisticRegression(tol=-1).fit, X, Y1)
80-
assert_raises_regexp(ValueError, msg,
81-
LogisticRegression(tol="test").fit, X, Y1)
76+
for LR in [LogisticRegression, LogisticRegressionCV]:
77+
msg = "Tolerance for stopping criteria must be positive"
78+
assert_raise_message(ValueError, msg, LR(tol=-1).fit, X, Y1)
79+
assert_raise_message(ValueError, msg, LR(tol="test").fit, X, Y1)
8280

83-
msg = "Maximum number of iteration must be positive"
84-
assert_raises_regexp(ValueError, msg,
85-
LogisticRegression(max_iter=-1).fit, X, Y1)
86-
assert_raises_regexp(ValueError, msg,
87-
LogisticRegression(max_iter="test").fit, X, Y1)
81+
msg = "Maximum number of iteration must be positive"
82+
assert_raise_message(ValueError, msg, LR(max_iter=-1).fit, X, Y1)
83+
assert_raise_message(ValueError, msg, LR(max_iter="test").fit, X, Y1)
8884

8985

9086
def test_predict_3_classes():
@@ -126,6 +122,39 @@ def test_multinomial_validation():
126122
assert_raises(ValueError, lr.fit, [[0, 1], [1, 0]], [0, 1])
127123

128124

125+
def test_check_solver_option():
126+
X, y = iris.data, iris.target
127+
for LR in [LogisticRegression, LogisticRegressionCV]:
128+
129+
msg = ("Logistic Regression supports only liblinear, newton-cg and"
130+
" lbfgs solvers, got wrong_name")
131+
lr = LR(solver="wrong_name")
132+
assert_raise_message(ValueError, msg, lr.fit, X, y)
133+
134+
msg = "multi_class should be either multinomial or ovr, got wrong_name"
135+
lr = LR(solver='newton-cg', multi_class="wrong_name")
136+
assert_raise_message(ValueError, msg, lr.fit, X, y)
137+
138+
# all solver except 'newton-cg' and 'lfbgs'
139+
for solver in ['liblinear']:
140+
msg = ("Solver %s does not support a multinomial backend." %
141+
solver)
142+
lr = LR(solver=solver, multi_class='multinomial')
143+
assert_raise_message(ValueError, msg, lr.fit, X, y)
144+
145+
# all solvers except 'liblinear'
146+
for solver in ['newton-cg', 'lbfgs']:
147+
msg = ("Solver %s supports only l2 penalties, got l1 penalty." %
148+
solver)
149+
lr = LR(solver=solver, penalty='l1')
150+
assert_raise_message(ValueError, msg, lr.fit, X, y)
151+
152+
msg = ("Solver %s supports only dual=False, got dual=True" %
153+
solver)
154+
lr = LR(solver=solver, dual=True)
155+
assert_raise_message(ValueError, msg, lr.fit, X, y)
156+
157+
129158
def test_multinomial_binary():
130159
# Test multinomial LR on a binary problem.
131160
target = (iris.target > 0).astype(np.intp)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy