diff --git a/.gitignore b/.gitignore index 4d06003..b57f5e7 100644 --- a/.gitignore +++ b/.gitignore @@ -107,4 +107,5 @@ dist # IDE Files .vscode/ -.idea/ \ No newline at end of file +.idea/ +.dccache \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 781ea72..a99ede5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,11 @@ +# [1.23.0](https://github.com/javascriptdata/scikit.js/compare/v1.22.0...v1.23.0) (2022-05-19) + + +### Features + +* added test case for custom callbacks. works great and somehow serializes. ([7fa5c42](https://github.com/javascriptdata/scikit.js/commit/7fa5c4259902d7dca0a925002cbfaf1937dc2b1b)) +* custom modelfitargs for linear models ([2ddcad9](https://github.com/javascriptdata/scikit.js/commit/2ddcad9053097c75bcdd97d234c1f8b605a88241)) + # [1.22.0](https://github.com/javascriptdata/scikit.js/compare/v1.21.0...v1.22.0) (2022-05-18) diff --git a/package-lock.json b/package-lock.json index 2d1ffc2..04e3044 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "scikitjs", - "version": "1.22.0", + "version": "1.23.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "scikitjs", - "version": "1.22.0", + "version": "1.23.0", "hasInstallScript": true, "license": "ISC", "dependencies": { diff --git a/package.json b/package.json index 59bc332..040769c 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "scikitjs", - "version": "1.22.0", + "version": "1.23.0", "description": "Scikit-Learn for JS", "output": { "node": "dist/node/index.js", diff --git a/src/linear_model/LinearRegression.test.ts b/src/linear_model/LinearRegression.test.ts index 2e54a97..6681df6 100644 --- a/src/linear_model/LinearRegression.test.ts +++ b/src/linear_model/LinearRegression.test.ts @@ -17,6 +17,38 @@ describe('LinearRegression', function () { expect(roughlyEqual(lr.intercept as number, 0)).toBe(true) }, 30000) + it('Works on arrays (small example) with custom callbacks', async function () { + let trainingHasStarted = false + const onTrainBegin = async (logs: any) => { + trainingHasStarted = true + console.log('training begins') + } + const lr = new LinearRegression({ + modelFitOptions: { callbacks: [new tf.CustomCallback({ onTrainBegin })] } + }) + await lr.fit([[1], [2]], [2, 4]) + expect(tensorEqual(lr.coef, tf.tensor1d([2]), 0.1)).toBe(true) + expect(roughlyEqual(lr.intercept as number, 0)).toBe(true) + expect(trainingHasStarted).toBe(true) + }, 30000) + + it('Works on arrays (small example) with custom callbacks', async function () { + let trainingHasStarted = false + const onTrainBegin = async (logs: any) => { + trainingHasStarted = true + console.log('training begins') + } + const lr = new LinearRegression({ + modelFitOptions: { callbacks: [new tf.CustomCallback({ onTrainBegin })] } + }) + await lr.fit([[1], [2]], [2, 4]) + + const serialized = await lr.toJSON() + const newModel = await fromJSON(serialized) + expect(tensorEqual(newModel.coef, tf.tensor1d([2]), 0.1)).toBe(true) + expect(roughlyEqual(newModel.intercept as number, 0)).toBe(true) + }, 30000) + it('Works on small multi-output example (small example)', async function () { const lr = new LinearRegression() await lr.fit( diff --git a/src/linear_model/LinearRegression.ts b/src/linear_model/LinearRegression.ts index 1913ed2..c09a620 100644 --- a/src/linear_model/LinearRegression.ts +++ b/src/linear_model/LinearRegression.ts @@ -15,6 +15,7 @@ import { SGDRegressor } from './SgdRegressor' import { getBackend } from '../tf-singleton' +import { ModelFitArgs } from '../types' /** * LinearRegression implementation using gradient descent @@ -39,6 +40,7 @@ export interface LinearRegressionParams { * **default = true** */ fitIntercept?: boolean + modelFitOptions?: Partial } /* @@ -50,7 +52,7 @@ Next steps: /** Linear Least Squares * @example * ```js - * import {LinearRegression} from 'scikitjs' + * import { LinearRegression } from 'scikitjs' * * let X = [ * [1, 2], @@ -60,13 +62,16 @@ Next steps: * [10, 20] * ] * let y = [3, 5, 8, 8, 30] - * const lr = new LinearRegression({fitIntercept: false}) + * const lr = new LinearRegression({ fitIntercept: false }) await lr.fit(X, y) lr.coef.print() // probably around [1, 1] * ``` */ export class LinearRegression extends SGDRegressor { - constructor({ fitIntercept = true }: LinearRegressionParams = {}) { + constructor({ + fitIntercept = true, + modelFitOptions + }: LinearRegressionParams = {}) { let tf = getBackend() super({ modelCompileArgs: { @@ -80,7 +85,8 @@ export class LinearRegression extends SGDRegressor { verbose: 0, callbacks: [ tf.callbacks.earlyStopping({ monitor: 'mse', patience: 30 }) - ] + ], + ...modelFitOptions }, denseLayerArgs: { units: 1, diff --git a/src/linear_model/LogisticRegression.ts b/src/linear_model/LogisticRegression.ts index 159cd36..b235bb3 100644 --- a/src/linear_model/LogisticRegression.ts +++ b/src/linear_model/LogisticRegression.ts @@ -15,6 +15,7 @@ import { SGDClassifier } from './SgdClassifier' import { getBackend } from '../tf-singleton' +import { ModelFitArgs } from '../types' // First pass at a LogisticRegression implementation using gradient descent // Trying to mimic the API of scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html @@ -35,6 +36,7 @@ export interface LogisticRegressionParams { C?: number /** Whether or not the intercept should be estimator not. **default = true** */ fitIntercept?: boolean + modelFitOptions?: Partial } /** Builds a linear classification model with associated penalty and regularization @@ -63,7 +65,8 @@ export class LogisticRegression extends SGDClassifier { constructor({ penalty = 'l2', C = 1, - fitIntercept = true + fitIntercept = true, + modelFitOptions }: LogisticRegressionParams = {}) { // Assume Binary classification // If we call fit, and it isn't binary then update args @@ -80,7 +83,8 @@ export class LogisticRegression extends SGDClassifier { verbose: 0, callbacks: [ tf.callbacks.earlyStopping({ monitor: 'loss', patience: 50 }) - ] + ], + ...modelFitOptions }, denseLayerArgs: { units: 1, pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy