Skip to content

Commit 6d604c9

Browse files
committed
FIX be robust to columns name dtype and also to dataframes that hold dtype=object.
1 parent 3001e6d commit 6d604c9

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

sklearn/utils/tests/test_validation.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from sklearn.utils.testing import assert_raises_regexp
1414
from sklearn.utils import as_float_array, check_array, check_symmetric
1515
from sklearn.utils import check_X_y
16+
from sklearn.utils.mocking import MockDataFrame
1617
from sklearn.utils.estimator_checks import NotAnArray
1718
from sklearn.random_projection import sparse_random_matrix
1819
from sklearn.linear_model import ARDRegression
@@ -218,6 +219,25 @@ def test_check_array():
218219
assert_true(isinstance(result, np.ndarray))
219220

220221

222+
def test_check_array_pandas_dtype_object_conversion():
223+
# test that data-frame like objects with dtype object
224+
# get converted
225+
X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.object)
226+
X_df = MockDataFrame(X)
227+
assert_equal(check_array(X_df).dtype.kind, "f")
228+
assert_equal(check_array(X_df, ensure_2d=False).dtype.kind, "f")
229+
# smoke-test against dataframes with column named "dtype"
230+
X_df.dtype = "Hans"
231+
assert_equal(check_array(X_df, ensure_2d=False).dtype.kind, "f")
232+
233+
234+
def test_check_array_dtype_stability():
235+
# test that lists with ints don't get converted to floats
236+
X = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
237+
assert_equal(check_array(X).dtype.kind, "i")
238+
assert_equal(check_array(X, ensure_2d=False).dtype.kind, "i")
239+
240+
221241
def test_check_array_min_samples_and_features_messages():
222242
# empty list is considered 2D by default:
223243
msg = "0 feature(s) (shape=(1, 0)) while a minimum of 1 is required."

sklearn/utils/validation.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -324,21 +324,27 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None,
324324
if isinstance(accept_sparse, str):
325325
accept_sparse = [accept_sparse]
326326

327+
# store whether originally we wanted numeric dtype
328+
dtype_numeric = dtype == "numeric"
329+
327330
if sp.issparse(array):
328-
if dtype == "numeric":
331+
if dtype_numeric:
329332
dtype = None
330333
array = _ensure_sparse_format(array, accept_sparse, dtype, order,
331334
copy, force_all_finite)
332335
else:
333336
if ensure_2d:
334337
array = np.atleast_2d(array)
335-
if dtype == "numeric":
336-
if hasattr(array, "dtype") and array.dtype.kind == "O":
338+
if dtype_numeric:
339+
if hasattr(array, "dtype") and getattr(array.dtype, "kind", None) == "O":
337340
# if input is object, convert to float.
338341
dtype = np.float64
339342
else:
340343
dtype = None
341344
array = np.array(array, dtype=dtype, order=order, copy=copy)
345+
# make sure we actually converted to numeric:
346+
if dtype_numeric and array.dtype.kind == "O":
347+
array = array.astype(np.float64)
342348
if not allow_nd and array.ndim >= 3:
343349
raise ValueError("Found array with dim %d. Expected <= 2" %
344350
array.ndim)
@@ -353,7 +359,6 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None,
353359
" minimum of %d is required."
354360
% (n_samples, shape_repr, ensure_min_samples))
355361

356-
357362
if ensure_min_features > 0 and array.ndim == 2:
358363
n_features = array.shape[1]
359364
if n_features < ensure_min_features:

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy