diff --git a/.github/workflows/package-extension.yml b/.github/workflows/package-extension.yml index 1eaa62c1a..2bfefea6c 100644 --- a/.github/workflows/package-extension.yml +++ b/.github/workflows/package-extension.yml @@ -4,7 +4,7 @@ on: workflow_dispatch: inputs: packageVersion: - default: "2.6.0" + default: "2.7.0" jobs: build: diff --git a/pgml-dashboard/Cargo.lock b/pgml-dashboard/Cargo.lock index b1d4c4d29..753b0e4b0 100644 --- a/pgml-dashboard/Cargo.lock +++ b/pgml-dashboard/Cargo.lock @@ -2078,7 +2078,7 @@ checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" [[package]] name = "pgml-dashboard" -version = "2.6.0" +version = "2.7.0" dependencies = [ "aho-corasick 0.7.20", "anyhow", diff --git a/pgml-dashboard/Cargo.toml b/pgml-dashboard/Cargo.toml index d0fcf4fb4..456118e4b 100644 --- a/pgml-dashboard/Cargo.toml +++ b/pgml-dashboard/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgml-dashboard" -version = "2.6.0" +version = "2.7.0" edition = "2021" authors = ["PostgresML "] license = "MIT" diff --git a/pgml-dashboard/content/docs/guides/setup/developers.md b/pgml-dashboard/content/docs/guides/setup/developers.md index 8d7da7f3d..4eb3226ff 100644 --- a/pgml-dashboard/content/docs/guides/setup/developers.md +++ b/pgml-dashboard/content/docs/guides/setup/developers.md @@ -127,7 +127,7 @@ SELECT pgml.version(); postgres=# select pgml.version(); version ------------------- - 2.6.0 + 2.7.0 (1 row) ``` diff --git a/pgml-dashboard/content/docs/guides/setup/v2/installation.md b/pgml-dashboard/content/docs/guides/setup/v2/installation.md index 23ad76557..06bd87039 100644 --- a/pgml-dashboard/content/docs/guides/setup/v2/installation.md +++ b/pgml-dashboard/content/docs/guides/setup/v2/installation.md @@ -217,7 +217,7 @@ SELECT pgml.version(); postgres=# select pgml.version(); version ------------------- - 2.6.0 + 2.7.0 (1 row) ``` diff --git a/pgml-dashboard/content/docs/guides/training/algorithm_selection.md b/pgml-dashboard/content/docs/guides/training/algorithm_selection.md index 38cc48eda..23aa51a1d 100644 --- a/pgml-dashboard/content/docs/guides/training/algorithm_selection.md +++ b/pgml-dashboard/content/docs/guides/training/algorithm_selection.md @@ -2,7 +2,7 @@ We currently support regression and classification algorithms from [scikit-learn](https://scikit-learn.org/), [XGBoost](https://xgboost.readthedocs.io/), and [LightGBM](https://lightgbm.readthedocs.io/). 
-## Algorithms +## Supervised Algorithms ### Gradient Boosting Algorithm | Regression | Classification @@ -54,6 +54,18 @@ Algorithm | Regression | Classification `kernel_ridge` | [KernelRidge](https://scikit-learn.org/stable/modules/generated/sklearn.kernel_ridge.KernelRidge.html) | - `gaussian_process` | [GaussianProcessRegressor](https://scikit-learn.org/stable/modules/generated/sklearn.gaussian_process.GaussianProcessRegressor.html) | [GaussianProcessClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.gaussian_process.GaussianProcessClassifier.html) +## Unsupervised Algorithms + +### Clustering + +|Algorithm | Reference | +|---|-------------------------------------------------------------------------------------------------------------------| +`affinity_propagation` | [AffinityPropagation](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AffinityPropagation.html) +`birch` | [Birch](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.Birch.html) +`kmeans` | [K-Means](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html) +`mini_batch_kmeans` | [MiniBatchKMeans](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MiniBatchKMeans.html) + + ## Comparing Algorithms Any of the above algorithms can be passed to our `pgml.train()` function using the `algorithm` parameter. If the parameter is omitted, linear regression is used by default. diff --git a/pgml-dashboard/src/models.rs b/pgml-dashboard/src/models.rs index 11f2e9563..c0b1f22db 100644 --- a/pgml-dashboard/src/models.rs +++ b/pgml-dashboard/src/models.rs @@ -57,6 +57,7 @@ impl Project { "summarization" => Ok("rouge_ngram_f1"), "translation" => Ok("bleu"), "text_generation" | "text2text" => Ok("perplexity"), + "cluster" => Ok("silhouette"), task => Err(anyhow::anyhow!("Unhandled task: {}", task)), } } @@ -68,6 +69,7 @@ impl Project { "summarization" => Ok("Rouge Ngram F1"), "translation" => Ok("Bleu"), "text_generation" | "text2text" => Ok("Perplexity"), + "cluster" => Ok("silhouette"), task => Err(anyhow::anyhow!("Unhandled task: {}", task)), } } @@ -544,7 +546,7 @@ impl Model { pub struct Snapshot { pub id: i64, pub relation_name: String, - pub y_column_name: Vec, + pub y_column_name: Option>, pub test_size: f32, pub test_sampling: Option, pub status: String, @@ -686,28 +688,42 @@ impl Snapshot { } } - pub fn features<'a>(&'a self) -> Option>> { + pub fn features(&self) -> Option>> { match self.columns() { - Some(columns) => Some( - columns - .into_iter() - .filter(|column| { - !self - .y_column_name - .contains(&column["name"].as_str().unwrap().to_string()) - }) - .collect(), - ), + Some(columns) => { + if self.y_column_name.is_none() { + return Some(columns.into_iter().collect()); + } + + Some( + columns + .into_iter() + .filter(|column| { + !self + .y_column_name + .as_ref() + .unwrap() + .contains(&column["name"].as_str().unwrap().to_string()) + }) + .collect(), + ) + } None => None, } } - pub fn labels<'a>(&'a self) -> Option>> { + pub fn labels(&self) -> Option>> { + if self.y_column_name.is_none() { + return Some(Vec::new()); + } + self.columns().map(|columns| { columns .into_iter() .filter(|column| { self.y_column_name + .as_ref() + .unwrap() .contains(&column["name"].as_str().unwrap().to_string()) }) .collect() diff --git a/pgml-dashboard/templates/content/dashboard/panels/snapshot.html b/pgml-dashboard/templates/content/dashboard/panels/snapshot.html index 86adcf47f..9963a90f5 100644 --- 
a/pgml-dashboard/templates/content/dashboard/panels/snapshot.html +++ b/pgml-dashboard/templates/content/dashboard/panels/snapshot.html @@ -73,9 +73,11 @@

bubble_chart Features

%>

<%= name %> <%= feature["pg_type"].as_str().unwrap() | upper %>

- <% for y_column_name in snapshot.y_column_name.iter() { %> + <% if snapshot.y_column_name.as_ref().is_some() { %> + <% for y_column_name in snapshot.y_column_name.as_ref().unwrap().iter() { %>
<% } %> + <% } %> <% } %> @@ -102,11 +104,13 @@

<%= name %> <%= feature["pg_type"].as_str().unwrap() | upper %>, <%= model.key_metric(project).unwrap() %>, [0, 1]); <% } %> - <% for y_column_name in snapshot.y_column_name.iter() { %> + <% if snapshot.y_column_name.as_ref().is_some() { %> + <% for y_column_name in snapshot.y_column_name.as_ref().unwrap().iter() { %> setTimeout(renderDistribution, delay, "<%= y_column_name %>", <%= y_column_name %>_samples, NaN); setTimeout(renderOutliers, delay, "<%= y_column_name %>", <%= y_column_name %>_samples, <%= snapshot.target_stddev(y_column_name) %>) <% } %> - + <% } %> + var delay = 600; <% for feature in snapshot.features().unwrap().iter() { @@ -116,9 +120,11 @@

<%= name %> <%= feature["pg_type"].as_str().unwrap() | upper %>", <%= name_machine %>_samples, NaN); - <% for y_column_name in snapshot.y_column_name.iter() { %> + <% if snapshot.y_column_name.as_ref().is_some() { %> + <% for y_column_name in snapshot.y_column_name.as_ref().unwrap().iter() { %> setTimeout(renderCorrelation, delay, "<%= name_machine %>", "<%= y_column_name %>", <%= name_machine %>_samples, <%= y_column_name %>_samples); <% } %> + <% } %> <% } %> } renderCharts(); diff --git a/pgml-extension/Cargo.lock b/pgml-extension/Cargo.lock index 137cbfbda..ff4235028 100644 --- a/pgml-extension/Cargo.lock +++ b/pgml-extension/Cargo.lock @@ -89,7 +89,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842" dependencies = [ "proc-macro2", - "quote 1.0.28", + "quote 1.0.29", "syn 2.0.18", ] @@ -166,7 +166,7 @@ dependencies = [ "log", "peeking_take_while", "proc-macro2", - "quote 1.0.28", + "quote 1.0.29", "regex", "rustc-hash", "shlex 0.1.1", @@ -187,7 +187,7 @@ dependencies = [ "log", "peeking_take_while", "proc-macro2", - "quote 1.0.28", + "quote 1.0.29", "regex", "rustc-hash", "shlex 1.1.0", @@ -197,19 +197,18 @@ dependencies = [ [[package]] name = "bindgen" -version = "0.65.1" +version = "0.66.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfdf7b466f9a4903edc73f95d6d2bcd5baf8ae620638762244d3f60143643cc5" +checksum = "f2b84e06fc203107bfbad243f4aba2af864eb7db3b1cf46ea0a023b0b433d2a7" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.3.3", "cexpr 0.6.0", "clang-sys", "lazy_static", "lazycell", "peeking_take_while", - "prettyplease", "proc-macro2", - "quote 1.0.28", + "quote 1.0.29", "regex", "rustc-hash", "shlex 1.1.0", @@ -224,9 +223,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.3.1" +version = "2.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6776fc96284a0bb647b615056fc496d1fe1644a7ab01829818a6d91cae888b84" +checksum = "630be753d4e58660abd17930c71b647fe46c27ea6b63cc59e1e3851406972e42" [[package]] name = "bitvec" @@ -405,7 +404,7 @@ checksum = "3f9644cd56d6b87dbe899ef8b053e331c0637664e9e21a33dfcdc36093f5c5c4" dependencies = [ "heck", "proc-macro2", - "quote 1.0.28", + "quote 1.0.29", "syn 2.0.18", ] @@ -553,7 +552,7 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd4056f63fce3b82d852c3da92b08ea59959890813a7f4ce9c0ff85b10cf301b" dependencies = [ - "quote 1.0.28", + "quote 1.0.29", "syn 2.0.18", ] @@ -576,7 +575,7 @@ dependencies = [ "fnv", "ident_case", "proc-macro2", - "quote 1.0.28", + "quote 1.0.29", "strsim 0.10.0", "syn 1.0.109", ] @@ -588,7 +587,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" dependencies = [ "darling_core", - "quote 1.0.28", + "quote 1.0.29", "syn 1.0.109", ] @@ -630,7 +629,7 @@ checksum = "1f91d4cfa921f1c05904dc3c57b4a32c38aed3340cce209f3a6fd1478babafc4" dependencies = [ "darling", "proc-macro2", - "quote 1.0.28", + "quote 1.0.29", "syn 1.0.109", ] @@ -745,7 +744,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a4da76b3b6116d758c7ba93f7ec6a35d2e2cf24feda76c6e38a375f4d5c59f2" dependencies = [ "proc-macro2", - "quote 1.0.28", + "quote 1.0.29", "syn 1.0.109", ] @@ -762,6 +761,12 @@ dependencies = [ "termcolor", ] +[[package]] +name = "equivalent" 
+version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88bffebc5d80432c9b140ee17875ff173a8ab62faad5b257da912bd2f6c1c0a1" + [[package]] name = "erased-serde" version = "0.3.25" @@ -904,7 +909,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", - "quote 1.0.28", + "quote 1.0.29", "syn 2.0.18", ] @@ -963,7 +968,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e77ac7b51b8e6313251737fcef4b1c01a2ea102bde68415b62c0ee9268fec357" dependencies = [ "proc-macro2", - "quote 1.0.28", + "quote 1.0.29", "syn 2.0.18", ] @@ -994,6 +999,12 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "hashbrown" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" + [[package]] name = "heapless" version = "0.7.16" @@ -1081,10 +1092,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", - "hashbrown", + "hashbrown 0.12.3", "serde", ] +[[package]] +name = "indexmap" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" +dependencies = [ + "equivalent", + "hashbrown 0.14.0", +] + [[package]] name = "indoc" version = "1.0.9" @@ -1159,9 +1180,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.145" +version = "0.2.147" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc86cde3ff845662b8f4ef6cb50ea0e20c524eb3d29ae048287e06a1b3fa6a81" +checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" [[package]] name = "libloading" @@ -1444,7 +1465,7 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af5a8477ac96877b5bd1fd67e0c28736c12943aba24eda92b127e036b0c8f400" dependencies = [ - "indexmap", + "indexmap 1.9.3", "itertools", "ndarray", "noisy_float", @@ -1641,7 +1662,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", - "quote 1.0.28", + "quote 1.0.29", "syn 2.0.18", ] @@ -1749,12 +1770,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4dd7d28ee937e54fe3080c91faa1c3a46c06de6252988a7f4592ba2310ef22a4" dependencies = [ "fixedbitset", - "indexmap", + "indexmap 1.9.3", ] [[package]] name = "pgml" -version = "2.6.0" +version = "2.7.0" dependencies = [ "anyhow", "blas", @@ -1762,7 +1783,7 @@ dependencies = [ "csv", "flate2", "heapless", - "indexmap", + "indexmap 1.9.3", "itertools", "lightgbm", "linfa", @@ -1789,12 +1810,12 @@ dependencies = [ [[package]] name = "pgrx" -version = "0.9.2" +version = "0.9.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0162bfa4dc9541cef1bf08cb7f0ed4a5ba65af9705d8457e75a77048bc584912" +checksum = "6186d4aa5911be4c00b52e555779deece35a7563c87fcfe794407dc2e9cc4dc1" dependencies = [ "atomic-traits", - "bitflags 2.3.1", + "bitflags 2.3.3", "bitvec", "enum-map", "heapless", @@ -1814,21 
+1835,21 @@ dependencies = [ [[package]] name = "pgrx-macros" -version = "0.9.2" +version = "0.9.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7af5e583e8da2b77b81245ceab25e419d2d2fd05ed77338ef92bfde751a3b081" +checksum = "479a66a8c582e0fdf101178473315cb13eaa10829c154db742c35ec0279cdaec" dependencies = [ "pgrx-sql-entity-graph", "proc-macro2", - "quote 1.0.28", + "quote 1.0.29", "syn 1.0.109", ] [[package]] name = "pgrx-pg-config" -version = "0.9.2" +version = "0.9.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca08f088b133b83de7bfb9ad4953d763af4a3dcbcc326d73edbb47b77ba8032f" +checksum = "1e45c557631217a13859e223899c01d35982ef0c860ee5ab65af496f830b1316" dependencies = [ "cargo_toml", "dirs 5.0.1", @@ -1844,11 +1865,11 @@ dependencies = [ [[package]] name = "pgrx-pg-sys" -version = "0.9.2" +version = "0.9.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32f328527c00c9a3c728a6e6fddb34b0ab0342203c3f68f9c1224fef5c6b572b" +checksum = "0dde896a17c638b6475d6fc12b571a176013a8486437bbc8a64ac2afb8ba5d58" dependencies = [ - "bindgen 0.65.1", + "bindgen 0.66.1", "eyre", "libc", "memoffset 0.9.0", @@ -1857,7 +1878,7 @@ dependencies = [ "pgrx-pg-config", "pgrx-sql-entity-graph", "proc-macro2", - "quote 1.0.28", + "quote 1.0.29", "serde", "shlex 1.1.0", "sptr", @@ -1866,24 +1887,24 @@ dependencies = [ [[package]] name = "pgrx-sql-entity-graph" -version = "0.9.2" +version = "0.9.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c143e91adbfd63d947389f90515c4b5502b4d9368a412b041192a2b2b43fa46" +checksum = "b1e9abc71b018d90aa9b7a34fedf48b76da5d55c04d2ed2288096827bebbf403" dependencies = [ "convert_case", "eyre", "petgraph", "proc-macro2", - "quote 1.0.28", + "quote 1.0.29", "syn 1.0.109", "unescape", ] [[package]] name = "pgrx-tests" -version = "0.9.2" +version = "0.9.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a605a0730e0e51cb825bab00bb1750a0513ae1fc6c72dd3a533981d06aff712f" +checksum = "39ac4ffedfa247f9d51421e4e2ac18c33d8d674350bad730f3fe5736bf298612" dependencies = [ "clap-cargo", "eyre", @@ -1986,21 +2007,11 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" -[[package]] -name = "prettyplease" -version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b69d39aab54d069e7f2fe8cb970493e7834601ca2d8c65fd7bbd183578080d1" -dependencies = [ - "proc-macro2", - "syn 2.0.18", -] - [[package]] name = "proc-macro2" -version = "1.0.59" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6aeca18b86b413c660b781aa319e4e2648a3e6f9eadc9b47e9038e6fe9f3451b" +checksum = "7b368fba921b0dce7e60f5e04ec15e565b3303972b42bcfde1d0713b881959eb" dependencies = [ "unicode-ident", ] @@ -2050,7 +2061,7 @@ checksum = "94144a1266e236b1c932682136dc35a9dee8d3589728f68130c7c3861ef96b28" dependencies = [ "proc-macro2", "pyo3-macros-backend", - "quote 1.0.28", + "quote 1.0.29", "syn 1.0.109", ] @@ -2061,7 +2072,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8df9be978a2d2f0cdebabb03206ed73b11314701a5bfe71b0d753b81997777f" dependencies = [ "proc-macro2", - "quote 1.0.28", + "quote 1.0.29", "syn 1.0.109", ] @@ -2073,9 +2084,9 @@ checksum = "7a6e920b65c65f10b2ae65c831a81a073a89edd28c7cce89475bff467ab4167a" [[package]] name = "quote" -version = 
"1.0.28" +version = "1.0.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b9ab9c7eadfd8df19006f1cf1a4aed13540ed5cbc047010ece5826e10825488" +checksum = "573015e8ab27661678357f27dc26460738fd2b6c86e46f386fde94cb5d913105" dependencies = [ "proc-macro2", ] @@ -2414,7 +2425,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c805777e3930c8883389c602315a24224bcc738b63905ef87cd1420353ea93e" dependencies = [ "proc-macro2", - "quote 1.0.28", + "quote 1.0.29", "syn 2.0.18", ] @@ -2424,7 +2435,7 @@ version = "1.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1" dependencies = [ - "indexmap", + "indexmap 1.9.3", "itoa", "ryu", "serde", @@ -2432,9 +2443,9 @@ dependencies = [ [[package]] name = "serde_spanned" -version = "0.6.2" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93107647184f6027e3b7dcb2e11034cf95ffa1e3a682c67951963ac69c1c007d" +checksum = "96426c9936fd7a0124915f9185ea1d20aa9445cc9821142f0a73bc9207a2e186" dependencies = [ "serde", ] @@ -2645,7 +2656,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" dependencies = [ "proc-macro2", - "quote 1.0.28", + "quote 1.0.29", "unicode-ident", ] @@ -2656,7 +2667,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32d41677bcbe24c20c52e7c70b0d8db04134c5d1066bf98662e2871ad200ea3e" dependencies = [ "proc-macro2", - "quote 1.0.28", + "quote 1.0.29", "unicode-ident", ] @@ -2671,9 +2682,9 @@ dependencies = [ [[package]] name = "sysinfo" -version = "0.29.0" +version = "0.29.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02f1dc6930a439cc5d154221b5387d153f8183529b07c19aca24ea31e0a167e1" +checksum = "5bcd0346f90b6bc83526c7b180039a8acd26a5c848cc556d457f6472eb148122" dependencies = [ "cfg-if", "core-foundation-sys", @@ -2771,7 +2782,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" dependencies = [ "proc-macro2", - "quote 1.0.28", + "quote 1.0.29", "syn 2.0.18", ] @@ -2884,9 +2895,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6135d499e69981f9ff0ef2167955a5333c35e36f6937d382974566b3d5b94ec" +checksum = "1ebafdf5ad1220cb59e7d17cf4d2c72015297b75b19a10472f99b89225089240" dependencies = [ "serde", "serde_spanned", @@ -2896,20 +2907,20 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.2" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a76a9312f5ba4c2dec6b9161fdf25d87ad8a09256ccea5a556fef03c706a10f" +checksum = "7cda73e2f1397b1262d6dfdcef8aafae14d1de7748d66822d3bfeeb6d03e5e4b" dependencies = [ "serde", ] [[package]] name = "toml_edit" -version = "0.19.10" +version = "0.19.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2380d56e8670370eee6566b0bfd4265f65b3f432e8c6d85623f728d4fa31f739" +checksum = "266f016b7f039eec8a1a80dfe6156b633d208b9fccca5e4db1d6775b0c4e34a7" dependencies = [ - "indexmap", + "indexmap 2.0.0", "serde", "serde_spanned", "toml_datetime", @@ -2962,7 +2973,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"2c3e1c30cedd24fc597f7d37a721efdbdc2b1acae012c1ef1218f4c7c2c0f3e7" dependencies = [ "proc-macro2", - "quote 1.0.28", + "quote 1.0.29", "syn 2.0.18", ] @@ -3051,9 +3062,9 @@ dependencies = [ [[package]] name = "uuid" -version = "1.3.3" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "345444e32442451b267fc254ae85a209c64be56d2890e601a0c37ff0c3c5ecd2" +checksum = "d023da39d1fde5a8a3fe1f3e01ca9632ada0a63e9797de55a879d6e2236277be" dependencies = [ "getrandom", ] @@ -3323,7 +3334,7 @@ version = "0.2.0" source = "git+https://github.com/postgresml/rust-xgboost.git?branch=master#1c396f79d6275c51a0e1dcf8dea9200cb6cee546" dependencies = [ "derive_builder 0.11.2", - "indexmap", + "indexmap 1.9.3", "libc", "log", "tempfile", diff --git a/pgml-extension/Cargo.toml b/pgml-extension/Cargo.toml index f7e6c9d65..0a548aa05 100644 --- a/pgml-extension/Cargo.toml +++ b/pgml-extension/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgml" -version = "2.6.0" +version = "2.7.0" edition = "2021" [lib] diff --git a/pgml-extension/Dockerfile b/pgml-extension/Dockerfile index 6af7f925b..b9c80c881 100644 --- a/pgml-extension/Dockerfile +++ b/pgml-extension/Dockerfile @@ -2,7 +2,7 @@ FROM nvidia/cuda:12.1.1-devel-ubuntu22.04 LABEL maintainer="team@postgresml.com" ARG DEBIAN_FRONTEND=noninteractive -ARG PGML_VERSION=2.6.0 +ARG PGML_VERSION=2.7.0 ENV TZ=Etc/UTC ENV PATH="/usr/local/cuda/bin:${PATH}" diff --git a/pgml-extension/examples/cluster.sql b/pgml-extension/examples/cluster.sql new file mode 100644 index 000000000..72db877b5 --- /dev/null +++ b/pgml-extension/examples/cluster.sql @@ -0,0 +1,43 @@ +-- This example trains models on the sklean digits dataset +-- which is a copy of the test set of the UCI ML hand-written digits datasets +-- https://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits +-- +-- This demonstrates using a table with a single array feature column +-- for clustering. 
You could do something similar with a vector column +-- + +-- Exit on error (psql) +-- \set ON_ERROR_STOP true +\timing on + +SELECT pgml.load_dataset('digits'); + +-- create an unlabeled table of the images for unsupervised learning +CREATE VIEW pgml.digit_vectors AS +SELECT image FROM pgml.digits; + +-- view the dataset +SELECT left(image::text, 40) || ',...}' FROM pgml.digit_vectors LIMIT 10; + +-- train a simple model to classify the data +SELECT * FROM pgml.train('Handwritten Digit Clusters', 'cluster', 'pgml.digit_vectors', hyperparams => '{"n_clusters": 10}'); + +-- check out the predictions +SELECT target, pgml.predict('Handwritten Digit Clusters', image) AS prediction +FROM pgml.digits +LIMIT 10; + +SELECT * FROM pgml.train('Handwritten Digit Clusters', algorithm => 'affinity_propagation', hyperparams => '{"n_clusters": 10}'); +SELECT * FROM pgml.train('Handwritten Digit Clusters', algorithm => 'birch', hyperparams => '{"n_clusters": 10}'); +SELECT * FROM pgml.train('Handwritten Digit Clusters', algorithm => 'kmeans', hyperparams => '{"n_clusters": 10}'); +SELECT * FROM pgml.train('Handwritten Digit Clusters', algorithm => 'mini_batch_kmeans', hyperparams => '{"n_clusters": 10}'); + +-- Offline clustering algorithms are not currently supported +--SELECT * FROM pgml.train('Handwritten Digit Clusters', algorithm => 'dbscan'); +--SELECT * FROM pgml.train('Handwritten Digit Clusters', algorithm => 'feature_agglomeration'); +--SELECT * FROM pgml.train('Handwritten Digit Clusters', algorithm => 'optics'); +--SELECT * FROM pgml.train('Handwritten Digit Clusters', algorithm => 'spectral', hyperparams => '{"n_clusters": 10}'); +--SELECT * FROM pgml.train('Handwritten Digit Clusters', algorithm => 'spectral_bi', hyperparams => '{"n_clusters": 10}'); +--SELECT * FROM pgml.train('Handwritten Digit Clusters', algorithm => 'spectral_co', hyperparams => '{"n_clusters": 10}'); +--SELECT * FROM pgml.train('Handwritten Digit Clusters', algorithm => 'mean_shift'); + diff --git a/pgml-extension/sql/pgml--2.6.0--2.7.0.sql b/pgml-extension/sql/pgml--2.6.0--2.7.0.sql new file mode 100644 index 000000000..4fabfada5 --- /dev/null +++ b/pgml-extension/sql/pgml--2.6.0--2.7.0.sql @@ -0,0 +1,19 @@ +ALTER TABLE pgml.snapshots ALTER COLUMN y_column_name DROP NOT NULL; + +ALTER FUNCTION pgml.embed(text, text[], jsonb) COST 1000000; +ALTER FUNCTION pgml.embed(text, text, jsonb) COST 1000000; +ALTER FUNCTION pgml.transform(jsonb, jsonb, text[], boolean) COST 10000000; +ALTER FUNCTION pgml.transform(text, jsonb, text[], boolean) COST 10000000; + +ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'cluster'; +ALTER TYPE pgml.algorithm ADD VALUE IF NOT EXISTS 'affinity_propagation'; +ALTER TYPE pgml.algorithm ADD VALUE IF NOT EXISTS 'agglomerative'; +ALTER TYPE pgml.algorithm ADD VALUE IF NOT EXISTS 'birch'; +ALTER TYPE pgml.algorithm ADD VALUE IF NOT EXISTS 'feature_agglomeration'; +ALTER TYPE pgml.algorithm ADD VALUE IF NOT EXISTS 'mini_batch_kmeans'; +ALTER TYPE pgml.algorithm ADD VALUE IF NOT EXISTS 'mean_shift'; +ALTER TYPE pgml.algorithm ADD VALUE IF NOT EXISTS 'optics'; +ALTER TYPE pgml.algorithm ADD VALUE IF NOT EXISTS 'spectral'; +ALTER TYPE pgml.algorithm ADD VALUE IF NOT EXISTS 'spectral_bi'; +ALTER TYPE pgml.algorithm ADD VALUE IF NOT EXISTS 'spectral_co'; + diff --git a/pgml-extension/sql/schema.sql b/pgml-extension/sql/schema.sql index 876824c38..fc2ea0694 100644 --- a/pgml-extension/sql/schema.sql +++ b/pgml-extension/sql/schema.sql @@ -63,7 +63,7 @@ CREATE UNIQUE INDEX IF NOT EXISTS projects_name_idx ON 
pgml.projects(name); CREATE TABLE IF NOT EXISTS pgml.snapshots( id BIGSERIAL PRIMARY KEY, relation_name TEXT NOT NULL, - y_column_name TEXT[] NOT NULL, + y_column_name TEXT[], test_size FLOAT4 NOT NULL, test_sampling pgml.sampling NOT NULL, status TEXT NOT NULL, diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index b67b32779..c192d86f2 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -188,7 +188,7 @@ fn train_joint( Some(project) => project, None => Project::create(project_name, match task { Some(task) => task, - None => error!("Project `{}` does not exist. To create a new project, provide the task (regression or classification).", project_name), + None => error!("Project `{}` does not exist. To create a new project, you must specify a `task`.", project_name), }), }; @@ -213,10 +213,13 @@ fn train_joint( relation_name ); + if project.task.is_supervised() && y_column_name.is_none() { + error!("You must pass a `y_column_name` when you pass a `relation_name` for a supervised task."); + } + let snapshot = Snapshot::create( relation_name, - y_column_name - .expect("You must pass a `y_column_name` when you pass a `relation_name`"), + y_column_name, test_size, test_sampling, materialize_snapshot, @@ -235,6 +238,13 @@ fn train_joint( } }; + // fix up default algorithm for clustering + let algorithm = if algorithm == Algorithm::linear && project.task == Task::cluster { + Algorithm::kmeans + } else { + algorithm + }; + // # Default repeatable random state when possible // let algorithm = Model.algorithm_from_name_and_task(algorithm, task); // if "random_state" in algorithm().get_params() and "random_state" not in hyperparams: @@ -273,24 +283,25 @@ fn train_joint( Some(true) | None => { if let Ok(Some(deployed_metrics)) = deployed_metrics { let deployed_metrics = deployed_metrics.0.as_object().unwrap(); - match project.task { - Task::classification => { - if deployed_metrics.get("f1").unwrap().as_f64() - > new_metrics.get("f1").unwrap().as_f64() - { - deploy = false; - } - } - Task::regression => { - if deployed_metrics.get("r2").unwrap().as_f64() - > new_metrics.get("r2").unwrap().as_f64() - { - deploy = false; - } - } - _ => error!( - "Training only supports `classification` and `regression` task types." 
- ), + let deployed_metric = deployed_metrics + .get(&project.task.default_target_metric()) + .unwrap() + .as_f64() + .unwrap(); + info!( + "Comparing to deployed model {}: {:?}", + project.task.default_target_metric(), + deployed_metric + ); + if project.task.value_is_better( + deployed_metric, + new_metrics + .get(&project.task.default_target_metric()) + .unwrap() + .as_f64() + .unwrap(), + ) { + deploy = false; } } } @@ -300,6 +311,8 @@ fn train_joint( if deploy { project.deploy(model.id); + } else { + warning!("Not deploying newly trained model."); } TableIterator::new( @@ -347,38 +360,13 @@ fn deploy( ); } match strategy { - Strategy::best_score => match task { - Task::classification | Task::question_answering | Task::text_classification => { - let _ = write!( - sql, - "{predicate}\nORDER BY models.metrics->>'f1' DESC NULLS LAST" - ); - } - Task::regression => { - let _ = write!( - sql, - "{predicate}\nORDER BY models.metrics->>'r2' DESC NULLS LAST" - ); - } - Task::summarization => { - let _ = write!( - sql, - "{predicate}\nORDER BY models.metrics->>'rouge_ngram_f1' DESC NULLS LAST" - ); - } - Task::text_generation | Task::text2text => { - let _ = write!( - sql, - "{predicate}\nORDER BY models.metrics->>'perplexity' ASC NULLS LAST" - ); - } - Task::translation => { - let _ = write!( - sql, - "{predicate}\nORDER BY models.metrics->>'bleu' DESC NULLS LAST" - ); - } - }, + Strategy::best_score => { + let _ = write!( + sql, + "{predicate}\n{}", + task.default_target_metric_sql_order() + ); + } Strategy::most_recent => { let _ = write!(sql, "{predicate}\nORDER by models.created_at DESC"); @@ -528,7 +516,7 @@ fn snapshot( ) -> TableIterator<'static, (name!(relation, String), name!(y_column_name, String))> { Snapshot::create( relation_name, - vec![y_column_name.to_string()], + Some(vec![y_column_name.to_string()]), test_size, test_sampling, true, @@ -733,9 +721,9 @@ fn tune( let snapshot = Snapshot::create( relation_name, - vec![y_column_name + Some(vec![y_column_name .expect("You must pass a `y_column_name` when you pass a `relation_name`") - .to_string()], + .to_string()]), test_size, test_sampling, materialize_snapshot, @@ -787,42 +775,19 @@ fn tune( Some(true) | None => { if let Ok(Some(deployed_metrics)) = deployed_metrics { let deployed_metrics = deployed_metrics.0.as_object().unwrap(); - match project.task { - Task::classification | Task::question_answering | Task::text_classification => { - if deployed_metrics.get("f1").unwrap().as_f64() - > new_metrics.get("f1").unwrap().as_f64() - { - deploy = false; - } - } - Task::regression => { - if deployed_metrics.get("r2").unwrap().as_f64() - > new_metrics.get("r2").unwrap().as_f64() - { - deploy = false; - } - } - Task::translation => { - if deployed_metrics.get("bleu").unwrap().as_f64() - > new_metrics.get("bleu").unwrap().as_f64() - { - deploy = false; - } - } - Task::summarization => { - if deployed_metrics.get("rouge_ngram_f1").unwrap().as_f64() - > new_metrics.get("rouge_ngram_f1").unwrap().as_f64() - { - deploy = false; - } - } - Task::text_generation | Task::text2text => { - if deployed_metrics.get("perplexity").unwrap().as_f64() - < new_metrics.get("perplexity").unwrap().as_f64() - { - deploy = false; - } - } + if project.task.value_is_better( + deployed_metrics + .get(&project.task.default_target_metric()) + .unwrap() + .as_f64() + .unwrap(), + new_metrics + .get(&project.task.default_target_metric()) + .unwrap() + .as_f64() + .unwrap(), + ) { + deploy = false; } } } @@ -991,7 +956,7 @@ mod tests { let snapshot = Snapshot::create( 
"pgml.diabetes", - vec!["target".to_string()], + Some(vec!["target".to_string()]), 0.5, Sampling::last, true, @@ -1007,7 +972,7 @@ mod tests { let result = std::panic::catch_unwind(|| { let _snapshot = Snapshot::create( "diabetes", - vec!["target".to_string()], + Some(vec!["target".to_string()]), 0.5, Sampling::last, true, diff --git a/pgml-extension/src/bindings/mod.rs b/pgml-extension/src/bindings/mod.rs index 77a46e161..3af906c28 100644 --- a/pgml-extension/src/bindings/mod.rs +++ b/pgml-extension/src/bindings/mod.rs @@ -52,7 +52,7 @@ mod tests { let regression = Project::create("regression", Task::regression); let mut diabetes = Snapshot::create( "pgml.diabetes", - vec!["target".to_string()], + Some(vec!["target".to_string()]), 0.5, Sampling::last, false, @@ -61,7 +61,7 @@ mod tests { let classification = Project::create("classification", Task::classification); let mut breast_cancer = Snapshot::create( "pgml.breast_cancer", - vec!["malignant".to_string()], + Some(vec!["malignant".to_string()]), 0.5, Sampling::last, false, diff --git a/pgml-extension/src/bindings/sklearn.py b/pgml-extension/src/bindings/sklearn.py index 0d4b2930f..f1d762fcc 100644 --- a/pgml-extension/src/bindings/sklearn.py +++ b/pgml-extension/src/bindings/sklearn.py @@ -3,6 +3,7 @@ # Wrapper around Scikit-Learn loaded by PyO3 # in our Rust crate::engines::sklearn module. # +import sklearn.cluster import sklearn.linear_model import sklearn.kernel_ridge import sklearn.svm @@ -27,6 +28,9 @@ mean_squared_error, mean_absolute_error, confusion_matrix, + silhouette_score, + calinski_harabasz_score, + fowlkes_mallows_score, ) _ALGORITHM_MAP = { @@ -77,6 +81,17 @@ "xgboost_random_forest_classification": xgb.XGBRFClassifier, "lightgbm_regression": lightgbm.LGBMRegressor, "lightgbm_classification": lightgbm.LGBMClassifier, + "affinity_propagation_clustering": sklearn.cluster.AffinityPropagation, + "birch_clustering": sklearn.cluster.Birch, + "dbscan_clustering": sklearn.cluster.DBSCAN, + "feature_agglomeration_clustering": sklearn.cluster.FeatureAgglomeration, + "kmeans_clustering": sklearn.cluster.KMeans, + "mini_batch_kmeans_clustering": sklearn.cluster.MiniBatchKMeans, + "mean_shift_clustering": sklearn.cluster.MeanShift, + "optics_clustering": sklearn.cluster.OPTICS, + "spectral_clustering": sklearn.cluster.SpectralClustering, + "spectral_biclustering": sklearn.cluster.SpectralBiclustering, + "spectral_coclustering": sklearn.cluster.SpectralCoclustering, } @@ -95,7 +110,6 @@ def estimator(algorithm, num_features, num_targets, hyperparams): hyperparams = {} else: hyperparams = json.loads(hyperparams) - def train(X_train, y_train): instance = _ALGORITHM_MAP[algorithm](**hyperparams) if num_targets > 1 and algorithm in [ @@ -117,8 +131,9 @@ def train(X_train, y_train): X_train = np.asarray(X_train).reshape((-1, num_features)) - # Only support single value models for just now. - y_train = np.asarray(y_train).reshape((-1, num_targets)) + if num_targets > 0: + # Only support single value models for just now. 
+ y_train = np.asarray(y_train).reshape((-1, num_targets)) instance.fit(X_train, y_train) return instance @@ -280,3 +295,12 @@ def classification_metrics(y_true, y_hat): "accuracy": accuracy, "mcc": mcc, } + + +def cluster_metrics(num_features, inputs_labels): + inputs = np.asarray(inputs_labels[0]).reshape((-1, num_features)) + labels = np.asarray(inputs_labels[1]).reshape((-1, 1)) + + return { + "silhouette": silhouette_score(inputs, labels), + } diff --git a/pgml-extension/src/bindings/sklearn.rs b/pgml-extension/src/bindings/sklearn.rs index 886504cff..404f7f720 100644 --- a/pgml-extension/src/bindings/sklearn.rs +++ b/pgml-extension/src/bindings/sklearn.rs @@ -1,3 +1,4 @@ +use pgrx::*; /// Scikit-Learn implementation. /// /// Scikit needs no introduction. It implements dozens of industry-standard @@ -297,6 +298,54 @@ pub fn lightgbm_classification(dataset: &Dataset, hyperparams: &Hyperparams) -> fit(dataset, hyperparams, "lightgbm_classification") } +pub fn affinity_propagation(dataset: &Dataset, hyperparams: &Hyperparams) -> Box { + fit(dataset, hyperparams, "affinity_propagation_clustering") +} + +pub fn agglomerative(dataset: &Dataset, hyperparams: &Hyperparams) -> Box { + fit(dataset, hyperparams, "agglomerative_clustering") +} + +pub fn birch(dataset: &Dataset, hyperparams: &Hyperparams) -> Box { + fit(dataset, hyperparams, "birch_clustering") +} + +pub fn dbscan(dataset: &Dataset, hyperparams: &Hyperparams) -> Box { + fit(dataset, hyperparams, "dbscan_clustering") +} + +pub fn feature_agglomeration(dataset: &Dataset, hyperparams: &Hyperparams) -> Box { + fit(dataset, hyperparams, "feature_agglomeration_clustering") +} + +pub fn kmeans(dataset: &Dataset, hyperparams: &Hyperparams) -> Box { + fit(dataset, hyperparams, "kmeans_clustering") +} + +pub fn mini_batch_kmeans(dataset: &Dataset, hyperparams: &Hyperparams) -> Box { + fit(dataset, hyperparams, "mini_batch_kmeans_clustering") +} + +pub fn mean_shift(dataset: &Dataset, hyperparams: &Hyperparams) -> Box { + fit(dataset, hyperparams, "mean_shift_clustering") +} + +pub fn optics(dataset: &Dataset, hyperparams: &Hyperparams) -> Box { + fit(dataset, hyperparams, "optics_clustering") +} + +pub fn spectral(dataset: &Dataset, hyperparams: &Hyperparams) -> Box { + fit(dataset, hyperparams, "spectral_clustering") +} + +pub fn spectral_bi(dataset: &Dataset, hyperparams: &Hyperparams) -> Box { + fit(dataset, hyperparams, "spectral_biclustering") +} + +pub fn spectral_co(dataset: &Dataset, hyperparams: &Hyperparams) -> Box { + fit(dataset, hyperparams, "spectral_coclustering") +} + fn fit( dataset: &Dataset, hyperparams: &Hyperparams, @@ -533,6 +582,24 @@ pub fn classification_metrics( scores } +pub fn cluster_metrics( + num_features: usize, + inputs: &[f32], + labels: &[f32], +) -> HashMap { + Python::with_gil(|py| -> HashMap { + let calculate_metric = PY_MODULE.getattr(py, "cluster_metrics").unwrap(); + + let scores: HashMap = calculate_metric + .call1(py, (num_features, PyTuple::new(py, &[inputs, labels]))) + .unwrap() + .extract(py) + .unwrap(); + + scores + }) +} + pub fn package_version(name: &str) -> String { Python::with_gil(|py| -> String { let package = py.import(name).unwrap(); diff --git a/pgml-extension/src/orm/algorithm.rs b/pgml-extension/src/orm/algorithm.rs index 098e04adb..1cc3ca7af 100644 --- a/pgml-extension/src/orm/algorithm.rs +++ b/pgml-extension/src/orm/algorithm.rs @@ -38,6 +38,15 @@ pub enum Algorithm { linear_svm, lightgbm, transformers, + affinity_propagation, + birch, + feature_agglomeration, + 
mini_batch_kmeans, + mean_shift, + optics, + spectral, + spectral_bi, + spectral_co, } impl std::str::FromStr for Algorithm { @@ -79,6 +88,15 @@ impl std::str::FromStr for Algorithm { "linear_svm" => Ok(Algorithm::linear_svm), "lightgbm" => Ok(Algorithm::lightgbm), "transformers" => Ok(Algorithm::transformers), + "affinity_propagation" => Ok(Algorithm::affinity_propagation), + "birch" => Ok(Algorithm::birch), + "feature_agglomeration" => Ok(Algorithm::feature_agglomeration), + "mini_batch_kmeans" => Ok(Algorithm::mini_batch_kmeans), + "mean_shift" => Ok(Algorithm::mean_shift), + "optics" => Ok(Algorithm::optics), + "spectral" => Ok(Algorithm::spectral), + "spectral_bi" => Ok(Algorithm::spectral_bi), + "spectral_co" => Ok(Algorithm::spectral_co), _ => Err(()), } } @@ -123,6 +141,15 @@ impl std::string::ToString for Algorithm { Algorithm::linear_svm => "linear_svm".to_string(), Algorithm::lightgbm => "lightgbm".to_string(), Algorithm::transformers => "transformers".to_string(), + Algorithm::affinity_propagation => "transformers".to_string(), + Algorithm::birch => "birch".to_string(), + Algorithm::feature_agglomeration => "feature_agglomeration".to_string(), + Algorithm::mini_batch_kmeans => "mini_batch_kmeans".to_string(), + Algorithm::mean_shift => "mean_shift".to_string(), + Algorithm::optics => "optics".to_string(), + Algorithm::spectral => "spectral".to_string(), + Algorithm::spectral_bi => "spectral_bi".to_string(), + Algorithm::spectral_co => "spectral_co".to_string(), } } } diff --git a/pgml-extension/src/orm/model.rs b/pgml-extension/src/orm/model.rs index ed39ac267..ecdbf6d6e 100644 --- a/pgml-extension/src/orm/model.rs +++ b/pgml-extension/src/orm/model.rs @@ -411,6 +411,9 @@ impl Model { Algorithm::svm => linfa::Svm::fit, _ => todo!(), }, + Task::cluster => match self.algorithm { + _ => todo!(), + }, _ => error!("use pgml.tune for transformers tasks"), }, @@ -490,6 +493,14 @@ impl Model { Algorithm::lightgbm => sklearn::lightgbm_classification, _ => panic!("{:?} does not support classification", self.algorithm), }, + Task::cluster => match self.algorithm { + Algorithm::affinity_propagation => sklearn::affinity_propagation, + Algorithm::birch => sklearn::birch, + Algorithm::kmeans => sklearn::kmeans, + Algorithm::mini_batch_kmeans => sklearn::mini_batch_kmeans, + Algorithm::mean_shift => sklearn::mean_shift, + _ => panic!("{:?} does not support clustering", self.algorithm), + }, _ => error!("use pgml.tune for transformers tasks"), }, } @@ -553,6 +564,7 @@ impl Model { // The box is borrowed so that it may be reused by the caller #[allow(clippy::borrowed_box)] fn test(&self, dataset: &Dataset) -> IndexMap { + info!("Testing {:?} estimator {:?}", self.project.task, self); // Test the estimator on the data let y_hat = self.predict_batch(&dataset.x_test); let y_test = &dataset.y_test; @@ -655,7 +667,18 @@ impl Model { // This one is inaccurate, I have it in my TODO to reimplement. 
metrics.insert("mcc".to_string(), confusion_matrix.mcc()); } - _ => error!("no tests for huggingface"), + Task::cluster => { + #[cfg(all(feature = "python"))] + { + let sklearn_metrics = crate::bindings::sklearn::cluster_metrics( + dataset.num_features, + &dataset.x_test, + &y_hat, + ); + metrics.insert("silhouette".to_string(), sklearn_metrics["silhouette"]); + } + } + task => error!("No test metrics available for task: {:?}", task), } metrics @@ -801,16 +824,7 @@ impl Model { search_results.insert("n_splits".to_string(), json!(cv)); // Find the best estimator, hyperparams and metrics - let target_metric = match self.project.task { - Task::regression => "r2", - Task::classification => "f1", - Task::question_answering => "f1", - Task::translation => "blue", - Task::summarization => "rouge_ngram_f1", - Task::text_classification => "f1", - Task::text_generation => "perplexity", - Task::text2text => "perplexity", - }; + let target_metric = self.project.task.default_target_metric(); let mut i = 0; let mut best_index = 0; let mut best_metric = f32::NEG_INFINITY; @@ -826,7 +840,7 @@ impl Model { let fold_i = i / all_hyperparams.len(); let hyperparams_i = i % all_hyperparams.len(); let hyperparams = &all_hyperparams[hyperparams_i]; - let metric = *metrics.get(target_metric).unwrap(); + let metric = *metrics.get(&target_metric).unwrap(); fit_times[hyperparams_i][fold_i] = *metrics.get("fit_time").unwrap(); score_times[hyperparams_i][fold_i] = *metrics.get("score_time").unwrap(); test_scores[hyperparams_i][fold_i] = metric; diff --git a/pgml-extension/src/orm/snapshot.rs b/pgml-extension/src/orm/snapshot.rs index 27f623347..7d2346ae7 100644 --- a/pgml-extension/src/orm/snapshot.rs +++ b/pgml-extension/src/orm/snapshot.rs @@ -413,7 +413,7 @@ impl Snapshot { let mut s = Snapshot { id: result.get(1).unwrap().unwrap(), relation_name: result.get(2).unwrap().unwrap(), - y_column_name: result.get(3).unwrap().unwrap(), + y_column_name: result.get(3).unwrap().unwrap_or_default(), test_size: result.get(4).unwrap().unwrap(), test_sampling: Sampling::from_str(result.get(5).unwrap().unwrap()).unwrap(), status: Status::from_str(result.get(6).unwrap().unwrap()).unwrap(), @@ -474,7 +474,7 @@ impl Snapshot { let mut s = Snapshot { id: result.get(1).unwrap().unwrap(), relation_name: result.get(2).unwrap().unwrap(), - y_column_name: result.get(3).unwrap().unwrap(), + y_column_name: result.get(3).unwrap().unwrap_or_default(), test_size: result.get(4).unwrap().unwrap(), test_sampling: Sampling::from_str(result.get(5).unwrap().unwrap()).unwrap(), status: Status::from_str(result.get(6).unwrap().unwrap()).unwrap(), @@ -495,7 +495,7 @@ impl Snapshot { pub fn create( relation_name: &str, - y_column_name: Vec, + y_column_name: Option>, test_size: f32, test_sampling: Sampling, materialized: bool, @@ -531,7 +531,10 @@ impl Snapshot { } let nullable = row[3].value::().unwrap().unwrap(); let position = row[4].value::().unwrap().unwrap() as usize; - let label = y_column_name.contains(&name); + let label = match y_column_name { + Some(ref y_column_name) => y_column_name.contains(&name), + None => false, + }; let mut statistics = Statistics::default(); let preprocessor = match preprocessors.get(&name) { Some(preprocessor) => { @@ -576,12 +579,15 @@ impl Snapshot { } ); }); - for column in &y_column_name { - if !columns.iter().any(|c| c.label && &c.name == column) { - error!( - "Column `{}` not found. 
Did you pass the correct `y_column_name`?", - column - ) + + if y_column_name.is_some() { + for column in y_column_name.as_ref().unwrap() { + if !columns.iter().any(|c| c.label && &c.name == column) { + error!( + "Column `{}` not found. Did you pass the correct `y_column_name`?", + column + ) + } } } @@ -601,7 +607,7 @@ impl Snapshot { let s = Snapshot { id: result.get(1).unwrap().unwrap(), relation_name: result.get(2).unwrap().unwrap(), - y_column_name: result.get(3).unwrap().unwrap(), + y_column_name: result.get(3).unwrap().unwrap_or_default(), test_size: result.get(4).unwrap().unwrap(), test_sampling: Sampling::from_str(result.get(5).unwrap().unwrap()).unwrap(), status, // 6 @@ -680,9 +686,12 @@ impl Snapshot { } pub(crate) fn num_classes(&self) -> usize { - match &self.first_label().statistics.categories { - Some(categories) => categories.len(), - None => self.first_label().statistics.distinct, + match &self.y_column_name.len() { + 0 => 0, + _ => match &self.first_label().statistics.categories { + Some(categories) => categories.len(), + None => self.first_label().statistics.distinct, + }, } } @@ -854,7 +863,6 @@ impl Snapshot { pub fn tabular_dataset(&mut self) -> Dataset { let numeric_encoded_dataset = self.numeric_encoded_dataset(); - // Analyze labels let label_data = ndarray::ArrayView2::from_shape( ( numeric_encoded_dataset.num_train_rows, @@ -863,17 +871,7 @@ impl Snapshot { &numeric_encoded_dataset.y_train, ) .unwrap(); - // The data for the first label - let target_data = label_data.columns().into_iter().next().unwrap(); - Zip::from(label_data.columns()) - .and(&self.label_positions()) - .for_each(|data, position| { - let column = &mut self.columns[position.column_position - 1]; // lookup the mutable one - column.analyze(&data, &target_data); - }); - - // Analyze features let feature_data = ndarray::ArrayView2::from_shape( ( numeric_encoded_dataset.num_train_rows, @@ -882,12 +880,36 @@ impl Snapshot { &numeric_encoded_dataset.x_train, ) .unwrap(); - Zip::from(feature_data.columns()) - .and(&self.feature_positions()) - .for_each(|data, position| { - let column = &mut self.columns[position.column_position - 1]; // lookup the mutable one - column.analyze(&data, &target_data); - }); + + // We only analyze supervised training sets that have labels for now. + if numeric_encoded_dataset.num_labels > 0 { + // We only analyze features against the first label in joint regression. 
+ let target_data = label_data.columns().into_iter().next().unwrap(); + + // Analyze labels + Zip::from(label_data.columns()) + .and(&self.label_positions()) + .for_each(|data, position| { + let column = &mut self.columns[position.column_position - 1]; // lookup the mutable one + column.analyze(&data, &target_data); + }); + + // Analyze features + Zip::from(feature_data.columns()) + .and(&self.feature_positions()) + .for_each(|data, position| { + let column = &mut self.columns[position.column_position - 1]; // lookup the mutable one + column.analyze(&data, &target_data); + }); + } else { + // Analyze features for unsupervised learning + Zip::from(feature_data.columns()) + .and(&self.feature_positions()) + .for_each(|data, position| { + let column = &mut self.columns[position.column_position - 1]; // lookup the mutable one + column.analyze(&data, &data); + }); + } let mut analysis = IndexMap::new(); analysis.insert( diff --git a/pgml-extension/src/orm/task.rs b/pgml-extension/src/orm/task.rs index bd9d69d56..b79a1f7dd 100644 --- a/pgml-extension/src/orm/task.rs +++ b/pgml-extension/src/orm/task.rs @@ -12,6 +12,7 @@ pub enum Task { text_classification, text_generation, text2text, + cluster, } // unfortunately the pgrx macro expands the enum names to underscore, but huggingface uses dash @@ -26,8 +27,65 @@ impl Task { Task::text_classification => "text_classification".to_string(), Task::text_generation => "text_generation".to_string(), Task::text2text => "text2text".to_string(), + Task::cluster => "cluster".to_string(), } } + + pub fn is_supervised(&self) -> bool { + match self { + Task::regression | Task::classification => true, + _ => false, + } + } + pub fn default_target_metric(&self) -> String { + match self { + Task::regression => "r2", + Task::classification => "f1", + Task::question_answering => "f1", + Task::translation => "blue", + Task::summarization => "rouge_ngram_f1", + Task::text_classification => "f1", + Task::text_generation => "perplexity", + Task::text2text => "perplexity", + Task::cluster => "silhouette", + } + .to_string() + } + + pub fn default_target_metric_positive(&self) -> bool { + match self { + Task::regression => true, + Task::classification => true, + Task::question_answering => true, + Task::translation => true, + Task::summarization => true, + Task::text_classification => true, + Task::text_generation => false, + Task::text2text => false, + Task::cluster => true, + } + } + + pub fn value_is_better(&self, value: f64, other: f64) -> bool { + if self.default_target_metric_positive() { + value > other + } else { + value < other + } + } + + pub fn default_target_metric_sql_order(&self) -> String { + let direction = if self.default_target_metric_positive() { + "DESC" + } else { + "ASC" + }; + format!( + "ORDER BY models.metrics->>'{}' {} NULL LAST", + self.default_target_metric(), + direction + ) + } } impl std::str::FromStr for Task { @@ -43,6 +101,7 @@ impl std::str::FromStr for Task { "text-classification" | "text_classification" => Ok(Task::text_classification), "text-generation" | "text_generation" => Ok(Task::text_generation), "text2text" => Ok(Task::text2text), + "cluster" => Ok(Task::cluster), _ => Err(()), } } @@ -59,6 +118,7 @@ impl std::string::ToString for Task { Task::text_classification => "text-classification".to_string(), Task::text_generation => "text-generation".to_string(), Task::text2text => "text2text".to_string(), + Task::cluster => "cluster".to_string(), } } } diff --git a/pgml-extension/tests/test.sql b/pgml-extension/tests/test.sql index 
c1ef81d58..1b9e3771b 100644 --- a/pgml-extension/tests/test.sql +++ b/pgml-extension/tests/test.sql @@ -21,6 +21,7 @@ SELECT pgml.load_dataset('iris'); SELECT pgml.load_dataset('linnerud'); SELECT pgml.load_dataset('wine'); +\i examples/cluster.sql \i examples/binary_classification.sql \i examples/image_classification.sql \i examples/joint_regression.sql
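
For reference, the new `cluster` task added by this patch is exercised end to end in pgml-extension/examples/cluster.sql above. Below is a minimal usage sketch of the same workflow, assuming the digits dataset has been loaded as in that example; the project name and hyperparameters are illustrative, not required by the patch.

-- Build an unlabeled view with a single array feature column.
SELECT pgml.load_dataset('digits');
CREATE VIEW pgml.digit_vectors AS SELECT image FROM pgml.digits;

-- Train with the new 'cluster' task. No y_column_name is needed for unsupervised
-- training, and the default algorithm is fixed up to kmeans for clustering projects.
SELECT * FROM pgml.train('Handwritten Digit Clusters', 'cluster', 'pgml.digit_vectors',
                         hyperparams => '{"n_clusters": 10}');

-- Predictions return the assigned cluster for each input row.
SELECT target, pgml.predict('Handwritten Digit Clusters', image) AS prediction
FROM pgml.digits
LIMIT 10;

-- Retrain with another clustering algorithm; the new model is only deployed if its
-- silhouette score (the default target metric for the cluster task) beats the
-- currently deployed model.
SELECT * FROM pgml.train('Handwritten Digit Clusters', algorithm => 'mini_batch_kmeans',
                         hyperparams => '{"n_clusters": 10}');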
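
Existing installations should be able to pick up the schema changes (nullable y_column_name, the cluster task, and the new clustering algorithm values) through the standard extension upgrade path once sql/pgml--2.6.0--2.7.0.sql ships with the package; a sketch, assuming the 2.7.0 build is installed on the server:

ALTER EXTENSION pgml UPDATE TO '2.7.0';
SELECT pgml.version();  -- expected to report 2.7.0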
