Content-Length: 15514 | pFad | http://github.com/postgresml/postgresml/pull/1336.diff
thub.com
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index e9b0b1412..e1859d8aa 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -47,7 +47,12 @@ jobs:
if: steps.pgml_extension_changed.outputs.PGML_EXTENSION_CHANGED_FILES != '0'
run: |
git submodule update --init --recursive
+ - name: Get current version
+ id: current-version
+ run: echo "CI_BRANCH=$(git name-rev --name-only HEAD)" >> $GITHUB_OUTPUT
- name: Run tests
+ env:
+ CI_BRANCH: ${{ steps.current-version.outputs.CI_BRANCH }}
if: steps.pgml_extension_changed.outputs.PGML_EXTENSION_CHANGED_FILES != '0'
run: |
curl https://sh.rustup.rs -sSf | sh -s -- -y
@@ -58,8 +63,13 @@ jobs:
cargo pgrx init
fi
+ git checkout master
+ echo "\q" | cargo pgrx run
+ psql -p 28816 -h localhost -d pgml -P pager -c "CREATE EXTENSION pgml;"
+ git checkout $CI_BRANCH
+ echo "\q" | cargo pgrx run
+ psql -p 28816 -h localhost -d pgml -P pager -c "ALTER EXTENSION pgml UPDATE;"
cargo pgrx test
-
# cargo pgrx start
# psql -p 28815 -h 127.0.0.1 -d pgml -P pager -f tests/test.sql
# cargo pgrx stop
diff --git a/pgml-extension/sql/pgml--2.8.1--2.8.2.sql b/pgml-extension/sql/pgml--2.8.1--2.8.2.sql
index 2c6264fb9..98e2216e9 100644
--- a/pgml-extension/sql/pgml--2.8.1--2.8.2.sql
+++ b/pgml-extension/sql/pgml--2.8.1--2.8.2.sql
@@ -25,3 +25,102 @@ CREATE FUNCTION pgml."deploy"(
AS 'MODULE_PATHNAME', 'deploy_strategy_wrapper';
ALTER TYPE pgml.strategy ADD VALUE 'specific';
+
+ALTER TYPE pgml.Sampling ADD VALUE 'stratified';
+
+-- src/api.rs:534
+-- pgml::api::snapshot
+DROP FUNCTION IF EXISTS pgml."snapshot"(text, text, real, pgml.Sampling, jsonb);
+CREATE FUNCTION pgml."snapshot"(
+ "relation_name" TEXT, /* &str */
+ "y_column_name" TEXT, /* &str */
+ "test_size" real DEFAULT 0.25, /* f32 */
+ "test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */
+ "preprocess" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */
+) RETURNS TABLE (
+ "relation" TEXT, /* alloc::string::String */
+ "y_column_name" TEXT /* alloc::string::String */
+)
+STRICT
+LANGUAGE c /* Rust */
+AS 'MODULE_PATHNAME', 'snapshot_wrapper';
+
+-- src/api.rs:802
+-- pgml::api::tune
+DROP FUNCTION IF EXISTS pgml."tune"(text, text, text, text, text, jsonb, real, pgml.Sampling, bool, bool);
+CREATE FUNCTION pgml."tune"(
+ "project_name" TEXT, /* &str */
+ "task" TEXT DEFAULT NULL, /* core::option::Option<&str> */
+ "relation_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
+ "y_column_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
+ "model_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
+ "hyperparams" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
+ "test_size" real DEFAULT 0.25, /* f32 */
+ "test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */
+ "automatic_deploy" bool DEFAULT true, /* core::option::Option */
+ "materialize_snapshot" bool DEFAULT false /* bool */
+) RETURNS TABLE (
+ "status" TEXT, /* alloc::string::String */
+ "task" TEXT, /* alloc::string::String */
+ "algorithm" TEXT, /* alloc::string::String */
+ "deployed" bool /* bool */
+)
+PARALLEL SAFE
+LANGUAGE c /* Rust */
+AS 'MODULE_PATHNAME', 'tune_wrapper';
+
+-- src/api.rs:92
+-- pgml::api::train
+DROP FUNCTION IF EXISTS pgml."train"(text, text, text, text, pgml.Algorithm, jsonb, pgml.Search, jsonb, jsonb, real, pgml.Sampling, pgml.Runtime, bool, bool, jsonb);
+CREATE FUNCTION pgml."train"(
+ "project_name" TEXT, /* &str */
+ "task" TEXT DEFAULT NULL, /* core::option::Option<&str> */
+ "relation_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
+ "y_column_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
+ "algorithm" pgml.Algorithm DEFAULT 'linear', /* pgml::orm::algorithm::Algorithm */
+ "hyperparams" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
+ "search" pgml.Search DEFAULT NULL, /* core::option::Option */
+ "search_params" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
+ "search_args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
+ "test_size" real DEFAULT 0.25, /* f32 */
+ "test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */
+ "runtime" pgml.Runtime DEFAULT NULL, /* core::option::Option */
+ "automatic_deploy" bool DEFAULT true, /* core::option::Option */
+ "materialize_snapshot" bool DEFAULT false, /* bool */
+ "preprocess" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */
+) RETURNS TABLE (
+ "project" TEXT, /* alloc::string::String */
+ "task" TEXT, /* alloc::string::String */
+ "algorithm" TEXT, /* alloc::string::String */
+ "deployed" bool /* bool */
+)
+LANGUAGE c /* Rust */
+AS 'MODULE_PATHNAME', 'train_wrapper';
+
+-- src/api.rs:138
+-- pgml::api::train_joint
+DROP FUNCTION IF EXISTS pgml."train_joint"(text, text, text, text, pgml.Algorithm, jsonb, pgml.Search, jsonb, jsonb, real, pgml.Sampling, pgml.Runtime, bool, bool, jsonb);
+CREATE FUNCTION pgml."train_joint"(
+ "project_name" TEXT, /* &str */
+ "task" TEXT DEFAULT NULL, /* core::option::Option<&str> */
+ "relation_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
+ "y_column_name" TEXT[] DEFAULT NULL, /* core::option::Option> */
+ "algorithm" pgml.Algorithm DEFAULT 'linear', /* pgml::orm::algorithm::Algorithm */
+ "hyperparams" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
+ "search" pgml.Search DEFAULT NULL, /* core::option::Option */
+ "search_params" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
+ "search_args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
+ "test_size" real DEFAULT 0.25, /* f32 */
+ "test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */
+ "runtime" pgml.Runtime DEFAULT NULL, /* core::option::Option */
+ "automatic_deploy" bool DEFAULT true, /* core::option::Option */
+ "materialize_snapshot" bool DEFAULT false, /* bool */
+ "preprocess" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */
+) RETURNS TABLE (
+ "project" TEXT, /* alloc::string::String */
+ "task" TEXT, /* alloc::string::String */
+ "algorithm" TEXT, /* alloc::string::String */
+ "deployed" bool /* bool */
+)
+LANGUAGE c /* Rust */
+AS 'MODULE_PATHNAME', 'train_joint_wrapper';
diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs
index 1580de944..7fd5012c8 100644
--- a/pgml-extension/src/api.rs
+++ b/pgml-extension/src/api.rs
@@ -100,7 +100,7 @@ fn train(
search_params: default!(JsonB, "'{}'"),
search_args: default!(JsonB, "'{}'"),
test_size: default!(f32, 0.25),
- test_sampling: default!(Sampling, "'last'"),
+ test_sampling: default!(Sampling, "'stratified'"),
runtime: default!(Option, "NULL"),
automatic_deploy: default!(Option, true),
materialize_snapshot: default!(bool, false),
@@ -146,7 +146,7 @@ fn train_joint(
search_params: default!(JsonB, "'{}'"),
search_args: default!(JsonB, "'{}'"),
test_size: default!(f32, 0.25),
- test_sampling: default!(Sampling, "'last'"),
+ test_sampling: default!(Sampling, "'stratified'"),
runtime: default!(Option, "NULL"),
automatic_deploy: default!(Option, true),
materialize_snapshot: default!(bool, false),
@@ -535,7 +535,7 @@ fn snapshot(
relation_name: &str,
y_column_name: &str,
test_size: default!(f32, 0.25),
- test_sampling: default!(Sampling, "'last'"),
+ test_sampling: default!(Sampling, "'stratified'"),
preprocess: default!(JsonB, "'{}'"),
) -> TableIterator<'static, (name!(relation, String), name!(y_column_name, String))> {
Snapshot::create(
@@ -807,7 +807,7 @@ fn tune(
model_name: default!(Option<&str>, "NULL"),
hyperparams: default!(JsonB, "'{}'"),
test_size: default!(f32, 0.25),
- test_sampling: default!(Sampling, "'last'"),
+ test_sampling: default!(Sampling, "'stratified'"),
automatic_deploy: default!(Option, true),
materialize_snapshot: default!(bool, false),
) -> TableIterator<
diff --git a/pgml-extension/src/orm/sampling.rs b/pgml-extension/src/orm/sampling.rs
index 6bb3d7b5a..2ecd66f5d 100644
--- a/pgml-extension/src/orm/sampling.rs
+++ b/pgml-extension/src/orm/sampling.rs
@@ -1,11 +1,14 @@
use pgrx::*;
use serde::Deserialize;
+use super::snapshot::Column;
+
#[derive(PostgresEnum, Copy, Clone, Eq, PartialEq, Debug, Deserialize)]
#[allow(non_camel_case_types)]
pub enum Sampling {
random,
last,
+ stratified,
}
impl std::str::FromStr for Sampling {
@@ -15,6 +18,7 @@ impl std::str::FromStr for Sampling {
match input {
"random" => Ok(Sampling::random),
"last" => Ok(Sampling::last),
+ "stratified" => Ok(Sampling::stratified),
_ => Err(()),
}
}
@@ -25,6 +29,111 @@ impl std::string::ToString for Sampling {
match *self {
Sampling::random => "random".to_string(),
Sampling::last => "last".to_string(),
+ Sampling::stratified => "stratified".to_string(),
}
}
}
+
+impl Sampling {
+ // Implementing the sampling strategy in SQL
+ // Effectively orders the table according to the train/test split
+ // e.g. first N rows are train, last M rows are test
+ // where M is configured by the user
+ pub fn get_sql(&self, relation_name: &str, y_column_names: Vec) -> String {
+ let col_string = y_column_names
+ .iter()
+ .map(|c| c.quoted_name())
+ .collect::>()
+ .join(", ");
+ match *self {
+ Sampling::random => {
+ format!("SELECT * FROM {relation_name} ORDER BY RANDOM()")
+ }
+ Sampling::last => {
+ format!("SELECT * FROM {relation_name}")
+ }
+ Sampling::stratified => {
+ format!(
+ "
+ SELECT *
+ FROM (
+ SELECT
+ *,
+ ROW_NUMBER() OVER(PARTITION BY {col_string} ORDER BY RANDOM()) AS rn
+ FROM {relation_name}
+ ) AS subquery
+ ORDER BY rn, RANDOM();
+ "
+ )
+ }
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::orm::snapshot::{Preprocessor, Statistics};
+
+ use super::*;
+
+ fn get_column_fixtures() -> Vec {
+ vec![
+ Column {
+ name: "col1".to_string(),
+ pg_type: "text".to_string(),
+ nullable: false,
+ label: true,
+ position: 0,
+ size: 0,
+ array: false,
+ preprocessor: Preprocessor::default(),
+ statistics: Statistics::default(),
+ },
+ Column {
+ name: "col2".to_string(),
+ pg_type: "text".to_string(),
+ nullable: false,
+ label: true,
+ position: 0,
+ size: 0,
+ array: false,
+ preprocessor: Preprocessor::default(),
+ statistics: Statistics::default(),
+ },
+ ]
+ }
+
+ #[test]
+ fn test_get_sql_random_sampling() {
+ let sampling = Sampling::random;
+ let columns = get_column_fixtures();
+ let sql = sampling.get_sql("my_table", columns);
+ assert_eq!(sql, "SELECT * FROM my_table ORDER BY RANDOM()");
+ }
+
+ #[test]
+ fn test_get_sql_last_sampling() {
+ let sampling = Sampling::last;
+ let columns = get_column_fixtures();
+ let sql = sampling.get_sql("my_table", columns);
+ assert_eq!(sql, "SELECT * FROM my_table");
+ }
+
+ #[test]
+ fn test_get_sql_stratified_sampling() {
+ let sampling = Sampling::stratified;
+ let columns = get_column_fixtures();
+ let sql = sampling.get_sql("my_table", columns);
+ let expected_sql = "
+ SELECT *
+ FROM (
+ SELECT
+ *,
+ ROW_NUMBER() OVER(PARTITION BY \"col1\", \"col2\" ORDER BY RANDOM()) AS rn
+ FROM my_table
+ ) AS subquery
+ ORDER BY rn, RANDOM();
+ ";
+ assert_eq!(sql, expected_sql);
+ }
+}
diff --git a/pgml-extension/src/orm/snapshot.rs b/pgml-extension/src/orm/snapshot.rs
index 1cecd2a8c..9a0c22780 100644
--- a/pgml-extension/src/orm/snapshot.rs
+++ b/pgml-extension/src/orm/snapshot.rs
@@ -119,7 +119,7 @@ pub(crate) struct Preprocessor {
}
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
-pub(crate) struct Column {
+pub struct Column {
pub(crate) name: String,
pub(crate) pg_type: String,
pub(crate) nullable: bool,
@@ -147,7 +147,7 @@ impl Column {
)
}
- fn quoted_name(&self) -> String {
+ pub(crate) fn quoted_name(&self) -> String {
format!(r#""{}""#, self.name)
}
@@ -608,13 +608,8 @@ impl Snapshot {
};
if materialized {
- let mut sql = format!(
- r#"CREATE TABLE "pgml"."snapshot_{}" AS SELECT * FROM {}"#,
- s.id, s.relation_name
- );
- if s.test_sampling == Sampling::random {
- sql += " ORDER BY random()";
- }
+ let sampled_query = s.test_sampling.get_sql(&s.relation_name, s.columns.clone());
+ let sql = format!(r#"CREATE TABLE "pgml"."snapshot_{}" AS {}"#, s.id, sampled_query);
client.update(&sql, None, None).unwrap();
}
snapshot = Some(s);
@@ -742,26 +737,20 @@ impl Snapshot {
}
fn select_sql(&self) -> String {
- format!(
- "SELECT {} FROM {} {}",
- self.columns
- .iter()
- .map(|c| c.quoted_name())
- .collect::>()
- .join(", "),
- self.relation_name_quoted(),
- match self.materialized {
- // If the snapshot is materialized, we already randomized it.
- true => "",
- false => {
- if self.test_sampling == Sampling::random {
- "ORDER BY random()"
- } else {
- ""
- }
- }
- },
- )
+ match self.materialized {
+ true => {
+ format!(
+ "SELECT {} FROM {}",
+ self.columns
+ .iter()
+ .map(|c| c.quoted_name())
+ .collect::>()
+ .join(", "),
+ self.relation_name_quoted()
+ )
+ }
+ false => self.test_sampling.get_sql(&self.relation_name_quoted(), self.columns.clone()),
+ }
}
fn train_test_split(&self, num_rows: usize) -> (usize, usize) {
--- a PPN by Garber Painting Akron. With Image Size Reduction included!Fetched URL: http://github.com/postgresml/postgresml/pull/1336.diff
Alternative Proxies:
Alternative Proxy
pFad Proxy
pFad v3 Proxy
pFad v4 Proxy