Content-Length: 32949 | pFad | http://github.com/postgresml/postgresml/pull/1336.patch
thub.com
From 0574574397373f56fbe15e5c77eb20017a8a59dc Mon Sep 17 00:00:00 2001
From: Adam Hendel
Date: Tue, 16 Jan 2024 22:37:31 -0600
Subject: [PATCH 1/9] handle model deploy when no metrics to compare
---
pgml-extension/src/api.rs | 59 ++++++++++++++++++++++++---------------
1 file changed, 37 insertions(+), 22 deletions(-)
diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs
index 380bfb330..24164a9f1 100644
--- a/pgml-extension/src/api.rs
+++ b/pgml-extension/src/api.rs
@@ -264,37 +264,52 @@ fn train_joint(
);
let mut deploy = true;
+
match automatic_deploy {
// Deploy only if metrics are better than previous model.
Some(true) | None => {
if let Ok(Some(deployed_metrics)) = deployed_metrics {
- let deployed_metrics = deployed_metrics.0.as_object().unwrap();
- let deployed_metric = deployed_metrics
- .get(&project.task.default_target_metric())
- .unwrap()
- .as_f64()
- .unwrap();
- info!(
- "Comparing to deployed model {}: {:?}",
- project.task.default_target_metric(),
- deployed_metric
- );
- if project.task.value_is_better(
- deployed_metric,
- new_metrics
- .get(&project.task.default_target_metric())
- .unwrap()
- .as_f64()
- .unwrap(),
- ) {
+ if let Some(deployed_metrics_obj) = deployed_metrics.0.as_object() {
+ let default_target_metric = project.task.default_target_metric();
+ let deployed_metric = deployed_metrics_obj
+ .get(&default_target_metric)
+ .and_then(|v| v.as_f64());
+ info!(
+ "Comparing to deployed model {}: {:?}",
+ default_target_metric, deployed_metric
+ );
+ if let (Some(deployed_metric_value), Some(new_metric_value)) = (
+ deployed_metric,
+ new_metrics.get(&default_target_metric).and_then(|v| v.as_f64()),
+ ) {
+ if project.task.value_is_better(deployed_metric_value, new_metric_value) {
+ warning!(
+ "New model's {} is not better than old model: {} is not better than {}",
+ &project.task.default_target_metric(),
+ new_metric_value,
+ deployed_metric_value
+ );
+ deploy = false;
+ }
+ } else {
+ warning!("Failed to retrieve or parse deployed/new metrics for {}. Ensure train/test split results in both positive and negative label records.",
+ &project.task.default_target_metric());
+ deploy = false;
+ }
+ } else {
+ warning!("Failed to parse deployed model metrics. Ensure train/test split results in both positive and negative label records.");
deploy = false;
}
+ } else {
+ warning!("Failed to obtain currently deployed model metrics. Check if the deployed model metrics are available and correctly formatted.");
+ deploy = false;
}
}
-
- Some(false) => deploy = false,
+ Some(false) => {
+ warning!("Automatic deployment disabled via configuration.");
+ deploy = false;
+ }
};
-
if deploy {
project.deploy(model.id, Strategy::new_score);
} else {
From 667f68ec88ab530e96828a11d25941080d7783a9 Mon Sep 17 00:00:00 2001
From: Adam Hendel
Date: Tue, 16 Jan 2024 22:40:31 -0600
Subject: [PATCH 2/9] better warn msg
---
pgml-extension/src/api.rs | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs
index 24164a9f1..f2308a139 100644
--- a/pgml-extension/src/api.rs
+++ b/pgml-extension/src/api.rs
@@ -284,7 +284,7 @@ fn train_joint(
) {
if project.task.value_is_better(deployed_metric_value, new_metric_value) {
warning!(
- "New model's {} is not better than old model: {} is not better than {}",
+ "New model's {} is not better than current model. New: {}, Current {}",
&project.task.default_target_metric(),
new_metric_value,
deployed_metric_value
From d0ff7251d381ac575a7376ded27022431d21cf17 Mon Sep 17 00:00:00 2001
From: Adam Hendel
Date: Wed, 17 Jan 2024 15:17:24 +0000
Subject: [PATCH 3/9] fix first run case
---
pgml-extension/src/api.rs | 7 ++-----
1 file changed, 2 insertions(+), 5 deletions(-)
diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs
index f2308a139..ada756c7a 100644
--- a/pgml-extension/src/api.rs
+++ b/pgml-extension/src/api.rs
@@ -266,7 +266,7 @@ fn train_joint(
let mut deploy = true;
match automatic_deploy {
- // Deploy only if metrics are better than previous model.
+ // Deploy only if metrics are better than previous model, or if its the first model
Some(true) | None => {
if let Ok(Some(deployed_metrics)) = deployed_metrics {
if let Some(deployed_metrics_obj) = deployed_metrics.0.as_object() {
@@ -297,12 +297,9 @@ fn train_joint(
deploy = false;
}
} else {
- warning!("Failed to parse deployed model metrics. Ensure train/test split results in both positive and negative label records.");
+ warning!("Failed to parse deployed model metrics. Check data types of model metadata on pgml.models.metrics");
deploy = false;
}
- } else {
- warning!("Failed to obtain currently deployed model metrics. Check if the deployed model metrics are available and correctly formatted.");
- deploy = false;
}
}
Some(false) => {
From accaab001f02130048871bf2ab6b407660683452 Mon Sep 17 00:00:00 2001
From: Adam Hendel
Date: Wed, 17 Jan 2024 22:12:49 +0000
Subject: [PATCH 4/9] impl stratified
---
pgml-extension/src/orm/sampling.rs | 109 +++++++++++++++++++++++++++++
pgml-extension/src/orm/snapshot.rs | 34 ++-------
2 files changed, 114 insertions(+), 29 deletions(-)
diff --git a/pgml-extension/src/orm/sampling.rs b/pgml-extension/src/orm/sampling.rs
index 6bb3d7b5a..442f683ac 100644
--- a/pgml-extension/src/orm/sampling.rs
+++ b/pgml-extension/src/orm/sampling.rs
@@ -1,11 +1,14 @@
use pgrx::*;
use serde::Deserialize;
+use super::snapshot::Column;
+
#[derive(PostgresEnum, Copy, Clone, Eq, PartialEq, Debug, Deserialize)]
#[allow(non_camel_case_types)]
pub enum Sampling {
random,
last,
+ stratified_random,
}
impl std::str::FromStr for Sampling {
@@ -15,6 +18,7 @@ impl std::str::FromStr for Sampling {
match input {
"random" => Ok(Sampling::random),
"last" => Ok(Sampling::last),
+ "stratified_random" => Ok(Sampling::stratified_random),
_ => Err(()),
}
}
@@ -25,6 +29,111 @@ impl std::string::ToString for Sampling {
match *self {
Sampling::random => "random".to_string(),
Sampling::last => "last".to_string(),
+ Sampling::stratified_random => "stratified_random".to_string(),
}
}
}
+
+impl Sampling {
+ // Implementing the sampling strategy in SQL
+ // Effectively orders the table according to the train/test split
+ // e.g. first N rows are train, last M rows are test
+ // where M is configured by the user
+ pub fn get_sql(&self, relation_name: &str, y_column_names: Vec) -> String {
+ let col_string = y_column_names
+ .iter()
+ .map(|c| c.quoted_name())
+ .collect::>()
+ .join(", ");
+ match *self {
+ Sampling::random => {
+ format!("SELECT {col_string} FROM {relation_name} ORDER BY RANDOM()")
+ }
+ Sampling::last => {
+ format!("SELECT {col_string} FROM {relation_name}")
+ }
+ Sampling::stratified_random => {
+ format!(
+ "
+ SELECT *
+ FROM (
+ SELECT
+ *,
+ ROW_NUMBER() OVER(PARTITION BY {col_string} ORDER BY RANDOM()) AS rn
+ FROM {relation_name}
+ ) AS subquery
+ ORDER BY rn, RANDOM();
+ "
+ )
+ }
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::orm::snapshot::{Preprocessor, Statistics};
+
+ use super::*;
+
+ fn get_column_fixtures() -> Vec {
+ vec![
+ Column {
+ name: "col1".to_string(),
+ pg_type: "text".to_string(),
+ nullable: false,
+ label: true,
+ position: 0,
+ size: 0,
+ array: false,
+ preprocessor: Preprocessor::default(),
+ statistics: Statistics::default(),
+ },
+ Column {
+ name: "col2".to_string(),
+ pg_type: "text".to_string(),
+ nullable: false,
+ label: true,
+ position: 0,
+ size: 0,
+ array: false,
+ preprocessor: Preprocessor::default(),
+ statistics: Statistics::default(),
+ },
+ ]
+ }
+
+ #[test]
+ fn test_get_sql_random_sampling() {
+ let sampling = Sampling::random;
+ let columns = get_column_fixtures();
+ let sql = sampling.get_sql("my_table", columns);
+ assert_eq!(sql, "SELECT \"col1\", \"col2\" FROM my_table ORDER BY RANDOM()");
+ }
+
+ #[test]
+ fn test_get_sql_last_sampling() {
+ let sampling = Sampling::last;
+ let columns = get_column_fixtures();
+ let sql = sampling.get_sql("my_table", columns);
+ assert_eq!(sql, "SELECT \"col1\", \"col2\" FROM my_table");
+ }
+
+ #[test]
+ fn test_get_sql_stratified_random_sampling() {
+ let sampling = Sampling::stratified_random;
+ let columns = get_column_fixtures();
+ let sql = sampling.get_sql("my_table", columns);
+ let expected_sql = "
+ SELECT *
+ FROM (
+ SELECT
+ *,
+ ROW_NUMBER() OVER(PARTITION BY \"col1\", \"col2\" ORDER BY RANDOM()) AS rn
+ FROM my_table
+ ) AS subquery
+ ORDER BY rn, RANDOM();
+ ";
+ assert_eq!(sql, expected_sql);
+ }
+}
diff --git a/pgml-extension/src/orm/snapshot.rs b/pgml-extension/src/orm/snapshot.rs
index 6a5973148..9b478fe8a 100644
--- a/pgml-extension/src/orm/snapshot.rs
+++ b/pgml-extension/src/orm/snapshot.rs
@@ -119,7 +119,7 @@ pub(crate) struct Preprocessor {
}
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
-pub(crate) struct Column {
+pub struct Column {
pub(crate) name: String,
pub(crate) pg_type: String,
pub(crate) nullable: bool,
@@ -147,7 +147,7 @@ impl Column {
)
}
- fn quoted_name(&self) -> String {
+ pub(crate) fn quoted_name(&self) -> String {
format!(r#""{}""#, self.name)
}
@@ -608,13 +608,8 @@ impl Snapshot {
};
if materialized {
- let mut sql = format!(
- r#"CREATE TABLE "pgml"."snapshot_{}" AS SELECT * FROM {}"#,
- s.id, s.relation_name
- );
- if s.test_sampling == Sampling::random {
- sql += " ORDER BY random()";
- }
+ let sampled_query = s.test_sampling.get_sql(&s.relation_name, s.columns.clone());
+ let sql = format!(r#"CREATE TABLE "pgml"."snapshot_{}" AS {}"#, s.id, sampled_query);
client.update(&sql, None, None).unwrap();
}
snapshot = Some(s);
@@ -742,26 +737,7 @@ impl Snapshot {
}
fn select_sql(&self) -> String {
- format!(
- "SELECT {} FROM {} {}",
- self.columns
- .iter()
- .map(|c| c.quoted_name())
- .collect::>()
- .join(", "),
- self.relation_name(),
- match self.materialized {
- // If the snapshot is materialized, we already randomized it.
- true => "",
- false => {
- if self.test_sampling == Sampling::random {
- "ORDER BY random()"
- } else {
- ""
- }
- }
- },
- )
+ self.test_sampling.get_sql(&self.relation_name(), self.columns.clone())
}
fn train_test_split(&self, num_rows: usize) -> (usize, usize) {
From 58310256fae05912aedda7fda7c0aee1ee10f9ea Mon Sep 17 00:00:00 2001
From: Adam Hendel
Date: Wed, 17 Jan 2024 22:29:24 +0000
Subject: [PATCH 5/9] handle case where exists has no metrics
---
pgml-extension/src/api.rs | 36 +++++++++++++++++++++---------------
1 file changed, 21 insertions(+), 15 deletions(-)
diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs
index ada756c7a..1580de944 100644
--- a/pgml-extension/src/api.rs
+++ b/pgml-extension/src/api.rs
@@ -278,23 +278,29 @@ fn train_joint(
"Comparing to deployed model {}: {:?}",
default_target_metric, deployed_metric
);
- if let (Some(deployed_metric_value), Some(new_metric_value)) = (
- deployed_metric,
- new_metrics.get(&default_target_metric).and_then(|v| v.as_f64()),
- ) {
- if project.task.value_is_better(deployed_metric_value, new_metric_value) {
- warning!(
- "New model's {} is not better than current model. New: {}, Current {}",
- &project.task.default_target_metric(),
- new_metric_value,
- deployed_metric_value
- );
+ let new_metric = new_metrics.get(&default_target_metric).and_then(|v| v.as_f64());
+
+ match (deployed_metric, new_metric) {
+ (Some(deployed), Some(new)) => {
+ // only compare metrics when both new and old model have metrics to compare
+ if project.task.value_is_better(deployed, new) {
+ warning!(
+ "New model's {} is not better than current model. New: {}, Current {}",
+ &default_target_metric,
+ new,
+ deployed
+ );
+ deploy = false;
+ }
+ }
+ (None, None) => {
+ warning!("No metrics available for both deployed and new model. Deploying new model.")
+ }
+ (Some(_deployed), None) => {
+ warning!("No metrics for new model. Retaining old model.");
deploy = false;
}
- } else {
- warning!("Failed to retrieve or parse deployed/new metrics for {}. Ensure train/test split results in both positive and negative label records.",
- &project.task.default_target_metric());
- deploy = false;
+ (None, Some(_new)) => warning!("No metrics for deployed model. Deploying new model."),
}
} else {
warning!("Failed to parse deployed model metrics. Check data types of model metadata on pgml.models.metrics");
From 7710b13378adf5431cdc7afbec92368d55b590dc Mon Sep 17 00:00:00 2001
From: Adam Hendel
Date: Wed, 17 Jan 2024 22:44:26 +0000
Subject: [PATCH 6/9] change default samping to stratified
---
pgml-extension/src/api.rs | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs
index 1580de944..17e3af172 100644
--- a/pgml-extension/src/api.rs
+++ b/pgml-extension/src/api.rs
@@ -100,7 +100,7 @@ fn train(
search_params: default!(JsonB, "'{}'"),
search_args: default!(JsonB, "'{}'"),
test_size: default!(f32, 0.25),
- test_sampling: default!(Sampling, "'last'"),
+ test_sampling: default!(Sampling, "'stratified_random'"),
runtime: default!(Option, "NULL"),
automatic_deploy: default!(Option, true),
materialize_snapshot: default!(bool, false),
@@ -146,7 +146,7 @@ fn train_joint(
search_params: default!(JsonB, "'{}'"),
search_args: default!(JsonB, "'{}'"),
test_size: default!(f32, 0.25),
- test_sampling: default!(Sampling, "'last'"),
+ test_sampling: default!(Sampling, "'stratified_random'"),
runtime: default!(Option, "NULL"),
automatic_deploy: default!(Option, true),
materialize_snapshot: default!(bool, false),
@@ -535,7 +535,7 @@ fn snapshot(
relation_name: &str,
y_column_name: &str,
test_size: default!(f32, 0.25),
- test_sampling: default!(Sampling, "'last'"),
+ test_sampling: default!(Sampling, "'stratified_random'"),
preprocess: default!(JsonB, "'{}'"),
) -> TableIterator<'static, (name!(relation, String), name!(y_column_name, String))> {
Snapshot::create(
@@ -807,7 +807,7 @@ fn tune(
model_name: default!(Option<&str>, "NULL"),
hyperparams: default!(JsonB, "'{}'"),
test_size: default!(f32, 0.25),
- test_sampling: default!(Sampling, "'last'"),
+ test_sampling: default!(Sampling, "'stratified_random'"),
automatic_deploy: default!(Option, true),
materialize_snapshot: default!(bool, false),
) -> TableIterator<
From 9b2c44aa1d716aa6111ed22c5a2c7669cbde92de Mon Sep 17 00:00:00 2001
From: Adam Hendel
Date: Fri, 19 Jan 2024 04:19:31 +0000
Subject: [PATCH 7/9] no rando when already materialized
---
pgml-extension/src/orm/sampling.rs | 8 ++++----
pgml-extension/src/orm/snapshot.rs | 15 ++++++++++++++-
2 files changed, 18 insertions(+), 5 deletions(-)
diff --git a/pgml-extension/src/orm/sampling.rs b/pgml-extension/src/orm/sampling.rs
index 442f683ac..e0b3bf238 100644
--- a/pgml-extension/src/orm/sampling.rs
+++ b/pgml-extension/src/orm/sampling.rs
@@ -47,10 +47,10 @@ impl Sampling {
.join(", ");
match *self {
Sampling::random => {
- format!("SELECT {col_string} FROM {relation_name} ORDER BY RANDOM()")
+ format!("SELECT * FROM {relation_name} ORDER BY RANDOM()")
}
Sampling::last => {
- format!("SELECT {col_string} FROM {relation_name}")
+ format!("SELECT * FROM {relation_name}")
}
Sampling::stratified_random => {
format!(
@@ -108,7 +108,7 @@ mod tests {
let sampling = Sampling::random;
let columns = get_column_fixtures();
let sql = sampling.get_sql("my_table", columns);
- assert_eq!(sql, "SELECT \"col1\", \"col2\" FROM my_table ORDER BY RANDOM()");
+ assert_eq!(sql, "SELECT * FROM my_table ORDER BY RANDOM()");
}
#[test]
@@ -116,7 +116,7 @@ mod tests {
let sampling = Sampling::last;
let columns = get_column_fixtures();
let sql = sampling.get_sql("my_table", columns);
- assert_eq!(sql, "SELECT \"col1\", \"col2\" FROM my_table");
+ assert_eq!(sql, "SELECT * FROM my_table");
}
#[test]
diff --git a/pgml-extension/src/orm/snapshot.rs b/pgml-extension/src/orm/snapshot.rs
index 9b478fe8a..12b9e9813 100644
--- a/pgml-extension/src/orm/snapshot.rs
+++ b/pgml-extension/src/orm/snapshot.rs
@@ -737,7 +737,20 @@ impl Snapshot {
}
fn select_sql(&self) -> String {
- self.test_sampling.get_sql(&self.relation_name(), self.columns.clone())
+ match self.materialized {
+ true => {
+ format!(
+ "SELECT {} FROM {}",
+ self.columns
+ .iter()
+ .map(|c| c.quoted_name())
+ .collect::>()
+ .join(", "),
+ self.relation_name()
+ )
+ }
+ false => self.test_sampling.get_sql(&self.relation_name(), self.columns.clone()),
+ }
}
fn train_test_split(&self, num_rows: usize) -> (usize, usize) {
From 489bce180212e8e8692b76b5a00b2231036c8621 Mon Sep 17 00:00:00 2001
From: Adam Hendel
Date: Fri, 19 Jan 2024 04:32:03 +0000
Subject: [PATCH 8/9] update enum and function signatures
---
pgml-extension/sql/pgml--2.8.1--2.8.2.sql | 106 ++++++++++++++++++++++
pgml-extension/src/api.rs | 8 +-
pgml-extension/src/orm/sampling.rs | 12 +--
3 files changed, 116 insertions(+), 10 deletions(-)
diff --git a/pgml-extension/sql/pgml--2.8.1--2.8.2.sql b/pgml-extension/sql/pgml--2.8.1--2.8.2.sql
index 2c6264fb9..3f00d9b71 100644
--- a/pgml-extension/sql/pgml--2.8.1--2.8.2.sql
+++ b/pgml-extension/sql/pgml--2.8.1--2.8.2.sql
@@ -25,3 +25,109 @@ CREATE FUNCTION pgml."deploy"(
AS 'MODULE_PATHNAME', 'deploy_strategy_wrapper';
ALTER TYPE pgml.strategy ADD VALUE 'specific';
+
+-- src/orm/sampling.rs:6
+-- pgml::orm::sampling::Sampling
+DROP TYPE IF EXISTS pgml.Sampling;
+CREATE TYPE pgml.Sampling AS ENUM (
+ 'random',
+ 'last',
+ 'stratified'
+);
+
+-- src/api.rs:534
+-- pgml::api::snapshot
+DROP FUNCTION IF EXISTS pgml."snapshot"(text, text, real, pgml.Sampling, jsonb);
+CREATE FUNCTION pgml."snapshot"(
+ "relation_name" TEXT, /* &str */
+ "y_column_name" TEXT, /* &str */
+ "test_size" real DEFAULT 0.25, /* f32 */
+ "test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */
+ "preprocess" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */
+) RETURNS TABLE (
+ "relation" TEXT, /* alloc::string::String */
+ "y_column_name" TEXT /* alloc::string::String */
+)
+STRICT
+LANGUAGE c /* Rust */
+AS 'MODULE_PATHNAME', 'snapshot_wrapper';
+
+-- src/api.rs:802
+-- pgml::api::tune
+DROP FUNCTION IF EXISTS pgml."tune"(text, text, text, text, text, jsonb, real, pgml.Sampling, bool, bool);
+CREATE FUNCTION pgml."tune"(
+ "project_name" TEXT, /* &str */
+ "task" TEXT DEFAULT NULL, /* core::option::Option<&str> */
+ "relation_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
+ "y_column_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
+ "model_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
+ "hyperparams" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
+ "test_size" real DEFAULT 0.25, /* f32 */
+ "test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */
+ "automatic_deploy" bool DEFAULT true, /* core::option::Option */
+ "materialize_snapshot" bool DEFAULT false /* bool */
+) RETURNS TABLE (
+ "status" TEXT, /* alloc::string::String */
+ "task" TEXT, /* alloc::string::String */
+ "algorithm" TEXT, /* alloc::string::String */
+ "deployed" bool /* bool */
+)
+PARALLEL SAFE
+LANGUAGE c /* Rust */
+AS 'MODULE_PATHNAME', 'tune_wrapper';
+
+-- src/api.rs:92
+-- pgml::api::train
+DROP FUNCTION IF EXISTS pgml."train"(text, text, text, text, pgml.Algorithm, jsonb, pgml.Search, jsonb, jsonb, real, pgml.Sampling, pgml.Runtime, bool, bool, jsonb);
+CREATE FUNCTION pgml."train"(
+ "project_name" TEXT, /* &str */
+ "task" TEXT DEFAULT NULL, /* core::option::Option<&str> */
+ "relation_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
+ "y_column_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
+ "algorithm" pgml.Algorithm DEFAULT 'linear', /* pgml::orm::algorithm::Algorithm */
+ "hyperparams" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
+ "search" pgml.Search DEFAULT NULL, /* core::option::Option */
+ "search_params" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
+ "search_args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
+ "test_size" real DEFAULT 0.25, /* f32 */
+ "test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */
+ "runtime" pgml.Runtime DEFAULT NULL, /* core::option::Option */
+ "automatic_deploy" bool DEFAULT true, /* core::option::Option */
+ "materialize_snapshot" bool DEFAULT false, /* bool */
+ "preprocess" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */
+) RETURNS TABLE (
+ "project" TEXT, /* alloc::string::String */
+ "task" TEXT, /* alloc::string::String */
+ "algorithm" TEXT, /* alloc::string::String */
+ "deployed" bool /* bool */
+)
+LANGUAGE c /* Rust */
+AS 'MODULE_PATHNAME', 'train_wrapper';
+
+-- src/api.rs:138
+-- pgml::api::train_joint
+DROP FUNCTION IF EXISTS pgml."train_joint"(text, text, text, text, pgml.Algorithm, jsonb, pgml.Search, jsonb, jsonb, real, pgml.Sampling, pgml.Runtime, bool, bool, jsonb);
+CREATE FUNCTION pgml."train_joint"(
+ "project_name" TEXT, /* &str */
+ "task" TEXT DEFAULT NULL, /* core::option::Option<&str> */
+ "relation_name" TEXT DEFAULT NULL, /* core::option::Option<&str> */
+ "y_column_name" TEXT[] DEFAULT NULL, /* core::option::Option> */
+ "algorithm" pgml.Algorithm DEFAULT 'linear', /* pgml::orm::algorithm::Algorithm */
+ "hyperparams" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
+ "search" pgml.Search DEFAULT NULL, /* core::option::Option */
+ "search_params" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
+ "search_args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
+ "test_size" real DEFAULT 0.25, /* f32 */
+ "test_sampling" pgml.Sampling DEFAULT 'stratified', /* pgml::orm::sampling::Sampling */
+ "runtime" pgml.Runtime DEFAULT NULL, /* core::option::Option */
+ "automatic_deploy" bool DEFAULT true, /* core::option::Option */
+ "materialize_snapshot" bool DEFAULT false, /* bool */
+ "preprocess" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */
+) RETURNS TABLE (
+ "project" TEXT, /* alloc::string::String */
+ "task" TEXT, /* alloc::string::String */
+ "algorithm" TEXT, /* alloc::string::String */
+ "deployed" bool /* bool */
+)
+LANGUAGE c /* Rust */
+AS 'MODULE_PATHNAME', 'train_joint_wrapper';
diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs
index 17e3af172..7fd5012c8 100644
--- a/pgml-extension/src/api.rs
+++ b/pgml-extension/src/api.rs
@@ -100,7 +100,7 @@ fn train(
search_params: default!(JsonB, "'{}'"),
search_args: default!(JsonB, "'{}'"),
test_size: default!(f32, 0.25),
- test_sampling: default!(Sampling, "'stratified_random'"),
+ test_sampling: default!(Sampling, "'stratified'"),
runtime: default!(Option, "NULL"),
automatic_deploy: default!(Option, true),
materialize_snapshot: default!(bool, false),
@@ -146,7 +146,7 @@ fn train_joint(
search_params: default!(JsonB, "'{}'"),
search_args: default!(JsonB, "'{}'"),
test_size: default!(f32, 0.25),
- test_sampling: default!(Sampling, "'stratified_random'"),
+ test_sampling: default!(Sampling, "'stratified'"),
runtime: default!(Option, "NULL"),
automatic_deploy: default!(Option, true),
materialize_snapshot: default!(bool, false),
@@ -535,7 +535,7 @@ fn snapshot(
relation_name: &str,
y_column_name: &str,
test_size: default!(f32, 0.25),
- test_sampling: default!(Sampling, "'stratified_random'"),
+ test_sampling: default!(Sampling, "'stratified'"),
preprocess: default!(JsonB, "'{}'"),
) -> TableIterator<'static, (name!(relation, String), name!(y_column_name, String))> {
Snapshot::create(
@@ -807,7 +807,7 @@ fn tune(
model_name: default!(Option<&str>, "NULL"),
hyperparams: default!(JsonB, "'{}'"),
test_size: default!(f32, 0.25),
- test_sampling: default!(Sampling, "'stratified_random'"),
+ test_sampling: default!(Sampling, "'stratified'"),
automatic_deploy: default!(Option, true),
materialize_snapshot: default!(bool, false),
) -> TableIterator<
diff --git a/pgml-extension/src/orm/sampling.rs b/pgml-extension/src/orm/sampling.rs
index e0b3bf238..2ecd66f5d 100644
--- a/pgml-extension/src/orm/sampling.rs
+++ b/pgml-extension/src/orm/sampling.rs
@@ -8,7 +8,7 @@ use super::snapshot::Column;
pub enum Sampling {
random,
last,
- stratified_random,
+ stratified,
}
impl std::str::FromStr for Sampling {
@@ -18,7 +18,7 @@ impl std::str::FromStr for Sampling {
match input {
"random" => Ok(Sampling::random),
"last" => Ok(Sampling::last),
- "stratified_random" => Ok(Sampling::stratified_random),
+ "stratified" => Ok(Sampling::stratified),
_ => Err(()),
}
}
@@ -29,7 +29,7 @@ impl std::string::ToString for Sampling {
match *self {
Sampling::random => "random".to_string(),
Sampling::last => "last".to_string(),
- Sampling::stratified_random => "stratified_random".to_string(),
+ Sampling::stratified => "stratified".to_string(),
}
}
}
@@ -52,7 +52,7 @@ impl Sampling {
Sampling::last => {
format!("SELECT * FROM {relation_name}")
}
- Sampling::stratified_random => {
+ Sampling::stratified => {
format!(
"
SELECT *
@@ -120,8 +120,8 @@ mod tests {
}
#[test]
- fn test_get_sql_stratified_random_sampling() {
- let sampling = Sampling::stratified_random;
+ fn test_get_sql_stratified_sampling() {
+ let sampling = Sampling::stratified;
let columns = get_column_fixtures();
let sql = sampling.get_sql("my_table", columns);
let expected_sql = "
From 0af75a8efdb6f287c09b2f965addbe32b4670427 Mon Sep 17 00:00:00 2001
From: Adam Hendel
Date: Thu, 29 Feb 2024 15:20:47 +0000
Subject: [PATCH 9/9] add upgrade test
---
.github/workflows/ci.yml | 12 +++++++++++-
1 file changed, 11 insertions(+), 1 deletion(-)
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index e9b0b1412..e1859d8aa 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -47,7 +47,12 @@ jobs:
if: steps.pgml_extension_changed.outputs.PGML_EXTENSION_CHANGED_FILES != '0'
run: |
git submodule update --init --recursive
+ - name: Get current version
+ id: current-version
+ run: echo "CI_BRANCH=$(git name-rev --name-only HEAD)" >> $GITHUB_OUTPUT
- name: Run tests
+ env:
+ CI_BRANCH: ${{ steps.current-version.outputs.CI_BRANCH }}
if: steps.pgml_extension_changed.outputs.PGML_EXTENSION_CHANGED_FILES != '0'
run: |
curl https://sh.rustup.rs -sSf | sh -s -- -y
@@ -58,8 +63,13 @@ jobs:
cargo pgrx init
fi
+ git checkout master
+ echo "\q" | cargo pgrx run
+ psql -p 28816 -h localhost -d pgml -P pager -c "CREATE EXTENSION pgml;"
+ git checkout $CI_BRANCH
+ echo "\q" | cargo pgrx run
+ psql -p 28816 -h localhost -d pgml -P pager -c "ALTER EXTENSION pgml UPDATE;"
cargo pgrx test
-
# cargo pgrx start
# psql -p 28815 -h 127.0.0.1 -d pgml -P pager -f tests/test.sql
# cargo pgrx stop
--- a PPN by Garber Painting Akron. With Image Size Reduction included!Fetched URL: http://github.com/postgresml/postgresml/pull/1336.patch
Alternative Proxies:
Alternative Proxy
pFad Proxy
pFad v3 Proxy
pFad v4 Proxy