
Commit 7097af0

PLS and CCA convergence check and input validation fixes
1 parent 28546c8 commit 7097af0

3 files changed: +82 -42 lines changed

sklearn/cross_decomposition/pls_.py

Lines changed: 37 additions & 32 deletions
@@ -235,17 +235,18 @@ def fit(self, X, Y):
 
         # copy since this will contains the residuals (deflated) matrices
         check_consistent_length(X, Y)
-        X = check_array(X, dtype=np.float, copy=self.copy)
-        Y = check_array(Y, dtype=np.float, copy=self.copy, ensure_2d=False)
+        X = check_array(X, dtype=np.float64, copy=self.copy)
+        Y = check_array(Y, dtype=np.float64, copy=self.copy, ensure_2d=False)
         if Y.ndim == 1:
-            Y = Y[:, None]
+            Y = Y.reshape(-1, 1)
 
         n = X.shape[0]
         p = X.shape[1]
         q = Y.shape[1]
 
         if self.n_components < 1 or self.n_components > p:
-            raise ValueError('invalid number of components')
+            raise ValueError('Invalid number of components: %d' %
+                             self.n_components)
         if self.algorithm not in ("svd", "nipals"):
             raise ValueError("Got algorithm %s when only 'svd' "
                              "and 'nipals' are known" % self.algorithm)
@@ -271,6 +272,10 @@ def fit(self, X, Y):
 
         # NIPALS algo: outer loop, over components
         for k in range(self.n_components):
+            if np.all(np.dot(Yk.T, Yk) < np.finfo(np.double).eps):
+                # Yk constant
+                warnings.warn('Y residual constant at iteration %s' % k)
+                break
             #1) weights estimation (inner loop)
             # -----------------------------------
             if self.algorithm == "nipals":
@@ -291,6 +296,7 @@ def fit(self, X, Y):
             # test for null variance
             if np.dot(x_scores.T, x_scores) < np.finfo(np.double).eps:
                 warnings.warn('X scores are null at iteration %s' % k)
+                break
             #2) Deflation (in place)
             # ----------------------
             # Possible memory footprint reduction may done here: in order to
@@ -335,6 +341,7 @@ def fit(self, X, Y):
             self.y_rotations_ = np.ones(1)
 
         if True or self.deflation_mode == "regression":
+            # FIXME what's with the if?
             # Estimate regression coefficient
             # Regress Y on T
             # Y = TQ' + Err,
@@ -367,23 +374,19 @@ def transform(self, X, Y=None, copy=True):
         x_scores if Y is not given, (x_scores, y_scores) otherwise.
         """
         check_is_fitted(self, 'x_mean_')
+        X = check_array(X, copy=copy)
         # Normalize
-        if copy:
-            Xc = (np.asarray(X) - self.x_mean_) / self.x_std_
-            if Y is not None:
-                Yc = (np.asarray(Y) - self.y_mean_) / self.y_std_
-        else:
-            X = np.asarray(X)
-            Xc -= self.x_mean_
-            Xc /= self.x_std_
-            if Y is not None:
-                Y = np.asarray(Y)
-                Yc -= self.y_mean_
-                Yc /= self.y_std_
+        X -= self.x_mean_
+        X /= self.x_std_
         # Apply rotation
-        x_scores = np.dot(Xc, self.x_rotations_)
+        x_scores = np.dot(X, self.x_rotations_)
         if Y is not None:
-            y_scores = np.dot(Yc, self.y_rotations_)
+            Y = check_array(Y, ensure_2d=False, copy=copy)
+            if Y.ndim == 1:
+                Y = Y.reshape(-1, 1)
+            Y -= self.y_mean_
+            Y /= self.y_std_
+            y_scores = np.dot(Y, self.y_rotations_)
             return x_scores, y_scores
 
         return x_scores
@@ -406,14 +409,11 @@ def predict(self, X, copy=True):
         be an issue in high dimensional space.
         """
         check_is_fitted(self, 'x_mean_')
+        X = check_array(X, copy=copy)
         # Normalize
-        if copy:
-            Xc = (np.asarray(X) - self.x_mean_)
-        else:
-            X = np.asarray(X)
-            Xc -= self.x_mean_
-        Xc /= self.x_std_
-        Ypred = np.dot(Xc, self.coef_)
+        X -= self.x_mean_
+        X /= self.x_std_
+        Ypred = np.dot(X, self.coef_)
         return Ypred + self.y_mean_
 
     def fit_transform(self, X, y=None, **fit_params):
@@ -724,13 +724,15 @@ def __init__(self, n_components=2, scale=True, copy=True):
     def fit(self, X, Y):
         # copy since this will contains the centered data
        check_consistent_length(X, Y)
-        X = check_array(X, dtype=np.float, copy=self.copy)
-        Y = check_array(Y, dtype=np.float, copy=self.copy)
-
-        p = X.shape[1]
+        X = check_array(X, dtype=np.float64, copy=self.copy)
+        Y = check_array(Y, dtype=np.float64, copy=self.copy, ensure_2d=False)
+        if Y.ndim == 1:
+            Y = Y.reshape(-1, 1)
 
-        if self.n_components < 1 or self.n_components > p:
-            raise ValueError('invalid number of components')
+        if self.n_components > max(Y.shape[1], X.shape[1]):
+            raise ValueError("Invalid number of components n_components=%d with "
+                             "X of shape %s and Y of shape %s."
+                             % (self.n_components, str(X.shape), str(Y.shape)))
 
         # Scale (in place)
         X, Y, self.x_mean_, self.y_mean_, self.x_std_, self.y_std_ =\
@@ -742,7 +744,7 @@ def fit(self, X, Y):
         # components is smaller than rank(X) - 1. Hence, if we want to extract
         # all the components (C.shape[1]), we have to use another one. Else,
         # let's use arpacks to compute only the interesting components.
-        if self.n_components == C.shape[1]:
+        if self.n_components >= np.min(C.shape):
            U, s, V = linalg.svd(C, full_matrices=False)
        else:
            U, s, V = arpack.svds(C, k=self.n_components)
@@ -756,9 +758,12 @@ def fit(self, X, Y):
     def transform(self, X, Y=None):
         """Apply the dimension reduction learned on the train data."""
         check_is_fitted(self, 'x_mean_')
+        X = check_array(X, dtype=np.float64)
         Xr = (X - self.x_mean_) / self.x_std_
         x_scores = np.dot(Xr, self.x_weights_)
         if Y is not None:
+            if Y.ndim == 1:
+                Y = Y.reshape(-1, 1)
             Yr = (Y - self.y_mean_) / self.y_std_
             y_scores = np.dot(Yr, self.y_weights_)
             return x_scores, y_scores
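
For illustration, a minimal sketch of the input validation added above (not part of the commit; assumes a scikit-learn checkout with this patch applied): a 1-D y is reshaped to a column vector internally, and an out-of-range n_components now raises the more informative error message.

    import numpy as np
    from sklearn.cross_decomposition import PLSRegression

    rng = np.random.RandomState(0)
    X = rng.rand(20, 3)
    y = rng.rand(20)                           # 1-D target, reshaped to (20, 1) internally

    PLSRegression(n_components=2).fit(X, y)    # works; dtype coerced to float64 by check_array

    try:
        PLSRegression(n_components=10).fit(X, y)   # 10 > n_features
    except ValueError as err:
        print(err)                             # Invalid number of components: 10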

sklearn/cross_decomposition/tests/test_pls.py

Lines changed: 36 additions & 1 deletion
@@ -1,5 +1,6 @@
 import numpy as np
-from sklearn.utils.testing import assert_array_almost_equal
+from sklearn.utils.testing import (assert_array_almost_equal,
+    assert_array_equal, assert_true, assert_raise_message)
 from sklearn.datasets import load_linnerud
 from sklearn.cross_decomposition import pls_
 from nose.tools import assert_equal
@@ -248,6 +249,30 @@ def test_univariate_pls_regression():
     assert_array_almost_equal(model1, model2)
 
 
+def test_predict_transform_copy():
+    # check that the "copy" keyword works
+    d = load_linnerud()
+    X = d.data
+    Y = d.target
+    clf = pls_.PLSCanonical()
+    X_copy = X.copy()
+    Y_copy = Y.copy()
+    clf.fit(X, Y)
+    # check that results are identical with copy
+    assert_array_almost_equal(clf.predict(X), clf.predict(X.copy(), copy=False))
+    assert_array_almost_equal(clf.transform(X), clf.transform(X.copy(), copy=False))
+
+    # check also if passing Y
+    assert_array_almost_equal(clf.transform(X, Y),
+                              clf.transform(X.copy(), Y.copy(), copy=False))
+    # check that copy doesn't destroy
+    # we do want to check exact equality here
+    assert_array_equal(X_copy, X)
+    assert_array_equal(Y_copy, Y)
+    # also check that mean wasn't zero before (to make sure we didn't touch it)
+    assert_true(np.all(X.mean(axis=0) != 0))
+
+
 def test_scale():
     d = load_linnerud()
     X = d.data
@@ -260,3 +285,13 @@ def test_scale():
                 pls_.PLSSVD()]:
         clf.set_params(scale=True)
         clf.fit(X, Y)
+
+
+def test_pls_errors():
+    d = load_linnerud()
+    X = d.data
+    Y = d.target
+    for clf in [pls_.PLSCanonical(), pls_.PLSRegression(),
+                pls_.PLSSVD()]:
+        clf.n_components = 4
+        assert_raise_message(ValueError, "Invalid number of components", clf.fit, X, Y)
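
As a side note on the copy semantics exercised by test_predict_transform_copy (a hedged sketch, not part of the commit): with copy=False, transform and predict run check_array without copying and then centre and scale the caller's float64 array in place, which is why the test passes X.copy() in those calls.

    import numpy as np
    from sklearn.datasets import load_linnerud
    from sklearn.cross_decomposition import PLSCanonical

    d = load_linnerud()
    X, Y = d.data, d.target
    pls = PLSCanonical().fit(X.copy(), Y.copy())

    X_in = X.copy()
    pls.transform(X_in, copy=False)       # X_in is centred and scaled in place
    print(np.allclose(X_in, X))           # False: the caller's array was modified

    pls.transform(X.copy(), copy=True)    # default: input left untouched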

sklearn/tests/test_common.py

Lines changed: 9 additions & 9 deletions
@@ -95,25 +95,25 @@ def test_non_meta_estimators():
             continue
         if name.endswith("HMM") or name.startswith("_"):
             continue
-        if name not in CROSS_DECOMPOSITION:
-            yield check_estimators_dtypes, name, Estimator
-            yield check_fit_score_takes_y, name, Estimator
-            yield check_dtype_object, name, Estimator
+        yield check_estimators_dtypes, name, Estimator
+        yield check_fit_score_takes_y, name, Estimator
+        yield check_dtype_object, name, Estimator
 
-            # Check that all estimator yield informative messages when
-            # trained on empty datasets
-            yield check_estimators_empty_data_messages, name, Estimator
+        # Check that all estimator yield informative messages when
+        # trained on empty datasets
+        yield check_estimators_empty_data_messages, name, Estimator
 
         if name not in CROSS_DECOMPOSITION + ['SpectralEmbedding']:
             # SpectralEmbedding is non-deterministic,
             # see issue #4236
+            # cross-decomposition's "transform" returns X and Y
             yield check_pipeline_consistency, name, Estimator
 
-        if name not in CROSS_DECOMPOSITION + ['Imputer']:
+        if name not in ['Imputer']:
             # Test that all estimators check their input for NaN's and infs
             yield check_estimators_nan_inf, name, Estimator
 
-        if name not in CROSS_DECOMPOSITION + ['GaussianProcess']:
+        if name not in ['GaussianProcess']:
             # FIXME!
             # in particular GaussianProcess!
             yield check_estimators_overwrite_params, name, Estimator
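
The comment added next to check_pipeline_consistency hints at why the cross-decomposition estimators stay excluded from that particular check even though they now run the other common tests: when Y is passed, their transform returns a pair of score arrays rather than a single array, which a Pipeline cannot chain. A small sketch (illustrative only, assuming this patch is applied):

    from sklearn.datasets import load_linnerud
    from sklearn.cross_decomposition import PLSCanonical

    d = load_linnerud()
    X, Y = d.data, d.target
    pls = PLSCanonical(n_components=2).fit(X, Y)

    x_scores = pls.transform(X)                 # single array of X scores
    x_scores, y_scores = pls.transform(X, Y)    # tuple: (x_scores, y_scores)
    print(x_scores.shape, y_scores.shape)       # (20, 2) (20, 2)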
