Skip to content

Commit 5c46c0c

Browse files
committed
Merge pull request scikit-learn#3248 from MechCoder/remove_precompute_multi
[MRG] Remove unused param precompute from MultiTask models
2 parents d298a37 + ae7a4ad commit 5c46c0c

File tree

1 file changed

+12
-18
lines changed

1 file changed

+12
-18
lines changed

sklearn/linear_model/coordinate_descent.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -889,7 +889,13 @@ def _path_residuals(X, y, train, test, path, path_params, alphas=None,
889889
y_test = y[test]
890890
fit_intercept = path_params['fit_intercept']
891891
normalize = path_params['normalize']
892-
precompute = path_params['precompute']
892+
893+
if y.ndim == 1:
894+
precompute = path_params['precompute']
895+
else:
896+
# No Gram variant of multi-task exists right now.
897+
# Fall back to default enet_multitask
898+
precompute = False
893899

894900
X_train, y_train, X_mean, y_mean, X_std, precompute, Xy = \
895901
_pre_fit(X_train, y_train, None, precompute, normalize, fit_intercept,
@@ -1638,11 +1644,6 @@ class MultiTaskElasticNetCV(LinearModelCV, RegressorMixin):
16381644
List of alphas where to compute the models.
16391645
If not provided, set automatically.
16401646
1641-
precompute : True | False | 'auto' | array-like
1642-
Whether to use a precomputed Gram matrix to speed up
1643-
calculations. If set to ``'auto'`` let us decide. The Gram
1644-
matrix can also be passed as argument.
1645-
16461647
n_alphas : int, optional
16471648
Number of alphas along the regularization path
16481649
@@ -1716,8 +1717,7 @@ class MultiTaskElasticNetCV(LinearModelCV, RegressorMixin):
17161717
... #doctest: +NORMALIZE_WHITESPACE
17171718
MultiTaskElasticNetCV(alphas=None, copy_X=True, cv=None, eps=0.001,
17181719
fit_intercept=True, l1_ratio=0.5, max_iter=1000, n_alphas=100,
1719-
n_jobs=1, normalize=False, precompute='auto', tol=0.0001,
1720-
verbose=0)
1720+
n_jobs=1, normalize=False, tol=0.0001, verbose=0)
17211721
>>> print(clf.coef_)
17221722
[[ 0.52875032 0.46958558]
17231723
[ 0.52875032 0.46958558]]
@@ -1740,7 +1740,7 @@ class MultiTaskElasticNetCV(LinearModelCV, RegressorMixin):
17401740
path = staticmethod(enet_path)
17411741

17421742
def __init__(self, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
1743-
fit_intercept=True, normalize=False, precompute='auto',
1743+
fit_intercept=True, normalize=False,
17441744
max_iter=1000, tol=1e-4, cv=None, copy_X=True,
17451745
verbose=0, n_jobs=1):
17461746
self.l1_ratio = l1_ratio
@@ -1749,7 +1749,6 @@ def __init__(self, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
17491749
self.alphas = alphas
17501750
self.fit_intercept = fit_intercept
17511751
self.normalize = normalize
1752-
self.precompute = precompute
17531752
self.max_iter = max_iter
17541753
self.tol = tol
17551754
self.cv = cv
@@ -1781,11 +1780,6 @@ class MultiTaskLassoCV(LinearModelCV, RegressorMixin):
17811780
List of alphas where to compute the models.
17821781
If not provided, set automaticlly.
17831782
1784-
precompute : True | False | 'auto' | array-like
1785-
Whether to use a precomputed Gram matrix to speed up
1786-
calculations. If set to ``'auto'`` let us decide. The Gram
1787-
matrix can also be passed as argument.
1788-
17891783
n_alphas : int, optional
17901784
Number of alphas along the regularization path
17911785
@@ -1856,10 +1850,10 @@ class MultiTaskLassoCV(LinearModelCV, RegressorMixin):
18561850
path = staticmethod(lasso_path)
18571851

18581852
def __init__(self, eps=1e-3, n_alphas=100, alphas=None, fit_intercept=True,
1859-
normalize=False, precompute='auto', max_iter=1000, tol=1e-4,
1860-
copy_X=True, cv=None, verbose=False, n_jobs=1):
1853+
normalize=False, max_iter=1000, tol=1e-4, copy_X=True,
1854+
cv=None, verbose=False, n_jobs=1):
18611855
super(MultiTaskLassoCV, self).__init__(
18621856
eps=eps, n_alphas=n_alphas, alphas=alphas,
18631857
fit_intercept=fit_intercept, normalize=normalize,
1864-
precompute=precompute, max_iter=max_iter, tol=tol, copy_X=copy_X,
1858+
max_iter=max_iter, tol=tol, copy_X=copy_X,
18651859
cv=cv, verbose=verbose, n_jobs=n_jobs)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy