Skip to content

Commit 8fff76d

Browse files
authored
Add clustering algorithms (#795)
1 parent fc209d9 commit 8fff76d

File tree

23 files changed

+534
-247
lines changed

23 files changed

+534
-247
lines changed

.github/workflows/package-extension.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ on:
44
workflow_dispatch:
55
inputs:
66
packageVersion:
7-
default: "2.6.0"
7+
default: "2.7.0"
88

99
jobs:
1010
build:

pgml-dashboard/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pgml-dashboard/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "pgml-dashboard"
3-
version = "2.6.0"
3+
version = "2.7.0"
44
edition = "2021"
55
authors = ["PostgresML <team@postgresml.org>"]
66
license = "MIT"

pgml-dashboard/content/docs/guides/setup/developers.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ SELECT pgml.version();
127127
postgres=# select pgml.version();
128128
version
129129
-------------------
130-
2.6.0
130+
2.7.0
131131
(1 row)
132132
```
133133

pgml-dashboard/content/docs/guides/setup/v2/installation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ SELECT pgml.version();
217217
postgres=# select pgml.version();
218218
version
219219
-------------------
220-
2.6.0
220+
2.7.0
221221
(1 row)
222222
```
223223

pgml-dashboard/content/docs/guides/training/algorithm_selection.md

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
We currently support regression and classification algorithms from [scikit-learn](https://scikit-learn.org/), [XGBoost](https://xgboost.readthedocs.io/), and [LightGBM](https://lightgbm.readthedocs.io/).
44

5-
## Algorithms
5+
## Supervised Algorithms
66

77
### Gradient Boosting
88
Algorithm | Regression | Classification
@@ -54,6 +54,18 @@ Algorithm | Regression | Classification
5454
`kernel_ridge` | [KernelRidge](https://scikit-learn.org/stable/modules/generated/sklearn.kernel_ridge.KernelRidge.html) | -
5555
`gaussian_process` | [GaussianProcessRegressor](https://scikit-learn.org/stable/modules/generated/sklearn.gaussian_process.GaussianProcessRegressor.html) | [GaussianProcessClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.gaussian_process.GaussianProcessClassifier.html)
5656

57+
## Unsupervised Algorithms
58+
59+
### Clustering
60+
61+
|Algorithm | Reference |
62+
|---|-------------------------------------------------------------------------------------------------------------------|
63+
`affinity_propagation` | [AffinityPropagation](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AffinityPropagation.html)
64+
`birch` | [Birch](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.Birch.html)
65+
`kmeans` | [K-Means](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html)
66+
`mini_batch_kmeans` | [MiniBatchKMeans](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MiniBatchKMeans.html)
67+
68+
5769
## Comparing Algorithms
5870

5971
Any of the above algorithms can be passed to our `pgml.train()` function using the `algorithm` parameter. If the parameter is omitted, linear regression is used by default.

pgml-dashboard/src/models.rs

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ impl Project {
5757
"summarization" => Ok("rouge_ngram_f1"),
5858
"translation" => Ok("bleu"),
5959
"text_generation" | "text2text" => Ok("perplexity"),
60+
"cluster" => Ok("silhouette"),
6061
task => Err(anyhow::anyhow!("Unhandled task: {}", task)),
6162
}
6263
}
@@ -68,6 +69,7 @@ impl Project {
6869
"summarization" => Ok("Rouge Ngram F<sup>1</sup>"),
6970
"translation" => Ok("Bleu"),
7071
"text_generation" | "text2text" => Ok("Perplexity"),
72+
"cluster" => Ok("silhouette"),
7173
task => Err(anyhow::anyhow!("Unhandled task: {}", task)),
7274
}
7375
}
@@ -544,7 +546,7 @@ impl Model {
544546
pub struct Snapshot {
545547
pub id: i64,
546548
pub relation_name: String,
547-
pub y_column_name: Vec<String>,
549+
pub y_column_name: Option<Vec<String>>,
548550
pub test_size: f32,
549551
pub test_sampling: Option<String>,
550552
pub status: String,
@@ -686,28 +688,42 @@ impl Snapshot {
686688
}
687689
}
688690

689-
pub fn features<'a>(&'a self) -> Option<Vec<&'a serde_json::Map<String, serde_json::Value>>> {
691+
pub fn features(&self) -> Option<Vec<&serde_json::Map<String, serde_json::Value>>> {
690692
match self.columns() {
691-
Some(columns) => Some(
692-
columns
693-
.into_iter()
694-
.filter(|column| {
695-
!self
696-
.y_column_name
697-
.contains(&column["name"].as_str().unwrap().to_string())
698-
})
699-
.collect(),
700-
),
693+
Some(columns) => {
694+
if self.y_column_name.is_none() {
695+
return Some(columns.into_iter().collect());
696+
}
697+
698+
Some(
699+
columns
700+
.into_iter()
701+
.filter(|column| {
702+
!self
703+
.y_column_name
704+
.as_ref()
705+
.unwrap()
706+
.contains(&column["name"].as_str().unwrap().to_string())
707+
})
708+
.collect(),
709+
)
710+
}
701711
None => None,
702712
}
703713
}
704714

705-
pub fn labels<'a>(&'a self) -> Option<Vec<&'a serde_json::Map<String, serde_json::Value>>> {
715+
pub fn labels(&self) -> Option<Vec<&serde_json::Map<String, serde_json::Value>>> {
716+
if self.y_column_name.is_none() {
717+
return Some(Vec::new());
718+
}
719+
706720
self.columns().map(|columns| {
707721
columns
708722
.into_iter()
709723
.filter(|column| {
710724
self.y_column_name
725+
.as_ref()
726+
.unwrap()
711727
.contains(&column["name"].as_str().unwrap().to_string())
712728
})
713729
.collect()

pgml-dashboard/templates/content/dashboard/panels/snapshot.html

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,11 @@ <h2><span class="material-symbols-outlined">bubble_chart</span>Features</h2>
7373
%>
7474
<h3><%= name %>&nbsp;<code><%= feature["pg_type"].as_str().unwrap() | upper %></code></h3>
7575
<figure id="<%= name_machine %>_distribution"></figure>
76-
<% for y_column_name in snapshot.y_column_name.iter() { %>
76+
<% if snapshot.y_column_name.as_ref().is_some() { %>
77+
<% for y_column_name in snapshot.y_column_name.as_ref().unwrap().iter() { %>
7778
<figure id="<%= name_machine %>_correlation_<%= y_column_name %>"></figure>
7879
<% } %>
80+
<% } %>
7981
<% } %>
8082
</section>
8183

@@ -102,11 +104,13 @@ <h3><%= name %>&nbsp;<code><%= feature["pg_type"].as_str().unwrap() | upper %></
102104
renderModel(<%= model.id %>, <%= model.key_metric(project).unwrap() %>, [0, 1]);
103105
<% } %>
104106

105-
<% for y_column_name in snapshot.y_column_name.iter() { %>
107+
<% if snapshot.y_column_name.as_ref().is_some() { %>
108+
<% for y_column_name in snapshot.y_column_name.as_ref().unwrap().iter() { %>
106109
setTimeout(renderDistribution, delay, "<%= y_column_name %>", <%= y_column_name %>_samples, NaN);
107110
setTimeout(renderOutliers, delay, "<%= y_column_name %>", <%= y_column_name %>_samples, <%= snapshot.target_stddev(y_column_name) %>)
108111
<% } %>
109-
112+
<% } %>
113+
110114
var delay = 600;
111115

112116
<% for feature in snapshot.features().unwrap().iter() {
@@ -116,9 +120,11 @@ <h3><%= name %>&nbsp;<code><%= feature["pg_type"].as_str().unwrap() | upper %></
116120
delay += 200;
117121

118122
setTimeout(renderDistribution, delay, "<%= name_machine %>", <%= name_machine %>_samples, NaN);
119-
<% for y_column_name in snapshot.y_column_name.iter() { %>
123+
<% if snapshot.y_column_name.as_ref().is_some() { %>
124+
<% for y_column_name in snapshot.y_column_name.as_ref().unwrap().iter() { %>
120125
setTimeout(renderCorrelation, delay, "<%= name_machine %>", "<%= y_column_name %>", <%= name_machine %>_samples, <%= y_column_name %>_samples);
121126
<% } %>
127+
<% } %>
122128
<% } %>
123129
}
124130
renderCharts();

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy