-
Notifications
You must be signed in to change notification settings - Fork 333
cleanup clippy lints #316
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
cleanup clippy lints #316
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
use std::fmt::Write; | ||
use std::str::FromStr; | ||
|
||
use pgx::*; | ||
|
@@ -11,6 +12,7 @@ use crate::orm::Snapshot; | |
use crate::orm::Strategy; | ||
use crate::orm::Task; | ||
|
||
#[allow(clippy::too_many_arguments)] | ||
#[pg_extern] | ||
fn train( | ||
project_name: &str, | ||
|
@@ -76,23 +78,24 @@ fn train( | |
vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())], | ||
); | ||
|
||
let mut deploy = false; | ||
if deployed_metrics.is_none() { | ||
deploy = true; | ||
} else { | ||
let deployed_metrics = deployed_metrics.unwrap().0; | ||
let deployed_metrics = deployed_metrics.as_object().unwrap(); | ||
if project.task == Task::classification | ||
&& deployed_metrics.get("f1").unwrap().as_f64() | ||
< new_metrics.get("f1").unwrap().as_f64() | ||
{ | ||
deploy = true; | ||
} | ||
if project.task == Task::regression | ||
&& deployed_metrics.get("r2").unwrap().as_f64() | ||
< new_metrics.get("r2").unwrap().as_f64() | ||
{ | ||
deploy = true; | ||
let mut deploy = true; | ||
if let Some(deployed_metrics) = deployed_metrics { | ||
let deployed_metrics = deployed_metrics.0.as_object().unwrap(); | ||
match project.task { | ||
Task::classification => { | ||
if deployed_metrics.get("f1").unwrap().as_f64() | ||
> new_metrics.get("f1").unwrap().as_f64() | ||
{ | ||
deploy = false; | ||
} | ||
} | ||
Task::regression => { | ||
if deployed_metrics.get("r2").unwrap().as_f64() | ||
> new_metrics.get("r2").unwrap().as_f64() | ||
{ | ||
deploy = false; | ||
} | ||
} | ||
} | ||
} | ||
|
||
|
@@ -133,34 +136,39 @@ fn deploy( | |
vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())], | ||
); | ||
let project_id = | ||
project_id.expect(format!("Project named `{}` does not exist.", project_name).as_str()); | ||
project_id.unwrap_or_else(|| panic!("Project named `{}` does not exist.", project_name)); | ||
let task = Task::from_str(&task.unwrap()).unwrap(); | ||
|
||
let mut sql = "SELECT models.id, models.algorithm::TEXT FROM pgml_rust.models JOIN pgml_rust.projects ON projects.id = models.project_id".to_string(); | ||
let mut predicate = "\nWHERE projects.name = $1".to_string(); | ||
match algorithm { | ||
Some(algorithm) => { | ||
predicate += &format!( | ||
"\nAND algorithm::TEXT = '{}'", | ||
algorithm.to_string().as_str() | ||
) | ||
} | ||
_ => (), | ||
if let Some(algorithm) = algorithm { | ||
let _ = write!( | ||
predicate, | ||
"\nAND algorithm::TEXT = '{}'", | ||
algorithm.to_string().as_str() | ||
); | ||
} | ||
match strategy { | ||
Strategy::best_score => match task { | ||
Task::regression => { | ||
sql += &format!("{predicate}\nORDER BY models.metrics->>'r2' DESC NULLS LAST"); | ||
let _ = write!( | ||
sql, | ||
"{predicate}\nORDER BY models.metrics->>'r2' DESC NULLS LAST" | ||
); | ||
} | ||
Task::classification => { | ||
sql += &format!("{predicate}\nORDER BY models.metrics->>'f1' DESC NULLS LAST"); | ||
let _ = write!( | ||
sql, | ||
"{predicate}\nORDER BY models.metrics->>'f1' DESC NULLS LAST" | ||
); | ||
} | ||
}, | ||
Strategy::most_recent => { | ||
sql += &format!("{predicate}\nORDER by models.created_at DESC"); | ||
let _ = write!(sql, "{predicate}\nORDER by models.created_at DESC"); | ||
} | ||
Strategy::rollback => { | ||
sql += &format!( | ||
let _ = write!( | ||
sql, | ||
" | ||
JOIN pgml_rust.deployments ON deployments.project_id = projects.id | ||
AND deployments.model_id = models.id | ||
|
@@ -230,10 +238,7 @@ fn load_dataset( | |
limit: Option<default!(i64, "NULL")>, | ||
) -> impl std::iter::Iterator<Item = (name!(table_name, String), name!(rows, i64))> { | ||
// cast limit since pgx doesn't support usize | ||
let limit: Option<usize> = match limit { | ||
Some(limit) => Some(limit.try_into().unwrap()), | ||
None => None, | ||
}; | ||
let limit: Option<usize> = limit.map(|limit| limit.try_into().unwrap()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Weird. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah, turns out you can map things like this... |
||
let (name, rows) = match source { | ||
"breast_cancer" => crate::orm::dataset::load_breast_cancer(limit), | ||
"diabetes" => crate::orm::dataset::load_diabetes(limit), | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,7 +28,7 @@ fn model_predict(model_id: i64, features: Vec<f32>) -> f32 { | |
|
||
match guard.get(&model_id) { | ||
Some(data) => { | ||
let bst = Booster::load_buffer(&data).unwrap(); | ||
let bst = Booster::load_buffer(data).unwrap(); | ||
let dmat = DMatrix::from_dense(&features, 1).unwrap(); | ||
|
||
bst.predict(&dmat).unwrap()[0] | ||
|
@@ -66,7 +66,7 @@ fn model_predict_batch(model_id: i64, features: Vec<f32>, num_rows: i32) -> Vec< | |
|
||
match guard.get(&model_id) { | ||
Some(data) => { | ||
let bst = Booster::load_buffer(&data).unwrap(); | ||
let bst = Booster::load_buffer(data).unwrap(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was a clippy warning? Should of been a compilation error... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The compiler handles the extra dereference for you. |
||
let dmat = DMatrix::from_dense(&features, num_rows as usize).unwrap(); | ||
|
||
bst.predict(&dmat).unwrap() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,10 +30,14 @@ impl Dataset { | |
|
||
pub fn distinct_labels(&self) -> u32 { | ||
let mut v = HashSet::new(); | ||
// Treat the f32 values as u32 for std::cmp::Eq. We don't | ||
// Treat the f32 values as u32 for std::cmp::Eq. We don't | ||
// care about the nuance of nan equality here, they should | ||
// already be filtered out upstream. | ||
self.y.iter().for_each(|i| if !i.is_nan() { v.insert(i.to_bits()); }); | ||
self.y.iter().for_each(|i| { | ||
if !i.is_nan() { | ||
v.insert(i.to_bits()); | ||
} | ||
}); | ||
v.len().try_into().unwrap() | ||
} | ||
} | ||
|
@@ -86,7 +90,7 @@ pub fn load_diabetes(limit: Option<usize>) -> (String, i64) { | |
None => diabetes.num_samples, | ||
}; | ||
for i in 0..limit { | ||
let age = diabetes.data[(i * diabetes.num_features) + 0]; | ||
let age = diabetes.data[(i * diabetes.num_features)]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hah! |
||
let sex = diabetes.data[(i * diabetes.num_features) + 1]; | ||
let bmi = diabetes.data[(i * diabetes.num_features) + 2]; | ||
let bp = diabetes.data[(i * diabetes.num_features) + 3]; | ||
|
@@ -153,6 +157,7 @@ pub fn load_digits(limit: Option<usize>) -> (String, i64) { | |
let target = digits.target[i]; | ||
// shape the image in a 2d array | ||
let mut image = vec![vec![0.; 8]; 8]; | ||
#[allow(clippy::needless_range_loop)] // x & y are in fact used | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh no, don't get me started with these haha! Rubocop was plenty, thank you |
||
for x in 0..8 { | ||
for y in 0..8 { | ||
image[x][y] = digits.data[(i * 64) + (x * 8) + y]; | ||
|
@@ -213,7 +218,7 @@ pub fn load_iris(limit: Option<usize>) -> (String, i64) { | |
None => iris.num_samples, | ||
}; | ||
for i in 0..limit { | ||
let sepal_length = iris.data[(i * iris.num_features) + 0]; | ||
let sepal_length = iris.data[(i * iris.num_features)]; | ||
let sepal_width = iris.data[(i * iris.num_features) + 1]; | ||
let petal_length = iris.data[(i * iris.num_features) + 2]; | ||
let petal_width = iris.data[(i * iris.num_features) + 3]; | ||
|
@@ -354,7 +359,7 @@ pub fn load_breast_cancer(limit: Option<usize>) -> (String, i64) { | |
vec![ | ||
( | ||
PgBuiltInOids::FLOAT4OID.oid(), | ||
breast_cancer.data[(i * breast_cancer.num_features) + 0].into_datum(), | ||
breast_cancer.data[(i * breast_cancer.num_features)].into_datum(), | ||
), | ||
( | ||
PgBuiltInOids::FLOAT4OID.oid(), | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,15 +13,16 @@ use crate::orm::Algorithm; | |
use crate::orm::Dataset; | ||
use crate::orm::Task; | ||
|
||
#[allow(clippy::type_complexity)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Clippy needs to get it together. |
||
static DEPLOYED_ESTIMATORS_BY_PROJECT_NAME: Lazy<Mutex<HashMap<String, Arc<Box<dyn Estimator>>>>> = | ||
Lazy::new(|| Mutex::new(HashMap::new())); | ||
|
||
pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estimator>> { | ||
{ | ||
let estimators = DEPLOYED_ESTIMATORS_BY_PROJECT_NAME.lock().unwrap(); | ||
let estimator = estimators.get(name); | ||
if estimator.is_some() { | ||
return estimator.unwrap().clone(); | ||
if let Some(estimator) = estimator { | ||
return estimator.clone(); | ||
} | ||
} | ||
|
||
|
@@ -40,33 +41,26 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima | |
LIMIT 1;", | ||
vec![(PgBuiltInOids::TEXTOID.oid(), name.into_datum())], | ||
); | ||
let task = Task::from_str( | ||
&task.expect( | ||
format!( | ||
"Project {} does not have a trained and deployed model.", | ||
name | ||
) | ||
.as_str(), | ||
), | ||
) | ||
let task = Task::from_str(&task.unwrap_or_else(|| { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok what's up with these. Why does not like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. expect input is always executed which may have side effects or allocations. or_else is only executed if needed. |
||
panic!( | ||
"Project {} does not have a trained and deployed model.", | ||
name | ||
) | ||
})) | ||
.unwrap(); | ||
let algorithm = Algorithm::from_str( | ||
&algorithm.expect( | ||
format!( | ||
"Project {} does not have a trained and deployed model.", | ||
name | ||
) | ||
.as_str(), | ||
), | ||
) | ||
let algorithm = Algorithm::from_str(&algorithm.unwrap_or_else(|| { | ||
panic!( | ||
"Project {} does not have a trained and deployed model.", | ||
name | ||
) | ||
})) | ||
.unwrap(); | ||
let data = data.expect( | ||
format!( | ||
let data = data.unwrap_or_else(|| { | ||
panic!( | ||
"Project {} does not have a trained and deployed model.", | ||
name | ||
) | ||
.as_str(), | ||
); | ||
}); | ||
|
||
let e: Box<dyn Estimator> = match task { | ||
Task::regression => match algorithm { | ||
|
@@ -125,7 +119,12 @@ fn predict_smartcore( | |
smartcore::api::Predictor::predict(predictor, &features).unwrap()[0] | ||
} | ||
|
||
fn calc_metrics(y_test: &Array1<f32>, y_hat: &Array1<f32>, distinct_labels: u32, task: Task) -> HashMap<String, f32> { | ||
fn calc_metrics( | ||
y_test: &Array1<f32>, | ||
y_hat: &Array1<f32>, | ||
distinct_labels: u32, | ||
task: Task, | ||
) -> HashMap<String, f32> { | ||
let mut results = HashMap::new(); | ||
match task { | ||
Task::regression => { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -60,7 +60,7 @@ impl Model { | |
(PgBuiltInOids::JSONBOID.oid(), search_args.into_datum()), | ||
]) | ||
).first(); | ||
if result.len() > 0 { | ||
if !result.is_empty() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. He got you there. |
||
model = Some(Model { | ||
id: result.get_datum(1).unwrap(), | ||
project_id: result.get_datum(2).unwrap(), | ||
|
@@ -69,7 +69,7 @@ impl Model { | |
hyperparams: result.get_datum(5).unwrap(), | ||
status: result.get_datum(6).unwrap(), | ||
metrics: result.get_datum(7), | ||
search: search, // TODO | ||
search, // TODO | ||
search_params: result.get_datum(9).unwrap(), | ||
search_args: result.get_datum(10).unwrap(), | ||
created_at: result.get_datum(11).unwrap(), | ||
|
@@ -82,8 +82,8 @@ impl Model { | |
}); | ||
let mut model = model.unwrap(); | ||
let dataset = snapshot.dataset(); | ||
model.fit(&project, &dataset); | ||
model.test(&project, &dataset); | ||
model.fit(project, &dataset); | ||
model.test(project, &dataset); | ||
model | ||
} | ||
|
||
|
@@ -119,7 +119,7 @@ impl Model { | |
.unwrap(), | ||
)), | ||
}; | ||
let bytes: Vec<u8> = rmp_serde::to_vec(&*estimator.as_ref().unwrap()).unwrap(); | ||
let bytes: Vec<u8> = rmp_serde::to_vec(estimator.as_ref().unwrap()).unwrap(); | ||
Spi::get_one_with_args::<i64>( | ||
"INSERT INTO pgml_rust.files (model_id, path, part, data) VALUES($1, 'estimator.rmp', 0, $2) RETURNING id", | ||
vec![ | ||
|
@@ -131,11 +131,11 @@ impl Model { | |
} | ||
Algorithm::xgboost => { | ||
let mut dtrain = | ||
DMatrix::from_dense(&dataset.x_train(), dataset.num_train_rows).unwrap(); | ||
DMatrix::from_dense(dataset.x_train(), dataset.num_train_rows).unwrap(); | ||
let mut dtest = | ||
DMatrix::from_dense(&dataset.x_test(), dataset.num_test_rows).unwrap(); | ||
dtrain.set_labels(&dataset.y_train()).unwrap(); | ||
dtest.set_labels(&dataset.y_test()).unwrap(); | ||
DMatrix::from_dense(dataset.x_test(), dataset.num_test_rows).unwrap(); | ||
dtrain.set_labels(dataset.y_train()).unwrap(); | ||
dtest.set_labels(dataset.y_test()).unwrap(); | ||
|
||
// specify datasets to evaluate against during training | ||
let evaluation_sets = &[(&dtrain, "train"), (&dtest, "test")]; | ||
|
@@ -146,8 +146,10 @@ impl Model { | |
.objective(match project.task { | ||
Task::regression => xgboost::parameters::learning::Objective::RegLinear, | ||
Task::classification => { | ||
xgboost::parameters::learning::Objective::MultiSoftmax(dataset.distinct_labels()) | ||
}, | ||
xgboost::parameters::learning::Objective::MultiSoftmax( | ||
dataset.distinct_labels(), | ||
) | ||
} | ||
}) | ||
.build() | ||
.unwrap(); | ||
|
@@ -207,12 +209,8 @@ impl Model { | |
} | ||
|
||
fn test(&mut self, project: &Project, dataset: &Dataset) { | ||
let metrics = self | ||
.estimator | ||
.as_ref() | ||
.unwrap() | ||
.test(project.task, &dataset); | ||
self.metrics = Some(JsonB(json!(metrics.clone()))); | ||
let metrics = self.estimator.as_ref().unwrap().test(project.task, dataset); | ||
self.metrics = Some(JsonB(json!(metrics))); | ||
Spi::get_one_with_args::<i64>( | ||
"UPDATE pgml_rust.models SET metrics = $1 WHERE id = $2 RETURNING id", | ||
vec![ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What come on clippy, this seems totally fine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
format! does an extra allocation compared to write... 🤷♂️