Skip to content

start integrating smartcore for common algos in rust #301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Sep 10, 2022
Merged
11 changes: 8 additions & 3 deletions pgml-extension/pgml_rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,19 @@ pg14 = ["pgx/pg14", "pgx-tests/pg14" ]
pg_test = []

[dependencies]
pgx = "=0.4.5"
xgboost = { path = "rust-xgboost" }
rustlearn = "0.5"
pgx = "0.4.5"
once_cell = "1"
rand = "0.8"
xgboost = { path = "rust-xgboost" }
smartcore = { version = "0.2.0", features = ["serde", "ndarray-bindings"] }
ndarray = { version = "0.15.6", features = ["serde", "blas"] }
blas = { version = "0.22.0" }
blas-src = { version = "0.8", features = ["openblas"] }
openblas-src = { version = "0.10", features = ["cblas", "system"] }
serde = { version = "1.0.2" }
serde_json = { version = "1.0.85" }
rmp-serde = { version = "1.1.0" }
typetag = "0.2"

[dev-dependencies]
pgx-tests = "=0.4.5"
Expand Down
3 changes: 2 additions & 1 deletion pgml-extension/pgml_rust/pgml_rust.control
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
comment = 'pgml_rust: Created by pgx'
comment = 'pgml_rust: Created by the PostgresML team'
default_version = '@CARGO_VERSION@'
module_pathname = '$libdir/pgml_rust'
relocatable = false
superuser = false
schema = 'pgml_rust'
141 changes: 130 additions & 11 deletions pgml-extension/pgml_rust/sql/schema.sql
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
CREATE SCHEMA IF NOT EXISTS pgml_rust;

---
--- Track of updates to data
---
Expand Down Expand Up @@ -33,43 +31,164 @@ BEGIN
) THEN
NEW.updated_at := clock_timestamp();
END IF;
RETURN NEW;
RETURN new;
END;
$$
LANGUAGE plpgsql;


---
--- Projects organize work
---
CREATE TABLE IF NOT EXISTS pgml_rust.projects(
id BIGSERIAL PRIMARY KEY,
name TEXT NOT NULL UNIQUE,
task TEXT NOT NULL,
name TEXT NOT NULL,
task pgml_rust.task NOT NULL,
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp()
);
SELECT pgml_rust.auto_updated_at('pgml_rust.projects');
CREATE UNIQUE INDEX IF NOT EXISTS projects_name_idx ON pgml_rust.projects(name);


CREATE TABLE IF NOT EXISTS pgml_rust.models (
---
--- Snapshots freeze data for training
---
CREATE TABLE IF NOT EXISTS pgml_rust.snapshots(
id BIGSERIAL PRIMARY KEY,
project_id BIGINT NOT NULL REFERENCES pgml_rust.projects(id),
algorithm VARCHAR,
data BYTEA
relation_name TEXT NOT NULL,
y_column_name TEXT[] NOT NULL,
test_size FLOAT4 NOT NULL,
test_sampling pgml_rust.sampling NOT NULL,
status TEXT NOT NULL,
columns JSONB,
analysis JSONB,
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp()
);
SELECT pgml_rust.auto_updated_at('pgml_rust.snapshots');


---
--- Deployments determine which model is live
--- Models save the learned parameters
---
CREATE TABLE IF NOT EXISTS pgml_rust.models(
id BIGSERIAL PRIMARY KEY,
project_id BIGINT NOT NULL,
snapshot_id BIGINT NOT NULL,
algorithm TEXT NOT NULL,
hyperparams JSONB NOT NULL,
status TEXT NOT NULL,
metrics JSONB,
search pgml_rust.search,
search_params JSONB NOT NULL,
search_args JSONB NOT NULL,
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
CONSTRAINT project_id_fk FOREIGN KEY(project_id) REFERENCES pgml_rust.projects(id),
CONSTRAINT snapshot_id_fk FOREIGN KEY(snapshot_id) REFERENCES pgml_rust.snapshots(id)
);
CREATE INDEX IF NOT EXISTS models_project_id_idx ON pgml_rust.models(project_id);
CREATE INDEX IF NOT EXISTS models_snapshot_id_idx ON pgml_rust.models(snapshot_id);
SELECT pgml_rust.auto_updated_at('pgml_rust.models');


---
--- Deployements determine which model is live
---
CREATE TABLE IF NOT EXISTS pgml_rust.deployments(
id BIGSERIAL PRIMARY KEY,
project_id BIGINT NOT NULL,
model_id BIGINT NOT NULL,
strategy TEXT NOT NULL,
strategy pgml_rust.strategy NOT NULL,
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
CONSTRAINT project_id_fk FOREIGN KEY(project_id) REFERENCES pgml_rust.projects(id),
CONSTRAINT model_id_fk FOREIGN KEY(model_id) REFERENCES pgml_rust.models(id)
);
CREATE INDEX IF NOT EXISTS deployments_project_id_created_at_idx ON pgml_rust.deployments(project_id);
CREATE INDEX IF NOT EXISTS deployments_model_id_created_at_idx ON pgml_rust.deployments(model_id);
SELECT pgml_rust.auto_updated_at('pgml_rust.deployments');

---
--- Distribute serialized models consistently for HA
---
CREATE TABLE IF NOT EXISTS pgml_rust.files(
id BIGSERIAL PRIMARY KEY,
model_id BIGINT NOT NULL,
path TEXT NOT NULL,
part INTEGER NOT NULL,
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
data BYTEA NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS files_model_id_path_part_idx ON pgml_rust.files(model_id, path, part);
SELECT pgml_rust.auto_updated_at('pgml_rust.files');

---
--- Quick status check on the system.
---
DROP VIEW IF EXISTS pgml_rust.overview;
CREATE VIEW pgml_rust.overview AS
SELECT
p.name,
d.created_at AS deployed_at,
p.task,
m.algorithm,
s.relation_name,
s.y_column_name,
s.test_sampling,
s.test_size
FROM pgml_rust.projects p
INNER JOIN pgml_rust.models m ON p.id = m.project_id
INNER JOIN pgml_rust.deployments d ON d.project_id = p.id
AND d.model_id = m.id
INNER JOIN pgml_rust.snapshots s ON s.id = m.snapshot_id
ORDER BY d.created_at DESC;


---
--- List details of trained models.
---
DROP VIEW IF EXISTS pgml_rust.trained_models;
CREATE VIEW pgml_rust.trained_models AS
SELECT
m.id,
p.name,
p.task,
m.algorithm,
m.created_at,
s.test_sampling,
s.test_size,
d.model_id IS NOT NULL AS deployed
FROM pgml_rust.projects p
INNER JOIN pgml_rust.models m ON p.id = m.project_id
INNER JOIN pgml_rust.snapshots s ON s.id = m.snapshot_id
LEFT JOIN (
SELECT DISTINCT ON(project_id)
project_id, model_id, created_at
FROM pgml_rust.deployments
ORDER BY project_id, created_at desc
) d ON d.model_id = m.id
ORDER BY m.created_at DESC;


---
--- List details of deployed models.
---
DROP VIEW IF EXISTS pgml_rust.deployed_models;
CREATE VIEW pgml_rust.deployed_models AS
SELECT
m.id,
p.name,
p.task,
m.algorithm,
d.created_at as deployed_at
FROM pgml_rust.projects p
INNER JOIN (
SELECT DISTINCT ON(project_id)
project_id, model_id, created_at
FROM pgml_rust.deployments
ORDER BY project_id, created_at desc
) d ON d.project_id = p.id
INNER JOIN pgml_rust.models m ON m.id = d.model_id
ORDER BY p.name ASC;
104 changes: 104 additions & 0 deletions pgml-extension/pgml_rust/src/api.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
use pgx::*;

use crate::orm::Algorithm;
use crate::orm::Model;
use crate::orm::Project;
use crate::orm::Sampling;
use crate::orm::Search;
use crate::orm::Snapshot;
use crate::orm::Strategy;
use crate::orm::Task;

#[pg_extern]
fn train(
project_name: &str,
task: Option<default!(Task, "NULL")>,
relation_name: Option<default!(&str, "NULL")>,
y_column_name: Option<default!(&str, "NULL")>,
algorithm: default!(Algorithm, "'linear'"),
hyperparams: default!(JsonB, "'{}'"),
search: Option<default!(Search, "NULL")>,
search_params: default!(JsonB, "'{}'"),
search_args: default!(JsonB, "'{}'"),
test_size: default!(f32, 0.25),
test_sampling: default!(Sampling, "'last'"),
) {
let project = match Project::find_by_name(project_name) {
Some(project) => project,
None => Project::create(project_name, task.unwrap()),
};
if task.is_some() && task.unwrap() != project.task {
error!("Project `{:?}` already exists with a different task: `{:?}`. Create a new project instead.", project.name, project.task);
}
let snapshot = match relation_name {
None => project.last_snapshot().expect("You must pass a `relation_name` and `y_column_name` to snapshot the first time you train a model."),
Some(relation_name) => Snapshot::create(relation_name, y_column_name.expect("You must pass a `y_column_name` when you pass a `relation_name`"), test_size, test_sampling)
};

// # Default repeatable random state when possible
// let algorithm = Model.algorithm_from_name_and_task(algorithm, task);
// if "random_state" in algorithm().get_params() and "random_state" not in hyperparams:
// hyperparams["random_state"] = 0

let model = Model::create(
&project,
&snapshot,
algorithm,
hyperparams,
search,
search_params,
search_args,
);

// TODO move deployment into a struct and only deploy if new model is better than old model
Spi::get_one_with_args::<i64>(
"INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, $3::pgml_rust.strategy) RETURNING id",
vec![
(PgBuiltInOids::INT8OID.oid(), project.id.into_datum()),
(PgBuiltInOids::INT8OID.oid(), model.id.into_datum()),
(PgBuiltInOids::TEXTOID.oid(), Strategy::most_recent.to_string().into_datum()),
]
);
}

#[pg_extern]
fn predict(project_name: &str, features: Vec<f32>) -> f32 {
let estimator = crate::orm::estimator::find_deployed_estimator_by_project_name(project_name);
estimator.predict(features)
}

// #[pg_extern]
// fn return_table_example() -> impl std::Iterator<Item = (name!(id, Option<i64>), name!(title, Option<String>))> {
// let tuple = Spi::get_two_with_args("SELECT 1 AS id, 2 AS title;", None, None)
// vec![tuple].into_iter()
// }

#[pg_extern]
fn create_snapshot(
relation_name: &str,
y_column_name: &str,
test_size: f32,
test_sampling: Sampling,
) -> i64 {
let snapshot = Snapshot::create(relation_name, y_column_name, test_size, test_sampling);
info!("{:?}", snapshot);
snapshot.id
}

#[cfg(any(test, feature = "pg_test"))]
#[pg_schema]
mod tests {
use super::*;

#[pg_test]
fn test_project_lifecycle() {
assert_eq!(Project::create("test", Task::regression).id, 1);
assert_eq!(Project::find(1).id, 1);
}

#[pg_test]
fn test_snapshot_lifecycle() {
let snapshot = Snapshot::create("test", "column", 0.5, Sampling::last);
assert_eq!(snapshot.id, 1);
}
}
Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy