Content-Length: 16770 | pFad | http://github.com/postgresml/postgresml/pull/1520.patch

thub.com From fb773bb307720ccfc7f7f566a448efb0b7c673bf Mon Sep 17 00:00:00 2001 From: Montana Low Date: Tue, 11 Jun 2024 11:18:02 -0700 Subject: [PATCH 1/2] fix and test preprocessing examples --- pgml-extension/examples/preprocessing.sql | 33 +++++++++++++++++++ .../src/bindings/transformers/mod.rs | 7 ++-- pgml-extension/src/orm/model.rs | 13 ++++---- pgml-extension/src/orm/sampling.rs | 2 +- pgml-extension/src/orm/snapshot.rs | 23 +++++++++---- pgml-extension/tests/test.sql | 1 + 6 files changed, 61 insertions(+), 18 deletions(-) create mode 100644 pgml-extension/examples/preprocessing.sql diff --git a/pgml-extension/examples/preprocessing.sql b/pgml-extension/examples/preprocessing.sql new file mode 100644 index 000000000..1e4d7b234 --- /dev/null +++ b/pgml-extension/examples/preprocessing.sql @@ -0,0 +1,33 @@ +-- load the diamonds dataset, that contains text categorical variables +SELECT pgml.load_dataset('jdxcosta/diamonds'); + +-- view the data +SELECT * FROM pgml."jdxcosta/diamonds" LIMIT 10; + +-- drop the Unamed column, since it's not useful for training (you could create a view instead) +ALTER TABLE pgml."jdxcosta/diamonds" DROP COLUMN "Unnamed: 0"; + +-- train a model using preprocessors to scale the numeric variables, and target encode the categoricals +SELECT pgml.train( + project_name => 'Diamond prices', + task => 'regression', + relation_name => 'pgml.jdxcosta/diamonds', + y_column_name => 'price', + algorithm => 'lightgbm', + preprocess => '{ + "carat": {"scale": "standard"}, + "depth": {"scale": "standard"}, + "table": {"scale": "standard"}, + "cut": {"encode": "target", "scale": "standard"}, + "color": {"encode": "target", "scale": "standard"}, + "clarity": {"encode": "target", "scale": "standard"} + }' +); + +-- run some predictions, notice we're passing a heterogeneous row (tuple) as input, rather than a homogenous ARRAY[]. +SELECT price, pgml.predict('Diamond prices', (carat, cut, color, clarity, depth, "table", x, y, z)) AS prediction +FROM pgml."jdxcosta/diamonds" +LIMIT 10; + +-- This is a difficult dataset for more algorithms, which makes it a good challenge for preprocessing, and additional +-- feature engineering. What's next? diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs index 6e529d394..cf3ba7fcc 100644 --- a/pgml-extension/src/bindings/transformers/mod.rs +++ b/pgml-extension/src/bindings/transformers/mod.rs @@ -380,7 +380,7 @@ pub fn load_dataset( .ok_or(anyhow!("dataset `data` key is not an object"))?; let column_names = types .iter() - .map(|(name, _type)| name.clone()) + .map(|(name, _type)| format!("\"{}\"", name) ) .collect::>() .join(", "); let column_types = types @@ -393,13 +393,14 @@ pub fn load_dataset( "int64" => "INT8", "int32" => "INT4", "int16" => "INT2", + "int8" => "INT2", "float64" => "FLOAT8", "float32" => "FLOAT4", "float16" => "FLOAT4", "bool" => "BOOLEAN", _ => bail!("unhandled dataset feature while reading dataset: {type_}"), }; - Ok(format!("{name} {type_}")) + Ok(format!("\"{name}\" {type_}")) }) .collect::>>()? .join(", "); @@ -455,7 +456,7 @@ pub fn load_dataset( .into_datum(), )), "dict" | "list" => row.push((PgBuiltInOids::JSONBOID.oid(), JsonB(value.clone()).into_datum())), - "int64" | "int32" | "int16" => row.push(( + "int64" | "int32" | "int16" | "int8" => row.push(( PgBuiltInOids::INT8OID.oid(), value .as_i64() diff --git a/pgml-extension/src/orm/model.rs b/pgml-extension/src/orm/model.rs index fb9eaae47..670e05651 100644 --- a/pgml-extension/src/orm/model.rs +++ b/pgml-extension/src/orm/model.rs @@ -344,13 +344,12 @@ impl Model { ).unwrap().first(); if !result.is_empty() { - let project_id = result.get(2).unwrap().unwrap(); - let project = Project::find(project_id).unwrap(); - let snapshot_id = result.get(3).unwrap().unwrap(); - let snapshot = Snapshot::find(snapshot_id).unwrap(); - let algorithm = Algorithm::from_str(result.get(4).unwrap().unwrap()).unwrap(); - let runtime = Runtime::from_str(result.get(5).unwrap().unwrap()).unwrap(); - + let project_id = result.get(2).unwrap().expect("project_id is i64"); + let project = Project::find(project_id).expect("project doesn't exist"); + let snapshot_id = result.get(3).unwrap().expect("snapshot_id is i64"); + let snapshot = Snapshot::find(snapshot_id).expect("snapshot doesn't exist"); + let algorithm = Algorithm::from_str(result.get(4).unwrap().unwrap()).expect("algorithm is malformed"); + let runtime = Runtime::from_str(result.get(5).unwrap().unwrap()).expect("runtime is malformed"); let data = Spi::get_one_with_args::>( " SELECT data diff --git a/pgml-extension/src/orm/sampling.rs b/pgml-extension/src/orm/sampling.rs index 2ecd66f5d..831b74d5f 100644 --- a/pgml-extension/src/orm/sampling.rs +++ b/pgml-extension/src/orm/sampling.rs @@ -55,7 +55,7 @@ impl Sampling { Sampling::stratified => { format!( " - SELECT * + SELECT {col_string} FROM ( SELECT *, diff --git a/pgml-extension/src/orm/snapshot.rs b/pgml-extension/src/orm/snapshot.rs index 4c8993ff6..b7470f6ad 100644 --- a/pgml-extension/src/orm/snapshot.rs +++ b/pgml-extension/src/orm/snapshot.rs @@ -230,16 +230,24 @@ impl Column { if self.preprocessor.encode == Encode::target { let categories = self.statistics.categories.as_mut().unwrap(); let mut sums = vec![0_f32; categories.len() + 1]; + let mut total = 0.; Zip::from(array).and(target).for_each(|&value, &target| { + total += target; sums[value as usize] += target; }); + let avg_target = total / categories.len() as f32; for category in categories.values_mut() { - let sum = sums[category.value as usize]; - category.value = sum / category.members as f32; + if category.members > 0 { + let sum = sums[category.value as usize]; + category.value = sum / category.members as f32; + } else { + // use avg target for categories w/ no members, e.g. __NULL__ category in a complete dataset + category.value = avg_target; + } } } - // Data is filtered for NaN because it is not well defined statistically, and they are counted as separate stat + // Data is filtered for NaN because it is not well-defined statistically, and they are counted as separate stat let mut data = array .iter() .filter_map(|n| if n.is_nan() { None } else { Some(*n) }) @@ -404,7 +412,7 @@ impl Snapshot { .first(); if !result.is_empty() { let jsonb: JsonB = result.get(7).unwrap().unwrap(); - let columns: Vec = serde_json::from_value(jsonb.0).unwrap(); + let columns: Vec = serde_json::from_value(jsonb.0).expect("invalid json description of columns"); // let jsonb: JsonB = result.get(8).unwrap(); // let analysis: Option> = Some(serde_json::from_value(jsonb.0).unwrap()); let mut s = Snapshot { @@ -500,9 +508,10 @@ impl Snapshot { let preprocessors: HashMap = serde_json::from_value(preprocess.0).expect("is valid"); + let mut position = 0; // Postgres column positions are not updated when other columns are dropped, but we expect consecutive positions when we read the table. Spi::connect(|mut client| { let mut columns: Vec = Vec::new(); - client.select("SELECT column_name::TEXT, udt_name::TEXT, is_nullable::BOOLEAN, ordinal_position::INTEGER FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2 ORDER BY ordinal_position ASC", + client.select("SELECT column_name::TEXT, udt_name::TEXT, is_nullable::BOOLEAN FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2 ORDER BY ordinal_position ASC", None, Some(vec![ (PgBuiltInOids::TEXTOID.oid(), schema_name.into_datum()), @@ -520,7 +529,7 @@ impl Snapshot { pg_type = pg_type[1..].to_string() + "[]"; } let nullable = row[3].value::().unwrap().unwrap(); - let position = row[4].value::().unwrap().unwrap() as usize; + position += 1; let label = match y_column_name { Some(ref y_column_name) => y_column_name.contains(&name), None => false, @@ -1158,7 +1167,7 @@ impl Snapshot { pub fn numeric_encoded_dataset(&mut self) -> Dataset { let mut data = None; Spi::connect(|client| { - // Postgres Arrays arrays are 1 indexed and so are SPI tuples... + // Postgres arrays are 1 indexed and so are SPI tuples... let result = client.select(&self.select_sql(), None, None).unwrap(); let num_rows = result.len(); let (num_train_rows, num_test_rows) = self.train_test_split(num_rows); diff --git a/pgml-extension/tests/test.sql b/pgml-extension/tests/test.sql index 2490678ee..2256e0ca4 100644 --- a/pgml-extension/tests/test.sql +++ b/pgml-extension/tests/test.sql @@ -30,5 +30,6 @@ SELECT pgml.load_dataset('wine'); \i examples/regression.sql \i examples/vectors.sql \i examples/chunking.sql +\i examples/preprocessing.sql -- transformers are generally too slow to run in the test suite --\i examples/transformers.sql From 7dff4aab22c553028687979aad6681b9fe07fbfd Mon Sep 17 00:00:00 2001 From: Montana Low Date: Tue, 11 Jun 2024 13:18:30 -0700 Subject: [PATCH 2/2] bump version for release --- .github/workflows/ubuntu-packages-and-docker-image.yml | 2 +- pgml-cms/docs/resources/developer-docs/contributing.md | 2 +- pgml-cms/docs/resources/developer-docs/installation.md | 2 +- .../docs/resources/developer-docs/quick-start-with-docker.md | 2 +- pgml-cms/docs/resources/developer-docs/self-hosting/pooler.md | 2 +- pgml-extension/Cargo.lock | 2 +- pgml-extension/Cargo.toml | 2 +- pgml-extension/sql/pgml--2.9.0--2.9.1.sql | 0 pgml-extension/src/bindings/transformers/mod.rs | 2 +- pgml-extension/src/orm/sampling.rs | 2 +- pgml-extension/src/orm/snapshot.rs | 3 ++- 11 files changed, 11 insertions(+), 10 deletions(-) create mode 100644 pgml-extension/sql/pgml--2.9.0--2.9.1.sql diff --git a/.github/workflows/ubuntu-packages-and-docker-image.yml b/.github/workflows/ubuntu-packages-and-docker-image.yml index 687b8dc4c..b493dd855 100644 --- a/.github/workflows/ubuntu-packages-and-docker-image.yml +++ b/.github/workflows/ubuntu-packages-and-docker-image.yml @@ -4,7 +4,7 @@ on: workflow_dispatch: inputs: packageVersion: - default: "2.8.2" + default: "2.9.1" jobs: # # PostgresML extension. diff --git a/pgml-cms/docs/resources/developer-docs/contributing.md b/pgml-cms/docs/resources/developer-docs/contributing.md index a739d5ac9..59a3f3481 100644 --- a/pgml-cms/docs/resources/developer-docs/contributing.md +++ b/pgml-cms/docs/resources/developer-docs/contributing.md @@ -127,7 +127,7 @@ SELECT pgml.version(); postgres=# select pgml.version(); version ------------------- - 2.7.4 + 2.9.1 (1 row) ``` {% endtab %} diff --git a/pgml-cms/docs/resources/developer-docs/installation.md b/pgml-cms/docs/resources/developer-docs/installation.md index 03ae95ece..237b32fce 100644 --- a/pgml-cms/docs/resources/developer-docs/installation.md +++ b/pgml-cms/docs/resources/developer-docs/installation.md @@ -132,7 +132,7 @@ CREATE EXTENSION pgml_test=# SELECT pgml.version(); version --------- - 2.7.4 + 2.9.1 (1 row) ``` diff --git a/pgml-cms/docs/resources/developer-docs/quick-start-with-docker.md b/pgml-cms/docs/resources/developer-docs/quick-start-with-docker.md index a84c38999..bdfa1e8ce 100644 --- a/pgml-cms/docs/resources/developer-docs/quick-start-with-docker.md +++ b/pgml-cms/docs/resources/developer-docs/quick-start-with-docker.md @@ -80,7 +80,7 @@ Time: 41.520 ms postgresml=# SELECT pgml.version(); version --------- - 2.7.13 + 2.9.1 (1 row) ``` diff --git a/pgml-cms/docs/resources/developer-docs/self-hosting/pooler.md b/pgml-cms/docs/resources/developer-docs/self-hosting/pooler.md index 5887a9220..344fbd937 100644 --- a/pgml-cms/docs/resources/developer-docs/self-hosting/pooler.md +++ b/pgml-cms/docs/resources/developer-docs/self-hosting/pooler.md @@ -115,6 +115,6 @@ Type "help" for help. postgresml=> SELECT pgml.version(); version --------- - 2.7.9 + 2.9.1 (1 row) ``` diff --git a/pgml-extension/Cargo.lock b/pgml-extension/Cargo.lock index 6c0f75838..76a5c60d1 100644 --- a/pgml-extension/Cargo.lock +++ b/pgml-extension/Cargo.lock @@ -1746,7 +1746,7 @@ dependencies = [ [[package]] name = "pgml" -version = "2.9.0" +version = "2.9.1" dependencies = [ "anyhow", "blas", diff --git a/pgml-extension/Cargo.toml b/pgml-extension/Cargo.toml index 7787eb25c..3396ae2a5 100644 --- a/pgml-extension/Cargo.toml +++ b/pgml-extension/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgml" -version = "2.9.0" +version = "2.9.1" edition = "2021" [lib] diff --git a/pgml-extension/sql/pgml--2.9.0--2.9.1.sql b/pgml-extension/sql/pgml--2.9.0--2.9.1.sql new file mode 100644 index 000000000..e69de29bb diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs index cf3ba7fcc..a34e5bbbb 100644 --- a/pgml-extension/src/bindings/transformers/mod.rs +++ b/pgml-extension/src/bindings/transformers/mod.rs @@ -380,7 +380,7 @@ pub fn load_dataset( .ok_or(anyhow!("dataset `data` key is not an object"))?; let column_names = types .iter() - .map(|(name, _type)| format!("\"{}\"", name) ) + .map(|(name, _type)| format!("\"{}\"", name)) .collect::>() .join(", "); let column_types = types diff --git a/pgml-extension/src/orm/sampling.rs b/pgml-extension/src/orm/sampling.rs index 831b74d5f..c48692394 100644 --- a/pgml-extension/src/orm/sampling.rs +++ b/pgml-extension/src/orm/sampling.rs @@ -125,7 +125,7 @@ mod tests { let columns = get_column_fixtures(); let sql = sampling.get_sql("my_table", columns); let expected_sql = " - SELECT * + SELECT \"col1\", \"col2\" FROM ( SELECT *, diff --git a/pgml-extension/src/orm/snapshot.rs b/pgml-extension/src/orm/snapshot.rs index b7470f6ad..7b1db546a 100644 --- a/pgml-extension/src/orm/snapshot.rs +++ b/pgml-extension/src/orm/snapshot.rs @@ -412,7 +412,8 @@ impl Snapshot { .first(); if !result.is_empty() { let jsonb: JsonB = result.get(7).unwrap().unwrap(); - let columns: Vec = serde_json::from_value(jsonb.0).expect("invalid json description of columns"); + let columns: Vec = + serde_json::from_value(jsonb.0).expect("invalid json description of columns"); // let jsonb: JsonB = result.get(8).unwrap(); // let analysis: Option> = Some(serde_json::from_value(jsonb.0).unwrap()); let mut s = Snapshot {








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: http://github.com/postgresml/postgresml/pull/1520.patch

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy