Skip to content

Commit 63963d4

Browse files
authored
Stratified sampling (#1336)
I verified tests locally, because I wasn't able to figure out how to get them running via github actions...
1 parent 347168a commit 63963d4

File tree

5 files changed

+241
-34
lines changed

5 files changed

+241
-34
lines changed

.github/workflows/ci.yml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,12 @@ jobs:
4747
if: steps.pgml_extension_changed.outputs.PGML_EXTENSION_CHANGED_FILES != '0'
4848
run: |
4949
git submodule update --init --recursive
50+
- name: Get current version
51+
id: current-version
52+
run: echo "CI_BRANCH=$(git name-rev --name-only HEAD)" >> $GITHUB_OUTPUT
5053
- name: Run tests
54+
env:
55+
CI_BRANCH: ${{ steps.current-version.outputs.CI_BRANCH }}
5156
if: steps.pgml_extension_changed.outputs.PGML_EXTENSION_CHANGED_FILES != '0'
5257
run: |
5358
curl https://sh.rustup.rs -sSf | sh -s -- -y
@@ -58,8 +63,13 @@ jobs:
5863
cargo pgrx init
5964
fi
6065
66+
git checkout master
67+
echo "\q" | cargo pgrx run
68+
psql -p 28816 -h localhost -d pgml -P pager -c "CREATE EXTENSION pgml;"
69+
git checkout $CI_BRANCH
70+
echo "\q" | cargo pgrx run
71+
psql -p 28816 -h localhost -d pgml -P pager -c "ALTER EXTENSION pgml UPDATE;"
6172
cargo pgrx test
62-
6373
# cargo pgrx start
6474
# psql -p 28815 -h 127.0.0.1 -d pgml -P pager -f tests/test.sql
6575
# cargo pgrx stop

pgml-extension/sql/pgml--2.8.1--2.8.2.sql

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,102 @@ CREATE FUNCTION pgml."deploy"(
2525
AS 'MODULE_PATHNAME', 'deploy_strategy_wrapper';
2626

2727
ALTER TYPE pgml.strategy ADD VALUE 'specific';
28+
29+
ALTER TYPE pgml.Sampling ADD VALUE 'stratified';
30+
31+
-- src/api.rs:534
32+
-- pgml::api::snapshot
33+
DROP FUNCTION IF EXISTS pgml."snapshot"(text, text, real, pgml.Sampling, jsonb);
34+
CREATE FUNCTION pgml."snapshot"(
35+
"relation_name" TEXT, /* &str */
36+
"y_column_name" TEXT, /* &str */
37+
"test_size" real DEFAULT 0.25, /* f32 */
38+
"test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */
39+
"preprocess" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */
40+
) RETURNS TABLE (
41+
"relation" TEXT, /* alloc::string::String */
42+
"y_column_name" TEXT /* alloc::string::String */
43+
)
44+
STRICT
45+
LANGUAGE c /* Rust */
46+
AS 'MODULE_PATHNAME', 'snapshot_wrapper';
47+
48+
-- src/api.rs:802
49+
-- pgml::api::tune
50+
DROP FUNCTION IF EXISTS pgml."tune"(text, text, text, text, text, jsonb, real, pgml.Sampling, bool, bool);
51+
CREATE FUNCTION pgml."tune"(
52+
"project_name" TEXT, /* &str */
53+
"task" TEXT DEFAULT NULL, /* core::option::Option<&str> */
54+
"relation_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
55+
"y_column_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
56+
"model_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
57+
"hyperparams" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
58+
"test_size" real DEFAULT 0.25, /* f32 */
59+
"test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */
60+
"automatic_deploy" bool DEFAULT true, /* core::option::Option<bool> */
61+
"materialize_snapshot" bool DEFAULT false /* bool */
62+
) RETURNS TABLE (
63+
"status" TEXT, /* alloc::string::String */
64+
"task" TEXT, /* alloc::string::String */
65+
"algorithm" TEXT, /* alloc::string::String */
66+
"deployed" bool /* bool */
67+
)
68+
PARALLEL SAFE
69+
LANGUAGE c /* Rust */
70+
AS 'MODULE_PATHNAME', 'tune_wrapper';
71+
72+
-- src/api.rs:92
73+
-- pgml::api::train
74+
DROP FUNCTION IF EXISTS pgml."train"(text, text, text, text, pgml.Algorithm, jsonb, pgml.Search, jsonb, jsonb, real, pgml.Sampling, pgml.Runtime, bool, bool, jsonb);
75+
CREATE FUNCTION pgml."train"(
76+
"project_name" TEXT, /* &str */
77+
"task" TEXT DEFAULT NULL, /* core::option::Option<&str> */
78+
"relation_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
79+
"y_column_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
80+
"algorithm" pgml.Algorithm DEFAULT 'linear', /* pgml::orm::algorithm::Algorithm */
81+
"hyperparams" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
82+
"search" pgml.Search DEFAULT NULL, /* core::option::Option<pgml::orm::search::Search> */
83+
"search_params" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
84+
"search_args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
85+
"test_size" real DEFAULT 0.25, /* f32 */
86+
"test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */
87+
"runtime" pgml.Runtime DEFAULT NULL, /* core::option::Option<pgml::orm::runtime::Runtime> */
88+
"automatic_deploy" bool DEFAULT true, /* core::option::Option<bool> */
89+
"materialize_snapshot" bool DEFAULT false, /* bool */
90+
"preprocess" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */
91+
) RETURNS TABLE (
92+
"project" TEXT, /* alloc::string::String */
93+
"task" TEXT, /* alloc::string::String */
94+
"algorithm" TEXT, /* alloc::string::String */
95+
"deployed" bool /* bool */
96+
)
97+
LANGUAGE c /* Rust */
98+
AS 'MODULE_PATHNAME', 'train_wrapper';
99+
100+
-- src/api.rs:138
101+
-- pgml::api::train_joint
102+
DROP FUNCTION IF EXISTS pgml."train_joint"(text, text, text, text, pgml.Algorithm, jsonb, pgml.Search, jsonb, jsonb, real, pgml.Sampling, pgml.Runtime, bool, bool, jsonb);
103+
CREATE FUNCTION pgml."train_joint"(
104+
"project_name" TEXT, /* &str */
105+
"task" TEXT DEFAULT NULL, /* core::option::Option<&str> */
106+
"relation_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
107+
"y_column_name" TEXT[] DEFAULT NULL, /* core::option::Option<alloc::vec::Vec<alloc::string::String>> */
108+
"algorithm" pgml.Algorithm DEFAULT 'linear', /* pgml::orm::algorithm::Algorithm */
109+
"hyperparams" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
110+
"search" pgml.Search DEFAULT NULL, /* core::option::Option<pgml::orm::search::Search> */
111+
"search_params" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
112+
"search_args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
113+
"test_size" real DEFAULT 0.25, /* f32 */
114+
"test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */
115+
"runtime" pgml.Runtime DEFAULT NULL, /* core::option::Option<pgml::orm::runtime::Runtime> */
116+
"automatic_deploy" bool DEFAULT true, /* core::option::Option<bool> */
117+
"materialize_snapshot" bool DEFAULT false, /* bool */
118+
"preprocess" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */
119+
) RETURNS TABLE (
120+
"project" TEXT, /* alloc::string::String */
121+
"task" TEXT, /* alloc::string::String */
122+
"algorithm" TEXT, /* alloc::string::String */
123+
"deployed" bool /* bool */
124+
)
125+
LANGUAGE c /* Rust */
126+
AS 'MODULE_PATHNAME', 'train_joint_wrapper';

pgml-extension/src/api.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ fn train(
100100
search_params: default!(JsonB, "'{}'"),
101101
search_args: default!(JsonB, "'{}'"),
102102
test_size: default!(f32, 0.25),
103-
test_sampling: default!(Sampling, "'last'"),
103+
test_sampling: default!(Sampling, "'stratified'"),
104104
runtime: default!(Option<Runtime>, "NULL"),
105105
automatic_deploy: default!(Option<bool>, true),
106106
materialize_snapshot: default!(bool, false),
@@ -146,7 +146,7 @@ fn train_joint(
146146
search_params: default!(JsonB, "'{}'"),
147147
search_args: default!(JsonB, "'{}'"),
148148
test_size: default!(f32, 0.25),
149-
test_sampling: default!(Sampling, "'last'"),
149+
test_sampling: default!(Sampling, "'stratified'"),
150150
runtime: default!(Option<Runtime>, "NULL"),
151151
automatic_deploy: default!(Option<bool>, true),
152152
materialize_snapshot: default!(bool, false),
@@ -535,7 +535,7 @@ fn snapshot(
535535
relation_name: &str,
536536
y_column_name: &str,
537537
test_size: default!(f32, 0.25),
538-
test_sampling: default!(Sampling, "'last'"),
538+
test_sampling: default!(Sampling, "'stratified'"),
539539
preprocess: default!(JsonB, "'{}'"),
540540
) -> TableIterator<'static, (name!(relation, String), name!(y_column_name, String))> {
541541
Snapshot::create(
@@ -807,7 +807,7 @@ fn tune(
807807
model_name: default!(Option<&str>, "NULL"),
808808
hyperparams: default!(JsonB, "'{}'"),
809809
test_size: default!(f32, 0.25),
810-
test_sampling: default!(Sampling, "'last'"),
810+
test_sampling: default!(Sampling, "'stratified'"),
811811
automatic_deploy: default!(Option<bool>, true),
812812
materialize_snapshot: default!(bool, false),
813813
) -> TableIterator<

pgml-extension/src/orm/sampling.rs

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
use pgrx::*;
22
use serde::Deserialize;
33

4+
use super::snapshot::Column;
5+
46
#[derive(PostgresEnum, Copy, Clone, Eq, PartialEq, Debug, Deserialize)]
57
#[allow(non_camel_case_types)]
68
pub enum Sampling {
79
random,
810
last,
11+
stratified,
912
}
1013

1114
impl std::str::FromStr for Sampling {
@@ -15,6 +18,7 @@ impl std::str::FromStr for Sampling {
1518
match input {
1619
"random" => Ok(Sampling::random),
1720
"last" => Ok(Sampling::last),
21+
"stratified" => Ok(Sampling::stratified),
1822
_ => Err(()),
1923
}
2024
}
@@ -25,6 +29,111 @@ impl std::string::ToString for Sampling {
2529
match *self {
2630
Sampling::random => "random".to_string(),
2731
Sampling::last => "last".to_string(),
32+
Sampling::stratified => "stratified".to_string(),
2833
}
2934
}
3035
}
36+
37+
impl Sampling {
38+
// Implementing the sampling strategy in SQL
39+
// Effectively orders the table according to the train/test split
40+
// e.g. first N rows are train, last M rows are test
41+
// where M is configured by the user
42+
pub fn get_sql(&self, relation_name: &str, y_column_names: Vec<Column>) -> String {
43+
let col_string = y_column_names
44+
.iter()
45+
.map(|c| c.quoted_name())
46+
.collect::<Vec<String>>()
47+
.join(", ");
48+
match *self {
49+
Sampling::random => {
50+
format!("SELECT * FROM {relation_name} ORDER BY RANDOM()")
51+
}
52+
Sampling::last => {
53+
format!("SELECT * FROM {relation_name}")
54+
}
55+
Sampling::stratified => {
56+
format!(
57+
"
58+
SELECT *
59+
FROM (
60+
SELECT
61+
*,
62+
ROW_NUMBER() OVER(PARTITION BY {col_string} ORDER BY RANDOM()) AS rn
63+
FROM {relation_name}
64+
) AS subquery
65+
ORDER BY rn, RANDOM();
66+
"
67+
)
68+
}
69+
}
70+
}
71+
}
72+
73+
#[cfg(test)]
74+
mod tests {
75+
use crate::orm::snapshot::{Preprocessor, Statistics};
76+
77+
use super::*;
78+
79+
fn get_column_fixtures() -> Vec<Column> {
80+
vec![
81+
Column {
82+
name: "col1".to_string(),
83+
pg_type: "text".to_string(),
84+
nullable: false,
85+
label: true,
86+
position: 0,
87+
size: 0,
88+
array: false,
89+
preprocessor: Preprocessor::default(),
90+
statistics: Statistics::default(),
91+
},
92+
Column {
93+
name: "col2".to_string(),
94+
pg_type: "text".to_string(),
95+
nullable: false,
96+
label: true,
97+
position: 0,
98+
size: 0,
99+
array: false,
100+
preprocessor: Preprocessor::default(),
101+
statistics: Statistics::default(),
102+
},
103+
]
104+
}
105+
106+
#[test]
107+
fn test_get_sql_random_sampling() {
108+
let sampling = Sampling::random;
109+
let columns = get_column_fixtures();
110+
let sql = sampling.get_sql("my_table", columns);
111+
assert_eq!(sql, "SELECT * FROM my_table ORDER BY RANDOM()");
112+
}
113+
114+
#[test]
115+
fn test_get_sql_last_sampling() {
116+
let sampling = Sampling::last;
117+
let columns = get_column_fixtures();
118+
let sql = sampling.get_sql("my_table", columns);
119+
assert_eq!(sql, "SELECT * FROM my_table");
120+
}
121+
122+
#[test]
123+
fn test_get_sql_stratified_sampling() {
124+
let sampling = Sampling::stratified;
125+
let columns = get_column_fixtures();
126+
let sql = sampling.get_sql("my_table", columns);
127+
let expected_sql = "
128+
SELECT *
129+
FROM (
130+
SELECT
131+
*,
132+
ROW_NUMBER() OVER(PARTITION BY \"col1\", \"col2\" ORDER BY RANDOM()) AS rn
133+
FROM my_table
134+
) AS subquery
135+
ORDER BY rn, RANDOM();
136+
";
137+
assert_eq!(sql, expected_sql);
138+
}
139+
}

pgml-extension/src/orm/snapshot.rs

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ pub(crate) struct Preprocessor {
119119
}
120120

121121
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
122-
pub(crate) struct Column {
122+
pub struct Column {
123123
pub(crate) name: String,
124124
pub(crate) pg_type: String,
125125
pub(crate) nullable: bool,
@@ -147,7 +147,7 @@ impl Column {
147147
)
148148
}
149149

150-
fn quoted_name(&self) -> String {
150+
pub(crate) fn quoted_name(&self) -> String {
151151
format!(r#""{}""#, self.name)
152152
}
153153

@@ -608,13 +608,8 @@ impl Snapshot {
608608
};
609609

610610
if materialized {
611-
let mut sql = format!(
612-
r#"CREATE TABLE "pgml"."snapshot_{}" AS SELECT * FROM {}"#,
613-
s.id, s.relation_name
614-
);
615-
if s.test_sampling == Sampling::random {
616-
sql += " ORDER BY random()";
617-
}
611+
let sampled_query = s.test_sampling.get_sql(&s.relation_name, s.columns.clone());
612+
let sql = format!(r#"CREATE TABLE "pgml"."snapshot_{}" AS {}"#, s.id, sampled_query);
618613
client.update(&sql, None, None).unwrap();
619614
}
620615
snapshot = Some(s);
@@ -742,26 +737,20 @@ impl Snapshot {
742737
}
743738

744739
fn select_sql(&self) -> String {
745-
format!(
746-
"SELECT {} FROM {} {}",
747-
self.columns
748-
.iter()
749-
.map(|c| c.quoted_name())
750-
.collect::<Vec<String>>()
751-
.join(", "),
752-
self.relation_name_quoted(),
753-
match self.materialized {
754-
// If the snapshot is materialized, we already randomized it.
755-
true => "",
756-
false => {
757-
if self.test_sampling == Sampling::random {
758-
"ORDER BY random()"
759-
} else {
760-
""
761-
}
762-
}
763-
},
764-
)
740+
match self.materialized {
741+
true => {
742+
format!(
743+
"SELECT {} FROM {}",
744+
self.columns
745+
.iter()
746+
.map(|c| c.quoted_name())
747+
.collect::<Vec<String>>()
748+
.join(", "),
749+
self.relation_name_quoted()
750+
)
751+
}
752+
false => self.test_sampling.get_sql(&self.relation_name_quoted(), self.columns.clone()),
753+
}
765754
}
766755

767756
fn train_test_split(&self, num_rows: usize) -> (usize, usize) {

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy