From 90bcc0a3dba521c0fa7cf09914e6357e66d89544 Mon Sep 17 00:00:00 2001 From: Dan Crescimanno Date: Sun, 22 May 2022 00:49:04 -0700 Subject: [PATCH] feat: added random state to sgd regressor --- src/linear_model/LinearRegression.test.ts | 18 ++++++++++-------- src/linear_model/LinearRegression.ts | 5 ++++- src/linear_model/SgdRegressor.ts | 22 ++++++++++++++++------ 3 files changed, 30 insertions(+), 15 deletions(-) diff --git a/src/linear_model/LinearRegression.test.ts b/src/linear_model/LinearRegression.test.ts index 6681df6..c238eff 100644 --- a/src/linear_model/LinearRegression.test.ts +++ b/src/linear_model/LinearRegression.test.ts @@ -11,7 +11,7 @@ function roughlyEqual(a: number, b: number, tol = 0.1) { describe('LinearRegression', function () { it('Works on arrays (small example)', async function () { - const lr = new LinearRegression() + const lr = new LinearRegression({ randomState: 42 }) await lr.fit([[1], [2]], [2, 4]) expect(tensorEqual(lr.coef, tf.tensor1d([2]), 0.1)).toBe(true) expect(roughlyEqual(lr.intercept as number, 0)).toBe(true) @@ -24,6 +24,7 @@ describe('LinearRegression', function () { console.log('training begins') } const lr = new LinearRegression({ + randomState: 42, modelFitOptions: { callbacks: [new tf.CustomCallback({ onTrainBegin })] } }) await lr.fit([[1], [2]], [2, 4]) @@ -39,6 +40,7 @@ describe('LinearRegression', function () { console.log('training begins') } const lr = new LinearRegression({ + randomState: 42, modelFitOptions: { callbacks: [new tf.CustomCallback({ onTrainBegin })] } }) await lr.fit([[1], [2]], [2, 4]) @@ -50,7 +52,7 @@ describe('LinearRegression', function () { }, 30000) it('Works on small multi-output example (small example)', async function () { - const lr = new LinearRegression() + const lr = new LinearRegression({ randomState: 42 }) await lr.fit( [[1], [2]], [ @@ -63,14 +65,14 @@ describe('LinearRegression', function () { }, 30000) it('Works on arrays with no intercept (small example)', async function () { - const lr = new LinearRegression({ fitIntercept: false }) + const lr = new LinearRegression({ fitIntercept: false, randomState: 42 }) await lr.fit([[1], [2]], [2, 4]) expect(tensorEqual(lr.coef, tf.tensor1d([2]), 0.1)).toBe(true) expect(roughlyEqual(lr.intercept as number, 0)).toBe(true) }, 30000) it('Works on arrays with none zero intercept (small example)', async function () { - const lr = new LinearRegression({ fitIntercept: true }) + const lr = new LinearRegression({ fitIntercept: true, randomState: 42 }) await lr.fit([[1], [2]], [3, 5]) expect(tensorEqual(lr.coef, tf.tensor1d([2]), 0.1)).toBe(true) expect(roughlyEqual(lr.intercept as number, 1)).toBe(true) @@ -95,7 +97,7 @@ describe('LinearRegression', function () { const yPlusJitter = y.add( tf.randomNormal([sizeOfMatrix], 0, 1, 'float32', seed) ) as tf.Tensor1D - const lr = new LinearRegression({ fitIntercept: false }) + const lr = new LinearRegression({ fitIntercept: false, randomState: 42 }) await lr.fit(mediumX, yPlusJitter) expect(tensorEqual(lr.coef, tf.tensor1d([2.5, 1]), 0.1)).toBe(true) @@ -121,7 +123,7 @@ describe('LinearRegression', function () { const yPlusJitter = y.add( tf.randomNormal([sizeOfMatrix], 0, 1, 'float32', seed) ) as tf.Tensor1D - const lr = new LinearRegression({ fitIntercept: false }) + const lr = new LinearRegression({ fitIntercept: false, randomState: 42 }) await lr.fit(mediumX, yPlusJitter) expect(tensorEqual(lr.coef, tf.tensor1d([2.5, 1]), 0.1)).toBe(true) @@ -158,7 +160,7 @@ describe('LinearRegression', function () { let score = 1.0 /*[[[end]]]*/ - const lr = new LinearRegression() + const lr = new LinearRegression({ randomState: 42 }) await lr.fit(X, y) expect(lr.score(X, y)).toBeCloseTo(score) }, 30000) @@ -180,7 +182,7 @@ describe('LinearRegression', function () { const yPlusJitter = y.add( tf.randomNormal([sizeOfMatrix], 0, 1, 'float32', seed) ) as tf.Tensor1D - const lr = new LinearRegression({ fitIntercept: false }) + const lr = new LinearRegression({ fitIntercept: false, randomState: 42 }) await lr.fit(mediumX, yPlusJitter) const serialized = await lr.toObject() diff --git a/src/linear_model/LinearRegression.ts b/src/linear_model/LinearRegression.ts index c09a620..5d36652 100644 --- a/src/linear_model/LinearRegression.ts +++ b/src/linear_model/LinearRegression.ts @@ -41,6 +41,7 @@ export interface LinearRegressionParams { */ fitIntercept?: boolean modelFitOptions?: Partial + randomState?: number } /* @@ -70,7 +71,8 @@ Next steps: export class LinearRegression extends SGDRegressor { constructor({ fitIntercept = true, - modelFitOptions + modelFitOptions, + randomState }: LinearRegressionParams = {}) { let tf = getBackend() super({ @@ -92,6 +94,7 @@ export class LinearRegression extends SGDRegressor { units: 1, useBias: Boolean(fitIntercept) }, + randomState, optimizerType: 'adam', lossType: 'meanSquaredError' }) diff --git a/src/linear_model/SgdRegressor.ts b/src/linear_model/SgdRegressor.ts index b089b55..f8e1bba 100644 --- a/src/linear_model/SgdRegressor.ts +++ b/src/linear_model/SgdRegressor.ts @@ -91,6 +91,8 @@ export interface SGDRegressorParams { optimizerType: OptimizerTypes lossType: LossTypes + + randomState?: number } export class SGDRegressor extends RegressorMixin { @@ -101,13 +103,15 @@ export class SGDRegressor extends RegressorMixin { isMultiOutput: boolean optimizerType: OptimizerTypes lossType: LossTypes + randomState?: number constructor({ modelFitArgs, modelCompileArgs, denseLayerArgs, optimizerType, - lossType + lossType, + randomState }: SGDRegressorParams) { super() this.tf = getBackend() @@ -118,6 +122,7 @@ export class SGDRegressor extends RegressorMixin { this.isMultiOutput = false this.optimizerType = optimizerType this.lossType = lossType + this.randomState = randomState } /** @@ -139,12 +144,17 @@ export class SGDRegressor extends RegressorMixin { ): void { this.denseLayerArgs.units = y.shape.length === 1 ? 1 : y.shape[1] const model = this.tf.sequential() - model.add( - this.tf.layers.dense({ - inputShape: [X.shape[1]], - ...this.denseLayerArgs + let denseLayerArgs = { + inputShape: [X.shape[1]], + ...this.denseLayerArgs + } + // If randomState is set, then use it to set the args in this layer + if (this.randomState) { + denseLayerArgs.kernelInitializer = this.tf.initializers.glorotUniform({ + seed: this.randomState }) - ) + } + model.add(this.tf.layers.dense(denseLayerArgs)) model.compile(this.modelCompileArgs) if (weightsTensors?.length) { model.setWeights(weightsTensors) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy