diff --git a/pgml-extension/pgml_rust/Cargo.toml b/pgml-extension/pgml_rust/Cargo.toml
index 91d17c8e2..60041d765 100644
--- a/pgml-extension/pgml_rust/Cargo.toml
+++ b/pgml-extension/pgml_rust/Cargo.toml
@@ -16,14 +16,19 @@ pg14 = ["pgx/pg14", "pgx-tests/pg14" ]
 pg_test = []
 
 [dependencies]
-pgx = "=0.4.5"
-xgboost = { path = "rust-xgboost" }
-rustlearn = "0.5"
+pgx = "0.4.5"
 once_cell = "1"
 rand = "0.8"
+xgboost = { path = "rust-xgboost" }
+smartcore = { version = "0.2.0", features = ["serde", "ndarray-bindings"] }
+ndarray = { version = "0.15.6", features = ["serde", "blas"] }
 blas = { version = "0.22.0" }
 blas-src = { version = "0.8", features = ["openblas"] }
 openblas-src = { version = "0.10", features = ["cblas", "system"] }
+serde = { version = "1.0.2" }
+serde_json = { version = "1.0.85" }
+rmp-serde = { version = "1.1.0" }
+typetag = "0.2"
 
 [dev-dependencies]
 pgx-tests = "=0.4.5"
diff --git a/pgml-extension/pgml_rust/pgml_rust.control b/pgml-extension/pgml_rust/pgml_rust.control
index 05223ba7c..d5a55a8b8 100644
--- a/pgml-extension/pgml_rust/pgml_rust.control
+++ b/pgml-extension/pgml_rust/pgml_rust.control
@@ -1,5 +1,6 @@
-comment = 'pgml_rust: Created by pgx'
+comment = 'pgml_rust: Created by the PostgresML team'
 default_version = '@CARGO_VERSION@'
 module_pathname = '$libdir/pgml_rust'
 relocatable = false
 superuser = false
+schema = 'pgml_rust'
diff --git a/pgml-extension/pgml_rust/sql/schema.sql b/pgml-extension/pgml_rust/sql/schema.sql
index 15f832ec4..cda5a2a14 100644
--- a/pgml-extension/pgml_rust/sql/schema.sql
+++ b/pgml-extension/pgml_rust/sql/schema.sql
@@ -1,5 +1,3 @@
-CREATE SCHEMA IF NOT EXISTS pgml_rust;
-
 ---
 --- Track of updates to data
 ---
@@ -33,39 +31,76 @@ BEGIN
     ) THEN
         NEW.updated_at := clock_timestamp();
     END IF;
-    RETURN NEW;
+    RETURN new;
 END;
 $$ LANGUAGE plpgsql;
+
 ---
 --- Projects organize work
 ---
 CREATE TABLE IF NOT EXISTS pgml_rust.projects(
     id BIGSERIAL PRIMARY KEY,
-    name TEXT NOT NULL UNIQUE,
-    task TEXT NOT NULL,
+    name TEXT NOT NULL,
+    task pgml_rust.task NOT NULL,
     created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
     updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp()
 );
 SELECT pgml_rust.auto_updated_at('pgml_rust.projects');
+CREATE UNIQUE INDEX IF NOT EXISTS projects_name_idx ON pgml_rust.projects(name);
 
-CREATE TABLE IF NOT EXISTS pgml_rust.models (
+---
+--- Snapshots freeze data for training
+---
+CREATE TABLE IF NOT EXISTS pgml_rust.snapshots(
     id BIGSERIAL PRIMARY KEY,
-    project_id BIGINT NOT NULL REFERENCES pgml_rust.projects(id),
-    algorithm VARCHAR,
-    data BYTEA
+    relation_name TEXT NOT NULL,
+    y_column_name TEXT[] NOT NULL,
+    test_size FLOAT4 NOT NULL,
+    test_sampling pgml_rust.sampling NOT NULL,
+    status TEXT NOT NULL,
+    columns JSONB,
+    analysis JSONB,
+    created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
+    updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp()
 );
+SELECT pgml_rust.auto_updated_at('pgml_rust.snapshots');
+
 ---
---- Deployments determine which model is live
+--- Models save the learned parameters
+---
+CREATE TABLE IF NOT EXISTS pgml_rust.models(
+    id BIGSERIAL PRIMARY KEY,
+    project_id BIGINT NOT NULL,
+    snapshot_id BIGINT NOT NULL,
+    algorithm TEXT NOT NULL,
+    hyperparams JSONB NOT NULL,
+    status TEXT NOT NULL,
+    metrics JSONB,
+    search pgml_rust.search,
+    search_params JSONB NOT NULL,
+    search_args JSONB NOT NULL,
+    created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
+    updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
+    CONSTRAINT project_id_fk FOREIGN KEY(project_id) REFERENCES pgml_rust.projects(id),
+    CONSTRAINT snapshot_id_fk FOREIGN KEY(snapshot_id) REFERENCES pgml_rust.snapshots(id)
+);
+CREATE INDEX IF NOT EXISTS models_project_id_idx ON pgml_rust.models(project_id);
+CREATE INDEX IF NOT EXISTS models_snapshot_id_idx ON pgml_rust.models(snapshot_id);
+SELECT pgml_rust.auto_updated_at('pgml_rust.models');
+
+
+---
+--- Deployments determine which model is live
+---
 CREATE TABLE IF NOT EXISTS pgml_rust.deployments(
     id BIGSERIAL PRIMARY KEY,
     project_id BIGINT NOT NULL,
     model_id BIGINT NOT NULL,
-    strategy TEXT NOT NULL,
+    strategy pgml_rust.strategy NOT NULL,
     created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
     CONSTRAINT project_id_fk FOREIGN KEY(project_id) REFERENCES pgml_rust.projects(id),
     CONSTRAINT model_id_fk FOREIGN KEY(model_id) REFERENCES pgml_rust.models(id)
@@ -73,3 +108,87 @@ CREATE TABLE IF NOT EXISTS pgml_rust.deployments(
 CREATE INDEX IF NOT EXISTS deployments_project_id_created_at_idx ON pgml_rust.deployments(project_id);
 CREATE INDEX IF NOT EXISTS deployments_model_id_created_at_idx ON pgml_rust.deployments(model_id);
 SELECT pgml_rust.auto_updated_at('pgml_rust.deployments');
+
+---
+--- Distribute serialized models consistently for HA
+---
+CREATE TABLE IF NOT EXISTS pgml_rust.files(
+    id BIGSERIAL PRIMARY KEY,
+    model_id BIGINT NOT NULL,
+    path TEXT NOT NULL,
+    part INTEGER NOT NULL,
+    created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
+    updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
+    data BYTEA NOT NULL
+);
+CREATE UNIQUE INDEX IF NOT EXISTS files_model_id_path_part_idx ON pgml_rust.files(model_id, path, part);
+SELECT pgml_rust.auto_updated_at('pgml_rust.files');
+
+---
+--- Quick status check on the system.
+---
+DROP VIEW IF EXISTS pgml_rust.overview;
+CREATE VIEW pgml_rust.overview AS
+SELECT
+    p.name,
+    d.created_at AS deployed_at,
+    p.task,
+    m.algorithm,
+    s.relation_name,
+    s.y_column_name,
+    s.test_sampling,
+    s.test_size
+FROM pgml_rust.projects p
+INNER JOIN pgml_rust.models m ON p.id = m.project_id
+INNER JOIN pgml_rust.deployments d ON d.project_id = p.id
+AND d.model_id = m.id
+INNER JOIN pgml_rust.snapshots s ON s.id = m.snapshot_id
+ORDER BY d.created_at DESC;
+
+
+---
+--- List details of trained models.
+---
+DROP VIEW IF EXISTS pgml_rust.trained_models;
+CREATE VIEW pgml_rust.trained_models AS
+SELECT
+    m.id,
+    p.name,
+    p.task,
+    m.algorithm,
+    m.created_at,
+    s.test_sampling,
+    s.test_size,
+    d.model_id IS NOT NULL AS deployed
+FROM pgml_rust.projects p
+INNER JOIN pgml_rust.models m ON p.id = m.project_id
+INNER JOIN pgml_rust.snapshots s ON s.id = m.snapshot_id
+LEFT JOIN (
+    SELECT DISTINCT ON(project_id)
+        project_id, model_id, created_at
+    FROM pgml_rust.deployments
+    ORDER BY project_id, created_at desc
+) d ON d.model_id = m.id
+ORDER BY m.created_at DESC;
+
+
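The `overview` and `trained_models` views above are the operator-facing status surface for the whole pipeline. As a quick sanity check they can be queried directly once at least one model has been trained and deployed; a minimal illustration (not part of the patch, result rows depend on your projects):

```sql
-- Which model is currently serving each project, and what data was it trained on?
SELECT name, task, algorithm, deployed_at, relation_name, y_column_name
FROM pgml_rust.overview;

-- Every model ever trained, with a flag for the ones currently deployed.
SELECT id, name, algorithm, deployed FROM pgml_rust.trained_models;
```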
+---
+--- List details of deployed models.
+---
+DROP VIEW IF EXISTS pgml_rust.deployed_models;
+CREATE VIEW pgml_rust.deployed_models AS
+SELECT
+    m.id,
+    p.name,
+    p.task,
+    m.algorithm,
+    d.created_at as deployed_at
+FROM pgml_rust.projects p
+INNER JOIN (
+    SELECT DISTINCT ON(project_id)
+        project_id, model_id, created_at
+    FROM pgml_rust.deployments
+    ORDER BY project_id, created_at desc
+) d ON d.project_id = p.id
+INNER JOIN pgml_rust.models m ON m.id = d.model_id
+ORDER BY p.name ASC;
diff --git a/pgml-extension/pgml_rust/src/api.rs b/pgml-extension/pgml_rust/src/api.rs
new file mode 100644
index 000000000..22fbf33f3
--- /dev/null
+++ b/pgml-extension/pgml_rust/src/api.rs
@@ -0,0 +1,104 @@
+use pgx::*;
+
+use crate::orm::Algorithm;
+use crate::orm::Model;
+use crate::orm::Project;
+use crate::orm::Sampling;
+use crate::orm::Search;
+use crate::orm::Snapshot;
+use crate::orm::Strategy;
+use crate::orm::Task;
+
+#[pg_extern]
+fn train(
+    project_name: &str,
+    task: Option<Task>,
+    relation_name: Option<&str>,
+    y_column_name: Option<&str>,
+    algorithm: default!(Algorithm, "'linear'"),
+    hyperparams: default!(JsonB, "'{}'"),
+    search: Option<Search>,
+    search_params: default!(JsonB, "'{}'"),
+    search_args: default!(JsonB, "'{}'"),
+    test_size: default!(f32, 0.25),
+    test_sampling: default!(Sampling, "'last'"),
+) {
+    let project = match Project::find_by_name(project_name) {
+        Some(project) => project,
+        None => Project::create(project_name, task.unwrap()),
+    };
+    if task.is_some() && task.unwrap() != project.task {
+        error!("Project `{:?}` already exists with a different task: `{:?}`. Create a new project instead.", project.name, project.task);
+    }
+    let snapshot = match relation_name {
+        None => project.last_snapshot().expect("You must pass a `relation_name` and `y_column_name` to snapshot the first time you train a model."),
+        Some(relation_name) => Snapshot::create(relation_name, y_column_name.expect("You must pass a `y_column_name` when you pass a `relation_name`"), test_size, test_sampling)
+    };
+
+    // # Default repeatable random state when possible
+    // let algorithm = Model.algorithm_from_name_and_task(algorithm, task);
+    // if "random_state" in algorithm().get_params() and "random_state" not in hyperparams:
+    //     hyperparams["random_state"] = 0
+
+    let model = Model::create(
+        &project,
+        &snapshot,
+        algorithm,
+        hyperparams,
+        search,
+        search_params,
+        search_args,
+    );
+
+    // TODO move deployment into a struct and only deploy if new model is better than old model
+    Spi::get_one_with_args::<i64>(
+        "INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, $3::pgml_rust.strategy) RETURNING id",
+        vec![
+            (PgBuiltInOids::INT8OID.oid(), project.id.into_datum()),
+            (PgBuiltInOids::INT8OID.oid(), model.id.into_datum()),
+            (PgBuiltInOids::TEXTOID.oid(), Strategy::most_recent.to_string().into_datum()),
+        ]
+    );
+}
+
+#[pg_extern]
+fn predict(project_name: &str, features: Vec<f32>) -> f32 {
+    let estimator = crate::orm::estimator::find_deployed_estimator_by_project_name(project_name);
+    estimator.predict(features)
+}
+
+// #[pg_extern]
+// fn return_table_example() -> impl std::iter::Iterator<Item = (name!(id, Option<i64>), name!(title, Option<String>))> {
+//     let tuple = Spi::get_two_with_args("SELECT 1 AS id, 2 AS title;", None, None)
+//     vec![tuple].into_iter()
+// }
+
+#[pg_extern]
+fn create_snapshot(
+    relation_name: &str,
+    y_column_name: &str,
+    test_size: f32,
+    test_sampling: Sampling,
+) -> i64 {
+    let snapshot = Snapshot::create(relation_name, y_column_name, test_size, test_sampling);
+    info!("{:?}", snapshot);
+    snapshot.id
+}
+
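A minimal end-to-end call sequence for the API above, as a hedged sketch: the project name and feature values are made up, `pgml_rust.diabetes` and its `target` column come from the bundled sql/diabetes.sql sample data, and the optional arguments rely on the `default!`/`Option` declarations in `train`:

```sql
-- The first call must pass relation_name and y_column_name so a snapshot can be taken;
-- algorithm defaults to 'linear', test_size to 0.25, test_sampling to 'last'.
SELECT pgml_rust.train('Diabetes Progression', 'regression', 'pgml_rust.diabetes', 'target');

-- Later calls may retrain on the project's last snapshot with another algorithm.
SELECT pgml_rust.train('Diabetes Progression', NULL, NULL, NULL, 'xgboost');

-- Predict takes one flat feature vector: one value per non-label column.
SELECT pgml_rust.predict('Diabetes Progression',
    ARRAY[0.05, -0.04, 0.06, 0.02, -0.04, -0.03, -0.04, 0.00, 0.02, -0.02]::REAL[]);
```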
"pg_test"))] +#[pg_schema] +mod tests { + use super::*; + + #[pg_test] + fn test_project_lifecycle() { + assert_eq!(Project::create("test", Task::regression).id, 1); + assert_eq!(Project::find(1).id, 1); + } + + #[pg_test] + fn test_snapshot_lifecycle() { + let snapshot = Snapshot::create("test", "column", 0.5, Sampling::last); + assert_eq!(snapshot.id, 1); + } +} diff --git a/pgml-extension/pgml_rust/src/lib.rs b/pgml-extension/pgml_rust/src/lib.rs index a1992f1db..76d733dce 100644 --- a/pgml-extension/pgml_rust/src/lib.rs +++ b/pgml-extension/pgml_rust/src/lib.rs @@ -1,19 +1,21 @@ extern crate blas; extern crate openblas_src; +extern crate serde; use once_cell::sync::Lazy; // 1.3.1 use pgx::*; use std::collections::HashMap; use std::fs; -use std::path::Path; use std::sync::Mutex; -use xgboost::{parameters, Booster, DMatrix}; +use xgboost::{Booster, DMatrix}; +pub mod api; +pub mod orm; pub mod vectors; pg_module_magic!(); -extension_sql_file!("../sql/schema.sql", name = "bootstrap_raw", bootstrap); +extension_sql_file!("../sql/schema.sql", name = "bootstrap_raw"); extension_sql_file!( "../sql/diabetes.sql", name = "diabetes", @@ -25,371 +27,131 @@ extension_sql_file!( // This space here is connection-specific. static MODELS: Lazy>>> = Lazy::new(|| Mutex::new(HashMap::new())); -/// Predict a novel data point using the model created by pgml_train. -/// -/// Example: -/// ``` -/// SELECT * FROM pgml_predict(ARRAY[1, 2, 3]); -#[pg_schema] -mod pgml_rust { - use super::*; - - #[derive(PostgresEnum, Copy, Clone)] - #[allow(non_camel_case_types)] - enum Algorithm { - xgboost, - } - - #[derive(PostgresEnum, Copy, Clone, PartialEq, Debug)] - #[allow(non_camel_case_types)] - enum ProjectTask { - regression, - classification, - } - - impl PartialEq for ProjectTask { - fn eq(&self, other: &String) -> bool { - match *self { - ProjectTask::regression => "regression" == other, - ProjectTask::classification => "classification" == other, - } - } - } - - impl ProjectTask { - pub fn to_string(&self) -> String { - match *self { - ProjectTask::regression => "regression".to_string(), - ProjectTask::classification => "classification".to_string(), - } - } - } - - /// Main training function to train an XGBoost model on a dataset. 
-    ///
-    /// Example:
-    ///
-    /// ```
-    /// SELECT * FROM pgml_rust.train('pgml_rust.diabetes', ARRAY['age', 'sex'], 'target');
-    #[pg_extern]
-    fn train(
-        project_name: String,
-        task: ProjectTask,
-        relation_name: String,
-        label: String,
-        _algorithm: Algorithm,
-        hyperparams: Json,
-    ) -> i64 {
-        let parts = relation_name
-            .split(".")
-            .map(|name| name.to_string())
-            .collect::<Vec<String>>();
-
-        let (schema_name, table_name) = match parts.len() {
-            1 => (String::from("public"), parts[0].clone()),
-            2 => (parts[0].clone(), parts[1].clone()),
-            _ => error!(
-                "Relation name {} is not parsable into schema name and table name",
-                relation_name
-            ),
-        };
-
-        let (mut x, mut y, mut num_rows, mut num_features) = (vec![], vec![], 0, 0);
+#[pg_extern]
+fn model_predict(model_id: i64, features: Vec<f32>) -> f32 {
+    let mut guard = MODELS.lock().unwrap();
 
-        let hyperparams = hyperparams.0;
+    match guard.get(&model_id) {
+        Some(data) => {
+            let bst = Booster::load_buffer(&data).unwrap();
+            let dmat = DMatrix::from_dense(&features, 1).unwrap();
 
-        let (projet_id, project_task) = Spi::get_two_with_args::<i64, String>("INSERT INTO pgml_rust.projects (name, task) VALUES ($1, $2) ON CONFLICT (name) DO UPDATE SET name = $1 RETURNING id, task",
-            vec![
-                (PgBuiltInOids::TEXTOID.oid(), project_name.clone().into_datum()),
-                (PgBuiltInOids::TEXTOID.oid(), task.to_string().into_datum()),
-            ]);
-
-        let (projet_id, project_task) = (projet_id.unwrap(), project_task.unwrap());
-
-        if project_task != task.to_string() {
-            error!(
-                "Project '{}' already exists with a different objective: {}",
-                project_name, project_task
-            );
+            bst.predict(&dmat).unwrap()[0]
         }
-        Spi::connect(|client| {
-            let mut features = Vec::new();
-
-            client.select("SELECT CAST(column_name AS TEXT) FROM information_schema.columns WHERE table_name = $1 AND table_schema = $2 AND column_name != $3",
-            None,
-            Some(vec![
-                (PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum()),
-                (PgBuiltInOids::TEXTOID.oid(), schema_name.into_datum()),
-                (PgBuiltInOids::TEXTOID.oid(), label.clone().into_datum()),
-            ]))
-            .for_each(|row| {
-                features.push(row[1].value::<String>().unwrap())
-            });
-
-            let features = features
-                .into_iter()
-                .map(|column| format!("CAST({} AS REAL)", column))
-                .collect::<Vec<String>>();
-
-            let query = format!(
-                "SELECT {}, CAST({} AS REAL) FROM {} ORDER BY RANDOM()",
-                features.clone().join(", "),
-                label,
-                relation_name
-            );
+        None => {
+            match Spi::get_one_with_args::<Vec<u8>>(
+                "SELECT data FROM pgml_rust.models WHERE id = $1",
+                vec![(PgBuiltInOids::INT8OID.oid(), model_id.into_datum())],
+            ) {
+                Some(data) => {
+                    info!("Model cache cold, loading from \"pgml_rust\".\"models\"");
-            info!("Fetching data: {}", query);
+                    guard.insert(model_id, data.clone());
+                    let bst = Booster::load_buffer(&data).unwrap();
+                    let dmat = DMatrix::from_dense(&features, 1).unwrap();
-            client.select(&query, None, None).for_each(|row| {
-                // Postgres arrays start at one and for some reason - so do these tuple indexes.
-                for i in 1..features.len() + 1 {
-                    x.push(row[i].value::<f32>().unwrap_or(0 as f32));
+                    bst.predict(&dmat).unwrap()[0]
                 }
-                y.push(row[features.len() + 1].value::<f32>().unwrap_or(0 as f32));
-                num_rows += 1;
-            });
-
-            num_features = features.len();
-
-            Ok(Some(()))
-        });
-
-        // todo parameterize test split instead of 0.5
-        let test_rows = (num_rows as f32 * 0.5).round() as usize;
-        let train_rows = num_rows - test_rows;
-        let mut dtrain = DMatrix::from_dense(&x[..train_rows * num_features], train_rows).unwrap();
-        let mut dtest = DMatrix::from_dense(&x[train_rows * num_features..], test_rows).unwrap();
-        dtrain.set_labels(&y[..train_rows]).unwrap();
-        dtest.set_labels(&y[train_rows..]).unwrap();
-
-
-        // specify datasets to evaluate against during training
-        let evaluation_sets = &[(&dtrain, "train"), (&dtest, "test")];
-
-        // configure objectives, metrics, etc.
-        let learning_params = parameters::learning::LearningTaskParametersBuilder::default()
-            .objective(match task {
-                ProjectTask::regression => xgboost::parameters::learning::Objective::RegLinear,
-                ProjectTask::classification => {
-                    xgboost::parameters::learning::Objective::RegLogistic
+                None => {
+                    error!("No model with id = {} found", model_id);
                 }
-            })
-            .build()
-            .unwrap();
-
-        // configure the tree-based learning model's parameters
-        let tree_params = parameters::tree::TreeBoosterParametersBuilder::default()
-            .max_depth(match hyperparams.get("max_depth") {
-                Some(value) => value.as_u64().unwrap_or(2) as u32,
-                None => 2,
-            })
-            .eta(0.3)
-            .build()
-            .unwrap();
-
-        // overall configuration for Booster
-        let booster_params = parameters::BoosterParametersBuilder::default()
-            .booster_type(parameters::BoosterType::Tree(tree_params))
-            .learning_params(learning_params)
-            .verbose(true)
-            .build()
-            .unwrap();
-
-
-        // overall configuration for training/evaluation
-        let params = parameters::TrainingParametersBuilder::default()
-            .dtrain(&dtrain) // dataset to train with
-            .boost_rounds(match hyperparams.get("n_estimators") {
-                Some(value) => value.as_u64().unwrap_or(2) as u32,
-                None => 2,
-            }) // number of training iterations
-            .booster_params(booster_params) // model parameters
-            .evaluation_sets(Some(evaluation_sets)) // optional datasets to evaluate against in each iteration
-            .build()
-            .unwrap();
-
-        // train model, and print evaluation data
-        let bst = match Booster::train(&params) {
-            Ok(bst) => bst,
-            Err(err) => error!("{}", err),
-        };
-
-        let r: u64 = rand::random();
-        let path = format!("/tmp/pgml_rust_{}.bin", r);
-
-        bst.save(Path::new(&path)).unwrap();
-
-        let bytes = fs::read(&path).unwrap();
-
-        let model_id = Spi::get_one_with_args::<i64>(
-            "INSERT INTO pgml_rust.models (id, project_id, algorithm, data) VALUES (DEFAULT, $1, 'xgboost', $2) RETURNING id",
-            vec![
-                (PgBuiltInOids::INT8OID.oid(), projet_id.into_datum()),
-                (PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum())
-            ]
-        ).unwrap();
-
-        Spi::get_one_with_args::<i64>(
-            "INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, 'last_trained') RETURNING id",
-            vec![
-                (PgBuiltInOids::INT8OID.oid(), projet_id.into_datum()),
-                (PgBuiltInOids::INT8OID.oid(), model_id.into_datum()),
-            ]
-        );
-
-        model_id
-    }
-
-    #[pg_extern]
-    fn predict(project_name: String, features: Vec<f32>) -> f32 {
-        let model_id = Spi::get_one_with_args(
-            "SELECT model_id
-            FROM pgml_rust.deployments
-            INNER JOIN pgml_rust.projects ON
-                pgml_rust.deployments.project_id = pgml_rust.projects.id
-                AND pgml_rust.projects.name = $1
-            ORDER BY pgml_rust.deployments.id DESC LIMIT 1",
-            vec![(
-                PgBuiltInOids::TEXTOID.oid(),
-                project_name.clone().into_datum(),
-            )],
-        );
-
-        match model_id {
-            Some(model_id) => model_predict(model_id, features),
-            None => error!("Project '{}' doesn't exist", project_name),
+                }
+            }
        }
+}
-    #[pg_extern]
-    fn model_predict(model_id: i64, features: Vec<f32>) -> f32 {
-        let mut guard = MODELS.lock().unwrap();
-
-        match guard.get(&model_id) {
-            Some(data) => {
-                let bst = Booster::load_buffer(&data).unwrap();
-                let dmat = DMatrix::from_dense(&features, 1).unwrap();
-
-                bst.predict(&dmat).unwrap()[0]
-            }
-
-            None => {
-                match Spi::get_one_with_args::<Vec<u8>>(
-                    "SELECT data FROM pgml_rust.models WHERE id = $1",
-                    vec![(PgBuiltInOids::INT8OID.oid(), model_id.into_datum())],
-                ) {
-                    Some(data) => {
-                        info!("Model cache cold, loading from \"pgml_rust\".\"models\"");
-
-                        guard.insert(model_id, data.clone());
-                        let bst = Booster::load_buffer(&data).unwrap();
-                        let dmat = DMatrix::from_dense(&features, 1).unwrap();
+
+#[pg_extern]
+fn model_predict_batch(model_id: i64, features: Vec<f32>, num_rows: i32) -> Vec<f32> {
+    let mut guard = MODELS.lock().unwrap();
-                        bst.predict(&dmat).unwrap()[0]
-                    }
-                    None => {
-                        error!("No model with id = {} found", model_id);
-                    }
-                }
-            }
-        }
+    if num_rows < 0 {
+        error!("Number of rows has to be greater than 0");
    }
-    #[pg_extern]
-    fn model_predict_batch(model_id: i64, features: Vec<f32>, num_rows: i32) -> Vec<f32> {
-        let mut guard = MODELS.lock().unwrap();
+    match guard.get(&model_id) {
+        Some(data) => {
+            let bst = Booster::load_buffer(&data).unwrap();
+            let dmat = DMatrix::from_dense(&features, num_rows as usize).unwrap();
-        if num_rows < 0 {
-            error!("Number of rows has to be greater than 0");
+            bst.predict(&dmat).unwrap()
        }
-        match guard.get(&model_id) {
-            Some(data) => {
-                let bst = Booster::load_buffer(&data).unwrap();
-                let dmat = DMatrix::from_dense(&features, num_rows as usize).unwrap();
+        None => {
+            match Spi::get_one_with_args::<Vec<u8>>(
+                "SELECT data FROM pgml_rust.models WHERE id = $1",
+                vec![(PgBuiltInOids::INT8OID.oid(), model_id.into_datum())],
+            ) {
+                Some(data) => {
+                    info!("Model cache cold, loading from \"pgml_rust\".\"models\"");
-                bst.predict(&dmat).unwrap()
-            }
-
-            None => {
-                match Spi::get_one_with_args::<Vec<u8>>(
-                    "SELECT data FROM pgml_rust.models WHERE id = $1",
-                    vec![(PgBuiltInOids::INT8OID.oid(), model_id.into_datum())],
-                ) {
-                    Some(data) => {
-                        info!("Model cache cold, loading from \"pgml_rust\".\"models\"");
+                    guard.insert(model_id, data.clone());
+                    let bst = Booster::load_buffer(&data).unwrap();
+                    let dmat = DMatrix::from_dense(&features, num_rows as usize).unwrap();
-                        guard.insert(model_id, data.clone());
-                        let bst = Booster::load_buffer(&data).unwrap();
-                        let dmat = DMatrix::from_dense(&features, num_rows as usize).unwrap();
+                    bst.predict(&dmat).unwrap()
+                }
+                None => {
+                    error!("No model with id = {} found", model_id);
                }
-                        bst.predict(&dmat).unwrap()
-                    }
-                    None => {
-                        error!("No model with id = {} found", model_id);
-                    }
-                }
-            }
-        }
+        }
+    }
+}
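The two `model_predict` entry points above bypass the project and deployment machinery entirely and score against a raw model id; the batch variant takes a single flattened, row-major array plus the row count. Illustrative calls only (the model id and feature values are made up):

```sql
-- One row with three features.
SELECT pgml_rust.model_predict(1, ARRAY[1.0, 2.0, 3.0]::REAL[]);

-- Three rows of three features each, flattened row by row.
SELECT pgml_rust.model_predict_batch(
    1,
    ARRAY[1.0, 2.0, 3.0,
          4.0, 5.0, 6.0,
          7.0, 8.0, 9.0]::REAL[],
    3
);
```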
-    /// Load a model into the extension. The model is saved in our table,
-    /// which is then replicated to replicas for load balancing.
-    #[pg_extern]
-    fn load_model(data: Vec<u8>) -> i64 {
-        Spi::get_one_with_args::<i64>(
-            "INSERT INTO pgml_rust.models (id, algorithm, data) VALUES (DEFAULT, 'xgboost', $1) RETURNING id",
-            vec![
-                (PgBuiltInOids::BYTEAOID.oid(), data.into_datum()),
-            ],
-        ).unwrap()
-    }
-
-    /// Load a model into the extension from a file.
-    #[pg_extern]
-    fn load_model_from_file(path: String) -> i64 {
-        let bytes = fs::read(&path).unwrap();
+/// Load a model into the extension. The model is saved in our table,
+/// which is then replicated to replicas for load balancing.
+#[pg_extern]
+fn load_model(data: Vec<u8>) -> i64 {
+    Spi::get_one_with_args::<i64>(
+        "INSERT INTO pgml_rust.models (id, algorithm, data) VALUES (DEFAULT, 'xgboost', $1) RETURNING id",
+        vec![
+            (PgBuiltInOids::BYTEAOID.oid(), data.into_datum()),
+        ],
+    ).unwrap()
+}
-        Spi::get_one_with_args::<i64>(
-            "INSERT INTO pgml_rust.models (id, algorithm, data) VALUES (DEFAULT, 'xgboost', $1) RETURNING id",
-            vec![
-                (PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()),
-            ],
-        ).unwrap()
-    }
+/// Load a model into the extension from a file.
+#[pg_extern]
+fn load_model_from_file(path: String) -> i64 {
+    let bytes = fs::read(&path).unwrap();
+
+    Spi::get_one_with_args::<i64>(
+        "INSERT INTO pgml_rust.models (id, algorithm, data) VALUES (DEFAULT, 'xgboost', $1) RETURNING id",
+        vec![
+            (PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()),
+        ],
+    ).unwrap()
+}
-    #[pg_extern]
-    fn delete_model(model_id: i64) {
-        Spi::run(&format!(
-            "DELETE FROM pgml_rust.models WHERE id = {}",
-            model_id
-        ));
-    }
+#[pg_extern]
+fn delete_model(model_id: i64) {
+    Spi::run(&format!(
+        "DELETE FROM pgml_rust.models WHERE id = {}",
+        model_id
+    ));
+}
-    #[pg_extern]
-    fn dump_model(model_id: i64) -> String {
-        let bytes = Spi::get_one_with_args::<Vec<u8>>(
-            "SELECT data FROM pgml_rust.models WHERE id = $1",
-            vec![(PgBuiltInOids::INT8OID.oid(), model_id.into_datum())],
-        );
+#[pg_extern]
+fn dump_model(model_id: i64) -> String {
+    let bytes = Spi::get_one_with_args::<Vec<u8>>(
+        "SELECT data FROM pgml_rust.models WHERE id = $1",
+        vec![(PgBuiltInOids::INT8OID.oid(), model_id.into_datum())],
+    );
-        match bytes {
-            Some(bytes) => match Booster::load_buffer(&bytes) {
-                Ok(bst) => bst.dump_model(true, None).unwrap(),
-                Err(err) => error!("Could not load XGBoost model: {:?}", err),
-            },
+    match bytes {
+        Some(bytes) => match Booster::load_buffer(&bytes) {
+            Ok(bst) => bst.dump_model(true, None).unwrap(),
+            Err(err) => error!("Could not load XGBoost model: {:?}", err),
+        },
-            None => error!("Model with id = {} does not exist", model_id),
-        }
+        None => error!("Model with id = {} does not exist", model_id),
    }
 }
 
 #[cfg(any(test, feature = "pg_test"))]
 #[pg_schema]
-mod tests {
-}
+mod tests {}
 
 #[cfg(test)]
 pub mod pg_test {
diff --git a/pgml-extension/pgml_rust/src/orm/algorithm.rs b/pgml-extension/pgml_rust/src/orm/algorithm.rs
new file mode 100644
index 000000000..f45507fcb
--- /dev/null
+++ b/pgml-extension/pgml_rust/src/orm/algorithm.rs
@@ -0,0 +1,30 @@
+use pgx::*;
+use serde::Deserialize;
+
+#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug, Deserialize)]
+#[allow(non_camel_case_types)]
+pub enum Algorithm {
+    linear,
+    xgboost,
+}
+
+impl std::str::FromStr for Algorithm {
+    type Err = ();
+
+    fn from_str(input: &str) -> Result<Self, Self::Err> {
+        match input {
+            "linear" => Ok(Algorithm::linear),
+            "xgboost" => Ok(Algorithm::xgboost),
+            _ => Err(()),
+        }
+    }
+}
+
+impl std::string::ToString for Algorithm {
+    fn to_string(&self) -> String {
+        match *self {
+            Algorithm::linear => "linear".to_string(),
+            Algorithm::xgboost => "xgboost".to_string(),
+        }
+    }
+}
diff --git a/pgml-extension/pgml_rust/src/orm/dataset.rs b/pgml-extension/pgml_rust/src/orm/dataset.rs
new file mode 100644
index 000000000..0950b8954
--- /dev/null
+++ b/pgml-extension/pgml_rust/src/orm/dataset.rs
@@ -0,0 +1,27 @@
+pub struct Dataset {
+    pub x: Vec<f32>,
+    pub y: Vec<f32>,
+    pub num_features: usize,
+    pub num_labels: usize,
+    pub num_rows: usize,
+    pub num_train_rows: usize,
+    pub num_test_rows: usize,
+}
+
+impl Dataset {
+    pub fn x_train(&self) -> &[f32] {
+        &self.x[..self.num_train_rows * self.num_features]
+    }
+
+    pub fn x_test(&self) -> &[f32] {
+        &self.x[self.num_train_rows * self.num_features..]
+    }
+
+    pub fn y_train(&self) -> &[f32] {
+        &self.y[..self.num_train_rows * self.num_labels]
+    }
+
+    pub fn y_test(&self) -> &[f32] {
+        &self.y[self.num_train_rows * self.num_labels..]
+    }
+}
diff --git a/pgml-extension/pgml_rust/src/orm/estimator.rs b/pgml-extension/pgml_rust/src/orm/estimator.rs
new file mode 100644
index 000000000..e31fdfc04
--- /dev/null
+++ b/pgml-extension/pgml_rust/src/orm/estimator.rs
@@ -0,0 +1,258 @@
+use std::collections::HashMap;
+use std::fmt::Debug;
+use std::str::FromStr;
+use std::sync::Arc;
+use std::sync::Mutex;
+
+use ndarray::{Array1, Array2};
+use once_cell::sync::Lazy;
+use pgx::*;
+use xgboost::{Booster, DMatrix};
+
+use crate::orm::Algorithm;
+use crate::orm::Dataset;
+use crate::orm::Task;
+
+static DEPLOYED_ESTIMATORS_BY_PROJECT_NAME: Lazy<Mutex<HashMap<String, Arc<Box<dyn Estimator>>>>> =
+    Lazy::new(|| Mutex::new(HashMap::new()));
+
+pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estimator>> {
+    {
+        let estimators = DEPLOYED_ESTIMATORS_BY_PROJECT_NAME.lock().unwrap();
+        let estimator = estimators.get(name);
+        if estimator.is_some() {
+            return estimator.unwrap().clone();
+        }
+    }
+
+    let (task, algorithm, data) = Spi::get_three_with_args::<String, String, Vec<u8>>(
+        "
+        SELECT projects.task::TEXT, models.algorithm::TEXT, files.data
+        FROM pgml_rust.files
+        JOIN pgml_rust.models
+            ON models.id = files.model_id
+        JOIN pgml_rust.deployments
+            ON deployments.model_id = models.id
+        JOIN pgml_rust.projects
+            ON projects.id = deployments.project_id
+        WHERE projects.name = $1
+        ORDER by deployments.created_at DESC
+        LIMIT 1;",
+        vec![(PgBuiltInOids::TEXTOID.oid(), name.into_datum())],
+    );
+    let task = Task::from_str(
+        &task.expect(
+            format!(
+                "Project {} does not have a trained and deployed model.",
+                name
+            )
+            .as_str(),
+        ),
+    )
+    .unwrap();
+    let algorithm = Algorithm::from_str(
+        &algorithm.expect(
+            format!(
+                "Project {} does not have a trained and deployed model.",
+                name
+            )
+            .as_str(),
+        ),
+    )
+    .unwrap();
+    let data = data.expect(
+        format!(
+            "Project {} does not have a trained and deployed model.",
+            name
+        )
+        .as_str(),
+    );
+
+    let e: Box<dyn Estimator> = match task {
+        Task::regression => match algorithm {
+            Algorithm::linear => {
+                let estimator: smartcore::linear::linear_regression::LinearRegression<
+                    f32,
+                    Array2<f32>,
+                > = rmp_serde::from_read(&*data).unwrap();
+                Box::new(estimator)
+            }
+            Algorithm::xgboost => {
+                let bst = Booster::load_buffer(&*data).unwrap();
+                Box::new(BoosterBox::new(bst))
+            }
+        },
+        Task::classification => match algorithm {
+            Algorithm::linear => {
+                let estimator: smartcore::linear::logistic_regression::LogisticRegression<
+                    f32,
+                    Array2<f32>,
+                > = rmp_serde::from_read(&*data).unwrap();
+                Box::new(estimator)
+            }
+            Algorithm::xgboost => {
+                let bst = Booster::load_buffer(&*data).unwrap();
+                Box::new(BoosterBox::new(bst))
+            }
+        },
+    };
+
+    let mut estimators = DEPLOYED_ESTIMATORS_BY_PROJECT_NAME.lock().unwrap();
+    estimators.insert(name.to_string(), Arc::new(e));
+    estimators.get(name).unwrap().clone()
+}
+
+fn test_smartcore(
+    predictor: &dyn smartcore::api::Predictor<Array2<f32>, Array1<f32>>,
+    task: Task,
+    dataset: &Dataset,
+) -> HashMap<String, f32> {
+    let x_test = Array2::from_shape_vec(
+        (dataset.num_test_rows, dataset.num_features),
+        dataset.x_test().to_vec(),
+    )
+    .unwrap();
+    let y_test = Array1::from_shape_vec(dataset.num_test_rows, dataset.y_test().to_vec()).unwrap();
+    let y_hat = smartcore::api::Predictor::predict(predictor, &x_test).unwrap();
+    calc_metrics(&y_test, &y_hat, task)
+}
+
+fn predict_smartcore(
+    predictor: &dyn smartcore::api::Predictor<Array2<f32>, Array1<f32>>,
+    features: Vec<f32>,
+) -> f32 {
+    let features = Array2::from_shape_vec((1, features.len()), features).unwrap();
+    smartcore::api::Predictor::predict(predictor, &features).unwrap()[0]
+}
+
+fn calc_metrics(y_test: &Array1<f32>, y_hat: &Array1<f32>, task: Task) -> HashMap<String, f32> {
+    let mut results = HashMap::new();
+    match task {
+        Task::regression => {
+            results.insert("r2".to_string(), smartcore::metrics::r2(y_test, y_hat));
+            results.insert(
+                "mean_absolute_error".to_string(),
+                smartcore::metrics::mean_absolute_error(y_test, y_hat),
+            );
+            results.insert(
+                "mean_squared_error".to_string(),
+                smartcore::metrics::mean_squared_error(y_test, y_hat),
+            );
+        }
+        Task::classification => {
+            results.insert(
+                "f1".to_string(),
+                smartcore::metrics::f1::F1 { beta: 1.0 }.get_score(y_test, y_hat),
+            );
+            results.insert(
+                "precision".to_string(),
+                smartcore::metrics::precision(y_test, y_hat),
+            );
+            results.insert(
+                "accuracy".to_string(),
+                smartcore::metrics::accuracy(y_test, y_hat),
+            );
+            results.insert(
+                "roc_auc_score".to_string(),
+                smartcore::metrics::roc_auc_score(y_test, y_hat),
+            );
+            results.insert(
+                "recall".to_string(),
+                smartcore::metrics::recall(y_test, y_hat),
+            );
+        }
+    }
+    results
+}
+
+#[typetag::serialize(tag = "type")]
+pub trait Estimator: Send + Sync + Debug {
+    fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32>;
+    fn predict(&self, features: Vec<f32>) -> f32;
+}
+
+#[typetag::serialize]
+impl Estimator for smartcore::linear::linear_regression::LinearRegression<f32, Array2<f32>> {
+    fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
+        test_smartcore(self, task, data)
+    }
+
+    fn predict(&self, features: Vec<f32>) -> f32 {
+        predict_smartcore(self, features)
+    }
+}
+
+#[typetag::serialize]
+impl Estimator for smartcore::linear::logistic_regression::LogisticRegression<f32, Array2<f32>> {
+    fn test(&self, task: Task, data: &Dataset) -> HashMap<String, f32> {
+        test_smartcore(self, task, data)
+    }
+
+    fn predict(&self, features: Vec<f32>) -> f32 {
+        predict_smartcore(self, features)
+    }
+}
+
+pub struct BoosterBox {
+    contents: Box<xgboost::Booster>,
+}
+
+impl BoosterBox {
+    pub fn new(contents: xgboost::Booster) -> Self {
+        BoosterBox {
+            contents: Box::new(contents),
+        }
+    }
+}
+
+impl std::ops::Deref for BoosterBox {
+    type Target = xgboost::Booster;
+
+    fn deref(&self) -> &Self::Target {
+        self.contents.as_ref()
+    }
+}
+
+impl std::ops::DerefMut for BoosterBox {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        self.contents.as_mut()
+    }
+}
+
+unsafe impl Send for BoosterBox {}
+unsafe impl Sync for BoosterBox {}
+impl std::fmt::Debug for BoosterBox {
+    fn fmt(
+        &self,
+        formatter: &mut std::fmt::Formatter<'_>,
+    ) -> std::result::Result<(), std::fmt::Error> {
+        formatter.debug_struct("BoosterBox").finish()
+    }
+}
+impl serde::Serialize for BoosterBox {
+    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        todo!("this is never hit for now, since we'd also need a deserializer.")
+    }
+}
+
+#[typetag::serialize]
+impl Estimator for BoosterBox {
+    fn test(&self, task: Task, dataset: &Dataset) -> HashMap<String, f32> {
+        let mut features = DMatrix::from_dense(dataset.x_test(), dataset.num_test_rows).unwrap();
+        features.set_labels(dataset.y_test()).unwrap();
+        let y_test =
+            Array1::from_shape_vec(dataset.num_test_rows, dataset.y_test().to_vec()).unwrap();
+        let y_hat = self.contents.predict(&features).unwrap();
+        let y_hat = Array1::from_shape_vec(dataset.num_test_rows, y_hat).unwrap();
+
+        calc_metrics(&y_test, &y_hat, task)
+    }
+
+    fn predict(&self, features: Vec<f32>) -> f32 {
+        let features = DMatrix::from_dense(&features, 1).unwrap();
+        self.contents.predict(&features).unwrap()[0]
+    }
+}
diff --git a/pgml-extension/pgml_rust/src/orm/mod.rs b/pgml-extension/pgml_rust/src/orm/mod.rs
new file mode 100644
index 000000000..4a9b78b58
--- /dev/null
+++ b/pgml-extension/pgml_rust/src/orm/mod.rs
@@ -0,0 +1,21 @@
+pub mod algorithm;
+pub mod dataset;
+pub mod estimator;
+pub mod model;
+pub mod project;
+pub mod sampling;
+pub mod search;
+pub mod snapshot;
+pub mod strategy;
+pub mod task;
+
+pub use algorithm::Algorithm;
+pub use dataset::Dataset;
+pub use estimator::Estimator;
+pub use model::Model;
+pub use project::Project;
+pub use sampling::Sampling;
+pub use search::Search;
+pub use snapshot::Snapshot;
+pub use strategy::Strategy;
+pub use task::Task;
diff --git a/pgml-extension/pgml_rust/src/orm/model.rs b/pgml-extension/pgml_rust/src/orm/model.rs
new file mode 100644
index 000000000..b78800108
--- /dev/null
+++ b/pgml-extension/pgml_rust/src/orm/model.rs
@@ -0,0 +1,228 @@
+use std::str::FromStr;
+
+use ndarray::{Array1, Array2};
+use pgx::*;
+use serde_json::json;
+use xgboost::{parameters, Booster, DMatrix};
+
+use crate::orm::estimator::BoosterBox;
+use crate::orm::Algorithm;
+use crate::orm::Dataset;
+use crate::orm::Estimator;
+use crate::orm::Project;
+use crate::orm::Search;
+use crate::orm::Snapshot;
+use crate::orm::Task;
+
+#[derive(Debug)]
+pub struct Model {
+    pub id: i64,
+    pub project_id: i64,
+    pub snapshot_id: i64,
+    pub algorithm: Algorithm,
+    pub hyperparams: JsonB,
+    pub status: String,
+    pub metrics: Option<JsonB>,
+    pub search: Option<Search>,
+    pub search_params: JsonB,
+    pub search_args: JsonB,
+    pub created_at: Timestamp,
+    pub updated_at: Timestamp,
+    estimator: Option<Box<dyn Estimator>>,
+}
+
+impl Model {
+    pub fn create(
+        project: &Project,
+        snapshot: &Snapshot,
+        algorithm: Algorithm,
+        hyperparams: JsonB,
+        search: Option<Search>,
+        search_params: JsonB,
+        search_args: JsonB,
+    ) -> Model {
+        let mut model: Option<Model> = None;
+
+        Spi::connect(|client| {
+            let result = client.select("
+                INSERT INTO pgml_rust.models (project_id, snapshot_id, algorithm, hyperparams, status, search, search_params, search_args)
+                VALUES ($1, $2, $3, $4, $5, $6::pgml_rust.search, $7, $8)
+                RETURNING id, project_id, snapshot_id, algorithm, hyperparams, status, metrics, search, search_params, search_args, created_at, updated_at;",
+                Some(1),
+                Some(vec![
+                    (PgBuiltInOids::INT8OID.oid(), project.id.into_datum()),
+                    (PgBuiltInOids::INT8OID.oid(), snapshot.id.into_datum()),
+                    (PgBuiltInOids::TEXTOID.oid(), algorithm.to_string().into_datum()),
+                    (PgBuiltInOids::JSONBOID.oid(), hyperparams.into_datum()),
+                    (PgBuiltInOids::TEXTOID.oid(), "new".to_string().into_datum()),
+                    (PgBuiltInOids::TEXTOID.oid(), search.into_datum()),
+                    (PgBuiltInOids::JSONBOID.oid(), search_params.into_datum()),
+                    (PgBuiltInOids::JSONBOID.oid(), search_args.into_datum()),
+                ])
+            ).first();
+            if result.len() > 0 {
+                model = Some(Model {
+                    id: result.get_datum(1).unwrap(),
+                    project_id: result.get_datum(2).unwrap(),
+                    snapshot_id: result.get_datum(3).unwrap(),
+                    algorithm: Algorithm::from_str(result.get_datum(4).unwrap()).unwrap(),
+                    hyperparams: result.get_datum(5).unwrap(),
+                    status: result.get_datum(6).unwrap(),
+                    metrics: result.get_datum(7),
+                    search: search, // TODO
+                    search_params: result.get_datum(9).unwrap(),
+                    search_args: result.get_datum(10).unwrap(),
+                    created_at: result.get_datum(11).unwrap(),
+                    updated_at: result.get_datum(12).unwrap(),
+                    estimator: None,
+                });
+            }
+
+            Ok(Some(1))
+        });
+        let mut model = model.unwrap();
+        let dataset = snapshot.dataset();
+        model.fit(&project, &dataset);
+        model.test(&project, &dataset);
+        model
+    }
+
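`create` persists the model row before fitting, so a failed training run stays visible with its `status`. Note that `fit` below only reads `max_depth` and `n_estimators` out of the hyperparams JSONB for XGBoost; everything else currently falls back to the hard-coded builder defaults. A hedged example of passing them through `train` (positional NULLs stand in for the optional arguments, values are illustrative):

```sql
SELECT pgml_rust.train(
    'Diabetes Progression', NULL, NULL, NULL,
    'xgboost',
    '{"max_depth": 5, "n_estimators": 25}'
);
```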
+    fn fit(&mut self, project: &Project, dataset: &Dataset) {
+        let hyperparams: &serde_json::Value = &self.hyperparams.0;
+        let hyperparams = hyperparams.as_object().unwrap();
+
+        self.estimator = match self.algorithm {
+            Algorithm::linear => {
+                let x_train = Array2::from_shape_vec(
+                    (dataset.num_train_rows, dataset.num_features),
+                    dataset.x_train().to_vec(),
+                )
+                .unwrap();
+                let y_train =
+                    Array1::from_shape_vec(dataset.num_train_rows, dataset.y_train().to_vec())
+                        .unwrap();
+                let estimator: Option<Box<dyn Estimator>> = match project.task {
+                    Task::regression => Some(Box::new(
+                        smartcore::linear::linear_regression::LinearRegression::fit(
+                            &x_train,
+                            &y_train,
+                            Default::default(),
+                        )
+                        .unwrap(),
+                    )),
+                    Task::classification => Some(Box::new(
+                        smartcore::linear::logistic_regression::LogisticRegression::fit(
+                            &x_train,
+                            &y_train,
+                            Default::default(),
+                        )
+                        .unwrap(),
+                    )),
+                };
+                let bytes: Vec<u8> = rmp_serde::to_vec(&*estimator.as_ref().unwrap()).unwrap();
+                Spi::get_one_with_args::<i64>(
+                    "INSERT INTO pgml_rust.files (model_id, path, part, data) VALUES($1, 'estimator.rmp', 0, $2) RETURNING id",
+                    vec![
+                        (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()),
+                        (PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()),
+                    ]
+                ).unwrap();
+                estimator
+            }
+            Algorithm::xgboost => {
+                let mut dtrain =
+                    DMatrix::from_dense(&dataset.x_train(), dataset.num_train_rows).unwrap();
+                let mut dtest =
+                    DMatrix::from_dense(&dataset.x_test(), dataset.num_test_rows).unwrap();
+                dtrain.set_labels(&dataset.y_train()).unwrap();
+                dtest.set_labels(&dataset.y_test()).unwrap();
+
+                // specify datasets to evaluate against during training
+                let evaluation_sets = &[(&dtrain, "train"), (&dtest, "test")];
+
+                // configure objectives, metrics, etc.
+                let learning_params =
+                    parameters::learning::LearningTaskParametersBuilder::default()
+                        .objective(match project.task {
+                            Task::regression => xgboost::parameters::learning::Objective::RegLinear,
+                            Task::classification => {
+                                xgboost::parameters::learning::Objective::RegLogistic
+                            }
+                        })
+                        .build()
+                        .unwrap();
+
+                // configure the tree-based learning model's parameters
+                let tree_params = parameters::tree::TreeBoosterParametersBuilder::default()
+                    .max_depth(match hyperparams.get("max_depth") {
+                        Some(value) => value.as_u64().unwrap_or(2) as u32,
+                        None => 2,
+                    })
+                    .eta(0.3)
+                    .build()
+                    .unwrap();
+
+                // overall configuration for Booster
+                let booster_params = parameters::BoosterParametersBuilder::default()
+                    .booster_type(parameters::BoosterType::Tree(tree_params))
+                    .learning_params(learning_params)
+                    .verbose(true)
+                    .build()
+                    .unwrap();
+
+                // overall configuration for training/evaluation
+                let params = parameters::TrainingParametersBuilder::default()
+                    .dtrain(&dtrain) // dataset to train with
+                    .boost_rounds(match hyperparams.get("n_estimators") {
+                        Some(value) => value.as_u64().unwrap_or(2) as u32,
+                        None => 2,
+                    }) // number of training iterations
+                    .booster_params(booster_params) // model parameters
+                    .evaluation_sets(Some(evaluation_sets)) // optional datasets to evaluate against in each iteration
+                    .build()
+                    .unwrap();
+
+                // train model, and print evaluation data
+                let bst = match Booster::train(&params) {
+                    Ok(bst) => bst,
+                    Err(err) => error!("{}", err),
+                };
+
+                let r: u64 = rand::random();
+                let path = format!("/tmp/pgml_rust_{}.bin", r);
+
+                bst.save(std::path::Path::new(&path)).unwrap();
+
+                let bytes = std::fs::read(&path).unwrap();
+                Spi::get_one_with_args::<i64>(
+                    "INSERT INTO pgml_rust.files (model_id, path, part, data) VALUES($1, 'estimator.rmp', 0, $2) RETURNING id",
+                    vec![
+                        (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()),
+                        (PgBuiltInOids::BYTEAOID.oid(), bytes.into_datum()),
+                    ]
+                ).unwrap();
+                Some(Box::new(BoosterBox::new(bst)))
+            }
+        };
+    }
+
+    fn test(&mut self, project: &Project, dataset: &Dataset) {
+        let metrics = self
+            .estimator
+            .as_ref()
+            .unwrap()
+            .test(project.task, &dataset);
+        self.metrics = Some(JsonB(json!(metrics.clone())));
+        Spi::get_one_with_args::<i64>(
+            "UPDATE pgml_rust.models SET metrics = $1 WHERE id = $2 RETURNING id",
+            vec![
+                (
+                    PgBuiltInOids::JSONBOID.oid(),
+                    JsonB(json!(metrics)).into_datum(),
+                ),
+                (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()),
+            ],
+        )
+        .unwrap();
+    }
+}
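`test` pushes the held-out split through the freshly fit estimator and stores the resulting metrics on the model row, so they are queryable from SQL afterwards. The key names come from `calc_metrics` in estimator.rs; for a regression project, for example (illustrative):

```sql
SELECT id, algorithm,
       metrics->>'r2' AS r2,
       metrics->>'mean_squared_error' AS mse
FROM pgml_rust.models
ORDER BY created_at DESC;
```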
diff --git a/pgml-extension/pgml_rust/src/orm/project.rs b/pgml-extension/pgml_rust/src/orm/project.rs
new file mode 100644
index 000000000..0d740df42
--- /dev/null
+++ b/pgml-extension/pgml_rust/src/orm/project.rs
@@ -0,0 +1,96 @@
+use std::str::FromStr;
+
+use pgx::*;
+
+use crate::orm::Snapshot;
+use crate::orm::Task;
+
+#[derive(Debug)]
+pub struct Project {
+    pub id: i64,
+    pub name: String,
+    pub task: Task,
+    pub created_at: Timestamp,
+    pub updated_at: Timestamp,
+}
+
+impl Project {
+    pub fn find(id: i64) -> Option<Project> {
+        let mut project: Option<Project> = None;
+
+        Spi::connect(|client| {
+            let result = client.select("SELECT id, name, task, created_at, updated_at FROM pgml_rust.projects WHERE id = $1 LIMIT 1;",
+                Some(1),
+                Some(vec![
+                    (PgBuiltInOids::INT8OID.oid(), id.into_datum()),
+                ])
+            ).first();
+            if result.len() > 0 {
+                project = Some(Project {
+                    id: result.get_datum(1).unwrap(),
+                    name: result.get_datum(2).unwrap(),
+                    task: Task::from_str(result.get_datum(3).unwrap()).unwrap(),
+                    created_at: result.get_datum(4).unwrap(),
+                    updated_at: result.get_datum(5).unwrap(),
+                });
+            }
+            Ok(Some(1))
+        });
+
+        project
+    }
+
+    pub fn find_by_name(name: &str) -> Option<Project> {
+        let mut project = None;
+
+        Spi::connect(|client| {
+            let result = client.select("SELECT id, name, task, created_at, updated_at FROM pgml_rust.projects WHERE name = $1 LIMIT 1;",
+                Some(1),
+                Some(vec![
+                    (PgBuiltInOids::TEXTOID.oid(), name.into_datum()),
+                ])
+            ).first();
+            if result.len() > 0 {
+                project = Some(Project {
+                    id: result.get_datum(1).unwrap(),
+                    name: result.get_datum(2).unwrap(),
+                    task: Task::from_str(result.get_datum(3).unwrap()).unwrap(),
+                    created_at: result.get_datum(4).unwrap(),
+                    updated_at: result.get_datum(5).unwrap(),
+                });
+            }
+            Ok(Some(1))
+        });
+
+        project
+    }
+
+    pub fn create(name: &str, task: Task) -> Project {
+        let mut project: Option<Project> = None;
+
+        Spi::connect(|client| {
+            let result = client.select(r#"INSERT INTO pgml_rust.projects (name, task) VALUES ($1, $2::pgml_rust.task) RETURNING id, name, task, created_at, updated_at;"#,
+                Some(1),
+                Some(vec![
+                    (PgBuiltInOids::TEXTOID.oid(), name.into_datum()),
+                    (PgBuiltInOids::TEXTOID.oid(), task.to_string().into_datum()),
+                ])
+            ).first();
+            if result.len() > 0 {
+                project = Some(Project {
+                    id: result.get_datum(1).unwrap(),
+                    name: result.get_datum(2).unwrap(),
+                    task: result.get_datum(3).unwrap(),
+                    created_at: result.get_datum(4).unwrap(),
+                    updated_at: result.get_datum(5).unwrap(),
+                });
+            }
+            Ok(Some(1))
+        });
+        project.unwrap()
+    }
+
+    pub fn last_snapshot(&self) -> Option<Snapshot> {
+        Snapshot::find_last_by_project_id(self.id)
+    }
+}
diff --git a/pgml-extension/pgml_rust/src/orm/sampling.rs b/pgml-extension/pgml_rust/src/orm/sampling.rs
new file mode 100644
index 000000000..ff19d97b4
--- /dev/null
+++ b/pgml-extension/pgml_rust/src/orm/sampling.rs
@@ -0,0 +1,30 @@
+use pgx::*;
+use serde::Deserialize;
+
+#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug, Deserialize)]
+#[allow(non_camel_case_types)]
+pub enum Sampling {
+    random,
+    last,
+}
+
+impl std::str::FromStr for Sampling {
+    type Err = ();
+
+    fn from_str(input: &str) -> Result<Self, Self::Err> {
+        match input {
+            "random" => Ok(Sampling::random),
+            "last" => Ok(Sampling::last),
+            _ => Err(()),
+        }
+    }
+}
+
+impl std::string::ToString for Sampling {
+    fn to_string(&self) -> String {
+        match *self {
+            Sampling::random => "random".to_string(),
+            Sampling::last => "last".to_string(),
+        }
+    }
+}
diff --git a/pgml-extension/pgml_rust/src/orm/search.rs b/pgml-extension/pgml_rust/src/orm/search.rs
new file mode 100644
index 000000000..f96170fd8
--- /dev/null
+++ b/pgml-extension/pgml_rust/src/orm/search.rs
@@ -0,0 +1,33 @@
+use pgx::*;
+use serde::Deserialize;
+
+#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug, Deserialize)]
+#[allow(non_camel_case_types)]
+pub enum Search {
+    grid,
+    random,
+    none,
+}
+
+impl std::str::FromStr for Search {
+    type Err = ();
+
+    fn from_str(input: &str) -> Result<Self, Self::Err> {
+        match input {
+            "grid" => Ok(Search::grid),
+            "random" => Ok(Search::random),
+            "none" => Ok(Search::none),
+            _ => Err(()),
+        }
+    }
+}
+
+impl std::string::ToString for Search {
+    fn to_string(&self) -> String {
+        match *self {
+            Search::grid => "grid".to_string(),
+            Search::random => "random".to_string(),
+            Search::none => "none".to_string(),
+        }
+    }
+}
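Snapshots can also be taken standalone through the `create_snapshot` function from api.rs, and the per-column statistics that `analyze()` below computes land in the `analysis` JSONB as flat `{column}_{stat}` keys. Illustrative only, with `age` standing in for any numeric column of the source relation:

```sql
SELECT pgml_rust.create_snapshot('pgml_rust.diabetes', 'target', 0.25, 'last');

SELECT analysis->>'samples'    AS samples,
       analysis->>'age_mean'   AS age_mean,
       analysis->>'age_stddev' AS age_stddev
FROM pgml_rust.snapshots
ORDER BY id DESC
LIMIT 1;
```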
diff --git a/pgml-extension/pgml_rust/src/orm/snapshot.rs b/pgml-extension/pgml_rust/src/orm/snapshot.rs
new file mode 100644
index 000000000..fdb175983
--- /dev/null
+++ b/pgml-extension/pgml_rust/src/orm/snapshot.rs
@@ -0,0 +1,300 @@
+use std::collections::HashMap;
+
+use pgx::*;
+use serde_json::json;
+
+use crate::orm::Dataset;
+use crate::orm::Sampling;
+
+#[derive(Debug)]
+pub struct Snapshot {
+    pub id: i64,
+    pub relation_name: String,
+    pub y_column_name: Vec<String>,
+    pub test_size: f32,
+    pub test_sampling: Sampling,
+    pub status: String,
+    pub columns: Option<JsonB>,
+    pub analysis: Option<JsonB>,
+    pub created_at: Timestamp,
+    pub updated_at: Timestamp,
+}
+
+impl Snapshot {
+    pub fn find_last_by_project_id(project_id: i64) -> Option<Snapshot> {
+        let mut snapshot = None;
+        Spi::connect(|client| {
+            let result = client.select(
+                "SELECT snapshots.id, snapshots.relation_name, snapshots.y_column_name, snapshots.test_size, snapshots.test_sampling, snapshots.status, snapshots.columns, snapshots.analysis, snapshots.created_at, snapshots.updated_at
+                FROM pgml_rust.snapshots
+                JOIN pgml_rust.models
+                    ON models.snapshot_id = snapshots.id
+                    AND models.project_id = $1
+                ORDER BY snapshots.id DESC
+                LIMIT 1;
+                ",
+                Some(1),
+                Some(vec![
+                    (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()),
+                ])
+            ).first();
+            if result.len() > 0 {
+                snapshot = Some(Snapshot {
+                    id: result.get_datum(1).unwrap(),
+                    relation_name: result.get_datum(2).unwrap(),
+                    y_column_name: result.get_datum(3).unwrap(),
+                    test_size: result.get_datum(4).unwrap(),
+                    test_sampling: result.get_datum(5).unwrap(),
+                    status: result.get_datum(6).unwrap(),
+                    columns: result.get_datum(7),
+                    analysis: result.get_datum(8),
+                    created_at: result.get_datum(9).unwrap(),
+                    updated_at: result.get_datum(10).unwrap(),
+                });
+            }
+            Ok(Some(1))
+        });
+        snapshot
+    }
+
+    pub fn create(
+        relation_name: &str,
+        y_column_name: &str,
+        test_size: f32,
+        test_sampling: Sampling,
+    ) -> Snapshot {
+        let mut snapshot: Option<Snapshot> = None;
+
+        Spi::connect(|client| {
+            let result = client.select("INSERT INTO pgml_rust.snapshots (relation_name, y_column_name, test_size, test_sampling, status) VALUES ($1, $2, $3, $4::pgml_rust.sampling, $5) RETURNING id, relation_name, y_column_name, test_size, test_sampling, status, columns, analysis, created_at, updated_at;",
+                Some(1),
+                Some(vec![
+                    (PgBuiltInOids::TEXTOID.oid(), relation_name.into_datum()),
+                    (PgBuiltInOids::TEXTARRAYOID.oid(), vec![y_column_name].into_datum()),
+                    (PgBuiltInOids::FLOAT4OID.oid(), test_size.into_datum()),
+                    (PgBuiltInOids::TEXTOID.oid(), test_sampling.to_string().into_datum()),
+                    (PgBuiltInOids::TEXTOID.oid(), "new".to_string().into_datum()),
+                ])
+            ).first();
+            let mut s = Snapshot {
+                id: result.get_datum(1).unwrap(),
+                relation_name: result.get_datum(2).unwrap(),
+                y_column_name: result.get_datum(3).unwrap(),
+                test_size: result.get_datum(4).unwrap(),
+                test_sampling: result.get_datum(5).unwrap(),
+                status: result.get_datum(6).unwrap(),
+                columns: None,
+                analysis: None,
+                created_at: result.get_datum(9).unwrap(),
+                updated_at: result.get_datum(10).unwrap(),
+            };
+            let mut sql = format!(
+                r#"CREATE TABLE "pgml_rust"."snapshot_{}" AS SELECT * FROM {}"#,
+                s.id, s.relation_name
+            );
+            if s.test_sampling == Sampling::random {
+                sql += " ORDER BY random()";
+            }
+            client.select(&sql, None, None);
+            client.select(
+                r#"UPDATE "pgml_rust"."snapshots" SET status = 'snapped' WHERE id = $1"#,
+                None,
+                Some(vec![(PgBuiltInOids::INT8OID.oid(), s.id.into_datum())]),
+            );
+            s.analyze();
+            snapshot = Some(s);
+            Ok(Some(1))
+        });
+
+        snapshot.unwrap()
+    }
+
+    fn analyze(&mut self) {
+        Spi::connect(|client| {
+            let parts = self
+                .relation_name
+                .split(".")
+                .map(|name| name.to_string())
+                .collect::<Vec<String>>();
+            let (schema_name, table_name) = match parts.len() {
+                1 => (String::from("public"), parts[0].clone()),
+                2 => (parts[0].clone(), parts[1].clone()),
+                _ => error!(
+                    "Relation name {} is not parsable into schema name and table name",
+                    self.relation_name
+                ),
+            };
+            let mut columns = HashMap::<String, String>::new();
+            client.select("SELECT column_name::TEXT, data_type::TEXT FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2",
+                None,
+                Some(vec![
+                    (PgBuiltInOids::TEXTOID.oid(), schema_name.into_datum()),
+                    (PgBuiltInOids::TEXTOID.oid(), table_name.into_datum()),
+                ]))
+            .for_each(|row| {
+                columns.insert(row[1].value::<String>().unwrap(), row[2].value::<String>().unwrap());
+            });
+
+            for column in &self.y_column_name {
+                if !columns.contains_key(column) {
+                    error!(
+                        "Column `{}` not found. Did you pass the correct `y_column_name`?",
+                        column
+                    )
+                }
+            }
+
+            // We have to pull this analysis data into Rust as opposed to using Postgres
+            // json_build_object(...), because Postgres functions have a limit of 100 arguments.
+            // Any table that has more than 10 columns will exceed the Postgres limit since we
+            // calculate 10 statistics per column.
+            let mut stats = vec![r#"count(*)::FLOAT4 AS "samples""#.to_string()];
+            let mut fields = vec!["samples".to_string()];
+            for (column, data_type) in &columns {
+                match data_type.as_str() {
+                    "real" | "double precision" | "smallint" | "integer" | "bigint" | "boolean" => {
+                        let column = column.to_string();
+                        let quoted_column = match data_type.as_str() {
+                            "boolean" => format!(r#""{}"::INT"#, column),
+                            _ => format!(r#""{}""#, column),
+                        };
+                        stats.push(format!(r#"min({quoted_column})::FLOAT4 AS "{column}_min""#));
+                        stats.push(format!(r#"max({quoted_column})::FLOAT4 AS "{column}_max""#));
+                        stats.push(format!(
+                            r#"avg({quoted_column})::FLOAT4 AS "{column}_mean""#
+                        ));
+                        stats.push(format!(
+                            r#"stddev({quoted_column})::FLOAT4 AS "{column}_stddev""#
+                        ));
+                        stats.push(format!(r#"percentile_disc(0.25) within group (order by {quoted_column})::FLOAT4 AS "{column}_p25""#));
+                        stats.push(format!(r#"percentile_disc(0.5) within group (order by {quoted_column})::FLOAT4 AS "{column}_p50""#));
+                        stats.push(format!(r#"percentile_disc(0.75) within group (order by {quoted_column})::FLOAT4 AS "{column}_p75""#));
+                        stats.push(format!(
+                            r#"count({quoted_column})::FLOAT4 AS "{column}_count""#
+                        ));
+                        stats.push(format!(
+                            r#"count(distinct {quoted_column})::FLOAT4 AS "{column}_distinct""#
+                        ));
+                        stats.push(format!(
+                            r#"sum(({quoted_column} IS NULL)::INT)::FLOAT4 AS "{column}_nulls""#
+                        ));
+                        fields.push(format!("{column}_min"));
+                        fields.push(format!("{column}_max"));
+                        fields.push(format!("{column}_mean"));
+                        fields.push(format!("{column}_stddev"));
+                        fields.push(format!("{column}_p25"));
+                        fields.push(format!("{column}_p50"));
+                        fields.push(format!("{column}_p75"));
+                        fields.push(format!("{column}_count"));
+                        fields.push(format!("{column}_distinct"));
+                        fields.push(format!("{column}_nulls"));
+                    }
+                    &_ => {}
+                }
+            }
+
+            let stats = stats.join(",");
+            let sql = format!(r#"SELECT {stats} FROM "pgml_rust"."snapshot_{}""#, self.id);
+            let result = client.select(&sql, Some(1), None).first();
+            let mut analysis = HashMap::new();
+            for (i, field) in fields.iter().enumerate() {
+                analysis.insert(
+                    field.to_owned(),
+                    result
+                        .get_datum::<f32>((i + 1).try_into().unwrap())
+                        .unwrap(),
+                );
+            }
+            let analysis_datum = JsonB(json!(analysis.clone()));
+            let column_datum = JsonB(json!(columns.clone()));
+            self.analysis = Some(JsonB(json!(analysis)));
+            self.columns = Some(JsonB(json!(columns)));
+            client.select("UPDATE pgml_rust.snapshots SET status = 'complete', analysis = $1, columns = $2 WHERE id = $3", Some(1), Some(vec![
+                (PgBuiltInOids::JSONBOID.oid(), analysis_datum.into_datum()),
+                (PgBuiltInOids::JSONBOID.oid(), column_datum.into_datum()),
+                (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()),
+            ]));
+
+            Ok(Some(1))
+        });
+    }
+
+    pub fn dataset(&self) -> Dataset {
+        let mut data = None;
+        Spi::connect(|client| {
+            let json: &serde_json::Value = &self.columns.as_ref().unwrap().0;
+            let feature_columns = json
+                .as_object()
+                .unwrap()
+                .keys()
+                .filter_map(|column| match self.y_column_name.contains(column) {
+                    true => None,
+                    false => Some(format!("{}::FLOAT4", column)),
+                })
+                .collect::<Vec<String>>();
+            let label_columns = self
+                .y_column_name
+                .iter()
+                .map(|column| format!("{}::FLOAT4", column))
+                .collect::<Vec<String>>();
+
+            let sql = format!(
+                "SELECT {}, {} FROM {}",
+                feature_columns.join(", "),
+                label_columns.join(", "),
+                self.snapshot_name()
+            );
+
+            info!("Fetching data: {}", sql);
+            let result = client.select(&sql, None, None);
+            let mut x = Vec::with_capacity(result.len() * feature_columns.len());
+            let mut y = Vec::with_capacity(result.len() * label_columns.len());
+            result.for_each(|row| {
+                // Postgres arrays are 1 indexed and so are SPI tuples...
+                for i in 1..feature_columns.len() + 1 {
+                    x.push(row[i].value::<f32>().unwrap());
+                }
+                for j in feature_columns.len() + 1..feature_columns.len() + label_columns.len() + 1
+                {
+                    y.push(row[j].value::<f32>().unwrap());
+                }
+            });
+            let num_rows = x.len() / feature_columns.len();
+            let num_test_rows = if self.test_size > 1.0 {
+                self.test_size as usize
+            } else {
+                (num_rows as f32 * self.test_size).round() as usize
+            };
+            let num_train_rows = num_rows - num_test_rows;
+            if num_train_rows <= 0 {
+                error!(
+                    "test_size = {} is too large. There are only {} samples.",
+                    num_test_rows, num_rows
+                );
+            }
+            info!(
+                "got features {:?} labels {:?} rows {:?}",
+                feature_columns.len(),
+                label_columns.len(),
+                num_rows
+            );
+            data = Some(Dataset {
+                x: x,
+                y: y,
+                num_features: feature_columns.len(),
+                num_labels: label_columns.len(),
+                num_rows: num_rows,
+                num_test_rows: num_test_rows,
+                num_train_rows: num_train_rows,
+            });
+
+            Ok(Some(()))
+        });
+
+        data.unwrap()
+    }
+
+    fn snapshot_name(&self) -> String {
+        format!("pgml_rust.snapshot_{}", self.id)
+    }
+}
diff --git a/pgml-extension/pgml_rust/src/orm/strategy.rs b/pgml-extension/pgml_rust/src/orm/strategy.rs
new file mode 100644
index 000000000..d4bf493e6
--- /dev/null
+++ b/pgml-extension/pgml_rust/src/orm/strategy.rs
@@ -0,0 +1,33 @@
+use pgx::*;
+use serde::Deserialize;
+
+#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug, Deserialize)]
+#[allow(non_camel_case_types)]
+pub enum Strategy {
+    best_score,
+    most_recent,
+    rollback,
+}
+
+impl std::str::FromStr for Strategy {
+    type Err = ();
+
+    fn from_str(input: &str) -> Result<Self, Self::Err> {
+        match input {
+            "best_score" => Ok(Strategy::best_score),
+            "most_recent" => Ok(Strategy::most_recent),
+            "rollback" => Ok(Strategy::rollback),
+            _ => Err(()),
+        }
+    }
+}
+
+impl std::string::ToString for Strategy {
+    fn to_string(&self) -> String {
+        match *self {
+            Strategy::best_score => "best_score".to_string(),
+            Strategy::most_recent => "most_recent".to_string(),
+            Strategy::rollback => "rollback".to_string(),
+        }
+    }
+}
diff --git a/pgml-extension/pgml_rust/src/orm/task.rs b/pgml-extension/pgml_rust/src/orm/task.rs
new file mode 100644
index 000000000..e2027f2ec
--- /dev/null
+++ b/pgml-extension/pgml_rust/src/orm/task.rs
@@ -0,0 +1,30 @@
+use pgx::*;
+use serde::Deserialize;
+
+#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug, Deserialize)]
+#[allow(non_camel_case_types)]
+pub enum Task {
+    regression,
+    classification,
+}
+
+impl std::str::FromStr for Task {
+    type Err = ();
+
+    fn from_str(input: &str) -> Result<Self, Self::Err> {
+        match input {
+            "regression" => Ok(Task::regression),
+            "classification" => Ok(Task::classification),
+            _ => Err(()),
+        }
+    }
+}
+
+impl std::string::ToString for Task {
+    fn to_string(&self) -> String {
+        match *self {
+            Task::regression => "regression".to_string(),
+            Task::classification => "classification".to_string(),
+        }
+    }
+}
diff --git a/pgml-extension/pgml_rust/src/vectors.rs b/pgml-extension/pgml_rust/src/vectors.rs
index 31b2af077..411f0b7eb 100644
--- a/pgml-extension/pgml_rust/src/vectors.rs
+++ b/pgml-extension/pgml_rust/src/vectors.rs
@@ -1,476 +1,635 @@
 use pgx::*;
 
+#[pg_extern(immutable, parallel_safe, strict, name = "add")]
+fn add_scalar_s(vector: Vec<f32>, addend: f32) -> Vec<f32> {
+    vector.as_slice().iter().map(|a| a + addend).collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "add")]
+fn add_scalar_d(vector: Vec<f64>, addend: f64) -> Vec<f64> {
+    vector.as_slice().iter().map(|a| a + addend).collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "subtract")]
+fn subtract_scalar_s(vector: Vec<f32>, subtrahend: f32) -> Vec<f32> {
+    vector.as_slice().iter().map(|a| a - subtrahend).collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "subtract")]
+fn subtract_scalar_d(vector: Vec<f64>, subtrahend: f64) -> Vec<f64> {
+    vector.as_slice().iter().map(|a| a - subtrahend).collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "multiply")]
+fn multiply_scalar_s(vector: Vec<f32>, multiplicand: f32) -> Vec<f32> {
+    vector.as_slice().iter().map(|a| a * multiplicand).collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "multiply")]
+fn multiply_scalar_d(vector: Vec<f64>, multiplicand: f64) -> Vec<f64> {
+    vector.as_slice().iter().map(|a| a * multiplicand).collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "divide")]
+fn divide_scalar_s(vector: Vec<f32>, dividend: f32) -> Vec<f32> {
+    vector.as_slice().iter().map(|a| a / dividend).collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "divide")]
+fn divide_scalar_d(vector: Vec<f64>, dividend: f64) -> Vec<f64> {
+    vector.as_slice().iter().map(|a| a / dividend).collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "add")]
+fn add_vector_s(vector: Vec<f32>, addend: Vec<f32>) -> Vec<f32> {
+    vector
+        .as_slice()
+        .iter()
+        .zip(addend.as_slice().iter())
+        .map(|(a, b)| a + b)
+        .collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "add")]
+fn add_vector_d(vector: Vec<f64>, addend: Vec<f64>) -> Vec<f64> {
+    vector
+        .as_slice()
+        .iter()
+        .zip(addend.as_slice().iter())
+        .map(|(a, b)| a + b)
+        .collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "subtract")]
+fn subtract_vector_s(vector: Vec<f32>, subtrahend: Vec<f32>) -> Vec<f32> {
+    vector
+        .as_slice()
+        .iter()
+        .zip(subtrahend.as_slice().iter())
+        .map(|(a, b)| a - b)
+        .collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "subtract")]
+fn subtract_vector_d(vector: Vec<f64>, subtrahend: Vec<f64>) -> Vec<f64> {
+    vector
+        .as_slice()
+        .iter()
+        .zip(subtrahend.as_slice().iter())
+        .map(|(a, b)| a - b)
+        .collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "multiply")]
+fn multiply_vector_s(vector: Vec<f32>, multiplicand: Vec<f32>) -> Vec<f32> {
+    vector
+        .as_slice()
+        .iter()
+        .zip(multiplicand.as_slice().iter())
+        .map(|(a, b)| a * b)
+        .collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "multiply")]
+fn multiply_vector_d(vector: Vec<f64>, multiplicand: Vec<f64>) -> Vec<f64> {
+    vector
+        .as_slice()
+        .iter()
+        .zip(multiplicand.as_slice().iter())
+        .map(|(a, b)| a * b)
+        .collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "divide")]
+fn divide_vector_s(vector: Vec<f32>, dividend: Vec<f32>) -> Vec<f32> {
+    vector
+        .as_slice()
+        .iter()
+        .zip(dividend.as_slice().iter())
+        .map(|(a, b)| a / b)
+        .collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "divide")]
+fn divide_vector_d(vector: Vec<f64>, dividend: Vec<f64>) -> Vec<f64> {
+    vector
+        .as_slice()
+        .iter()
+        .zip(dividend.as_slice().iter())
+        .map(|(a, b)| a / b)
+        .collect()
+}
+
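Each arithmetic operation above is registered twice under a single SQL name, so Postgres overload resolution picks the `REAL[]` or `DOUBLE PRECISION[]` variant from the argument types (the control file puts everything in the `pgml_rust` schema). A couple of illustrative calls:

```sql
SELECT pgml_rust.add(ARRAY[1.0, 2.0, 3.0]::REAL[], 1.0::REAL);               -- {2,3,4}
SELECT pgml_rust.multiply(ARRAY[1.0, 2.0]::REAL[], ARRAY[3.0, 4.0]::REAL[]); -- {3,8}
```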
diff --git a/pgml-extension/pgml_rust/src/vectors.rs b/pgml-extension/pgml_rust/src/vectors.rs
index 31b2af077..411f0b7eb 100644
--- a/pgml-extension/pgml_rust/src/vectors.rs
+++ b/pgml-extension/pgml_rust/src/vectors.rs
@@ -1,476 +1,635 @@
 use pgx::*;
+
+#[pg_extern(immutable, parallel_safe, strict, name = "add")]
+fn add_scalar_s(vector: Vec<f32>, addend: f32) -> Vec<f32> {
+    vector.as_slice().iter().map(|a| a + addend).collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "add")]
+fn add_scalar_d(vector: Vec<f64>, addend: f64) -> Vec<f64> {
+    vector.as_slice().iter().map(|a| a + addend).collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "subtract")]
+fn subtract_scalar_s(vector: Vec<f32>, subtrahend: f32) -> Vec<f32> {
+    vector.as_slice().iter().map(|a| a - subtrahend).collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "subtract")]
+fn subtract_scalar_d(vector: Vec<f64>, subtrahend: f64) -> Vec<f64> {
+    vector.as_slice().iter().map(|a| a - subtrahend).collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "multiply")]
+fn multiply_scalar_s(vector: Vec<f32>, multiplicand: f32) -> Vec<f32> {
+    vector.as_slice().iter().map(|a| a * multiplicand).collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "multiply")]
+fn multiply_scalar_d(vector: Vec<f64>, multiplicand: f64) -> Vec<f64> {
+    vector.as_slice().iter().map(|a| a * multiplicand).collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "divide")]
+fn divide_scalar_s(vector: Vec<f32>, dividend: f32) -> Vec<f32> {
+    vector.as_slice().iter().map(|a| a / dividend).collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "divide")]
+fn divide_scalar_d(vector: Vec<f64>, dividend: f64) -> Vec<f64> {
+    vector.as_slice().iter().map(|a| a / dividend).collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "add")]
+fn add_vector_s(vector: Vec<f32>, addend: Vec<f32>) -> Vec<f32> {
+    vector
+        .as_slice()
+        .iter()
+        .zip(addend.as_slice().iter())
+        .map(|(a, b)| a + b)
+        .collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "add")]
+fn add_vector_d(vector: Vec<f64>, addend: Vec<f64>) -> Vec<f64> {
+    vector
+        .as_slice()
+        .iter()
+        .zip(addend.as_slice().iter())
+        .map(|(a, b)| a + b)
+        .collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "subtract")]
+fn subtract_vector_s(vector: Vec<f32>, subtrahend: Vec<f32>) -> Vec<f32> {
+    vector
+        .as_slice()
+        .iter()
+        .zip(subtrahend.as_slice().iter())
+        .map(|(a, b)| a - b)
+        .collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "subtract")]
+fn subtract_vector_d(vector: Vec<f64>, subtrahend: Vec<f64>) -> Vec<f64> {
+    vector
+        .as_slice()
+        .iter()
+        .zip(subtrahend.as_slice().iter())
+        .map(|(a, b)| a - b)
+        .collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "multiply")]
+fn multiply_vector_s(vector: Vec<f32>, multiplicand: Vec<f32>) -> Vec<f32> {
+    vector
+        .as_slice()
+        .iter()
+        .zip(multiplicand.as_slice().iter())
+        .map(|(a, b)| a * b)
+        .collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "multiply")]
+fn multiply_vector_d(vector: Vec<f64>, multiplicand: Vec<f64>) -> Vec<f64> {
+    vector
+        .as_slice()
+        .iter()
+        .zip(multiplicand.as_slice().iter())
+        .map(|(a, b)| a * b)
+        .collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "divide")]
+fn divide_vector_s(vector: Vec<f32>, dividend: Vec<f32>) -> Vec<f32> {
+    vector
+        .as_slice()
+        .iter()
+        .zip(dividend.as_slice().iter())
+        .map(|(a, b)| a / b)
+        .collect()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "divide")]
+fn divide_vector_d(vector: Vec<f64>, dividend: Vec<f64>) -> Vec<f64> {
+    vector
+        .as_slice()
+        .iter()
+        .zip(dividend.as_slice().iter())
+        .map(|(a, b)| a / b)
+        .collect()
+}
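Each element-wise operation is registered twice under a single SQL name (`add`, `subtract`, `multiply`, `divide`): the `_s` variant for `f32` (which pgx maps to `REAL[]`) and the `_d` variant for `f64` (`DOUBLE PRECISION[]`), so Postgres overload resolution picks the matching one. Called directly from Rust the semantics look like this (a sketch; note that the vector-vector variants `zip` their inputs, so a length mismatch silently truncates to the shorter vector rather than raising an error):

    fn demo() {
        assert_eq!(add_scalar_s(vec![1.0, 2.0], 0.5), vec![1.5, 2.5]);
        assert_eq!(multiply_vector_d(vec![1.0, 2.0], vec![3.0, 4.0]), vec![3.0, 8.0]);
        // zip() stops at the shorter input: no error on mismatched lengths.
        assert_eq!(add_vector_s(vec![1.0, 2.0, 3.0], vec![1.0]), vec![2.0]);
    }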
+
+#[pg_extern(immutable, parallel_safe, strict, name = "norm_l0")]
+fn norm_l0_s(vector: Vec<f32>) -> f32 {
+    vector
+        .as_slice()
+        .iter()
+        .map(|a| if *a == 0.0 { 0.0 } else { 1.0 })
+        .sum()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "norm_l0")]
+fn norm_l0_d(vector: Vec<f64>) -> f64 {
+    vector
+        .as_slice()
+        .iter()
+        .map(|a| if *a == 0.0 { 0.0 } else { 1.0 })
+        .sum()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "norm_l1")]
+fn norm_l1_s(vector: Vec<f32>) -> f32 {
+    unsafe { blas::sasum(vector.len().try_into().unwrap(), vector.as_slice(), 1) }
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "norm_l1")]
+fn norm_l1_d(vector: Vec<f64>) -> f64 {
+    unsafe { blas::dasum(vector.len().try_into().unwrap(), vector.as_slice(), 1) }
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "norm_l2")]
+fn norm_l2_s(vector: Vec<f32>) -> f32 {
+    unsafe { blas::snrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1) }
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "norm_l2")]
+fn norm_l2_d(vector: Vec<f64>) -> f64 {
+    unsafe { blas::dnrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1) }
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "norm_max")]
+fn norm_max_s(vector: Vec<f32>) -> f32 {
+    unsafe {
+        let index = blas::isamax(vector.len().try_into().unwrap(), vector.as_slice(), 1);
+        vector[index - 1].abs()
+    }
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "norm_max")]
+fn norm_max_d(vector: Vec<f64>) -> f64 {
+    unsafe {
+        let index = blas::idamax(vector.len().try_into().unwrap(), vector.as_slice(), 1);
+        vector[index - 1].abs()
+    }
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "normalize_l1")]
+fn normalize_l1_s(vector: Vec<f32>) -> Vec<f32> {
+    let norm: f32;
+    unsafe {
+        norm = blas::sasum(vector.len().try_into().unwrap(), vector.as_slice(), 1);
+    }
+    divide_scalar_s(vector, norm)
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "normalize_l1")]
+fn normalize_l1_d(vector: Vec<f64>) -> Vec<f64> {
+    let norm: f64;
+    unsafe {
+        norm = blas::dasum(vector.len().try_into().unwrap(), vector.as_slice(), 1);
+    }
+    divide_scalar_d(vector, norm)
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "normalize_l2")]
+fn normalize_l2_s(vector: Vec<f32>) -> Vec<f32> {
+    let norm: f32;
+    unsafe {
+        norm = blas::snrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1);
+    }
+    divide_scalar_s(vector, norm)
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "normalize_l2")]
+fn normalize_l2_d(vector: Vec<f64>) -> Vec<f64> {
+    let norm: f64;
+    unsafe {
+        norm = blas::dnrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1);
+    }
+    divide_scalar_d(vector, norm)
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "normalize_max")]
+fn normalize_max_s(vector: Vec<f32>) -> Vec<f32> {
+    let norm;
+    unsafe {
+        let index = blas::isamax(vector.len().try_into().unwrap(), vector.as_slice(), 1);
+        norm = vector[index - 1].abs();
+    }
+    divide_scalar_s(vector, norm)
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "normalize_max")]
+fn normalize_max_d(vector: Vec<f64>) -> Vec<f64> {
+    let norm;
+    unsafe {
+        let index = blas::idamax(vector.len().try_into().unwrap(), vector.as_slice(), 1);
+        norm = vector[index - 1].abs();
+    }
+    divide_scalar_d(vector, norm)
+}
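The `norm_*` functions delegate the L1 and L2 cases to BLAS (`sasum`/`dasum`, `snrm2`/`dnrm2`), and each `normalize_*` function simply divides the vector by the corresponding norm, so the result has that norm equal to 1 up to floating-point rounding. For example (a sketch, not part of the diff):

    fn demo() {
        let v = vec![3.0_f32, 4.0];
        assert_eq!(norm_l2_s(v.clone()), 5.0);         // sqrt(3^2 + 4^2)
        assert_eq!(normalize_l2_s(v), vec![0.6, 0.8]); // unit length under L2
    }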
+
+#[pg_extern(immutable, parallel_safe, strict, name = "distance_l1")]
+fn distance_l1_s(vector: Vec<f32>, other: Vec<f32>) -> f32 {
+    vector
+        .as_slice()
+        .iter()
+        .zip(other.as_slice().iter())
+        .map(|(a, b)| (a - b).abs())
+        .sum()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "distance_l1")]
+fn distance_l1_d(vector: Vec<f64>, other: Vec<f64>) -> f64 {
+    vector
+        .as_slice()
+        .iter()
+        .zip(other.as_slice().iter())
+        .map(|(a, b)| (a - b).abs())
+        .sum()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "distance_l2")]
+fn distance_l2_s(vector: Vec<f32>, other: Vec<f32>) -> f32 {
+    vector
+        .as_slice()
+        .iter()
+        .zip(other.as_slice().iter())
+        .map(|(a, b)| (a - b).powf(2.0))
+        .sum::<f32>()
+        .sqrt()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "distance_l2")]
+fn distance_l2_d(vector: Vec<f64>, other: Vec<f64>) -> f64 {
+    vector
+        .as_slice()
+        .iter()
+        .zip(other.as_slice().iter())
+        .map(|(a, b)| (a - b).powf(2.0))
+        .sum::<f64>()
+        .sqrt()
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "dot_product")]
+fn dot_product_s(vector: Vec<f32>, other: Vec<f32>) -> f32 {
+    unsafe {
+        blas::sdot(
+            vector.len().try_into().unwrap(),
+            vector.as_slice(),
+            1,
+            other.as_slice(),
+            1,
+        )
+    }
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "dot_product")]
+fn dot_product_d(vector: Vec<f64>, other: Vec<f64>) -> f64 {
+    unsafe {
+        blas::ddot(
+            vector.len().try_into().unwrap(),
+            vector.as_slice(),
+            1,
+            other.as_slice(),
+            1,
+        )
+    }
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "cosine_similarity")]
+fn cosine_similarity_s(vector: Vec<f32>, other: Vec<f32>) -> f32 {
+    unsafe {
+        let dot = blas::sdot(
+            vector.len().try_into().unwrap(),
+            vector.as_slice(),
+            1,
+            other.as_slice(),
+            1,
+        );
+        let a_norm = blas::snrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1);
+        let b_norm = blas::snrm2(other.len().try_into().unwrap(), other.as_slice(), 1);
+        dot / (a_norm * b_norm)
+    }
+}
+
+#[pg_extern(immutable, parallel_safe, strict, name = "cosine_similarity")]
+fn cosine_similarity_d(vector: Vec<f64>, other: Vec<f64>) -> f64 {
+    unsafe {
+        let dot = blas::ddot(
+            vector.len().try_into().unwrap(),
+            vector.as_slice(),
+            1,
+            other.as_slice(),
+            1,
+        );
+        let a_norm = blas::dnrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1);
+        let b_norm = blas::dnrm2(other.len().try_into().unwrap(), other.as_slice(), 1);
+        dot / (a_norm * b_norm)
+    }
+}
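`cosine_similarity` computes dot(a, b) / (norm_l2(a) * norm_l2(b)), so identical vectors score 1.0 and orthogonal vectors 0.0; the `f32` variant lands at 0.99999994 for identical inputs due to single-precision rounding, as the tests below assert. A quick check (sketch, not part of the diff):

    fn demo() {
        // Orthogonal vectors: zero dot product, hence zero cosine similarity.
        assert_eq!(dot_product_d(vec![1.0, 0.0], vec![0.0, 1.0]), 0.0);
        assert_eq!(cosine_similarity_d(vec![1.0, 0.0], vec![0.0, 1.0]), 0.0);
        assert_eq!(distance_l1_d(vec![1.0, 2.0], vec![2.0, 4.0]), 3.0);
    }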
+
+#[cfg(any(test, feature = "pg_test"))]
 #[pg_schema]
-mod pgml {
+mod tests {
     use super::*;

-    #[pg_extern(immutable, parallel_safe, strict, name="add")]
-    fn add_scalar_s(vector: Vec<f32>, addend: f32) -> Vec<f32> {
-        vector.as_slice().iter().map(|a| a + addend).collect()
+    #[pg_test]
+    fn test_add_scalar_s() {
+        assert_eq!(
+            add_scalar_s([1.0, 2.0, 3.0].to_vec(), 1.0),
+            [2.0, 3.0, 4.0].to_vec()
+        )
+    }
+
+    #[pg_test]
+    fn test_add_scalar_d() {
+        assert_eq!(
+            add_scalar_d([1.0, 2.0, 3.0].to_vec(), 1.0),
+            [2.0, 3.0, 4.0].to_vec()
+        )
+    }
+
+    #[pg_test]
+    fn test_subtract_scalar_s() {
+        assert_eq!(
+            subtract_scalar_s([1.0, 2.0, 3.0].to_vec(), 1.0),
+            [0.0, 1.0, 2.0].to_vec()
+        )
+    }
+
+    #[pg_test]
+    fn test_subtract_scalar_d() {
+        assert_eq!(
+            subtract_scalar_d([1.0, 2.0, 3.0].to_vec(), 1.0),
+            [0.0, 1.0, 2.0].to_vec()
+        )
+    }
+
+    #[pg_test]
+    fn test_multiply_scalar_s() {
+        assert_eq!(
+            multiply_scalar_s([1.0, 2.0, 3.0].to_vec(), 2.0),
+            [2.0, 4.0, 6.0].to_vec()
+        )
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="add")]
-    fn add_scalar_d(vector: Vec<f64>, addend: f64) -> Vec<f64> {
-        vector.as_slice().iter().map(|a| a + addend).collect()
+    #[pg_test]
+    fn test_multiply_scalar_d() {
+        assert_eq!(
+            multiply_scalar_d([1.0, 2.0, 3.0].to_vec(), 2.0),
+            [2.0, 4.0, 6.0].to_vec()
+        )
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="subtract")]
-    fn subtract_scalar_s(vector: Vec<f32>, subtahend: f32) -> Vec<f32> {
-        vector.as_slice().iter().map(|a| a - subtahend).collect()
+    #[pg_test]
+    fn test_divide_scalar_s() {
+        assert_eq!(
+            divide_scalar_s([2.0, 4.0, 6.0].to_vec(), 2.0),
+            [1.0, 2.0, 3.0].to_vec()
+        )
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="subtract")]
-    fn subtract_scalar_d(vector: Vec<f64>, subtahend: f64) -> Vec<f64> {
-        vector.as_slice().iter().map(|a| a - subtahend).collect()
+    #[pg_test]
+    fn test_divide_scalar_d() {
+        assert_eq!(
+            divide_scalar_d([2.0, 4.0, 6.0].to_vec(), 2.0),
+            [1.0, 2.0, 3.0].to_vec()
+        )
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="multiply")]
-    fn multiply_scalar_s(vector: Vec<f32>, multiplicand: f32) -> Vec<f32> {
-        vector.as_slice().iter().map(|a| a * multiplicand).collect()
+    #[pg_test]
+    fn test_add_vector_s() {
+        assert_eq!(
+            add_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()),
+            [2.0, 4.0, 6.0].to_vec()
+        )
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="multiply")]
-    fn multiply_scalar_d(vector: Vec<f64>, multiplicand: f64) -> Vec<f64> {
-        vector.as_slice().iter().map(|a| a * multiplicand).collect()
+    #[pg_test]
+    fn test_add_vector_d() {
+        assert_eq!(
+            add_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()),
+            [2.0, 4.0, 6.0].to_vec()
+        )
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="divide")]
-    fn divide_scalar_s(vector: Vec<f32>, dividend: f32) -> Vec<f32> {
-        vector.as_slice().iter().map(|a| a / dividend).collect()
+    #[pg_test]
+    fn test_subtract_vector_s() {
+        assert_eq!(
+            subtract_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()),
+            [0.0, 0.0, 0.0].to_vec()
+        )
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="divide")]
-    fn divide_scalar_d(vector: Vec<f64>, dividend: f64) -> Vec<f64> {
-        vector.as_slice().iter().map(|a| a / dividend).collect()
+    #[pg_test]
+    fn test_subtract_vector_d() {
+        assert_eq!(
+            subtract_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()),
+            [0.0, 0.0, 0.0].to_vec()
+        )
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="add")]
-    fn add_vector_s(vector: Vec<f32>, addend: Vec<f32>) -> Vec<f32> {
-        vector.as_slice().iter()
-            .zip(addend.as_slice().iter())
-            .map(|(a, b)| a + b ).collect()
+    #[pg_test]
+    fn test_multiply_vector_s() {
+        assert_eq!(
+            multiply_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()),
+            [1.0, 4.0, 9.0].to_vec()
+        )
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="add")]
-    fn add_vector_d(vector: Vec<f64>, addend: Vec<f64>) -> Vec<f64> {
-        vector.as_slice().iter()
-            .zip(addend.as_slice().iter())
-            .map(|(a, b)| a + b ).collect()
+    #[pg_test]
+    fn test_multiply_vector_d() {
+        assert_eq!(
+            multiply_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()),
+            [1.0, 4.0, 9.0].to_vec()
+        )
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="subtract")]
-    fn subtract_vector_s(vector: Vec<f32>, subtahend: Vec<f32>) -> Vec<f32> {
-        vector.as_slice().iter()
-            .zip(subtahend.as_slice().iter())
-            .map(|(a, b)| a - b ).collect()
+    #[pg_test]
+    fn test_divide_vector_s() {
+        assert_eq!(
+            divide_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()),
+            [1.0, 1.0, 1.0].to_vec()
+        )
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="subtract")]
-    fn subtract_vector_d(vector: Vec<f64>, subtahend: Vec<f64>) -> Vec<f64> {
-        vector.as_slice().iter()
-            .zip(subtahend.as_slice().iter())
-            .map(|(a, b)| a - b ).collect()
+    #[pg_test]
+    fn test_divide_vector_d() {
+        assert_eq!(
+            divide_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()),
+            [1.0, 1.0, 1.0].to_vec()
+        )
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="multiply")]
-    fn multiply_vector_s(vector: Vec<f32>, multiplicand: Vec<f32>) -> Vec<f32> {
-        vector.as_slice().iter()
-            .zip(multiplicand.as_slice().iter())
-            .map(|(a, b)| a * b ).collect()
+    #[pg_test]
+    fn test_norm_l0_s() {
+        assert_eq!(norm_l0_s([1.0, 2.0, 3.0].to_vec()), 3.0)
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="multiply")]
-    fn multiply_vector_d(vector: Vec<f64>, multiplicand: Vec<f64>) -> Vec<f64> {
-        vector.as_slice().iter()
-            .zip(multiplicand.as_slice().iter())
-            .map(|(a, b)| a * b ).collect()
+    #[pg_test]
+    fn test_norm_l0_d() {
+        assert_eq!(norm_l0_d([1.0, 2.0, 3.0].to_vec()), 3.0)
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="divide")]
-    fn divide_vector_s(vector: Vec<f32>, dividend: Vec<f32>) -> Vec<f32> {
-        vector.as_slice().iter()
-            .zip(dividend.as_slice().iter())
-            .map(|(a, b)| a / b ).collect()
+    #[pg_test]
+    fn test_norm_l1_s() {
+        assert_eq!(norm_l1_s([1.0, 2.0, 3.0].to_vec()), 6.0)
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="divide")]
-    fn divide_vector_d(vector: Vec<f64>, dividend: Vec<f64>) -> Vec<f64> {
-        vector.as_slice().iter()
-            .zip(dividend.as_slice().iter())
-            .map(|(a, b)| a / b ).collect()
+    #[pg_test]
+    fn test_norm_l1_d() {
+        assert_eq!(norm_l1_d([1.0, 2.0, 3.0].to_vec()), 6.0)
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="norm_l0")]
-    fn norm_l0_s(vector: Vec<f32>) -> f32 {
-        vector.as_slice().iter().map(|a| if *a == 0.0 { 0.0 } else { 1.0 } ).sum()
+    #[pg_test]
+    fn test_norm_l2_s() {
+        assert_eq!(norm_l2_s([1.0, 2.0, 3.0].to_vec()), 3.7416575);
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="norm_l0")]
-    fn norm_l0_d(vector: Vec<f64>) -> f64 {
-        vector.as_slice().iter().map(|a| if *a == 0.0 { 0.0 } else { 1.0 } ).sum()
+    #[pg_test]
+    fn test_norm_l2_d() {
+        assert_eq!(norm_l2_d([1.0, 2.0, 3.0].to_vec()), 3.7416573867739413);
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="norm_l1")]
-    fn norm_l1_s(vector: Vec<f32>) -> f32 {
-        unsafe {
-            blas::sasum(vector.len().try_into().unwrap(), vector.as_slice(), 1)
-        }
+    #[pg_test]
+    fn test_norm_max_s() {
+        assert_eq!(norm_max_s([1.0, 2.0, 3.0].to_vec()), 3.0);
+        assert_eq!(norm_max_s([1.0, 2.0, 3.0, -4.0].to_vec()), 4.0);
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="norm_l1")]
-    fn norm_l1_d(vector: Vec<f64>) -> f64 {
-        unsafe {
-            blas::dasum(vector.len().try_into().unwrap(), vector.as_slice(), 1)
-        }
+    #[pg_test]
+    fn test_norm_max_d() {
+        assert_eq!(norm_max_d([1.0, 2.0, 3.0].to_vec()), 3.0);
+        assert_eq!(norm_max_d([1.0, 2.0, 3.0, -4.0].to_vec()), 4.0);
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="norm_l2")]
-    fn norm_l2_s(vector: Vec<f32>) -> f32 {
-        unsafe {
-            blas::snrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1)
-        }
+    #[pg_test]
+    fn test_normalize_l1_s() {
+        assert_eq!(
+            normalize_l1_s([1.0, 2.0, 3.0].to_vec()),
+            [0.16666667, 0.33333334, 0.5].to_vec()
+        );
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="norm_l2")]
-    fn norm_l2_d(vector: Vec<f64>) -> f64 {
-        unsafe {
-            blas::dnrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1)
-        }
+    #[pg_test]
+    fn test_normalize_l1_d() {
+        assert_eq!(
+            normalize_l1_d([1.0, 2.0, 3.0].to_vec()),
+            [0.16666666666666666, 0.3333333333333333, 0.5].to_vec()
+        );
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="norm_max")]
-    fn norm_max_s(vector: Vec<f32>) -> f32 {
-        unsafe {
-            let index = blas::isamax(vector.len().try_into().unwrap(), vector.as_slice(), 1);
-            vector[index - 1].abs()
-        }
+    #[pg_test]
+    fn test_normalize_l2_s() {
+        assert_eq!(
+            normalize_l2_s([1.0, 2.0, 3.0].to_vec()),
+            [0.26726124, 0.5345225, 0.8017837].to_vec()
+        );
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="norm_max")]
-    fn norm_max_d(vector: Vec<f64>) -> f64 {
-        unsafe {
-            let index = blas::idamax(vector.len().try_into().unwrap(), vector.as_slice(), 1);
-            vector[index - 1].abs()
-        }
+    #[pg_test]
+    fn test_normalize_l2_d() {
+        assert_eq!(
+            normalize_l2_d([1.0, 2.0, 3.0].to_vec()),
+            [0.2672612419124244, 0.5345224838248488, 0.8017837257372732].to_vec()
+        );
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="normalize_l1")]
-    fn normalize_l1_s(vector: Vec<f32>) -> Vec<f32> {
-        let norm: f32;
-        unsafe {
-            norm = blas::sasum(vector.len().try_into().unwrap(), vector.as_slice(), 1);
-        }
-        divide_scalar_s(vector, norm)
+    #[pg_test]
+    fn test_normalize_max_s() {
+        assert_eq!(
+            normalize_max_s([1.0, 2.0, 3.0].to_vec()),
+            [0.33333334, 0.6666667, 1.0].to_vec()
+        );
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="normalize_l1")]
-    fn normalize_l1_d(vector: Vec<f64>) -> Vec<f64> {
-        let norm: f64;
-        unsafe {
-            norm = blas::dasum(vector.len().try_into().unwrap(), vector.as_slice(), 1);
-        }
-        divide_scalar_d(vector, norm)
+    #[pg_test]
+    fn test_normalize_max_d() {
+        assert_eq!(
+            normalize_max_d([1.0, 2.0, 3.0].to_vec()),
+            [0.3333333333333333, 0.6666666666666666, 1.0].to_vec()
+        );
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="normalize_l2")]
-    fn normalize_l2_s(vector: Vec<f32>) -> Vec<f32> {
-        let norm: f32;
-        unsafe {
-            norm = blas::snrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1);
-        }
-        divide_scalar_s(vector, norm)
+    #[pg_test]
+    fn test_distance_l1_s() {
+        assert_eq!(
+            distance_l1_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()),
+            0.0
+        );
+    }
+
+    #[pg_test]
+    fn test_distance_l1_d() {
+        assert_eq!(
+            distance_l1_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()),
+            0.0
+        );
+    }
+
+    #[pg_test]
+    fn test_distance_l2_s() {
+        assert_eq!(
+            distance_l2_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()),
+            0.0
+        );
     }

-    #[pg_extern(immutable, parallel_safe, strict, name="normalize_l2")]
-    fn normalize_l2_d(vector: Vec<f64>) -> Vec<f64> {
-        let norm: f64;
-        unsafe {
-            norm = blas::dnrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1);
-        }
-        divide_scalar_d(vector, norm)
-    }
-
-    #[pg_extern(immutable, parallel_safe, strict, name="normalize_max")]
-    fn normalize_max_s(vector: Vec<f32>) -> Vec<f32> {
-        let norm;
-        unsafe {
-            let index = blas::isamax(vector.len().try_into().unwrap(), vector.as_slice(), 1);
-            norm = vector[index - 1].abs();
-        }
-        divide_scalar_s(vector, norm)
-    }
-
-    #[pg_extern(immutable, parallel_safe, strict, name="normalize_max")]
-    fn normalize_max_d(vector: Vec<f64>) -> Vec<f64> {
-        let norm;
-        unsafe {
-            let index = blas::idamax(vector.len().try_into().unwrap(), vector.as_slice(), 1);
-            norm = vector[index - 1].abs();
-        }
-        divide_scalar_d(vector, norm)
-    }
-
-    #[pg_extern(immutable, parallel_safe, strict, name="distance_l1")]
-    fn distance_l1_s(vector: Vec<f32>, other: Vec<f32>) -> f32 {
-        vector.as_slice().iter()
-            .zip(other.as_slice().iter())
-            .map(|(a, b)| (a - b).abs() ).sum()
-    }
-
-    #[pg_extern(immutable, parallel_safe, strict, name="distance_l1")]
-    fn distance_l1_d(vector: Vec<f64>, other: Vec<f64>) -> f64 {
-        vector.as_slice().iter()
-            .zip(other.as_slice().iter())
-            .map(|(a, b)| (a - b).abs() ).sum()
-    }
-
-    #[pg_extern(immutable, parallel_safe, strict, name="distance_l2")]
-    fn distance_l2_s(vector: Vec<f32>, other: Vec<f32>) -> f32 {
-        vector.as_slice().iter()
-            .zip(other.as_slice().iter())
-            .map(|(a, b)| (a - b).powf(2.0) ).sum::<f32>().sqrt()
-    }
-
-    #[pg_extern(immutable, parallel_safe, strict, name="distance_l2")]
-    fn distance_l2_d(vector: Vec<f64>, other: Vec<f64>) -> f64 {
-        vector.as_slice().iter()
-            .zip(other.as_slice().iter())
-            .map(|(a, b)| (a - b).powf(2.0) ).sum::<f64>().sqrt()
-    }
-
-    #[pg_extern(immutable, parallel_safe, strict, name="dot_product")]
-    fn dot_product_s(vector: Vec<f32>, other: Vec<f32>) -> f32 {
-        unsafe {
-            blas::sdot(vector.len().try_into().unwrap(), vector.as_slice(), 1, other.as_slice(), 1)
-        }
-    }
-
-    #[pg_extern(immutable, parallel_safe, strict, name="dot_product")]
-    fn dot_product_d(vector: Vec<f64>, other: Vec<f64>) -> f64 {
-        unsafe {
-            blas::ddot(vector.len().try_into().unwrap(), vector.as_slice(), 1, other.as_slice(), 1)
-        }
-    }
-
-    #[pg_extern(immutable, parallel_safe, strict, name="cosine_similarity")]
-    fn cosine_similarity_s(vector: Vec<f32>, other: Vec<f32>) -> f32 {
-        unsafe {
-            let dot = blas::sdot(vector.len().try_into().unwrap(), vector.as_slice(), 1, other.as_slice(), 1);
-            let a_norm = blas::snrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1);
-            let b_norm = blas::snrm2(other.len().try_into().unwrap(), other.as_slice(), 1);
-            dot / (a_norm * b_norm)
-        }
-    }
-
-    #[pg_extern(immutable, parallel_safe, strict, name="cosine_similarity")]
-    fn cosine_similarity_d(vector: Vec<f64>, other: Vec<f64>) -> f64 {
-        unsafe {
-            let dot = blas::ddot(vector.len().try_into().unwrap(), vector.as_slice(), 1, other.as_slice(), 1);
-            let a_norm = blas::dnrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1);
-            let b_norm = blas::dnrm2(other.len().try_into().unwrap(), other.as_slice(), 1);
-            dot / (a_norm * b_norm)
-        }
-    }
-
-    #[cfg(any(test, feature = "pg_test"))]
-    #[pg_schema]
-    mod tests {
-        use super::*;
-
-        #[pg_test]
-        fn test_add_scalar_s() {
-            assert_eq!(add_scalar_s([1.0, 2.0, 3.0].to_vec(), 1.0), [2.0, 3.0, 4.0].to_vec())
-        }
-
-        #[pg_test]
-        fn test_add_scalar_d() {
-            assert_eq!(add_scalar_d([1.0, 2.0, 3.0].to_vec(), 1.0), [2.0, 3.0, 4.0].to_vec())
-        }
-
-        #[pg_test]
-        fn test_subtract_scalar_s() {
-            assert_eq!(subtract_scalar_s([1.0, 2.0, 3.0].to_vec(), 1.0), [0.0, 1.0, 2.0].to_vec())
-        }
-
-        #[pg_test]
-        fn test_subtract_scalar_d() {
-            assert_eq!(subtract_scalar_d([1.0, 2.0, 3.0].to_vec(), 1.0), [0.0, 1.0, 2.0].to_vec())
-        }
-
-        #[pg_test]
-        fn test_multiply_scalar_s() {
-            assert_eq!(multiply_scalar_d([1.0, 2.0, 3.0].to_vec(), 2.0), [2.0, 4.0, 6.0].to_vec())
-        }
-
-        #[pg_test]
-        fn test_multiply_scalar_d() {
-            assert_eq!(multiply_scalar_d([1.0, 2.0, 3.0].to_vec(), 2.0), [2.0, 4.0, 6.0].to_vec())
-        }
-
-        #[pg_test]
-        fn test_divide_scalar_s() {
-            assert_eq!(divide_scalar_s([2.0, 4.0, 6.0].to_vec(), 2.0), [1.0, 2.0, 3.0].to_vec())
-        }
-
-        #[pg_test]
-        fn test_divide_scalar_d() {
-            assert_eq!(divide_scalar_d([2.0, 4.0, 6.0].to_vec(), 2.0), [1.0, 2.0, 3.0].to_vec())
-        }
-
-        #[pg_test]
-        fn test_add_vector_s() {
-            assert_eq!(add_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), [2.0, 4.0, 6.0].to_vec())
-        }
-
-        #[pg_test]
-        fn test_add_vector_d() {
-            assert_eq!(add_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), [2.0, 4.0, 6.0].to_vec())
-        }
-
-        #[pg_test]
-        fn test_subtract_vector_s() {
-            assert_eq!(subtract_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), [0.0, 0.0, 0.0].to_vec())
-        }
-
-        #[pg_test]
-        fn test_subtract_vector_d() {
-            assert_eq!(subtract_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), [0.0, 0.0, 0.0].to_vec())
-        }
-
-        #[pg_test]
-        fn test_multiply_vector_s() {
-            assert_eq!(multiply_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), [1.0, 4.0, 9.0].to_vec())
-        }
-
-        #[pg_test]
-        fn test_multiply_vector_d() {
-            assert_eq!(multiply_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), [1.0, 4.0, 9.0].to_vec())
-        }
-
-        #[pg_test]
-        fn test_divide_vector_s() {
-            assert_eq!(divide_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), [1.0, 1.0, 1.0].to_vec())
-        }
-
-        #[pg_test]
-        fn test_divide_vector_d() {
-            assert_eq!(divide_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), [1.0, 1.0, 1.0].to_vec())
-        }
-
-        #[pg_test]
-        fn test_norm_l0_s() {
-            assert_eq!(norm_l0_s([1.0, 2.0, 3.0].to_vec()), 3.0)
-        }
-
-        #[pg_test]
-        fn test_norm_l0_d() {
-            assert_eq!(norm_l0_d([1.0, 2.0, 3.0].to_vec()), 3.0)
-        }
-
-        #[pg_test]
-        fn test_norm_l1_s() {
-            assert_eq!(norm_l1_s([1.0, 2.0, 3.0].to_vec()), 6.0)
-        }
-
-        #[pg_test]
-        fn test_norm_l1_d() {
-            assert_eq!(norm_l1_d([1.0, 2.0, 3.0].to_vec()), 6.0)
-        }
-
-        #[pg_test]
-        fn test_norm_l2_s() {
-            assert_eq!(norm_l2_s([1.0, 2.0, 3.0].to_vec()), 3.7416575);
-        }
-
-        #[pg_test]
-        fn test_norm_l2_d() {
-            assert_eq!(norm_l2_d([1.0, 2.0, 3.0].to_vec()), 3.7416573867739413);
-        }
-
-        #[pg_test]
-        fn test_norm_max_s() {
-            assert_eq!(norm_max_s([1.0, 2.0, 3.0].to_vec()), 3.0);
-            assert_eq!(norm_max_s([1.0, 2.0, 3.0, -4.0].to_vec()), 4.0);
-        }
-
-        #[pg_test]
-        fn test_norm_max_d() {
-            assert_eq!(norm_max_d([1.0, 2.0, 3.0].to_vec()), 3.0);
-            assert_eq!(norm_max_d([1.0, 2.0, 3.0, -4.0].to_vec()), 4.0);
-        }
-
-        #[pg_test]
-        fn test_normalize_l1_s() {
-            assert_eq!(normalize_l1_s([1.0, 2.0, 3.0].to_vec()), [0.16666667, 0.33333334, 0.5].to_vec());
-        }
-
-        #[pg_test]
-        fn test_normalize_l1_d() {
-            assert_eq!(normalize_l1_d([1.0, 2.0, 3.0].to_vec()), [0.16666666666666666, 0.3333333333333333, 0.5].to_vec());
-        }
-
-        #[pg_test]
-        fn test_normalize_l2_s() {
-            assert_eq!(normalize_l2_s([1.0, 2.0, 3.0].to_vec()), [0.26726124, 0.5345225, 0.8017837].to_vec());
-        }
-
-        #[pg_test]
-        fn test_normalize_l2_d() {
-            assert_eq!(normalize_l2_d([1.0, 2.0, 3.0].to_vec()), [0.2672612419124244, 0.5345224838248488, 0.8017837257372732].to_vec());
-        }
-
-        #[pg_test]
-        fn test_normalize_max_s() {
-            assert_eq!(normalize_max_s([1.0, 2.0, 3.0].to_vec()), [0.33333334, 0.6666667, 1.0].to_vec());
-        }
-
-        #[pg_test]
-        fn test_normalize_max_d() {
-            assert_eq!(normalize_max_d([1.0, 2.0, 3.0].to_vec()), [0.3333333333333333, 0.6666666666666666, 1.0].to_vec());
-        }
-
-        #[pg_test]
-        fn test_distance_l1_s() {
-            assert_eq!(distance_l1_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), 0.0);
-        }
-
-        #[pg_test]
-        fn test_distance_l1_d() {
-            assert_eq!(distance_l1_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), 0.0);
-        }
-
-        #[pg_test]
-        fn test_distance_l2_s() {
-            assert_eq!(distance_l2_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), 0.0);
-        }
-
-        #[pg_test]
-        fn test_distance_l2_d() {
-            assert_eq!(distance_l2_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), 0.0);
-        }
-
-        #[pg_test]
-        fn test_dot_product_s() {
-            assert_eq!(dot_product_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), 14.0);
-            assert_eq!(dot_product_s([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()), 20.0);
-        }
-
-        #[pg_test]
-        fn test_dot_product_d() {
-            assert_eq!(dot_product_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), 14.0);
-            assert_eq!(dot_product_d([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()), 20.0);
-        }
-
-        #[pg_test]
-        fn test_cosine_similarity_s() {
-            assert_eq!(cosine_similarity_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), 0.99999994);
-            assert_eq!(cosine_similarity_s([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()), 0.9925833);
-        }
-
-        #[pg_test]
-        fn test_cosine_similarity_d() {
-            assert_eq!(cosine_similarity_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), 1.0);
-            assert_eq!(cosine_similarity_d([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()), 0.9925833339709303);
-        }
+    #[pg_test]
+    fn test_distance_l2_d() {
+        assert_eq!(
+            distance_l2_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()),
+            0.0
+        );
+    }
+
+    #[pg_test]
+    fn test_dot_product_s() {
+        assert_eq!(
+            dot_product_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()),
+            14.0
+        );
+        assert_eq!(
+            dot_product_s([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()),
+            20.0
+        );
+    }
+
+    #[pg_test]
+    fn test_dot_product_d() {
+        assert_eq!(
+            dot_product_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()),
+            14.0
+        );
+        assert_eq!(
+            dot_product_d([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()),
+            20.0
+        );
+    }
+
+    #[pg_test]
+    fn test_cosine_similarity_s() {
+        assert_eq!(
+            cosine_similarity_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()),
+            0.99999994
+        );
+        assert_eq!(
+            cosine_similarity_s([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()),
+            0.9925833
+        );
+    }
+
+    #[pg_test]
+    fn test_cosine_similarity_d() {
+        assert_eq!(
+            cosine_similarity_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()),
+            1.0
+        );
+        assert_eq!(
+            cosine_similarity_d([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()),
+            0.9925833339709303
+        );
     }
 }
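The `#[pg_test]` cases above run inside a temporary Postgres instance managed by pgx; with pgx 0.4 they are normally executed via `cargo pgx test` (a tooling assumption, not shown in this diff). Note that the `f32` tests assert rounded values such as 0.99999994 for the cosine of identical vectors, which is expected single-precision behavior rather than a bug.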
