From 2ddcad9053097c75bcdd97d234c1f8b605a88241 Mon Sep 17 00:00:00 2001 From: Yaw Joseph Etse Date: Wed, 18 May 2022 22:12:38 -0400 Subject: [PATCH 1/3] feat: custom modelfitargs for linear models --- .gitignore | 3 ++- src/linear_model/LinearRegression.ts | 8 ++++++-- src/linear_model/LogisticRegression.ts | 8 ++++++-- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 4d06003..b57f5e7 100644 --- a/.gitignore +++ b/.gitignore @@ -107,4 +107,5 @@ dist # IDE Files .vscode/ -.idea/ \ No newline at end of file +.idea/ +.dccache \ No newline at end of file diff --git a/src/linear_model/LinearRegression.ts b/src/linear_model/LinearRegression.ts index 1913ed2..995f577 100644 --- a/src/linear_model/LinearRegression.ts +++ b/src/linear_model/LinearRegression.ts @@ -15,6 +15,7 @@ import { SGDRegressor } from './SgdRegressor' import { getBackend } from '../tf-singleton' +import { ModelFitArgs } from '../types' /** * LinearRegression implementation using gradient descent @@ -39,6 +40,8 @@ export interface LinearRegressionParams { * **default = true** */ fitIntercept?: boolean + modelFitOptions?: Partial + } /* @@ -66,7 +69,7 @@ Next steps: * ``` */ export class LinearRegression extends SGDRegressor { - constructor({ fitIntercept = true }: LinearRegressionParams = {}) { + constructor({ fitIntercept = true, modelFitOptions }: LinearRegressionParams = {}) { let tf = getBackend() super({ modelCompileArgs: { @@ -80,7 +83,8 @@ export class LinearRegression extends SGDRegressor { verbose: 0, callbacks: [ tf.callbacks.earlyStopping({ monitor: 'mse', patience: 30 }) - ] + ], + ...modelFitOptions }, denseLayerArgs: { units: 1, diff --git a/src/linear_model/LogisticRegression.ts b/src/linear_model/LogisticRegression.ts index 159cd36..b235bb3 100644 --- a/src/linear_model/LogisticRegression.ts +++ b/src/linear_model/LogisticRegression.ts @@ -15,6 +15,7 @@ import { SGDClassifier } from './SgdClassifier' import { getBackend } from '../tf-singleton' +import { ModelFitArgs } from '../types' // First pass at a LogisticRegression implementation using gradient descent // Trying to mimic the API of scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html @@ -35,6 +36,7 @@ export interface LogisticRegressionParams { C?: number /** Whether or not the intercept should be estimator not. **default = true** */ fitIntercept?: boolean + modelFitOptions?: Partial } /** Builds a linear classification model with associated penalty and regularization @@ -63,7 +65,8 @@ export class LogisticRegression extends SGDClassifier { constructor({ penalty = 'l2', C = 1, - fitIntercept = true + fitIntercept = true, + modelFitOptions }: LogisticRegressionParams = {}) { // Assume Binary classification // If we call fit, and it isn't binary then update args @@ -80,7 +83,8 @@ export class LogisticRegression extends SGDClassifier { verbose: 0, callbacks: [ tf.callbacks.earlyStopping({ monitor: 'loss', patience: 50 }) - ] + ], + ...modelFitOptions }, denseLayerArgs: { units: 1, From 7fa5c4259902d7dca0a925002cbfaf1937dc2b1b Mon Sep 17 00:00:00 2001 From: Dan Crescimanno Date: Wed, 18 May 2022 21:41:05 -0700 Subject: [PATCH 2/3] feat: added test case for custom callbacks. works great and somehow serializes. --- src/linear_model/LinearRegression.test.ts | 32 +++++++++++++++++++++++ src/linear_model/LinearRegression.ts | 10 ++++--- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/src/linear_model/LinearRegression.test.ts b/src/linear_model/LinearRegression.test.ts index 2e54a97..6681df6 100644 --- a/src/linear_model/LinearRegression.test.ts +++ b/src/linear_model/LinearRegression.test.ts @@ -17,6 +17,38 @@ describe('LinearRegression', function () { expect(roughlyEqual(lr.intercept as number, 0)).toBe(true) }, 30000) + it('Works on arrays (small example) with custom callbacks', async function () { + let trainingHasStarted = false + const onTrainBegin = async (logs: any) => { + trainingHasStarted = true + console.log('training begins') + } + const lr = new LinearRegression({ + modelFitOptions: { callbacks: [new tf.CustomCallback({ onTrainBegin })] } + }) + await lr.fit([[1], [2]], [2, 4]) + expect(tensorEqual(lr.coef, tf.tensor1d([2]), 0.1)).toBe(true) + expect(roughlyEqual(lr.intercept as number, 0)).toBe(true) + expect(trainingHasStarted).toBe(true) + }, 30000) + + it('Works on arrays (small example) with custom callbacks', async function () { + let trainingHasStarted = false + const onTrainBegin = async (logs: any) => { + trainingHasStarted = true + console.log('training begins') + } + const lr = new LinearRegression({ + modelFitOptions: { callbacks: [new tf.CustomCallback({ onTrainBegin })] } + }) + await lr.fit([[1], [2]], [2, 4]) + + const serialized = await lr.toJSON() + const newModel = await fromJSON(serialized) + expect(tensorEqual(newModel.coef, tf.tensor1d([2]), 0.1)).toBe(true) + expect(roughlyEqual(newModel.intercept as number, 0)).toBe(true) + }, 30000) + it('Works on small multi-output example (small example)', async function () { const lr = new LinearRegression() await lr.fit( diff --git a/src/linear_model/LinearRegression.ts b/src/linear_model/LinearRegression.ts index 995f577..c09a620 100644 --- a/src/linear_model/LinearRegression.ts +++ b/src/linear_model/LinearRegression.ts @@ -41,7 +41,6 @@ export interface LinearRegressionParams { */ fitIntercept?: boolean modelFitOptions?: Partial - } /* @@ -53,7 +52,7 @@ Next steps: /** Linear Least Squares * @example * ```js - * import {LinearRegression} from 'scikitjs' + * import { LinearRegression } from 'scikitjs' * * let X = [ * [1, 2], @@ -63,13 +62,16 @@ Next steps: * [10, 20] * ] * let y = [3, 5, 8, 8, 30] - * const lr = new LinearRegression({fitIntercept: false}) + * const lr = new LinearRegression({ fitIntercept: false }) await lr.fit(X, y) lr.coef.print() // probably around [1, 1] * ``` */ export class LinearRegression extends SGDRegressor { - constructor({ fitIntercept = true, modelFitOptions }: LinearRegressionParams = {}) { + constructor({ + fitIntercept = true, + modelFitOptions + }: LinearRegressionParams = {}) { let tf = getBackend() super({ modelCompileArgs: { From 3d7731cdcfefb6121a40b78f2000cedeb05a7d29 Mon Sep 17 00:00:00 2001 From: semantic-release-bot Date: Thu, 19 May 2022 05:00:22 +0000 Subject: [PATCH 3/3] chore(release): 1.23.0 [skip ci] # [1.23.0](https://github.com/javascriptdata/scikit.js/compare/v1.22.0...v1.23.0) (2022-05-19) ### Features * added test case for custom callbacks. works great and somehow serializes. ([7fa5c42](https://github.com/javascriptdata/scikit.js/commit/7fa5c4259902d7dca0a925002cbfaf1937dc2b1b)) * custom modelfitargs for linear models ([2ddcad9](https://github.com/javascriptdata/scikit.js/commit/2ddcad9053097c75bcdd97d234c1f8b605a88241)) --- CHANGELOG.md | 8 ++++++++ package-lock.json | 4 ++-- package.json | 2 +- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 781ea72..a99ede5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,11 @@ +# [1.23.0](https://github.com/javascriptdata/scikit.js/compare/v1.22.0...v1.23.0) (2022-05-19) + + +### Features + +* added test case for custom callbacks. works great and somehow serializes. ([7fa5c42](https://github.com/javascriptdata/scikit.js/commit/7fa5c4259902d7dca0a925002cbfaf1937dc2b1b)) +* custom modelfitargs for linear models ([2ddcad9](https://github.com/javascriptdata/scikit.js/commit/2ddcad9053097c75bcdd97d234c1f8b605a88241)) + # [1.22.0](https://github.com/javascriptdata/scikit.js/compare/v1.21.0...v1.22.0) (2022-05-18) diff --git a/package-lock.json b/package-lock.json index 2d1ffc2..04e3044 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "scikitjs", - "version": "1.22.0", + "version": "1.23.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "scikitjs", - "version": "1.22.0", + "version": "1.23.0", "hasInstallScript": true, "license": "ISC", "dependencies": { diff --git a/package.json b/package.json index 59bc332..040769c 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "scikitjs", - "version": "1.22.0", + "version": "1.23.0", "description": "Scikit-Learn for JS", "output": { "node": "dist/node/index.js", pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy