Skip to content

Commit ce4c105

Browse files
authored
start integrating smartcore for common algos in rust (#301)
1 parent 77deb6e commit ce4c105

File tree

17 files changed

+2000
-764
lines changed

17 files changed

+2000
-764
lines changed

pgml-extension/pgml_rust/Cargo.toml

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,19 @@ pg14 = ["pgx/pg14", "pgx-tests/pg14" ]
1616
pg_test = []
1717

1818
[dependencies]
19-
pgx = "=0.4.5"
20-
xgboost = { path = "rust-xgboost" }
21-
rustlearn = "0.5"
19+
pgx = "0.4.5"
2220
once_cell = "1"
2321
rand = "0.8"
22+
xgboost = { path = "rust-xgboost" }
23+
smartcore = { version = "0.2.0", features = ["serde", "ndarray-bindings"] }
24+
ndarray = { version = "0.15.6", features = ["serde", "blas"] }
2425
blas = { version = "0.22.0" }
2526
blas-src = { version = "0.8", features = ["openblas"] }
2627
openblas-src = { version = "0.10", features = ["cblas", "system"] }
28+
serde = { version = "1.0.2" }
29+
serde_json = { version = "1.0.85" }
30+
rmp-serde = { version = "1.1.0" }
31+
typetag = "0.2"
2732

2833
[dev-dependencies]
2934
pgx-tests = "=0.4.5"
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
comment = 'pgml_rust: Created by pgx'
1+
comment = 'pgml_rust: Created by the PostgresML team'
22
default_version = '@CARGO_VERSION@'
33
module_pathname = '$libdir/pgml_rust'
44
relocatable = false
55
superuser = false
6+
schema = 'pgml_rust'

pgml-extension/pgml_rust/sql/schema.sql

Lines changed: 130 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
CREATE SCHEMA IF NOT EXISTS pgml_rust;
2-
31
---
42
--- Track of updates to data
53
---
@@ -33,43 +31,164 @@ BEGIN
3331
) THEN
3432
NEW.updated_at := clock_timestamp();
3533
END IF;
36-
RETURN NEW;
34+
RETURN new;
3735
END;
3836
$$
3937
LANGUAGE plpgsql;
4038

39+
4140
---
4241
--- Projects organize work
4342
---
4443
CREATE TABLE IF NOT EXISTS pgml_rust.projects(
4544
id BIGSERIAL PRIMARY KEY,
46-
name TEXT NOT NULL UNIQUE,
47-
task TEXT NOT NULL,
45+
name TEXT NOT NULL,
46+
task pgml_rust.task NOT NULL,
4847
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
4948
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp()
5049
);
5150
SELECT pgml_rust.auto_updated_at('pgml_rust.projects');
51+
CREATE UNIQUE INDEX IF NOT EXISTS projects_name_idx ON pgml_rust.projects(name);
5252

5353

54-
CREATE TABLE IF NOT EXISTS pgml_rust.models (
54+
---
55+
--- Snapshots freeze data for training
56+
---
57+
CREATE TABLE IF NOT EXISTS pgml_rust.snapshots(
5558
id BIGSERIAL PRIMARY KEY,
56-
project_id BIGINT NOT NULL REFERENCES pgml_rust.projects(id),
57-
algorithm VARCHAR,
58-
data BYTEA
59+
relation_name TEXT NOT NULL,
60+
y_column_name TEXT[] NOT NULL,
61+
test_size FLOAT4 NOT NULL,
62+
test_sampling pgml_rust.sampling NOT NULL,
63+
status TEXT NOT NULL,
64+
columns JSONB,
65+
analysis JSONB,
66+
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
67+
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp()
5968
);
69+
SELECT pgml_rust.auto_updated_at('pgml_rust.snapshots');
70+
6071

6172
---
62-
--- Deployments determine which model is live
73+
--- Models save the learned parameters
74+
---
75+
CREATE TABLE IF NOT EXISTS pgml_rust.models(
76+
id BIGSERIAL PRIMARY KEY,
77+
project_id BIGINT NOT NULL,
78+
snapshot_id BIGINT NOT NULL,
79+
algorithm TEXT NOT NULL,
80+
hyperparams JSONB NOT NULL,
81+
status TEXT NOT NULL,
82+
metrics JSONB,
83+
search pgml_rust.search,
84+
search_params JSONB NOT NULL,
85+
search_args JSONB NOT NULL,
86+
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
87+
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
88+
CONSTRAINT project_id_fk FOREIGN KEY(project_id) REFERENCES pgml_rust.projects(id),
89+
CONSTRAINT snapshot_id_fk FOREIGN KEY(snapshot_id) REFERENCES pgml_rust.snapshots(id)
90+
);
91+
CREATE INDEX IF NOT EXISTS models_project_id_idx ON pgml_rust.models(project_id);
92+
CREATE INDEX IF NOT EXISTS models_snapshot_id_idx ON pgml_rust.models(snapshot_id);
93+
SELECT pgml_rust.auto_updated_at('pgml_rust.models');
94+
95+
96+
---
97+
--- Deployments determine which model is live
6398
---
6499
CREATE TABLE IF NOT EXISTS pgml_rust.deployments(
65100
id BIGSERIAL PRIMARY KEY,
66101
project_id BIGINT NOT NULL,
67102
model_id BIGINT NOT NULL,
68-
strategy TEXT NOT NULL,
103+
strategy pgml_rust.strategy NOT NULL,
69104
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
70105
CONSTRAINT project_id_fk FOREIGN KEY(project_id) REFERENCES pgml_rust.projects(id),
71106
CONSTRAINT model_id_fk FOREIGN KEY(model_id) REFERENCES pgml_rust.models(id)
72107
);
73108
CREATE INDEX IF NOT EXISTS deployments_project_id_created_at_idx ON pgml_rust.deployments(project_id);
74109
CREATE INDEX IF NOT EXISTS deployments_model_id_created_at_idx ON pgml_rust.deployments(model_id);
75110
SELECT pgml_rust.auto_updated_at('pgml_rust.deployments');
111+
112+
---
113+
--- Distribute serialized models consistently for HA
114+
---
115+
CREATE TABLE IF NOT EXISTS pgml_rust.files(
116+
id BIGSERIAL PRIMARY KEY,
117+
model_id BIGINT NOT NULL,
118+
path TEXT NOT NULL,
119+
part INTEGER NOT NULL,
120+
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
121+
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
122+
data BYTEA NOT NULL
123+
);
124+
CREATE UNIQUE INDEX IF NOT EXISTS files_model_id_path_part_idx ON pgml_rust.files(model_id, path, part);
125+
SELECT pgml_rust.auto_updated_at('pgml_rust.files');
126+
127+
---
128+
--- Quick status check on the system.
129+
---
130+
DROP VIEW IF EXISTS pgml_rust.overview;
131+
CREATE VIEW pgml_rust.overview AS
132+
SELECT
133+
p.name,
134+
d.created_at AS deployed_at,
135+
p.task,
136+
m.algorithm,
137+
s.relation_name,
138+
s.y_column_name,
139+
s.test_sampling,
140+
s.test_size
141+
FROM pgml_rust.projects p
142+
INNER JOIN pgml_rust.models m ON p.id = m.project_id
143+
INNER JOIN pgml_rust.deployments d ON d.project_id = p.id
144+
AND d.model_id = m.id
145+
INNER JOIN pgml_rust.snapshots s ON s.id = m.snapshot_id
146+
ORDER BY d.created_at DESC;
147+
148+
149+
---
150+
--- List details of trained models.
151+
---
152+
DROP VIEW IF EXISTS pgml_rust.trained_models;
153+
CREATE VIEW pgml_rust.trained_models AS
154+
SELECT
155+
m.id,
156+
p.name,
157+
p.task,
158+
m.algorithm,
159+
m.created_at,
160+
s.test_sampling,
161+
s.test_size,
162+
d.model_id IS NOT NULL AS deployed
163+
FROM pgml_rust.projects p
164+
INNER JOIN pgml_rust.models m ON p.id = m.project_id
165+
INNER JOIN pgml_rust.snapshots s ON s.id = m.snapshot_id
166+
LEFT JOIN (
167+
SELECT DISTINCT ON(project_id)
168+
project_id, model_id, created_at
169+
FROM pgml_rust.deployments
170+
ORDER BY project_id, created_at desc
171+
) d ON d.model_id = m.id
172+
ORDER BY m.created_at DESC;
173+
174+
175+
---
176+
--- List details of deployed models.
177+
---
178+
DROP VIEW IF EXISTS pgml_rust.deployed_models;
179+
CREATE VIEW pgml_rust.deployed_models AS
180+
SELECT
181+
m.id,
182+
p.name,
183+
p.task,
184+
m.algorithm,
185+
d.created_at as deployed_at
186+
FROM pgml_rust.projects p
187+
INNER JOIN (
188+
SELECT DISTINCT ON(project_id)
189+
project_id, model_id, created_at
190+
FROM pgml_rust.deployments
191+
ORDER BY project_id, created_at desc
192+
) d ON d.project_id = p.id
193+
INNER JOIN pgml_rust.models m ON m.id = d.model_id
194+
ORDER BY p.name ASC;

pgml-extension/pgml_rust/src/api.rs

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
use pgx::*;
2+
3+
use crate::orm::Algorithm;
4+
use crate::orm::Model;
5+
use crate::orm::Project;
6+
use crate::orm::Sampling;
7+
use crate::orm::Search;
8+
use crate::orm::Snapshot;
9+
use crate::orm::Strategy;
10+
use crate::orm::Task;
11+
12+
#[pg_extern]
13+
fn train(
14+
project_name: &str,
15+
task: Option<default!(Task, "NULL")>,
16+
relation_name: Option<default!(&str, "NULL")>,
17+
y_column_name: Option<default!(&str, "NULL")>,
18+
algorithm: default!(Algorithm, "'linear'"),
19+
hyperparams: default!(JsonB, "'{}'"),
20+
search: Option<default!(Search, "NULL")>,
21+
search_params: default!(JsonB, "'{}'"),
22+
search_args: default!(JsonB, "'{}'"),
23+
test_size: default!(f32, 0.25),
24+
test_sampling: default!(Sampling, "'last'"),
25+
) {
26+
let project = match Project::find_by_name(project_name) {
27+
Some(project) => project,
28+
None => Project::create(project_name, task.unwrap()),
29+
};
30+
if task.is_some() && task.unwrap() != project.task {
31+
error!("Project `{:?}` already exists with a different task: `{:?}`. Create a new project instead.", project.name, project.task);
32+
}
33+
let snapshot = match relation_name {
34+
None => project.last_snapshot().expect("You must pass a `relation_name` and `y_column_name` to snapshot the first time you train a model."),
35+
Some(relation_name) => Snapshot::create(relation_name, y_column_name.expect("You must pass a `y_column_name` when you pass a `relation_name`"), test_size, test_sampling)
36+
};
37+
38+
// # Default repeatable random state when possible
39+
// let algorithm = Model.algorithm_from_name_and_task(algorithm, task);
40+
// if "random_state" in algorithm().get_params() and "random_state" not in hyperparams:
41+
// hyperparams["random_state"] = 0
42+
43+
let model = Model::create(
44+
&project,
45+
&snapshot,
46+
algorithm,
47+
hyperparams,
48+
search,
49+
search_params,
50+
search_args,
51+
);
52+
53+
// TODO move deployment into a struct and only deploy if new model is better than old model
54+
Spi::get_one_with_args::<i64>(
55+
"INSERT INTO pgml_rust.deployments (project_id, model_id, strategy) VALUES ($1, $2, $3::pgml_rust.strategy) RETURNING id",
56+
vec![
57+
(PgBuiltInOids::INT8OID.oid(), project.id.into_datum()),
58+
(PgBuiltInOids::INT8OID.oid(), model.id.into_datum()),
59+
(PgBuiltInOids::TEXTOID.oid(), Strategy::most_recent.to_string().into_datum()),
60+
]
61+
);
62+
}
63+
64+
#[pg_extern]
65+
fn predict(project_name: &str, features: Vec<f32>) -> f32 {
66+
let estimator = crate::orm::estimator::find_deployed_estimator_by_project_name(project_name);
67+
estimator.predict(features)
68+
}
69+
70+
// #[pg_extern]
71+
// fn return_table_example() -> impl std::Iterator<Item = (name!(id, Option<i64>), name!(title, Option<String>))> {
72+
// let tuple = Spi::get_two_with_args("SELECT 1 AS id, 2 AS title;", None, None)
73+
// vec![tuple].into_iter()
74+
// }
75+
76+
#[pg_extern]
77+
fn create_snapshot(
78+
relation_name: &str,
79+
y_column_name: &str,
80+
test_size: f32,
81+
test_sampling: Sampling,
82+
) -> i64 {
83+
let snapshot = Snapshot::create(relation_name, y_column_name, test_size, test_sampling);
84+
info!("{:?}", snapshot);
85+
snapshot.id
86+
}
87+
88+
#[cfg(any(test, feature = "pg_test"))]
#[pg_schema]
mod tests {
    use super::*;

    /// A freshly created project gets id 1 and can be found again by that id.
    #[pg_test]
    fn test_project_lifecycle() {
        let created = Project::create("test", Task::regression);
        assert_eq!(created.id, 1);

        let found = Project::find(1);
        assert_eq!(found.id, 1);
    }

    /// The first snapshot created gets id 1.
    #[pg_test]
    fn test_snapshot_lifecycle() {
        let created = Snapshot::create("test", "column", 0.5, Sampling::last);
        assert_eq!(created.id, 1);
    }
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy