Content-Length: 15514 | pFad | http://github.com/postgresml/postgresml/pull/1336.diff

thub.com diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e9b0b1412..e1859d8aa 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -47,7 +47,12 @@ jobs: if: steps.pgml_extension_changed.outputs.PGML_EXTENSION_CHANGED_FILES != '0' run: | git submodule update --init --recursive + - name: Get current version + id: current-version + run: echo "CI_BRANCH=$(git name-rev --name-only HEAD)" >> $GITHUB_OUTPUT - name: Run tests + env: + CI_BRANCH: ${{ steps.current-version.outputs.CI_BRANCH }} if: steps.pgml_extension_changed.outputs.PGML_EXTENSION_CHANGED_FILES != '0' run: | curl https://sh.rustup.rs -sSf | sh -s -- -y @@ -58,8 +63,13 @@ jobs: cargo pgrx init fi + git checkout master + echo "\q" | cargo pgrx run + psql -p 28816 -h localhost -d pgml -P pager -c "CREATE EXTENSION pgml;" + git checkout $CI_BRANCH + echo "\q" | cargo pgrx run + psql -p 28816 -h localhost -d pgml -P pager -c "ALTER EXTENSION pgml UPDATE;" cargo pgrx test - # cargo pgrx start # psql -p 28815 -h 127.0.0.1 -d pgml -P pager -f tests/test.sql # cargo pgrx stop diff --git a/pgml-extension/sql/pgml--2.8.1--2.8.2.sql b/pgml-extension/sql/pgml--2.8.1--2.8.2.sql index 2c6264fb9..98e2216e9 100644 --- a/pgml-extension/sql/pgml--2.8.1--2.8.2.sql +++ b/pgml-extension/sql/pgml--2.8.1--2.8.2.sql @@ -25,3 +25,102 @@ CREATE FUNCTION pgml."deploy"( AS 'MODULE_PATHNAME', 'deploy_strategy_wrapper'; ALTER TYPE pgml.strategy ADD VALUE 'specific'; + +ALTER TYPE pgml.Sampling ADD VALUE 'stratified'; + +-- src/api.rs:534 +-- pgml::api::snapshot +DROP FUNCTION IF EXISTS pgml."snapshot"(text, text, real, pgml.Sampling, jsonb); +CREATE FUNCTION pgml."snapshot"( + "relation_name" TEXT, /* &str */ + "y_column_name" TEXT, /* &str */ + "test_size" real DEFAULT 0.25, /* f32 */ + "test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */ + "preprocess" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */ +) RETURNS TABLE ( + "relation" TEXT, /* alloc::string::String */ + "y_column_name" TEXT /* alloc::string::String */ +) +STRICT +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', 'snapshot_wrapper'; + +-- src/api.rs:802 +-- pgml::api::tune +DROP FUNCTION IF EXISTS pgml."tune"(text, text, text, text, text, jsonb, real, pgml.Sampling, bool, bool); +CREATE FUNCTION pgml."tune"( + "project_name" TEXT, /* &str */ + "task" TEXT DEFAULT NULL, /* core::option::Option<&str> */ + "relation_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */ + "y_column_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */ + "model_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */ + "hyperparams" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ + "test_size" real DEFAULT 0.25, /* f32 */ + "test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */ + "automatic_deploy" bool DEFAULT true, /* core::option::Option */ + "materialize_snapshot" bool DEFAULT false /* bool */ +) RETURNS TABLE ( + "status" TEXT, /* alloc::string::String */ + "task" TEXT, /* alloc::string::String */ + "algorithm" TEXT, /* alloc::string::String */ + "deployed" bool /* bool */ +) +PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', 'tune_wrapper'; + +-- src/api.rs:92 +-- pgml::api::train +DROP FUNCTION IF EXISTS pgml."train"(text, text, text, text, pgml.Algorithm, jsonb, pgml.Search, jsonb, jsonb, real, pgml.Sampling, pgml.Runtime, bool, bool, jsonb); +CREATE FUNCTION pgml."train"( + "project_name" TEXT, /* &str */ + "task" TEXT DEFAULT NULL, /* core::option::Option<&str> */ + "relation_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */ + "y_column_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */ + "algorithm" pgml.Algorithm DEFAULT 'linear', /* pgml::orm::algorithm::Algorithm */ + "hyperparams" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ + "search" pgml.Search DEFAULT NULL, /* core::option::Option */ + "search_params" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ + "search_args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ + "test_size" real DEFAULT 0.25, /* f32 */ + "test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */ + "runtime" pgml.Runtime DEFAULT NULL, /* core::option::Option */ + "automatic_deploy" bool DEFAULT true, /* core::option::Option */ + "materialize_snapshot" bool DEFAULT false, /* bool */ + "preprocess" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */ +) RETURNS TABLE ( + "project" TEXT, /* alloc::string::String */ + "task" TEXT, /* alloc::string::String */ + "algorithm" TEXT, /* alloc::string::String */ + "deployed" bool /* bool */ +) +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', 'train_wrapper'; + +-- src/api.rs:138 +-- pgml::api::train_joint +DROP FUNCTION IF EXISTS pgml."train_joint"(text, text, text, text, pgml.Algorithm, jsonb, pgml.Search, jsonb, jsonb, real, pgml.Sampling, pgml.Runtime, bool, bool, jsonb); +CREATE FUNCTION pgml."train_joint"( + "project_name" TEXT, /* &str */ + "task" TEXT DEFAULT NULL, /* core::option::Option<&str> */ + "relation_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */ + "y_column_name" TEXT[] DEFAULT NULL, /* core::option::Option> */ + "algorithm" pgml.Algorithm DEFAULT 'linear', /* pgml::orm::algorithm::Algorithm */ + "hyperparams" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ + "search" pgml.Search DEFAULT NULL, /* core::option::Option */ + "search_params" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ + "search_args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ + "test_size" real DEFAULT 0.25, /* f32 */ + "test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */ + "runtime" pgml.Runtime DEFAULT NULL, /* core::option::Option */ + "automatic_deploy" bool DEFAULT true, /* core::option::Option */ + "materialize_snapshot" bool DEFAULT false, /* bool */ + "preprocess" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */ +) RETURNS TABLE ( + "project" TEXT, /* alloc::string::String */ + "task" TEXT, /* alloc::string::String */ + "algorithm" TEXT, /* alloc::string::String */ + "deployed" bool /* bool */ +) +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', 'train_joint_wrapper'; diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 1580de944..7fd5012c8 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -100,7 +100,7 @@ fn train( search_params: default!(JsonB, "'{}'"), search_args: default!(JsonB, "'{}'"), test_size: default!(f32, 0.25), - test_sampling: default!(Sampling, "'last'"), + test_sampling: default!(Sampling, "'stratified'"), runtime: default!(Option, "NULL"), automatic_deploy: default!(Option, true), materialize_snapshot: default!(bool, false), @@ -146,7 +146,7 @@ fn train_joint( search_params: default!(JsonB, "'{}'"), search_args: default!(JsonB, "'{}'"), test_size: default!(f32, 0.25), - test_sampling: default!(Sampling, "'last'"), + test_sampling: default!(Sampling, "'stratified'"), runtime: default!(Option, "NULL"), automatic_deploy: default!(Option, true), materialize_snapshot: default!(bool, false), @@ -535,7 +535,7 @@ fn snapshot( relation_name: &str, y_column_name: &str, test_size: default!(f32, 0.25), - test_sampling: default!(Sampling, "'last'"), + test_sampling: default!(Sampling, "'stratified'"), preprocess: default!(JsonB, "'{}'"), ) -> TableIterator<'static, (name!(relation, String), name!(y_column_name, String))> { Snapshot::create( @@ -807,7 +807,7 @@ fn tune( model_name: default!(Option<&str>, "NULL"), hyperparams: default!(JsonB, "'{}'"), test_size: default!(f32, 0.25), - test_sampling: default!(Sampling, "'last'"), + test_sampling: default!(Sampling, "'stratified'"), automatic_deploy: default!(Option, true), materialize_snapshot: default!(bool, false), ) -> TableIterator< diff --git a/pgml-extension/src/orm/sampling.rs b/pgml-extension/src/orm/sampling.rs index 6bb3d7b5a..2ecd66f5d 100644 --- a/pgml-extension/src/orm/sampling.rs +++ b/pgml-extension/src/orm/sampling.rs @@ -1,11 +1,14 @@ use pgrx::*; use serde::Deserialize; +use super::snapshot::Column; + #[derive(PostgresEnum, Copy, Clone, Eq, PartialEq, Debug, Deserialize)] #[allow(non_camel_case_types)] pub enum Sampling { random, last, + stratified, } impl std::str::FromStr for Sampling { @@ -15,6 +18,7 @@ impl std::str::FromStr for Sampling { match input { "random" => Ok(Sampling::random), "last" => Ok(Sampling::last), + "stratified" => Ok(Sampling::stratified), _ => Err(()), } } @@ -25,6 +29,111 @@ impl std::string::ToString for Sampling { match *self { Sampling::random => "random".to_string(), Sampling::last => "last".to_string(), + Sampling::stratified => "stratified".to_string(), } } } + +impl Sampling { + // Implementing the sampling strategy in SQL + // Effectively orders the table according to the train/test split + // e.g. first N rows are train, last M rows are test + // where M is configured by the user + pub fn get_sql(&self, relation_name: &str, y_column_names: Vec) -> String { + let col_string = y_column_names + .iter() + .map(|c| c.quoted_name()) + .collect::>() + .join(", "); + match *self { + Sampling::random => { + format!("SELECT * FROM {relation_name} ORDER BY RANDOM()") + } + Sampling::last => { + format!("SELECT * FROM {relation_name}") + } + Sampling::stratified => { + format!( + " + SELECT * + FROM ( + SELECT + *, + ROW_NUMBER() OVER(PARTITION BY {col_string} ORDER BY RANDOM()) AS rn + FROM {relation_name} + ) AS subquery + ORDER BY rn, RANDOM(); + " + ) + } + } + } +} + +#[cfg(test)] +mod tests { + use crate::orm::snapshot::{Preprocessor, Statistics}; + + use super::*; + + fn get_column_fixtures() -> Vec { + vec![ + Column { + name: "col1".to_string(), + pg_type: "text".to_string(), + nullable: false, + label: true, + position: 0, + size: 0, + array: false, + preprocessor: Preprocessor::default(), + statistics: Statistics::default(), + }, + Column { + name: "col2".to_string(), + pg_type: "text".to_string(), + nullable: false, + label: true, + position: 0, + size: 0, + array: false, + preprocessor: Preprocessor::default(), + statistics: Statistics::default(), + }, + ] + } + + #[test] + fn test_get_sql_random_sampling() { + let sampling = Sampling::random; + let columns = get_column_fixtures(); + let sql = sampling.get_sql("my_table", columns); + assert_eq!(sql, "SELECT * FROM my_table ORDER BY RANDOM()"); + } + + #[test] + fn test_get_sql_last_sampling() { + let sampling = Sampling::last; + let columns = get_column_fixtures(); + let sql = sampling.get_sql("my_table", columns); + assert_eq!(sql, "SELECT * FROM my_table"); + } + + #[test] + fn test_get_sql_stratified_sampling() { + let sampling = Sampling::stratified; + let columns = get_column_fixtures(); + let sql = sampling.get_sql("my_table", columns); + let expected_sql = " + SELECT * + FROM ( + SELECT + *, + ROW_NUMBER() OVER(PARTITION BY \"col1\", \"col2\" ORDER BY RANDOM()) AS rn + FROM my_table + ) AS subquery + ORDER BY rn, RANDOM(); + "; + assert_eq!(sql, expected_sql); + } +} diff --git a/pgml-extension/src/orm/snapshot.rs b/pgml-extension/src/orm/snapshot.rs index 1cecd2a8c..9a0c22780 100644 --- a/pgml-extension/src/orm/snapshot.rs +++ b/pgml-extension/src/orm/snapshot.rs @@ -119,7 +119,7 @@ pub(crate) struct Preprocessor { } #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] -pub(crate) struct Column { +pub struct Column { pub(crate) name: String, pub(crate) pg_type: String, pub(crate) nullable: bool, @@ -147,7 +147,7 @@ impl Column { ) } - fn quoted_name(&self) -> String { + pub(crate) fn quoted_name(&self) -> String { format!(r#""{}""#, self.name) } @@ -608,13 +608,8 @@ impl Snapshot { }; if materialized { - let mut sql = format!( - r#"CREATE TABLE "pgml"."snapshot_{}" AS SELECT * FROM {}"#, - s.id, s.relation_name - ); - if s.test_sampling == Sampling::random { - sql += " ORDER BY random()"; - } + let sampled_query = s.test_sampling.get_sql(&s.relation_name, s.columns.clone()); + let sql = format!(r#"CREATE TABLE "pgml"."snapshot_{}" AS {}"#, s.id, sampled_query); client.update(&sql, None, None).unwrap(); } snapshot = Some(s); @@ -742,26 +737,20 @@ impl Snapshot { } fn select_sql(&self) -> String { - format!( - "SELECT {} FROM {} {}", - self.columns - .iter() - .map(|c| c.quoted_name()) - .collect::>() - .join(", "), - self.relation_name_quoted(), - match self.materialized { - // If the snapshot is materialized, we already randomized it. - true => "", - false => { - if self.test_sampling == Sampling::random { - "ORDER BY random()" - } else { - "" - } - } - }, - ) + match self.materialized { + true => { + format!( + "SELECT {} FROM {}", + self.columns + .iter() + .map(|c| c.quoted_name()) + .collect::>() + .join(", "), + self.relation_name_quoted() + ) + } + false => self.test_sampling.get_sql(&self.relation_name_quoted(), self.columns.clone()), + } } fn train_test_split(&self, num_rows: usize) -> (usize, usize) {








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: http://github.com/postgresml/postgresml/pull/1336.diff

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy