> {
+ self.bindings
+ .as_ref()
+ .unwrap()
+ .predict(vector, self.num_features, self.num_classes)
+ }
}
diff --git a/pgml-extension/src/orm/task.rs b/pgml-extension/src/orm/task.rs
index 1116d98ae..7c23d0861 100644
--- a/pgml-extension/src/orm/task.rs
+++ b/pgml-extension/src/orm/task.rs
@@ -6,31 +6,33 @@ use serde::Deserialize;
pub enum Task {
regression,
classification,
+ decomposition,
+ clustering,
question_answering,
summarization,
translation,
text_classification,
text_generation,
text2text,
- cluster,
embedding,
text_pair_classification,
conversation,
}
-// unfortunately the pgrx macro expands the enum names to underscore, but huggingface uses dash
+// unfortunately the pgrx macro expands the enum names to underscore, but hugging face uses dash
impl Task {
pub fn to_pg_enum(&self) -> String {
match *self {
Task::regression => "regression".to_string(),
Task::classification => "classification".to_string(),
+ Task::decomposition => "decomposition".to_string(),
+ Task::clustering => "clustering".to_string(),
Task::question_answering => "question_answering".to_string(),
Task::summarization => "summarization".to_string(),
Task::translation => "translation".to_string(),
Task::text_classification => "text_classification".to_string(),
Task::text_generation => "text_generation".to_string(),
Task::text2text => "text2text".to_string(),
- Task::cluster => "cluster".to_string(),
Task::embedding => "embedding".to_string(),
Task::text_pair_classification => "text_pair_classification".to_string(),
Task::conversation => "conversation".to_string(),
@@ -45,13 +47,14 @@ impl Task {
match self {
Task::regression => "r2",
Task::classification => "f1",
+ Task::decomposition => "cumulative_explained_variance",
+ Task::clustering => "silhouette",
Task::question_answering => "f1",
Task::translation => "blue",
Task::summarization => "rouge_ngram_f1",
Task::text_classification => "f1",
Task::text_generation => "perplexity",
Task::text2text => "perplexity",
- Task::cluster => "silhouette",
Task::embedding => error!("No default target metric for embedding task"),
Task::text_pair_classification => "f1",
Task::conversation => "bleu",
@@ -63,13 +66,14 @@ impl Task {
match self {
Task::regression => true,
Task::classification => true,
+ Task::decomposition => true,
+ Task::clustering => true,
Task::question_answering => true,
Task::translation => true,
Task::summarization => true,
Task::text_classification => true,
Task::text_generation => false,
Task::text2text => false,
- Task::cluster => true,
Task::embedding => error!("No default target metric positive for embedding task"),
Task::text_pair_classification => true,
Task::conversation => true,
@@ -105,13 +109,14 @@ impl std::str::FromStr for Task {
match input {
"regression" => Ok(Task::regression),
"classification" => Ok(Task::classification),
+ "decomposition" => Ok(Task::decomposition),
+ "clustering" => Ok(Task::clustering),
"question-answering" | "question_answering" => Ok(Task::question_answering),
"summarization" => Ok(Task::summarization),
"translation" => Ok(Task::translation),
"text-classification" | "text_classification" => Ok(Task::text_classification),
"text-generation" | "text_generation" => Ok(Task::text_generation),
"text2text" => Ok(Task::text2text),
- "cluster" => Ok(Task::cluster),
"text-pair-classification" | "text_pair_classification" => Ok(Task::text_pair_classification),
"conversation" => Ok(Task::conversation),
_ => Err(()),
@@ -124,13 +129,14 @@ impl std::string::ToString for Task {
match *self {
Task::regression => "regression".to_string(),
Task::classification => "classification".to_string(),
+ Task::decomposition => "decomposition".to_string(),
+ Task::clustering => "clustering".to_string(),
Task::question_answering => "question-answering".to_string(),
Task::summarization => "summarization".to_string(),
Task::translation => "translation".to_string(),
Task::text_classification => "text-classification".to_string(),
Task::text_generation => "text-generation".to_string(),
Task::text2text => "text2text".to_string(),
- Task::cluster => "cluster".to_string(),
Task::embedding => "embedding".to_string(),
Task::text_pair_classification => "text-pair-classification".to_string(),
Task::conversation => "conversation".to_string(),
diff --git a/pgml-extension/tests/test.sql b/pgml-extension/tests/test.sql
index a6c75dee9..2490678ee 100644
--- a/pgml-extension/tests/test.sql
+++ b/pgml-extension/tests/test.sql
@@ -21,7 +21,8 @@ SELECT pgml.load_dataset('iris');
SELECT pgml.load_dataset('linnerud');
SELECT pgml.load_dataset('wine');
-\i examples/cluster.sql
+\i examples/clustering.sql
+\i examples/decomposition.sql
\i examples/binary_classification.sql
\i examples/image_classification.sql
\i examples/joint_regression.sql
pFad - Phonifier reborn
Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies:
Alternative Proxy
pFad Proxy
pFad v3 Proxy
pFad v4 Proxy