Skip to content

Commit 4cc0235

Browse files
committed
Merge pull request scikit-learn#4352 from amueller/issue-4297-infinite-isotonic_bak
[MRG + 2] Adding fix for issue scikit-learn#4297, isotonic infinite loop
2 parents 555b859 + 16075fb commit 4cc0235

File tree

2 files changed

+31
-4
lines changed

2 files changed

+31
-4
lines changed

sklearn/isotonic.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,6 @@ def _build_y(self, X, y, sample_weight):
252252
"""Build the y_ IsotonicRegression."""
253253
check_consistent_length(X, y, sample_weight)
254254
X, y = [check_array(x, ensure_2d=False) for x in [X, y]]
255-
if sample_weight is not None:
256-
sample_weight = check_array(sample_weight, ensure_2d=False)
257255

258256
y = as_float_array(y)
259257
self._check_fit_data(X, y, sample_weight)
@@ -264,10 +262,16 @@ def _build_y(self, X, y, sample_weight):
264262
else:
265263
self.increasing_ = self.increasing
266264

265+
# If sample_weight is passed, remove zero-weight values and clean order
266+
if sample_weight is not None:
267+
sample_weight = check_array(sample_weight, ensure_2d=False)
268+
mask = sample_weight > 0
269+
X, y, sample_weight = X[mask], y[mask], sample_weight[mask]
270+
else:
271+
sample_weight = np.ones(len(y))
272+
267273
order = np.lexsort((y, X))
268274
order_inv = np.argsort(order)
269-
if sample_weight is None:
270-
sample_weight = np.ones(len(y))
271275
X, y, sample_weight = [astype(array[order], np.float64, copy=False)
272276
for array in [X, y, sample_weight]]
273277
unique_X, unique_y, unique_sample_weight = _make_unique(X, y, sample_weight)

sklearn/tests/test_isotonic.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,29 @@ def test_isotonic_duplicate_min_entry():
325325
all_predictions_finite = np.all(np.isfinite(ir.predict(x)))
326326
assert_true(all_predictions_finite)
327327

328+
329+
def test_isotonic_zero_weight_loop():
330+
# Test from @ogrisel's issue:
331+
# https://github.com/scikit-learn/scikit-learn/issues/4297
332+
333+
# Get deterministic RNG with seed
334+
rng = np.random.RandomState(42)
335+
336+
# Create regression and samples
337+
regression = IsotonicRegression()
338+
n_samples = 50
339+
x = np.linspace(-3, 3, n_samples)
340+
y = x + rng.uniform(size=n_samples)
341+
342+
# Get some random weights and zero out a few of them
343+
w = rng.uniform(size=n_samples)
344+
w[5:8] = 0
345+
regression.fit(x, y, sample_weight=w)
346+
347+
# This will hang in the failure case.
348+
regression.fit(x, y, sample_weight=w)
349+
350+
328351
if __name__ == "__main__":
329352
import nose
330353
nose.run(argv=['', __file__])

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy