We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 083b5f5 commit 37c65e5Copy full SHA for 37c65e5
sklearn/linear_model/coordinate_descent.py
@@ -467,12 +467,16 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
467
coef_, l1_reg, l2_reg, X.data, X.indices,
468
X.indptr, y, X_sparse_scaling,
469
max_iter, tol, positive)
470
- elif not multi_output:
471
- model = cd_fast.enet_coordinate_descent(
472
- coef_, l1_reg, l2_reg, X, y, max_iter, tol, positive)
473
- else:
+ elif multi_output:
474
model = cd_fast.enet_coordinate_descent_multi_task(
475
coef_, l1_reg, l2_reg, X, y, max_iter, tol)
+ elif isinstance(precompute, np.ndarray):
+ model = cd_fast.enet_coordinate_descent_gram(
+ coef_, l1_reg, l2_reg, precompute, Xy, y, max_iter,
476
+ tol, positive)
477
+ else:
478
+ model = cd_fast.enet_coordinate_descent(
479
+ coef_, l1_reg, l2_reg, X, y, max_iter, tol, positive)
480
coef_, dual_gap_, eps_ = model
481
coefs[..., i] = coef_
482
dual_gaps[i] = dual_gap_
0 commit comments