
Commit 3001e6d

FIX make shuffle / resample pass-through indexing utilities
1 parent f0f4c79 commit 3001e6d

File tree

2 files changed: +47, -17 lines


sklearn/utils/__init__.py

Lines changed: 14 additions & 17 deletions

@@ -169,7 +169,9 @@ def resample(*arrays, **options):
 
     Parameters
     ----------
-    *arrays : sequence of arrays or scipy.sparse matrices with same shape[0]
+    *arrays : sequence of indexable data-structures
+        Indexable data-structures can be arrays, lists, dataframes or scipy
+        sparse matrices with consistent first dimension.
 
     replace : boolean, True by default
         Implements resampling with replacement. If False, this will implement
@@ -184,16 +186,15 @@ def resample(*arrays, **options):
 
     Returns
     -------
-    resampled_arrays : sequence of arrays or scipy.sparse matrices with same \
-    shape[0]
-        Sequence of resampled views of the collections. The original arrays are
+    resampled_arrays : sequence of indexable data-structures
+        Sequence of resampled views of the collections. The original arrays are
         not impacted.
 
     Examples
     --------
     It is possible to mix sparse and dense arrays in the same run::
 
-      >>> X = [[1., 0.], [2., 1.], [0., 0.]]
+      >>> X = np.array([[1., 0.], [2., 1.], [0., 0.]])
       >>> y = np.array([0, 1, 2])
 
       >>> from scipy.sparse import coo_matrix
@@ -247,8 +248,6 @@ def resample(*arrays, **options):
             max_n_samples, n_samples))
 
     check_consistent_length(*arrays)
-    arrays = [check_array(x, accept_sparse='csr', ensure_2d=False,
-                          allow_nd=True) for x in arrays]
 
     if replace:
         indices = random_state.randint(0, n_samples, size=(max_n_samples,))
@@ -257,12 +256,9 @@ def resample(*arrays, **options):
         random_state.shuffle(indices)
         indices = indices[:max_n_samples]
 
-    resampled_arrays = []
-
-    for array in arrays:
-        array = array[indices]
-        resampled_arrays.append(array)
-
+    # convert sparse matrices to CSR for row-based indexing
+    arrays = [a.tocsr() if issparse(a) else a for a in arrays]
+    resampled_arrays = [safe_indexing(a, indices) for a in arrays]
     if len(resampled_arrays) == 1:
         # syntactic sugar for the unit argument case
         return resampled_arrays[0]
@@ -278,7 +274,9 @@ def shuffle(*arrays, **options):
 
     Parameters
     ----------
-    *arrays : sequence of arrays or scipy.sparse matrices with same shape[0]
+    *arrays : sequence of indexable data-structures
+        Indexable data-structures can be arrays, lists, dataframes or scipy
+        sparse matrices with consistent first dimension.
 
     random_state : int or RandomState instance
         Control the shuffling for reproducible behavior.
@@ -289,16 +287,15 @@ def shuffle(*arrays, **options):
 
     Returns
     -------
-    shuffled_arrays : sequence of arrays or scipy.sparse matrices with same \
-    shape[0]
+    shuffled_arrays : sequence of indexable data-structures
         Sequence of shuffled views of the collections. The original arrays are
         not impacted.
 
     Examples
     --------
     It is possible to mix sparse and dense arrays in the same run::
 
-      >>> X = [[1., 0.], [2., 1.], [0., 0.]]
+      >>> X = np.array([[1., 0.], [2., 1.], [0., 0.]])
       >>> y = np.array([0, 1, 2])
 
       >>> from scipy.sparse import coo_matrix
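With this change, resample (and shuffle, which delegates to it) no longer pushes its inputs through check_array; rows are picked with safe_indexing, so each argument keeps its original container type. A minimal doctest-style sketch of that pass-through behaviour, assuming a scikit-learn build that includes this commit (the expected values mirror the new test below, which also uses random_state=0):

  >>> import numpy as np
  >>> from scipy.sparse import csc_matrix
  >>> from sklearn.utils import shuffle
  >>> letters = ['a', 'b', 'c']                        # plain Python list
  >>> rows = csc_matrix(np.arange(6).reshape(3, 2))    # sparse matrix, 3 rows
  >>> letters_s, rows_s = shuffle(letters, rows, random_state=0)
  >>> letters_s                                        # still a list, not an ndarray
  ['c', 'b', 'a']
  >>> rows_s.toarray()                                 # rows permuted consistently
  array([[4, 5],
         [2, 3],
         [0, 1]])

Note that sparse inputs come back in CSR format, since they are converted with tocsr() for row-based indexing.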

sklearn/utils/tests/test_utils.py

Lines changed: 33 additions & 0 deletions

@@ -186,3 +186,36 @@ def to_tuple(A):  # to make the inner arrays hashable
     S = set(to_tuple(A))
     shuffle(A)  # shouldn't raise a ValueError for dim = 3
     assert_equal(set(to_tuple(A)), S)
+
+
+def test_shuffle_dont_convert_to_array():
+    # Check that shuffle does not try to convert to numpy arrays with float
+    # dtypes and lets any indexable datastructure pass-through.
+    a = ['a', 'b', 'c']
+    b = np.array(['a', 'b', 'c'], dtype=object)
+    c = [1, 2, 3]
+    d = MockDataFrame(np.array([['a', 0],
+                                ['b', 1],
+                                ['c', 2]],
+                               dtype=object))
+    e = sp.csc_matrix(np.arange(6).reshape(3, 2))
+    a_s, b_s, c_s, d_s, e_s = shuffle(a, b, c, d, e, random_state=0)
+
+    assert_equal(a_s, ['c', 'b', 'a'])
+    assert_equal(type(a_s), list)
+
+    assert_array_equal(b_s, ['c', 'b', 'a'])
+    assert_equal(b_s.dtype, object)
+
+    assert_equal(c_s, [3, 2, 1])
+    assert_equal(type(c_s), list)
+
+    assert_array_equal(d_s, np.array([['c', 2],
+                                      ['b', 1],
+                                      ['a', 0]],
+                                     dtype=object))
+    assert_equal(type(d_s), MockDataFrame)
+
+    assert_array_equal(e_s.toarray(), np.array([[4, 5],
+                                                [2, 3],
+                                                [0, 1]]))
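For completeness, a hedged sketch of the same pass-through with resample and a real pandas DataFrame (the test above uses the MockDataFrame helper instead; pandas is an optional dependency and this snippet assumes it is installed):

  >>> import numpy as np
  >>> import pandas as pd
  >>> from sklearn.utils import resample
  >>> df = pd.DataFrame({'letter': ['a', 'b', 'c'], 'num': [0, 1, 2]})
  >>> y = np.array([0, 1, 2])
  >>> df_r, y_r = resample(df, y, replace=False, n_samples=2, random_state=0)
  >>> type(df_r).__name__        # the DataFrame is not converted to an ndarray
  'DataFrame'
  >>> len(df_r), len(y_r)        # both resampled to the requested two rows
  (2, 2)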
