diff --git a/src/linear_model/LogisticRegression.test.ts b/src/linear_model/LogisticRegression.test.ts
index 1d38ba5..80b3964 100644
--- a/src/linear_model/LogisticRegression.test.ts
+++ b/src/linear_model/LogisticRegression.test.ts
@@ -47,6 +47,68 @@ describe('LogisticRegression', function () {
     expect(results.arraySync()).toEqual([0, 0, 0, 1, 1, 1])
     expect(logreg.score(X, y) > 0.5).toBe(true)
   }, 30000)
+  it('Test of the function used with 2 classes (one hot)', async function () {
+    let X = [
+      [0, -1],
+      [1, 0],
+      [1, 1],
+      [1, -1],
+      [2, 0],
+      [2, 1],
+      [2, -1],
+      [3, 2],
+      [0, 4],
+      [1, 3],
+      [1, 4],
+      [1, 5],
+      [2, 3],
+      [2, 4],
+      [2, 5],
+      [3, 4]
+    ]
+    let y = [
+      [1, 0],
+      [1, 0],
+      [1, 0],
+      [1, 0],
+      [1, 0],
+      [1, 0],
+      [1, 0],
+      [1, 0],
+      [0, 1],
+      [0, 1],
+      [0, 1],
+      [0, 1],
+      [0, 1],
+      [0, 1],
+      [0, 1],
+      [0, 1]
+    ]
+
+    let Xtest = [
+      [0, -2],
+      [1, 0.5],
+      [1.5, -1],
+      [1, 4.5],
+      [2, 3.5],
+      [1.5, 5]
+    ]
+
+    let logreg = new LogisticRegression({ penalty: 'none' })
+    await logreg.fit(X, y)
+    let probabilities = logreg.predictProba(X)
+    expect(probabilities instanceof tf.Tensor).toBe(true)
+    let results = logreg.predict(Xtest) // compute predictions on the test set
+    expect(results.arraySync()).toEqual([
+      [1, 0],
+      [1, 0],
+      [1, 0],
+      [0, 1],
+      [0, 1],
+      [0, 1]
+    ])
+    expect(logreg.score(X, y) > 0.5).toBe(true)
+  }, 30000)
   it('Test of the prediction with 3 classes', async function () {
     let X = [
       [0, -1],
diff --git a/src/linear_model/SgdClassifier.ts b/src/linear_model/SgdClassifier.ts
index 56a55e1..bc9ee15 100644
--- a/src/linear_model/SgdClassifier.ts
+++ b/src/linear_model/SgdClassifier.ts
@@ -13,7 +13,10 @@
 * ==========================================================================
 */
 
-import { convertToNumericTensor1D, convertToNumericTensor2D } from '../utils'
+import {
+  convertToNumericTensor1D_2D,
+  convertToNumericTensor2D
+} from '../utils'
 import {
   Scikit2D,
   Scikit1D,
@@ -23,8 +26,7 @@ import {
   Tensor2D,
   Tensor,
   ModelCompileArgs,
-  ModelFitArgs,
-  RecursiveArray
+  ModelFitArgs
 } from '../types'
 import { OneHotEncoder } from '../preprocessing/OneHotEncoder'
 import { assert } from '../typesUtils'
@@ -103,6 +105,7 @@ export class SGDClassifier extends ClassifierMixin {
   lossType: LossTypes
   oneHot: OneHotEncoder
   tf: any
+  isMultiOutput: boolean
 
   constructor({
     modelFitArgs,
@@ -119,6 +122,7 @@ export class SGDClassifier extends ClassifierMixin {
     this.denseLayerArgs = denseLayerArgs
     this.optimizerType = optimizerType
     this.lossType = lossType
+    this.isMultiOutput = false
     // Next steps: Implement "drop" mechanics for OneHotEncoder
     // There is a possibility to do a drop => if_binary which would
     // squash down on the number of variables that we'd have to learn
@@ -200,12 +204,17 @@ export class SGDClassifier extends ClassifierMixin {
   *   // lr model weights have been updated
   */
 
-  public async fit(X: Scikit2D, y: Scikit1D): Promise<SGDClassifier> {
+  public async fit(
+    X: Scikit2D,
+    y: Scikit1D | Scikit2D
+  ): Promise<SGDClassifier> {
     let XTwoD = convertToNumericTensor2D(X)
-    let yOneD = convertToNumericTensor1D(y)
+    let yOneD = convertToNumericTensor1D_2D(y)
 
     const yTwoD = this.initializeModelForClassification(yOneD)
-
+    if (yOneD.shape.length > 1) {
+      this.isMultiOutput = true
+    }
     if (this.model.layers.length === 0) {
       this.initializeModel(XTwoD, yTwoD)
     }
@@ -344,6 +353,9 @@ export class SGDClassifier extends ClassifierMixin {
   public predict(X: Scikit2D): Tensor1D {
     assert(this.model.layers.length > 0, 'Need to call "fit" before "predict"')
     const y2D = this.predictProba(X)
+    if (this.isMultiOutput) {
+      return this.tf.oneHot(y2D.argMax(1), y2D.shape[1])
+    }
     return this.tf.tensor1d(this.oneHot.inverseTransform(y2D))
   }
 
@@ -418,10 +430,4 @@ export class SGDClassifier extends ClassifierMixin {
 
     return intercept
   }
-
-  private getModelWeight(): Promise<RecursiveArray<number>> {
-    return Promise.all(
-      this.model.getWeights().map((weight: any) => weight.array())
-    )
-  }
 }
diff --git a/src/mixins.ts b/src/mixins.ts
index 2d71b81..225f7f7 100644
--- a/src/mixins.ts
+++ b/src/mixins.ts
@@ -1,6 +1,8 @@
 import { Scikit2D, Scikit1D, Tensor2D, Tensor1D } from './types'
 import { r2Score, accuracyScore } from './metrics/metrics'
 import { Serialize } from './simpleSerializer'
+import { assert, isScikit2D } from './typesUtils'
+import { convertToNumericTensor1D_2D } from './utils'
 export class TransformerMixin extends Serialize {
   // We assume that fit and transform exist
   [x: string]: any
@@ -35,8 +37,17 @@ export class ClassifierMixin extends Serialize {
   [x: string]: any
   EstimatorType = 'classifier'
 
-  public score(X: Scikit2D, y: Scikit1D): number {
+  public score(X: Scikit2D, y: Scikit1D | Scikit2D): number {
     const yPred = this.predict(X)
+    const yTrue = convertToNumericTensor1D_2D(y)
+    assert(
+      yPred.shape.length === yTrue.shape.length,
+      "The shape of the model output doesn't match the shape of the actual y values"
+    )
+
+    if (isScikit2D(y)) {
+      return accuracyScore(yTrue.argMax(1) as Scikit1D, yPred.argMax(1))
+    }
     return accuracyScore(y, yPred)
   }
 }
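Note: the hunks above widen the public fit/predict/score surface so that targets may be passed either as 1D class labels (Scikit1D) or as one-hot rows (Scikit2D). The sketch below is a minimal, illustrative usage example distilled from the new test case; it is not part of the diff, and the 'scikitjs' import path is an assumption that may differ from the local module path used in the test file.

// Illustrative usage sketch of the new one-hot target support (not part of the diff).
// Assumption: LogisticRegression is exported from the package entry point 'scikitjs'.
import { LogisticRegression } from 'scikitjs'

async function demo() {
  const X = [
    [0, -1],
    [1, 0],
    [0, 4],
    [1, 3]
  ]
  // Targets supplied as one-hot rows (Scikit2D) instead of 1D class labels.
  const y = [
    [1, 0],
    [1, 0],
    [0, 1],
    [0, 1]
  ]

  const logreg = new LogisticRegression({ penalty: 'none' })
  await logreg.fit(X, y) // fit now accepts Scikit1D | Scikit2D

  // With one-hot targets, predict returns one-hot rows (tf.oneHot of the argMax).
  const preds = logreg.predict([
    [0, -2],
    [1, 4.5]
  ])
  console.log(preds.arraySync()) // shape [2, 2], e.g. [[1, 0], [0, 1]]

  // score compares argMax of the predictions against argMax of the one-hot y.
  console.log(logreg.score(X, y))
}

demo()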
