
Commit eedc1cd

Use more natural class_weight="auto" heuristic
1 parent 2c79a98 commit eedc1cd

21 files changed: +323, −178 lines

doc/modules/svm.rst

Lines changed: 1 addition & 1 deletion
@@ -405,7 +405,7 @@ Tips on Practical Use
     approximates the fraction of training errors and support vectors.
 
   * In :class:`SVC`, if data for classification are unbalanced (e.g. many
-    positive and few negative), set ``class_weight='auto'`` and/or try
+    positive and few negative), set ``class_weight='balanced'`` and/or try
     different penalty parameters ``C``.
 
   * The underlying :class:`LinearSVC` implementation uses a random
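The tip this hunk updates is easy to try out. A minimal sketch, not part of the commit — the synthetic dataset, its 90/10 class ratio, and the RBF kernel are illustrative assumptions:

    # class_weight='balanced' rescales each class's penalty C by
    # n_samples / (n_classes * np.bincount(y)), favouring the rare class.
    from sklearn.datasets import make_classification
    from sklearn.svm import SVC

    X, y = make_classification(n_samples=1000, weights=[0.9, 0.1],
                               random_state=0)

    plain = SVC(kernel='rbf').fit(X, y)
    balanced = SVC(kernel='rbf', class_weight='balanced').fit(X, y)

    # The balanced model typically predicts the minority class more often.
    print((plain.predict(X) == 1).sum(), (balanced.predict(X) == 1).sum())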

doc/whats_new.rst

Lines changed: 5 additions & 1 deletion
@@ -56,6 +56,9 @@ Enhancements
      :class:`linear_model.LogisticRegression`, by avoiding loss computation.
      By `Mathieu Blondel`_ and `Tom Dupre la Tour`_.
 
+   - Improved heuristic for ``class_weight="auto"`` for classifiers supporting
+     ``class_weight`` by Hanna Wallach and `Andreas Müller`_
+
 Bug fixes
 .........
 
@@ -339,6 +342,7 @@ Enhancements
    - :class:`svm.SVC` fitted on sparse input now implements ``decision_function``.
      By `Rob Zinkov`_ and `Andreas Müller`_.
 
+
 Documentation improvements
 ..........................
 
@@ -462,7 +466,7 @@ Bug fixes
      in GMM. By `Alexis Mignon`_.
 
    - Fixed a error in the computation of conditional probabilities in
-     :class:`naive_bayes.BernoulliNB`. By `Hanna Wallach`_.
+     :class:`naive_bayes.BernoulliNB`. By Hanna Wallach.
 
    - Make the method ``radius_neighbors`` of
      :class:`neighbors.NearestNeighbors` return the samples lying on the

examples/applications/face_recognition.py

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@
 t0 = time()
 param_grid = {'C': [1e3, 5e3, 1e4, 5e4, 1e5],
               'gamma': [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.1], }
-clf = GridSearchCV(SVC(kernel='rbf', class_weight='auto'), param_grid)
+clf = GridSearchCV(SVC(kernel='rbf', class_weight='balanced'), param_grid)
 clf = clf.fit(X_train_pca, y_train)
 print("done in %0.3fs" % (time() - t0))
 print("Best estimator found by grid search:")
sklearn/ensemble/forest.py

Lines changed: 16 additions & 14 deletions
@@ -89,7 +89,7 @@ def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees,
         curr_sample_weight *= sample_counts
 
         if class_weight == 'subsample':
-            curr_sample_weight *= compute_sample_weight('auto', y, indices)
+            curr_sample_weight *= compute_sample_weight('balanced', y, indices)
 
         tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False)
 
@@ -414,17 +414,17 @@ def _validate_y_class_weight(self, y):
             self.n_classes_.append(classes_k.shape[0])
 
         if self.class_weight is not None:
-            valid_presets = ('auto', 'subsample')
+            valid_presets = ('auto', 'balanced', 'subsample', 'auto')
             if isinstance(self.class_weight, six.string_types):
                 if self.class_weight not in valid_presets:
                     raise ValueError('Valid presets for class_weight include '
-                                     '"auto" and "subsample". Given "%s".'
+                                     '"balanced" and "subsample". Given "%s".'
                                      % self.class_weight)
                 if self.warm_start:
-                    warn('class_weight presets "auto" or "subsample" are '
+                    warn('class_weight presets "balanced" or "subsample" are '
                          'not recommended for warm_start if the fitted data '
                          'differs from the full dataset. In order to use '
-                         '"auto" weights, use compute_class_weight("auto", '
+                         '"auto" weights, use compute_class_weight("balanced", '
                          'classes, y). In place of y you can use a large '
                          'enough sample of the full training set target to '
                          'properly estimate the class frequency '
@@ -433,7 +433,7 @@ def _validate_y_class_weight(self, y):
 
             if self.class_weight != 'subsample' or not self.bootstrap:
                 if self.class_weight == 'subsample':
-                    class_weight = 'auto'
+                    class_weight = 'balanced'
                 else:
                     class_weight = self.class_weight
                 expanded_class_weight = compute_sample_weight(class_weight,
@@ -758,17 +758,18 @@ class RandomForestClassifier(ForestClassifier):
         and add more estimators to the ensemble, otherwise, just fit a whole
         new forest.
 
-    class_weight : dict, list of dicts, "auto", "subsample" or None, optional
+    class_weight : dict, list of dicts, "balanced", "subsample" or None, optional
 
         Weights associated with classes in the form ``{class_label: weight}``.
         If not given, all classes are supposed to have weight one. For
         multi-output problems, a list of dicts can be provided in the same
         order as the columns of y.
 
-        The "auto" mode uses the values of y to automatically adjust
-        weights inversely proportional to class frequencies in the input data.
+        The "balanced" mode uses the values of y to automatically adjust
+        weights inversely proportional to class frequencies in the input data
+        as ``n_samples / (n_classes * np.bincount(y))``
 
-        The "subsample" mode is the same as "auto" except that weights are
+        The "subsample" mode is the same as "balanced" except that weights are
         computed based on the bootstrap sample for every tree grown.
 
         For multi-output, the weights of each column of y will be multiplied.
@@ -1100,17 +1101,18 @@ class ExtraTreesClassifier(ForestClassifier):
         and add more estimators to the ensemble, otherwise, just fit a whole
         new forest.
 
-    class_weight : dict, list of dicts, "auto", "subsample" or None, optional
+    class_weight : dict, list of dicts, "balanced", "subsample" or None, optional
 
         Weights associated with classes in the form ``{class_label: weight}``.
         If not given, all classes are supposed to have weight one. For
         multi-output problems, a list of dicts can be provided in the same
         order as the columns of y.
 
-        The "auto" mode uses the values of y to automatically adjust
-        weights inversely proportional to class frequencies in the input data.
+        The "balanced" mode uses the values of y to automatically adjust
+        weights inversely proportional to class frequencies in the input data
+        as ``n_samples / (n_classes * np.bincount(y))``
 
-        The "subsample" mode is the same as "auto" except that weights are
+        The "subsample" mode is the same as "balanced" except that weights are
         computed based on the bootstrap sample for every tree grown.
 
         For multi-output, the weights of each column of y will be multiplied.
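The ``n_samples / (n_classes * np.bincount(y))`` heuristic that these docstrings now spell out is easy to reproduce by hand. A short sketch — the toy label vector is an illustrative assumption, and the positional call form follows the warning message in the hunk above:

    import numpy as np
    from sklearn.utils.class_weight import compute_class_weight

    y = np.array([0, 0, 0, 0, 1, 1])   # 4 samples of class 0, 2 of class 1
    classes = np.unique(y)

    # "balanced": n_samples / (n_classes * np.bincount(y))
    manual = len(y) / (len(classes) * np.bincount(y))
    print(manual)                                         # [ 0.75  1.5 ]
    print(compute_class_weight('balanced', classes, y))   # same values

The "subsample" mode differs only in that, as the ``_parallel_build_trees`` hunk shows, each tree recomputes these weights on its own bootstrap indices.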

sklearn/ensemble/tests/test_forest.py

Lines changed: 12 additions & 12 deletions
@@ -329,7 +329,7 @@ def test_parallel():
         yield check_parallel, name, iris.data, iris.target
 
     for name in FOREST_REGRESSORS:
-        yield check_parallel, name, boston.data, boston.target
+        yield check_parallel, name, boston.data, boston.target
 
 
 def check_pickle(name, X, y):
@@ -352,7 +352,7 @@ def test_pickle():
         yield check_pickle, name, iris.data[::2], iris.target[::2]
 
     for name in FOREST_REGRESSORS:
-        yield check_pickle, name, boston.data[::2], boston.target[::2]
+        yield check_pickle, name, boston.data[::2], boston.target[::2]
 
 
 def check_multioutput(name):
@@ -749,10 +749,10 @@ def check_class_weights(name):
     # Check class_weights resemble sample_weights behavior.
     ForestClassifier = FOREST_CLASSIFIERS[name]
 
-    # Iris is balanced, so no effect expected for using 'auto' weights
+    # Iris is balanced, so no effect expected for using 'balanced' weights
     clf1 = ForestClassifier(random_state=0)
     clf1.fit(iris.data, iris.target)
-    clf2 = ForestClassifier(class_weight='auto', random_state=0)
+    clf2 = ForestClassifier(class_weight='balanced', random_state=0)
     clf2.fit(iris.data, iris.target)
     assert_almost_equal(clf1.feature_importances_, clf2.feature_importances_)
 
@@ -765,8 +765,8 @@ def check_class_weights(name):
                             random_state=0)
     clf3.fit(iris.data, iris_multi)
     assert_almost_equal(clf2.feature_importances_, clf3.feature_importances_)
-    # Check against multi-output "auto" which should also have no effect
-    clf4 = ForestClassifier(class_weight='auto', random_state=0)
+    # Check against multi-output "balanced" which should also have no effect
+    clf4 = ForestClassifier(class_weight='balanced', random_state=0)
     clf4.fit(iris.data, iris_multi)
     assert_almost_equal(clf3.feature_importances_, clf4.feature_importances_)
 
@@ -782,7 +782,7 @@ def check_class_weights(name):
 
     # Check that sample_weight and class_weight are multiplicative
     clf1 = ForestClassifier(random_state=0)
-    clf1.fit(iris.data, iris.target, sample_weight**2)
+    clf1.fit(iris.data, iris.target, sample_weight ** 2)
     clf2 = ForestClassifier(class_weight=class_weight, random_state=0)
     clf2.fit(iris.data, iris.target, sample_weight)
     assert_almost_equal(clf1.feature_importances_, clf2.feature_importances_)
@@ -793,11 +793,11 @@ def test_class_weights():
         yield check_class_weights, name
 
 
-def check_class_weight_auto_and_bootstrap_multi_output(name):
-    # Test class_weight works for multi-output
+def check_class_weight_balanced_and_bootstrap_multi_output(name):
+    # Test class_weight works for multi-output"""
     ForestClassifier = FOREST_CLASSIFIERS[name]
     _y = np.vstack((y, np.array(y) * 2)).T
-    clf = ForestClassifier(class_weight='auto', random_state=0)
+    clf = ForestClassifier(class_weight='balanced', random_state=0)
     clf.fit(X, _y)
     clf = ForestClassifier(class_weight=[{-1: 0.5, 1: 1.}, {-2: 1., 2: 1.}],
                            random_state=0)
@@ -806,9 +806,9 @@ def check_class_weight_auto_and_bootstrap_multi_output(name):
     clf.fit(X, _y)
 
 
-def test_class_weight_auto_and_bootstrap_multi_output():
+def test_class_weight_balanced_and_bootstrap_multi_output():
     for name in FOREST_CLASSIFIERS:
-        yield check_class_weight_auto_and_bootstrap_multi_output, name
+        yield check_class_weight_balanced_and_bootstrap_multi_output, name
 
 
 def check_class_weight_errors(name):
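The class weighting exercised by these tests flows through ``compute_sample_weight``, which expands per-class weights into one weight per sample. A quick sketch of the "balanced" preset — the toy labels are an illustrative assumption:

    import numpy as np
    from sklearn.utils.class_weight import compute_sample_weight

    y = np.array([0, 0, 0, 0, 1, 1])
    print(compute_sample_weight('balanced', y))
    # expected: [ 0.75  0.75  0.75  0.75  1.5   1.5 ]

Each sample simply receives its class's weight, which is why the tests can check that ``class_weight`` and ``sample_weight`` act multiplicatively.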

sklearn/linear_model/logistic.py

Lines changed: 37 additions & 27 deletions
@@ -473,11 +473,13 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
         is called repeatedly with the same data, as y is modified
         along the path.
 
-    class_weight : {dict, 'auto'}, optional
-        Over-/undersamples the samples of each class according to the given
-        weights. If not given, all classes are supposed to have weight one.
-        The 'auto' mode selects weights inversely proportional to class
-        frequencies in the training set.
+    class_weight : dict or 'balanced', optional
+        Weights associated with classes in the form ``{class_label: weight}``.
+        If not given, all classes are supposed to have weight one.
+
+        The "balanced" mode uses the values of y to automatically adjust
+        weights inversely proportional to class frequencies in the input data
+        as ``n_samples / (n_classes * np.bincount(y))``
 
     dual : bool
         Dual or primal formulation. Dual formulation is only implemented for
@@ -734,11 +736,13 @@ def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10,
     tol : float
         Tolerance for stopping criteria.
 
-    class_weight : {dict, 'auto'}, optional
-        Over-/undersamples the samples of each class according to the given
-        weights. If not given, all classes are supposed to have weight one.
-        The 'auto' mode selects weights inversely proportional to class
-        frequencies in the training set.
+    class_weight : dict or 'balanced', optional
+        Weights associated with classes in the form ``{class_label: weight}``.
+        If not given, all classes are supposed to have weight one.
+
+        The "balanced" mode uses the values of y to automatically adjust
+        weights inversely proportional to class frequencies in the input data
+        as ``n_samples / (n_classes * np.bincount(y))``
 
     verbose : int
         For the liblinear and lbfgs solvers set verbose to any positive
@@ -903,11 +907,13 @@ class LogisticRegression(BaseEstimator, LinearClassifierMixin,
         To lessen the effect of regularization on synthetic feature weight
         (and therefore on the intercept) intercept_scaling has to be increased.
 
-    class_weight : {dict, 'auto'}, optional
-        Over-/undersamples the samples of each class according to the given
-        weights. If not given, all classes are supposed to have weight one.
-        The 'auto' mode selects weights inversely proportional to class
-        frequencies in the training set.
+    class_weight : dict or 'balanced', optional
+        Weights associated with classes in the form ``{class_label: weight}``.
+        If not given, all classes are supposed to have weight one.
+
+        The "balanced" mode uses the values of y to automatically adjust
+        weights inversely proportional to class frequencies in the input data
+        as ``n_samples / (n_classes * np.bincount(y))``
 
     max_iter : int
         Useful only for the newton-cg and lbfgs solvers. Maximum number of
@@ -1150,11 +1156,13 @@ class LogisticRegressionCV(LogisticRegression, BaseEstimator,
         Specifies if a constant (a.k.a. bias or intercept) should be
         added the decision function.
 
-    class_weight : {dict, 'auto'}, optional
-        Over-/undersamples the samples of each class according to the given
-        weights. If not given, all classes are supposed to have weight one.
-        The 'auto' mode selects weights inversely proportional to class
-        frequencies in the training set.
+    class_weight : dict or 'balanced', optional
+        Weights associated with classes in the form ``{class_label: weight}``.
+        If not given, all classes are supposed to have weight one.
+
+        The "balanced" mode uses the values of y to automatically adjust
+        weights inversely proportional to class frequencies in the input data
+        as ``n_samples / (n_classes * np.bincount(y))``
 
     cv : integer or cross-validation generator
         The default cross-validation generator used is Stratified K-Folds.
@@ -1185,11 +1193,13 @@ class LogisticRegressionCV(LogisticRegression, BaseEstimator,
     max_iter : int, optional
         Maximum number of iterations of the optimization algorithm.
 
-    class_weight : {dict, 'auto'}, optional
-        Over-/undersamples the samples of each class according to the given
-        weights. If not given, all classes are supposed to have weight one.
-        The 'auto' mode selects weights inversely proportional to class
-        frequencies in the training set.
+    class_weight : dict or 'balanced', optional
+        Weights associated with classes in the form ``{class_label: weight}``.
+        If not given, all classes are supposed to have weight one.
+
+        The "balanced" mode uses the values of y to automatically adjust
+        weights inversely proportional to class frequencies in the input data
+        as ``n_samples / (n_classes * np.bincount(y))``
 
     n_jobs : int, optional
         Number of CPU cores used during the cross-validation loop. If given
@@ -1363,9 +1373,9 @@ def fit(self, X, y):
             iter_labels = [None]
 
         if self.class_weight and not(isinstance(self.class_weight, dict) or
-                                     self.class_weight == 'auto'):
+                                     self.class_weight in ['balanced', 'auto']):
             raise ValueError("class_weight provided should be a "
-                             "dict or 'auto'")
+                             "dict or 'balanced'")
 
         path_func = delayed(_log_reg_scoring_path)
 
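Because the "balanced" weights are fully determined by ``y``, passing the string should match passing the equivalent explicit dict. A minimal sketch — the toy data is an illustrative assumption, not from the commit:

    import numpy as np
    from sklearn.linear_model import LogisticRegression

    X = np.array([[0.0], [0.2], [0.4], [0.6], [0.8], [1.0]])
    y = np.array([0, 0, 0, 0, 1, 1])

    # balanced weights: 6 / (2 * [4, 2]) -> {0: 0.75, 1: 1.5}
    clf_str = LogisticRegression(class_weight='balanced').fit(X, y)
    clf_dict = LogisticRegression(class_weight={0: 0.75, 1: 1.5}).fit(X, y)

    print(np.allclose(clf_str.coef_, clf_dict.coef_))   # expected: True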

sklearn/linear_model/perceptron.py

Lines changed: 4 additions & 3 deletions
@@ -44,14 +44,15 @@ class Perceptron(BaseSGDClassifier, _LearntSelectorMixin):
     eta0 : double
         Constant by which the updates are multiplied. Defaults to 1.
 
-    class_weight : dict, {class_label: weight} or "auto" or None, optional
+    class_weight : dict, {class_label: weight} or "balanced" or None, optional
         Preset for the class_weight fit parameter.
 
         Weights associated with classes. If not given, all classes
         are supposed to have weight one.
 
-        The "auto" mode uses the values of y to automatically adjust
-        weights inversely proportional to class frequencies.
+        The "balanced" mode uses the values of y to automatically adjust
+        weights inversely proportional to class frequencies in the input data
+        as ``n_samples / (n_classes * np.bincount(y))``
 
     warm_start : bool, optional
         When set to True, reuse the solution of the previous call to fit as

sklearn/linear_model/ridge.py

Lines changed: 14 additions & 8 deletions
@@ -507,10 +507,13 @@ class RidgeClassifier(LinearClassifierMixin, _BaseRidge):
         ``(2*C)^-1`` in other linear models such as LogisticRegression or
         LinearSVC.
 
-    class_weight : dict, optional
-        Weights associated with classes in the form
-        ``{class_label : weight}``. If not given, all classes are
-        supposed to have weight one.
+    class_weight : dict or 'balanced', optional
+        Weights associated with classes in the form ``{class_label: weight}``.
+        If not given, all classes are supposed to have weight one.
+
+        The "balanced" mode uses the values of y to automatically adjust
+        weights inversely proportional to class frequencies in the input data
+        as ``n_samples / (n_classes * np.bincount(y))``
 
     copy_X : boolean, optional, default True
         If True, X will be copied; else, it may be overwritten.
@@ -994,10 +997,13 @@ class RidgeClassifierCV(LinearClassifierMixin, _BaseRidgeCV):
         If None, Generalized Cross-Validation (efficient Leave-One-Out)
         will be used.
 
-    class_weight : dict, optional
-        Weights associated with classes in the form
-        ``{class_label : weight}``. If not given, all classes are
-        supposed to have weight one.
+    class_weight : dict or 'balanced', optional
+        Weights associated with classes in the form ``{class_label: weight}``.
+        If not given, all classes are supposed to have weight one.
+
+        The "balanced" mode uses the values of y to automatically adjust
+        weights inversely proportional to class frequencies in the input data
+        as ``n_samples / (n_classes * np.bincount(y))``
 
     Attributes
     ----------
