
Commit b278440

zaxtax authored and amueller committed
Adding sparse support for decision function
Adding test for decision_function on sparse matrices
Fixing typos
Fix Memory Leak
PEP8 fixes
COSMIT PEP8
Fixed class_weight_label problem
Fixed label problem in tests
renamed label_ to classes_
Fixed OneClass bug
Adding kernel tests for sparse matrices
Pass dense data to decision function
Don't check precomputed kernel
use safe_sparse_dot in test, regenerate cython
create a core-dump by testing the binary case... hurray?
1 parent 2ff021f commit b278440

File tree

5 files changed: +3787 additions, -3063 deletions


sklearn/svm/base.py

Lines changed: 41 additions & 14 deletions
@@ -366,29 +366,58 @@ def decision_function(self, X):
         # NOTE: _validate_for_predict contains check for is_fitted
         # hence must be placed before any other attributes are used.
         X = self._validate_for_predict(X)
-        if self._sparse:
-            raise NotImplementedError("Decision_function not supported for"
-                                      " sparse SVM.")
         X = self._compute_kernel(X)
 
+        X = self._validate_for_predict(X)
+
+        if self._sparse:
+            dec_func = self._sparse_decision_function(X)
+        else:
+            dec_func = self._dense_decision_function(X)
+
+        # In binary case, we need to flip the sign of coef, intercept and
+        # decision function.
+        if self.impl != 'one_class' and len(self.classes_) == 2:
+            return -dec_func
+
+        return dec_func
+
+    def _dense_decision_function(self, X):
+        X = check_array(X, dtype=np.float64, order="C")
+
         kernel = self.kernel
         if callable(kernel):
             kernel = 'precomputed'
 
-        dec_func = libsvm.decision_function(
+        return libsvm.decision_function(
             X, self.support_, self.support_vectors_, self.n_support_,
             self._dual_coef_, self._intercept_,
             self.probA_, self.probB_,
             svm_type=LIBSVM_IMPL.index(self._impl),
             kernel=kernel, degree=self.degree, cache_size=self.cache_size,
             coef0=self.coef0, gamma=self._gamma)
 
-        # In binary case, we need to flip the sign of coef, intercept and
-        # decision function.
-        if self._impl in ['c_svc', 'nu_svc'] and len(self.classes_) == 2:
-            return -dec_func.ravel()
+    def _sparse_decision_function(self, X):
+        X.data = np.asarray(X.data, dtype=np.float64, order='C')
 
-        return dec_func
+        kernel = self.kernel
+        if hasattr(kernel, '__call__'):
+            kernel = 'precomputed'
+
+        kernel_type = self._sparse_kernels.index(kernel)
+
+        return libsvm_sparse.libsvm_sparse_decision_function(
+            X.data, X.indices, X.indptr,
+            self.support_vectors_.data,
+            self.support_vectors_.indices,
+            self.support_vectors_.indptr,
+            self._dual_coef_.data, self._intercept_,
+            LIBSVM_IMPL.index(self.impl), kernel_type,
+            self.degree, self.gamma, self.coef0, self.tol,
+            self.C, self.class_weight_,
+            self.nu, self.epsilon, self.shrinking,
+            self.probability, self.n_support_, self._label,
+            self.probA_, self.probB_)
 
     def _validate_for_predict(self, X):
         check_is_fitted(self, 'support_')
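With this hunk, sparse input is routed to _sparse_decision_function and dense input to _dense_decision_function, and the sign of the decision values is flipped in the two-class case. A minimal usage sketch of what this enables, not part of the commit: the toy data and the linear-kernel SVC below are illustrative assumptions, mirroring the sparse-versus-dense comparison test the commit message describes.

import numpy as np
import scipy.sparse as sp
from sklearn.svm import SVC

# Small, linearly separable toy problem (made up for this sketch).
X = np.array([[0., 0.], [0., 1.], [2., 0.], [2., 1.]])
y = np.array([0, 0, 1, 1])
X_sp = sp.csr_matrix(X)

# Fit one estimator on dense data and one on the CSR copy of the same data.
clf_dense = SVC(kernel='linear').fit(X, y)
clf_sparse = SVC(kernel='linear').fit(X_sp, y)

# After this change, the sparse path is expected to return the same decision
# values as the dense path instead of raising NotImplementedError.
assert np.allclose(clf_dense.decision_function(X),
                   clf_sparse.decision_function(X_sp))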
@@ -666,9 +695,8 @@ def _get_liblinear_solver_type(multi_class, penalty, loss, dual):
                              % (penalty, loss, dual))
         else:
             return solver_num
-
-    raise ValueError(('Unsupported set of arguments: %s, '
-                      'Parameters: penalty=%r, loss=%r, dual=%r')
+    raise ValueError('Unsupported set of arguments: %s, '
+                     'Parameters: penalty=%r, loss=%r, dual=%r'
                      % (error_string, penalty, loss, dual))
 
 
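This hunk removes a stray blank line and the redundant parentheses around the format string in the final raise. The two spellings build the same message, because adjacent string literals are concatenated before the % operator is applied; a short illustration, with placeholder values that are not taken from the code:

# Illustration only: implicit concatenation of adjacent string literals
# happens at compile time, so the extra parentheses change nothing.
error_string, penalty, loss, dual = 'unsupported combination', 'l2', 'l2', True

msg_old = (('Unsupported set of arguments: %s, '
            'Parameters: penalty=%r, loss=%r, dual=%r')
           % (error_string, penalty, loss, dual))
msg_new = ('Unsupported set of arguments: %s, '
           'Parameters: penalty=%r, loss=%r, dual=%r'
           % (error_string, penalty, loss, dual))

assert msg_old == msg_new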

@@ -811,8 +839,7 @@ def _fit_liblinear(X, y, C, fit_intercept, intercept_scaling, class_weight,
     raw_coef_, n_iter_ = liblinear.train_wrap(
         X, y_ind, sp.isspmatrix(X), solver_type, tol, bias, C,
         class_weight_, max_iter, rnd.randint(np.iinfo('i').max),
-        epsilon
-    )
+        epsilon)
     # Regarding rnd.randint(..) in the above signature:
     # seed for srand in range [0..INT_MAX); due to limitations in Numpy
     # on 32-bit platforms, we can't get to the UINT_MAX limit that
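The last hunk only moves the closing parenthesis onto the epsilon line, but the neighbouring comment about the seed is easy to verify: rnd.randint(np.iinfo('i').max) stays inside [0, INT_MAX). A quick illustrative check; the RandomState seed below is arbitrary and not taken from the commit.

import numpy as np

# np.iinfo('i') describes the platform C int, whose max is the INT_MAX
# (2**31 - 1) bound mentioned in the comment above.
rnd = np.random.RandomState(0)
seed = rnd.randint(np.iinfo('i').max)

assert 0 <= seed < np.iinfo('i').max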
