Skip to content

Commit a76b5be

Browse files
committed
don't do input validation in each tree for RandomForest.predict
1 parent 6d604c9 commit a76b5be

File tree

2 files changed

+31
-16
lines changed

2 files changed

+31
-16
lines changed

sklearn/ensemble/forest.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,8 @@ def _set_oob_score(self, X, y):
365365
mask = np.ones(n_samples, dtype=np.bool)
366366
mask[estimator.indices_] = False
367367
mask_indices = sample_indices[mask]
368-
p_estimator = estimator.predict_proba(X[mask_indices, :])
368+
p_estimator = estimator.predict_proba(X[mask_indices, :],
369+
check_input=False)
369370

370371
if self.n_outputs_ == 1:
371372
p_estimator = [p_estimator]
@@ -508,7 +509,7 @@ class in a leaf.
508509
# Parallel loop
509510
all_proba = Parallel(n_jobs=n_jobs, verbose=self.verbose,
510511
backend="threading")(
511-
delayed(_parallel_helper)(e, 'predict_proba', X)
512+
delayed(_parallel_helper)(e, 'predict_proba', X, check_input=False)
512513
for e in self.estimators_)
513514

514515
# Reduce
@@ -614,6 +615,10 @@ def predict(self, X):
614615

615616
# Check data
616617
X = check_array(X, dtype=DTYPE, accept_sparse="csr")
618+
if issparse(X) and (X.indices.dtype != np.intc or
619+
X.indptr.dtype != np.intc):
620+
raise ValueError("No support for np.int64 index based "
621+
"sparse matrices")
617622

618623
# Assign chunk of trees to jobs
619624
n_jobs, n_trees, starts = _partition_estimators(self.n_estimators,
@@ -622,7 +627,7 @@ def predict(self, X):
622627
# Parallel loop
623628
all_y_hat = Parallel(n_jobs=n_jobs, verbose=self.verbose,
624629
backend="threading")(
625-
delayed(_parallel_helper)(e, 'predict', X)
630+
delayed(_parallel_helper)(e, 'predict', X, check_input=False)
626631
for e in self.estimators_)
627632

628633
# Reduce
@@ -642,7 +647,7 @@ def _set_oob_score(self, X, y):
642647
mask = np.ones(n_samples, dtype=np.bool)
643648
mask[estimator.indices_] = False
644649
mask_indices = sample_indices[mask]
645-
p_estimator = estimator.predict(X[mask_indices, :])
650+
p_estimator = estimator.predict(X[mask_indices, :], check_input=False)
646651

647652
if self.n_outputs_ == 1:
648653
p_estimator = p_estimator[:, np.newaxis]

sklearn/tree/tree.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
309309

310310
return self
311311

312-
def predict(self, X):
312+
def predict(self, X, check_input=True):
313313
"""Predict class or regression value for X.
314314
315315
For a classification model, the predicted class for each sample in X is
@@ -323,16 +323,21 @@ def predict(self, X):
323323
``dtype=np.float32`` and if a sparse matrix is provided
324324
to a sparse ``csr_matrix``.
325325
326+
check_input : boolean, (default=True)
327+
Allows bypassing several input checks.
328+
Don't use this parameter unless you know what you are doing.
329+
326330
Returns
327331
-------
328332
y : array of shape = [n_samples] or [n_samples, n_outputs]
329333
The predicted classes, or the predict values.
330334
"""
331-
X = check_array(X, dtype=DTYPE, accept_sparse="csr")
332-
if issparse(X) and (X.indices.dtype != np.intc or
333-
X.indptr.dtype != np.intc):
334-
raise ValueError("No support for np.int64 index based "
335-
"sparse matrices")
335+
if check_input:
336+
X = check_array(X, dtype=DTYPE, accept_sparse="csr")
337+
if issparse(X) and (X.indices.dtype != np.intc or
338+
X.indptr.dtype != np.intc):
339+
raise ValueError("No support for np.int64 index based "
340+
"sparse matrices")
336341

337342
n_samples, n_features = X.shape
338343

@@ -541,12 +546,16 @@ def __init__(self,
541546
class_weight=class_weight,
542547
random_state=random_state)
543548

544-
def predict_proba(self, X):
549+
def predict_proba(self, X, check_input=True):
545550
"""Predict class probabilities of the input samples X.
546551
547552
The predicted class probability is the fraction of samples of the same
548553
class in a leaf.
549554
555+
check_input : boolean, (default=True)
556+
Allows bypassing several input checks.
557+
Don't use this parameter unless you know what you are doing.
558+
550559
Parameters
551560
----------
552561
X : array-like or sparse matrix of shape = [n_samples, n_features]
@@ -562,11 +571,12 @@ class in a leaf.
562571
classes corresponds to that in the attribute `classes_`.
563572
"""
564573
check_is_fitted(self, 'n_outputs_')
565-
X = check_array(X, dtype=DTYPE, accept_sparse="csr")
566-
if issparse(X) and (X.indices.dtype != np.intc or
567-
X.indptr.dtype != np.intc):
568-
raise ValueError("No support for np.int64 index based "
569-
"sparse matrices")
574+
if check_input:
575+
X = check_array(X, dtype=DTYPE, accept_sparse="csr")
576+
if issparse(X) and (X.indices.dtype != np.intc or
577+
X.indptr.dtype != np.intc):
578+
raise ValueError("No support for np.int64 index based "
579+
"sparse matrices")
570580

571581
n_samples, n_features = X.shape
572582

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy