>()
- .join(", "),
- self.relation_name(),
- match self.materialized {
- // If the snapshot is materialized, we already randomized it.
- true => "",
- false => {
- if self.test_sampling == Sampling::random {
- "ORDER BY random()"
- } else {
- ""
- }
- }
- },
- );
-
// Postgres Arrays arrays are 1 indexed and so are SPI tuples...
- let result = client.select(&sql, None, None).unwrap();
+ let result = client.select(&self.select_sql(), None, None).unwrap();
let num_rows = result.len();
-
- let num_test_rows = if self.test_size > 1.0 {
- self.test_size as usize
- } else {
- (num_rows as f32 * self.test_size).round() as usize
- };
-
- let num_train_rows = num_rows - num_test_rows;
- if num_train_rows == 0 {
- error!(
- "test_size = {} is too large. There are only {} samples.",
- num_test_rows, num_rows
- );
- }
-
+ let (num_train_rows, num_test_rows) = self.train_test_split(num_rows);
let num_features = self.num_features();
let num_labels = self.num_labels();
@@ -968,7 +1098,7 @@ impl Snapshot {
let num_features = self.num_features();
let num_labels = self.num_labels();
- data = Some(Dataset {
+ data = Some(Dataset{
x_train,
y_train,
x_test,
@@ -1009,6 +1139,9 @@ fn check_column_size(column: &mut Column, len: usize) {
if column.size == 0 {
column.size = len;
} else if column.size != len {
- error!("Mismatched array length for feature `{}`. Expected: {} Received: {}", column.name, column.size, len);
+ error!(
+ "Mismatched array length for feature `{}`. Expected: {} Received: {}",
+ column.name, column.size, len
+ );
}
}
diff --git a/pgml-extension/src/orm/task.rs b/pgml-extension/src/orm/task.rs
index 1c2cd92cf..f9285a2cd 100644
--- a/pgml-extension/src/orm/task.rs
+++ b/pgml-extension/src/orm/task.rs
@@ -6,6 +6,28 @@ use serde::Deserialize;
pub enum Task {
regression,
classification,
+ question_answering,
+ summarization,
+ translation,
+ text_classification,
+ text_generation,
+ text2text,
+}
+
+// unfortunately the pgx macro expands the enum names to underscore, but huggingface uses dash
+impl Task {
+ pub fn to_pg_enum(&self) -> String {
+ match *self {
+ Task::regression => "regression".to_string(),
+ Task::classification => "classification".to_string(),
+ Task::question_answering => "question_answering".to_string(),
+ Task::summarization => "summarization".to_string(),
+ Task::translation => "translation".to_string(),
+ Task::text_classification => "text_classification".to_string(),
+ Task::text_generation => "text_generation".to_string(),
+ Task::text2text => "text2text".to_string(),
+ }
+ }
}
impl std::str::FromStr for Task {
@@ -15,6 +37,12 @@ impl std::str::FromStr for Task {
match input {
"regression" => Ok(Task::regression),
"classification" => Ok(Task::classification),
+ "question-answering" | "question_answering" => Ok(Task::question_answering),
+ "summarization" => Ok(Task::summarization),
+ "translation" => Ok(Task::translation),
+ "text-classification" | "text_classification" => Ok(Task::text_classification),
+ "text-generation" | "text_generation" => Ok(Task::text_generation),
+ "text2text" => Ok(Task::text2text),
_ => Err(()),
}
}
@@ -25,6 +53,12 @@ impl std::string::ToString for Task {
match *self {
Task::regression => "regression".to_string(),
Task::classification => "classification".to_string(),
+ Task::question_answering => "question-answering".to_string(),
+ Task::summarization => "summarization".to_string(),
+ Task::translation => "translation".to_string(),
+ Task::text_classification => "text-classification".to_string(),
+ Task::text_generation => "text-generation".to_string(),
+ Task::text2text => "text2text".to_string(),
}
}
}
pFad - Phonifier reborn
Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies:
Alternative Proxy
pFad Proxy
pFad v3 Proxy
pFad v4 Proxy