Skip to content

Commit fc9d7be

Browse files
committed
FIX use random_state in LogisticRegression
1 parent 4eda9e6 commit fc9d7be

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

sklearn/linear_model/logistic.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from ..preprocessing import LabelEncoder, LabelBinarizer
2121
from ..svm.base import _fit_liblinear
2222
from ..utils import check_array, check_consistent_length, compute_class_weight
23+
from ..utils import check_random_state
2324
from ..utils.extmath import (logsumexp, log_logistic, safe_sparse_dot,
2425
squared_norm)
2526
from ..utils.optimize import newton_cg
@@ -417,7 +418,8 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
417418
max_iter=100, tol=1e-4, verbose=0,
418419
solver='lbfgs', coef=None, copy=True,
419420
class_weight=None, dual=False, penalty='l2',
420-
intercept_scaling=1., multi_class='ovr'):
421+
intercept_scaling=1., multi_class='ovr',
422+
random_state=None):
421423
"""Compute a Logistic Regression model for a list of regularization
422424
parameters.
423425
@@ -502,8 +504,12 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
502504
Multiclass option can be either 'ovr' or 'multinomial'. If the option
503505
chosen is 'ovr', then a binary problem is fit for each label. Else
504506
the loss minimised is the multinomial loss fit across
505-
the entire probability distribution. Works only for the 'lbfgs'
506-
solver.
507+
the entire probability distribution. Works only for the 'lbfgs' and
508+
'newton-cg' solvers.
509+
510+
random_state : int seed, RandomState instance, or None (default)
511+
The seed of the pseudo random number generator to use when
512+
shuffling the data.
507513
508514
Returns
509515
-------
@@ -531,6 +537,7 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
531537
_, n_features = X.shape
532538
check_consistent_length(X, y)
533539
classes = np.unique(y)
540+
random_state = check_random_state(random_state)
534541

535542
if pos_class is None and multi_class != 'multinomial':
536543
if (classes.size > 2):
@@ -659,7 +666,7 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
659666
elif solver == 'liblinear':
660667
coef_, intercept_, _, = _fit_liblinear(
661668
X, y, C, fit_intercept, intercept_scaling, class_weight,
662-
penalty, dual, verbose, max_iter, tol,
669+
penalty, dual, verbose, max_iter, tol, random_state
663670
)
664671
if fit_intercept:
665672
w0 = np.concatenate([coef_.ravel(), intercept_])
@@ -1029,7 +1036,7 @@ def fit(self, X, y):
10291036
self.coef_, self.intercept_, self.n_iter_ = _fit_liblinear(
10301037
X, y, self.C, self.fit_intercept, self.intercept_scaling,
10311038
self.class_weight, self.penalty, self.dual, self.verbose,
1032-
self.max_iter, self.tol
1039+
self.max_iter, self.tol, self.random_state
10331040
)
10341041
return self
10351042

sklearn/linear_model/tests/test_logistic.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -266,13 +266,22 @@ def test_consistency_path():
266266
assert_array_almost_equal(lr_coef, coefs[0], decimal=4)
267267

268268

269-
def test_liblinear_random_state():
269+
def test_liblinear_dual_random_state():
270+
# random_state is relevant for liblinear solver only if dual=True
270271
X, y = make_classification(n_samples=20)
271-
lr1 = LogisticRegression(random_state=0)
272+
lr1 = LogisticRegression(random_state=0, dual=True, max_iter=1, tol=1e-15)
272273
lr1.fit(X, y)
273-
lr2 = LogisticRegression(random_state=0)
274+
lr2 = LogisticRegression(random_state=0, dual=True, max_iter=1, tol=1e-15)
274275
lr2.fit(X, y)
276+
lr3 = LogisticRegression(random_state=8, dual=True, max_iter=1, tol=1e-15)
277+
lr3.fit(X, y)
278+
279+
# same result for same random state
275280
assert_array_almost_equal(lr1.coef_, lr2.coef_)
281+
# different results for different random states
282+
msg = "Arrays are not almost equal to 6 decimals"
283+
assert_raise_message(AssertionError, msg,
284+
assert_array_almost_equal, lr1.coef_, lr3.coef_)
276285

277286

278287
def test_logistic_loss_and_grad():

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy