Skip to content

Commit a8d8218

Browse files
authored
fix and test preprocessing examples (#1520)
1 parent c3a8514 commit a8d8218

File tree

14 files changed

+70
-26
lines changed

14 files changed

+70
-26
lines changed

.github/workflows/ubuntu-packages-and-docker-image.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ on:
44
workflow_dispatch:
55
inputs:
66
packageVersion:
7-
default: "2.8.2"
7+
default: "2.9.1"
88
jobs:
99
#
1010
# PostgresML extension.

pgml-cms/docs/resources/developer-docs/contributing.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ SELECT pgml.version();
127127
postgres=# select pgml.version();
128128
version
129129
-------------------
130-
2.7.4
130+
2.9.1
131131
(1 row)
132132
```
133133
{% endtab %}

pgml-cms/docs/resources/developer-docs/installation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ CREATE EXTENSION
132132
pgml_test=# SELECT pgml.version();
133133
version
134134
---------
135-
2.7.4
135+
2.9.1
136136
(1 row)
137137
```
138138

pgml-cms/docs/resources/developer-docs/quick-start-with-docker.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ Time: 41.520 ms
8080
postgresml=# SELECT pgml.version();
8181
version
8282
---------
83-
2.7.13
83+
2.9.1
8484
(1 row)
8585
```
8686

pgml-cms/docs/resources/developer-docs/self-hosting/pooler.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,6 @@ Type "help" for help.
115115
postgresml=> SELECT pgml.version();
116116
version
117117
---------
118-
2.7.9
118+
2.9.1
119119
(1 row)
120120
```

pgml-extension/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pgml-extension/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "pgml"
3-
version = "2.9.0"
3+
version = "2.9.1"
44
edition = "2021"
55

66
[lib]
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
-- load the diamonds dataset, that contains text categorical variables
2+
SELECT pgml.load_dataset('jdxcosta/diamonds');
3+
4+
-- view the data
5+
SELECT * FROM pgml."jdxcosta/diamonds" LIMIT 10;
6+
7+
-- drop the "Unnamed: 0" column, since it's not useful for training (you could create a view instead)
8+
ALTER TABLE pgml."jdxcosta/diamonds" DROP COLUMN "Unnamed: 0";
9+
10+
-- train a model using preprocessors to scale the numeric variables, and target encode the categoricals
11+
SELECT pgml.train(
12+
project_name => 'Diamond prices',
13+
task => 'regression',
14+
relation_name => 'pgml.jdxcosta/diamonds',
15+
y_column_name => 'price',
16+
algorithm => 'lightgbm',
17+
preprocess => '{
18+
"carat": {"scale": "standard"},
19+
"depth": {"scale": "standard"},
20+
"table": {"scale": "standard"},
21+
"cut": {"encode": "target", "scale": "standard"},
22+
"color": {"encode": "target", "scale": "standard"},
23+
"clarity": {"encode": "target", "scale": "standard"}
24+
}'
25+
);
26+
27+
-- run some predictions; notice we're passing a heterogeneous row (tuple) as input, rather than a homogeneous ARRAY[].
28+
SELECT price, pgml.predict('Diamond prices', (carat, cut, color, clarity, depth, "table", x, y, z)) AS prediction
29+
FROM pgml."jdxcosta/diamonds"
30+
LIMIT 10;
31+
32+
-- This is a difficult dataset for most algorithms, which makes it a good challenge for preprocessing and additional
33+
-- feature engineering. What's next?

pgml-extension/sql/pgml--2.9.0--2.9.1.sql

Whitespace-only changes.

pgml-extension/src/bindings/transformers/mod.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ pub fn load_dataset(
380380
.ok_or(anyhow!("dataset `data` key is not an object"))?;
381381
let column_names = types
382382
.iter()
383-
.map(|(name, _type)| name.clone())
383+
.map(|(name, _type)| format!("\"{}\"", name))
384384
.collect::<Vec<String>>()
385385
.join(", ");
386386
let column_types = types
@@ -393,13 +393,14 @@ pub fn load_dataset(
393393
"int64" => "INT8",
394394
"int32" => "INT4",
395395
"int16" => "INT2",
396+
"int8" => "INT2",
396397
"float64" => "FLOAT8",
397398
"float32" => "FLOAT4",
398399
"float16" => "FLOAT4",
399400
"bool" => "BOOLEAN",
400401
_ => bail!("unhandled dataset feature while reading dataset: {type_}"),
401402
};
402-
Ok(format!("{name} {type_}"))
403+
Ok(format!("\"{name}\" {type_}"))
403404
})
404405
.collect::<Result<Vec<String>>>()?
405406
.join(", ");
@@ -455,7 +456,7 @@ pub fn load_dataset(
455456
.into_datum(),
456457
)),
457458
"dict" | "list" => row.push((PgBuiltInOids::JSONBOID.oid(), JsonB(value.clone()).into_datum())),
458-
"int64" | "int32" | "int16" => row.push((
459+
"int64" | "int32" | "int16" | "int8" => row.push((
459460
PgBuiltInOids::INT8OID.oid(),
460461
value
461462
.as_i64()

pgml-extension/src/orm/model.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -344,13 +344,12 @@ impl Model {
344344
).unwrap().first();
345345

346346
if !result.is_empty() {
347-
let project_id = result.get(2).unwrap().unwrap();
348-
let project = Project::find(project_id).unwrap();
349-
let snapshot_id = result.get(3).unwrap().unwrap();
350-
let snapshot = Snapshot::find(snapshot_id).unwrap();
351-
let algorithm = Algorithm::from_str(result.get(4).unwrap().unwrap()).unwrap();
352-
let runtime = Runtime::from_str(result.get(5).unwrap().unwrap()).unwrap();
353-
347+
let project_id = result.get(2).unwrap().expect("project_id is i64");
348+
let project = Project::find(project_id).expect("project doesn't exist");
349+
let snapshot_id = result.get(3).unwrap().expect("snapshot_id is i64");
350+
let snapshot = Snapshot::find(snapshot_id).expect("snapshot doesn't exist");
351+
let algorithm = Algorithm::from_str(result.get(4).unwrap().unwrap()).expect("algorithm is malformed");
352+
let runtime = Runtime::from_str(result.get(5).unwrap().unwrap()).expect("runtime is malformed");
354353
let data = Spi::get_one_with_args::<Vec<u8>>(
355354
"
356355
SELECT data

pgml-extension/src/orm/sampling.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ impl Sampling {
5555
Sampling::stratified => {
5656
format!(
5757
"
58-
SELECT *
58+
SELECT {col_string}
5959
FROM (
6060
SELECT
6161
*,
@@ -125,7 +125,7 @@ mod tests {
125125
let columns = get_column_fixtures();
126126
let sql = sampling.get_sql("my_table", columns);
127127
let expected_sql = "
128-
SELECT *
128+
SELECT \"col1\", \"col2\"
129129
FROM (
130130
SELECT
131131
*,

pgml-extension/src/orm/snapshot.rs

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -230,16 +230,24 @@ impl Column {
230230
if self.preprocessor.encode == Encode::target {
231231
let categories = self.statistics.categories.as_mut().unwrap();
232232
let mut sums = vec![0_f32; categories.len() + 1];
233+
let mut total = 0.;
233234
Zip::from(array).and(target).for_each(|&value, &target| {
235+
total += target;
234236
sums[value as usize] += target;
235237
});
238+
let avg_target = total / categories.len() as f32;
236239
for category in categories.values_mut() {
237-
let sum = sums[category.value as usize];
238-
category.value = sum / category.members as f32;
240+
if category.members > 0 {
241+
let sum = sums[category.value as usize];
242+
category.value = sum / category.members as f32;
243+
} else {
244+
// use avg target for categories w/ no members, e.g. __NULL__ category in a complete dataset
245+
category.value = avg_target;
246+
}
239247
}
240248
}
241249

242-
// Data is filtered for NaN because it is not well defined statistically, and they are counted as separate stat
250+
// Data is filtered for NaN because it is not well-defined statistically, and they are counted as separate stat
243251
let mut data = array
244252
.iter()
245253
.filter_map(|n| if n.is_nan() { None } else { Some(*n) })
@@ -404,7 +412,8 @@ impl Snapshot {
404412
.first();
405413
if !result.is_empty() {
406414
let jsonb: JsonB = result.get(7).unwrap().unwrap();
407-
let columns: Vec<Column> = serde_json::from_value(jsonb.0).unwrap();
415+
let columns: Vec<Column> =
416+
serde_json::from_value(jsonb.0).expect("invalid json description of columns");
408417
// let jsonb: JsonB = result.get(8).unwrap();
409418
// let analysis: Option<IndexMap<String, f32>> = Some(serde_json::from_value(jsonb.0).unwrap());
410419
let mut s = Snapshot {
@@ -500,9 +509,10 @@ impl Snapshot {
500509

501510
let preprocessors: HashMap<String, Preprocessor> = serde_json::from_value(preprocess.0).expect("is valid");
502511

512+
let mut position = 0; // Postgres column positions are not updated when other columns are dropped, but we expect consecutive positions when we read the table.
503513
Spi::connect(|mut client| {
504514
let mut columns: Vec<Column> = Vec::new();
505-
client.select("SELECT column_name::TEXT, udt_name::TEXT, is_nullable::BOOLEAN, ordinal_position::INTEGER FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2 ORDER BY ordinal_position ASC",
515+
client.select("SELECT column_name::TEXT, udt_name::TEXT, is_nullable::BOOLEAN FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2 ORDER BY ordinal_position ASC",
506516
None,
507517
Some(vec![
508518
(PgBuiltInOids::TEXTOID.oid(), schema_name.into_datum()),
@@ -520,7 +530,7 @@ impl Snapshot {
520530
pg_type = pg_type[1..].to_string() + "[]";
521531
}
522532
let nullable = row[3].value::<bool>().unwrap().unwrap();
523-
let position = row[4].value::<i32>().unwrap().unwrap() as usize;
533+
position += 1;
524534
let label = match y_column_name {
525535
Some(ref y_column_name) => y_column_name.contains(&name),
526536
None => false,
@@ -1158,7 +1168,7 @@ impl Snapshot {
11581168
pub fn numeric_encoded_dataset(&mut self) -> Dataset {
11591169
let mut data = None;
11601170
Spi::connect(|client| {
1161-
// Postgres Arrays arrays are 1 indexed and so are SPI tuples...
1171+
// Postgres arrays are 1 indexed and so are SPI tuples...
11621172
let result = client.select(&self.select_sql(), None, None).unwrap();
11631173
let num_rows = result.len();
11641174
let (num_train_rows, num_test_rows) = self.train_test_split(num_rows);

pgml-extension/tests/test.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,6 @@ SELECT pgml.load_dataset('wine');
3030
\i examples/regression.sql
3131
\i examples/vectors.sql
3232
\i examples/chunking.sql
33+
\i examples/preprocessing.sql
3334
-- transformers are generally too slow to run in the test suite
3435
--\i examples/transformers.sql

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy