Skip to content

fix and test preprocessing examples #1520

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ubuntu-packages-and-docker-image.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ on:
workflow_dispatch:
inputs:
packageVersion:
default: "2.8.2"
default: "2.9.1"
jobs:
#
# PostgresML extension.
Expand Down
2 changes: 1 addition & 1 deletion pgml-cms/docs/resources/developer-docs/contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ SELECT pgml.version();
postgres=# select pgml.version();
version
-------------------
2.7.4
2.9.1
(1 row)
```
{% endtab %}
Expand Down
2 changes: 1 addition & 1 deletion pgml-cms/docs/resources/developer-docs/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ CREATE EXTENSION
pgml_test=# SELECT pgml.version();
version
---------
2.7.4
2.9.1
(1 row)
```

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ Time: 41.520 ms
postgresml=# SELECT pgml.version();
version
---------
2.7.13
2.9.1
(1 row)
```

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,6 @@ Type "help" for help.
postgresml=> SELECT pgml.version();
version
---------
2.7.9
2.9.1
(1 row)
```
2 changes: 1 addition & 1 deletion pgml-extension/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pgml-extension/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pgml"
version = "2.9.0"
version = "2.9.1"
edition = "2021"

[lib]
Expand Down
33 changes: 33 additions & 0 deletions pgml-extension/examples/preprocessing.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
-- Example: training and using a regression model with per-column preprocessing.
-- Load the diamonds dataset, that contains text categorical variables.
SELECT pgml.load_dataset('jdxcosta/diamonds');

-- View the data. The table name must be double-quoted because it contains a '/'.
SELECT * FROM pgml."jdxcosta/diamonds" LIMIT 10;

-- Drop the "Unnamed: 0" column, since it's not useful for training (you could create a view instead)
ALTER TABLE pgml."jdxcosta/diamonds" DROP COLUMN "Unnamed: 0";

-- Train a model using preprocessors to scale the numeric variables, and target encode the categoricals.
-- NOTE(review): relation_name is written without inner quotes ('pgml.jdxcosta/diamonds') while the
-- queries above quote the table part — presumably pgml.train() quotes the identifier itself; confirm
-- against the pgml.train() documentation.
SELECT pgml.train(
project_name => 'Diamond prices',
task => 'regression',
relation_name => 'pgml.jdxcosta/diamonds',
y_column_name => 'price',
algorithm => 'lightgbm',
preprocess => '{
"carat": {"scale": "standard"},
"depth": {"scale": "standard"},
"table": {"scale": "standard"},
"cut": {"encode": "target", "scale": "standard"},
"color": {"encode": "target", "scale": "standard"},
"clarity": {"encode": "target", "scale": "standard"}
}'
);

-- Run some predictions. Notice we're passing a heterogeneous row (tuple) as input, rather than a
-- homogeneous ARRAY[]. The "table" column is double-quoted because TABLE is a reserved word in SQL.
SELECT price, pgml.predict('Diamond prices', (carat, cut, color, clarity, depth, "table", x, y, z)) AS prediction
FROM pgml."jdxcosta/diamonds"
LIMIT 10;

-- This is a difficult dataset for most algorithms, which makes it a good challenge for preprocessing, and additional
-- feature engineering. What's next?
Empty file.
7 changes: 4 additions & 3 deletions pgml-extension/src/bindings/transformers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ pub fn load_dataset(
.ok_or(anyhow!("dataset `data` key is not an object"))?;
let column_names = types
.iter()
.map(|(name, _type)| name.clone())
.map(|(name, _type)| format!("\"{}\"", name))
.collect::<Vec<String>>()
.join(", ");
let column_types = types
Expand All @@ -393,13 +393,14 @@ pub fn load_dataset(
"int64" => "INT8",
"int32" => "INT4",
"int16" => "INT2",
"int8" => "INT2",
"float64" => "FLOAT8",
"float32" => "FLOAT4",
"float16" => "FLOAT4",
"bool" => "BOOLEAN",
_ => bail!("unhandled dataset feature while reading dataset: {type_}"),
};
Ok(format!("{name} {type_}"))
Ok(format!("\"{name}\" {type_}"))
})
.collect::<Result<Vec<String>>>()?
.join(", ");
Expand Down Expand Up @@ -455,7 +456,7 @@ pub fn load_dataset(
.into_datum(),
)),
"dict" | "list" => row.push((PgBuiltInOids::JSONBOID.oid(), JsonB(value.clone()).into_datum())),
"int64" | "int32" | "int16" => row.push((
"int64" | "int32" | "int16" | "int8" => row.push((
PgBuiltInOids::INT8OID.oid(),
value
.as_i64()
Expand Down
13 changes: 6 additions & 7 deletions pgml-extension/src/orm/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,13 +344,12 @@ impl Model {
).unwrap().first();

if !result.is_empty() {
let project_id = result.get(2).unwrap().unwrap();
let project = Project::find(project_id).unwrap();
let snapshot_id = result.get(3).unwrap().unwrap();
let snapshot = Snapshot::find(snapshot_id).unwrap();
let algorithm = Algorithm::from_str(result.get(4).unwrap().unwrap()).unwrap();
let runtime = Runtime::from_str(result.get(5).unwrap().unwrap()).unwrap();

let project_id = result.get(2).unwrap().expect("project_id is i64");
let project = Project::find(project_id).expect("project doesn't exist");
let snapshot_id = result.get(3).unwrap().expect("snapshot_id is i64");
let snapshot = Snapshot::find(snapshot_id).expect("snapshot doesn't exist");
let algorithm = Algorithm::from_str(result.get(4).unwrap().unwrap()).expect("algorithm is malformed");
let runtime = Runtime::from_str(result.get(5).unwrap().unwrap()).expect("runtime is malformed");
let data = Spi::get_one_with_args::<Vec<u8>>(
"
SELECT data
Expand Down
4 changes: 2 additions & 2 deletions pgml-extension/src/orm/sampling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl Sampling {
Sampling::stratified => {
format!(
"
SELECT *
SELECT {col_string}
FROM (
SELECT
*,
Expand Down Expand Up @@ -125,7 +125,7 @@ mod tests {
let columns = get_column_fixtures();
let sql = sampling.get_sql("my_table", columns);
let expected_sql = "
SELECT *
SELECT \"col1\", \"col2\"
FROM (
SELECT
*,
Expand Down
24 changes: 17 additions & 7 deletions pgml-extension/src/orm/snapshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,16 +230,24 @@ impl Column {
if self.preprocessor.encode == Encode::target {
let categories = self.statistics.categories.as_mut().unwrap();
let mut sums = vec![0_f32; categories.len() + 1];
let mut total = 0.;
Zip::from(array).and(target).for_each(|&value, &target| {
total += target;
sums[value as usize] += target;
});
let avg_target = total / categories.len() as f32;
for category in categories.values_mut() {
let sum = sums[category.value as usize];
category.value = sum / category.members as f32;
if category.members > 0 {
let sum = sums[category.value as usize];
category.value = sum / category.members as f32;
} else {
// use avg target for categories w/ no members, e.g. __NULL__ category in a complete dataset
category.value = avg_target;
}
}
}

// Data is filtered for NaN because it is not well defined statistically, and they are counted as separate stat
// Data is filtered for NaN because it is not well-defined statistically, and they are counted as separate stat
let mut data = array
.iter()
.filter_map(|n| if n.is_nan() { None } else { Some(*n) })
Expand Down Expand Up @@ -404,7 +412,8 @@ impl Snapshot {
.first();
if !result.is_empty() {
let jsonb: JsonB = result.get(7).unwrap().unwrap();
let columns: Vec<Column> = serde_json::from_value(jsonb.0).unwrap();
let columns: Vec<Column> =
serde_json::from_value(jsonb.0).expect("invalid json description of columns");
// let jsonb: JsonB = result.get(8).unwrap();
// let analysis: Option<IndexMap<String, f32>> = Some(serde_json::from_value(jsonb.0).unwrap());
let mut s = Snapshot {
Expand Down Expand Up @@ -500,9 +509,10 @@ impl Snapshot {

let preprocessors: HashMap<String, Preprocessor> = serde_json::from_value(preprocess.0).expect("is valid");

let mut position = 0; // Postgres column positions are not updated when other columns are dropped, but we expect consecutive positions when we read the table.
Spi::connect(|mut client| {
let mut columns: Vec<Column> = Vec::new();
client.select("SELECT column_name::TEXT, udt_name::TEXT, is_nullable::BOOLEAN, ordinal_position::INTEGER FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2 ORDER BY ordinal_position ASC",
client.select("SELECT column_name::TEXT, udt_name::TEXT, is_nullable::BOOLEAN FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2 ORDER BY ordinal_position ASC",
None,
Some(vec![
(PgBuiltInOids::TEXTOID.oid(), schema_name.into_datum()),
Expand All @@ -520,7 +530,7 @@ impl Snapshot {
pg_type = pg_type[1..].to_string() + "[]";
}
let nullable = row[3].value::<bool>().unwrap().unwrap();
let position = row[4].value::<i32>().unwrap().unwrap() as usize;
position += 1;
let label = match y_column_name {
Some(ref y_column_name) => y_column_name.contains(&name),
None => false,
Expand Down Expand Up @@ -1158,7 +1168,7 @@ impl Snapshot {
pub fn numeric_encoded_dataset(&mut self) -> Dataset {
let mut data = None;
Spi::connect(|client| {
// Postgres Arrays arrays are 1 indexed and so are SPI tuples...
// Postgres arrays are 1 indexed and so are SPI tuples...
let result = client.select(&self.select_sql(), None, None).unwrap();
let num_rows = result.len();
let (num_train_rows, num_test_rows) = self.train_test_split(num_rows);
Expand Down
1 change: 1 addition & 0 deletions pgml-extension/tests/test.sql
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,6 @@ SELECT pgml.load_dataset('wine');
\i examples/regression.sql
\i examples/vectors.sql
\i examples/chunking.sql
\i examples/preprocessing.sql
-- transformers are generally too slow to run in the test suite
--\i examples/transformers.sql
Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy