Skip to content

Commit 5aaac95

Browse files
authored
cleanup clippy lints (#316)
1 parent fbbb67a commit 5aaac95

File tree

12 files changed

+107
-100
lines changed

12 files changed

+107
-100
lines changed

pgml-extension/pgml_rust/src/api.rs

Lines changed: 39 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
use std::fmt::Write;
12
use std::str::FromStr;
23

34
use pgx::*;
@@ -11,6 +12,7 @@ use crate::orm::Snapshot;
1112
use crate::orm::Strategy;
1213
use crate::orm::Task;
1314

15+
#[allow(clippy::too_many_arguments)]
1416
#[pg_extern]
1517
fn train(
1618
project_name: &str,
@@ -76,23 +78,24 @@ fn train(
7678
vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())],
7779
);
7880

79-
let mut deploy = false;
80-
if deployed_metrics.is_none() {
81-
deploy = true;
82-
} else {
83-
let deployed_metrics = deployed_metrics.unwrap().0;
84-
let deployed_metrics = deployed_metrics.as_object().unwrap();
85-
if project.task == Task::classification
86-
&& deployed_metrics.get("f1").unwrap().as_f64()
87-
< new_metrics.get("f1").unwrap().as_f64()
88-
{
89-
deploy = true;
90-
}
91-
if project.task == Task::regression
92-
&& deployed_metrics.get("r2").unwrap().as_f64()
93-
< new_metrics.get("r2").unwrap().as_f64()
94-
{
95-
deploy = true;
81+
let mut deploy = true;
82+
if let Some(deployed_metrics) = deployed_metrics {
83+
let deployed_metrics = deployed_metrics.0.as_object().unwrap();
84+
match project.task {
85+
Task::classification => {
86+
if deployed_metrics.get("f1").unwrap().as_f64()
87+
> new_metrics.get("f1").unwrap().as_f64()
88+
{
89+
deploy = false;
90+
}
91+
}
92+
Task::regression => {
93+
if deployed_metrics.get("r2").unwrap().as_f64()
94+
> new_metrics.get("r2").unwrap().as_f64()
95+
{
96+
deploy = false;
97+
}
98+
}
9699
}
97100
}
98101

@@ -133,34 +136,39 @@ fn deploy(
133136
vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())],
134137
);
135138
let project_id =
136-
project_id.expect(format!("Project named `{}` does not exist.", project_name).as_str());
139+
project_id.unwrap_or_else(|| panic!("Project named `{}` does not exist.", project_name));
137140
let task = Task::from_str(&task.unwrap()).unwrap();
138141

139142
let mut sql = "SELECT models.id, models.algorithm::TEXT FROM pgml_rust.models JOIN pgml_rust.projects ON projects.id = models.project_id".to_string();
140143
let mut predicate = "\nWHERE projects.name = $1".to_string();
141-
match algorithm {
142-
Some(algorithm) => {
143-
predicate += &format!(
144-
"\nAND algorithm::TEXT = '{}'",
145-
algorithm.to_string().as_str()
146-
)
147-
}
148-
_ => (),
144+
if let Some(algorithm) = algorithm {
145+
let _ = write!(
146+
predicate,
147+
"\nAND algorithm::TEXT = '{}'",
148+
algorithm.to_string().as_str()
149+
);
149150
}
150151
match strategy {
151152
Strategy::best_score => match task {
152153
Task::regression => {
153-
sql += &format!("{predicate}\nORDER BY models.metrics->>'r2' DESC NULLS LAST");
154+
let _ = write!(
155+
sql,
156+
"{predicate}\nORDER BY models.metrics->>'r2' DESC NULLS LAST"
157+
);
154158
}
155159
Task::classification => {
156-
sql += &format!("{predicate}\nORDER BY models.metrics->>'f1' DESC NULLS LAST");
160+
let _ = write!(
161+
sql,
162+
"{predicate}\nORDER BY models.metrics->>'f1' DESC NULLS LAST"
163+
);
157164
}
158165
},
159166
Strategy::most_recent => {
160-
sql += &format!("{predicate}\nORDER by models.created_at DESC");
167+
let _ = write!(sql, "{predicate}\nORDER by models.created_at DESC");
161168
}
162169
Strategy::rollback => {
163-
sql += &format!(
170+
let _ = write!(
171+
sql,
164172
"
165173
JOIN pgml_rust.deployments ON deployments.project_id = projects.id
166174
AND deployments.model_id = models.id
@@ -230,10 +238,7 @@ fn load_dataset(
230238
limit: Option<default!(i64, "NULL")>,
231239
) -> impl std::iter::Iterator<Item = (name!(table_name, String), name!(rows, i64))> {
232240
// cast limit since pgx doesn't support usize
233-
let limit: Option<usize> = match limit {
234-
Some(limit) => Some(limit.try_into().unwrap()),
235-
None => None,
236-
};
241+
let limit: Option<usize> = limit.map(|limit| limit.try_into().unwrap());
237242
let (name, rows) = match source {
238243
"breast_cancer" => crate::orm::dataset::load_breast_cancer(limit),
239244
"diabetes" => crate::orm::dataset::load_diabetes(limit),

pgml-extension/pgml_rust/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ fn model_predict(model_id: i64, features: Vec<f32>) -> f32 {
2828

2929
match guard.get(&model_id) {
3030
Some(data) => {
31-
let bst = Booster::load_buffer(&data).unwrap();
31+
let bst = Booster::load_buffer(data).unwrap();
3232
let dmat = DMatrix::from_dense(&features, 1).unwrap();
3333

3434
bst.predict(&dmat).unwrap()[0]
@@ -66,7 +66,7 @@ fn model_predict_batch(model_id: i64, features: Vec<f32>, num_rows: i32) -> Vec<
6666

6767
match guard.get(&model_id) {
6868
Some(data) => {
69-
let bst = Booster::load_buffer(&data).unwrap();
69+
let bst = Booster::load_buffer(data).unwrap();
7070
let dmat = DMatrix::from_dense(&features, num_rows as usize).unwrap();
7171

7272
bst.predict(&dmat).unwrap()

pgml-extension/pgml_rust/src/orm/algorithm.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
use pgx::*;
22
use serde::Deserialize;
33

4-
#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug, Deserialize)]
4+
#[derive(PostgresEnum, Copy, Clone, Eq, PartialEq, Debug, Deserialize)]
55
#[allow(non_camel_case_types)]
66
pub enum Algorithm {
77
linear,

pgml-extension/pgml_rust/src/orm/dataset.rs

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -30,10 +30,14 @@ impl Dataset {
3030

3131
pub fn distinct_labels(&self) -> u32 {
3232
let mut v = HashSet::new();
33-
// Treat the f32 values as u32 for std::cmp::Eq. We don't
33+
// Treat the f32 values as u32 for std::cmp::Eq. We don't
3434
// care about the nuance of nan equality here, they should
3535
// already be filtered out upstream.
36-
self.y.iter().for_each(|i| if !i.is_nan() { v.insert(i.to_bits()); });
36+
self.y.iter().for_each(|i| {
37+
if !i.is_nan() {
38+
v.insert(i.to_bits());
39+
}
40+
});
3741
v.len().try_into().unwrap()
3842
}
3943
}
@@ -86,7 +90,7 @@ pub fn load_diabetes(limit: Option<usize>) -> (String, i64) {
8690
None => diabetes.num_samples,
8791
};
8892
for i in 0..limit {
89-
let age = diabetes.data[(i * diabetes.num_features) + 0];
93+
let age = diabetes.data[(i * diabetes.num_features)];
9094
let sex = diabetes.data[(i * diabetes.num_features) + 1];
9195
let bmi = diabetes.data[(i * diabetes.num_features) + 2];
9296
let bp = diabetes.data[(i * diabetes.num_features) + 3];
@@ -153,6 +157,7 @@ pub fn load_digits(limit: Option<usize>) -> (String, i64) {
153157
let target = digits.target[i];
154158
// shape the image in a 2d array
155159
let mut image = vec![vec![0.; 8]; 8];
160+
#[allow(clippy::needless_range_loop)] // x & y are in fact used
156161
for x in 0..8 {
157162
for y in 0..8 {
158163
image[x][y] = digits.data[(i * 64) + (x * 8) + y];
@@ -213,7 +218,7 @@ pub fn load_iris(limit: Option<usize>) -> (String, i64) {
213218
None => iris.num_samples,
214219
};
215220
for i in 0..limit {
216-
let sepal_length = iris.data[(i * iris.num_features) + 0];
221+
let sepal_length = iris.data[(i * iris.num_features)];
217222
let sepal_width = iris.data[(i * iris.num_features) + 1];
218223
let petal_length = iris.data[(i * iris.num_features) + 2];
219224
let petal_width = iris.data[(i * iris.num_features) + 3];
@@ -354,7 +359,7 @@ pub fn load_breast_cancer(limit: Option<usize>) -> (String, i64) {
354359
vec![
355360
(
356361
PgBuiltInOids::FLOAT4OID.oid(),
357-
breast_cancer.data[(i * breast_cancer.num_features) + 0].into_datum(),
362+
breast_cancer.data[(i * breast_cancer.num_features)].into_datum(),
358363
),
359364
(
360365
PgBuiltInOids::FLOAT4OID.oid(),

pgml-extension/pgml_rust/src/orm/estimator.rs

Lines changed: 24 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -13,15 +13,16 @@ use crate::orm::Algorithm;
1313
use crate::orm::Dataset;
1414
use crate::orm::Task;
1515

16+
#[allow(clippy::type_complexity)]
1617
static DEPLOYED_ESTIMATORS_BY_PROJECT_NAME: Lazy<Mutex<HashMap<String, Arc<Box<dyn Estimator>>>>> =
1718
Lazy::new(|| Mutex::new(HashMap::new()));
1819

1920
pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estimator>> {
2021
{
2122
let estimators = DEPLOYED_ESTIMATORS_BY_PROJECT_NAME.lock().unwrap();
2223
let estimator = estimators.get(name);
23-
if estimator.is_some() {
24-
return estimator.unwrap().clone();
24+
if let Some(estimator) = estimator {
25+
return estimator.clone();
2526
}
2627
}
2728

@@ -40,33 +41,26 @@ pub fn find_deployed_estimator_by_project_name(name: &str) -> Arc<Box<dyn Estima
4041
LIMIT 1;",
4142
vec![(PgBuiltInOids::TEXTOID.oid(), name.into_datum())],
4243
);
43-
let task = Task::from_str(
44-
&task.expect(
45-
format!(
46-
"Project {} does not have a trained and deployed model.",
47-
name
48-
)
49-
.as_str(),
50-
),
51-
)
44+
let task = Task::from_str(&task.unwrap_or_else(|| {
45+
panic!(
46+
"Project {} does not have a trained and deployed model.",
47+
name
48+
)
49+
}))
5250
.unwrap();
53-
let algorithm = Algorithm::from_str(
54-
&algorithm.expect(
55-
format!(
56-
"Project {} does not have a trained and deployed model.",
57-
name
58-
)
59-
.as_str(),
60-
),
61-
)
51+
let algorithm = Algorithm::from_str(&algorithm.unwrap_or_else(|| {
52+
panic!(
53+
"Project {} does not have a trained and deployed model.",
54+
name
55+
)
56+
}))
6257
.unwrap();
63-
let data = data.expect(
64-
format!(
58+
let data = data.unwrap_or_else(|| {
59+
panic!(
6560
"Project {} does not have a trained and deployed model.",
6661
name
6762
)
68-
.as_str(),
69-
);
63+
});
7064

7165
let e: Box<dyn Estimator> = match task {
7266
Task::regression => match algorithm {
@@ -125,7 +119,12 @@ fn predict_smartcore(
125119
smartcore::api::Predictor::predict(predictor, &features).unwrap()[0]
126120
}
127121

128-
fn calc_metrics(y_test: &Array1<f32>, y_hat: &Array1<f32>, distinct_labels: u32, task: Task) -> HashMap<String, f32> {
122+
fn calc_metrics(
123+
y_test: &Array1<f32>,
124+
y_hat: &Array1<f32>,
125+
distinct_labels: u32,
126+
task: Task,
127+
) -> HashMap<String, f32> {
129128
let mut results = HashMap::new();
130129
match task {
131130
Task::regression => {

pgml-extension/pgml_rust/src/orm/model.rs

Lines changed: 15 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,7 @@ impl Model {
6060
(PgBuiltInOids::JSONBOID.oid(), search_args.into_datum()),
6161
])
6262
).first();
63-
if result.len() > 0 {
63+
if !result.is_empty() {
6464
model = Some(Model {
6565
id: result.get_datum(1).unwrap(),
6666
project_id: result.get_datum(2).unwrap(),
@@ -69,7 +69,7 @@ impl Model {
6969
hyperparams: result.get_datum(5).unwrap(),
7070
status: result.get_datum(6).unwrap(),
7171
metrics: result.get_datum(7),
72-
search: search, // TODO
72+
search, // TODO
7373
search_params: result.get_datum(9).unwrap(),
7474
search_args: result.get_datum(10).unwrap(),
7575
created_at: result.get_datum(11).unwrap(),
@@ -82,8 +82,8 @@ impl Model {
8282
});
8383
let mut model = model.unwrap();
8484
let dataset = snapshot.dataset();
85-
model.fit(&project, &dataset);
86-
model.test(&project, &dataset);
85+
model.fit(project, &dataset);
86+
model.test(project, &dataset);
8787
model
8888
}
8989

@@ -119,7 +119,7 @@ impl Model {
119119
.unwrap(),
120120
)),
121121
};
122-
let bytes: Vec<u8> = rmp_serde::to_vec(&*estimator.as_ref().unwrap()).unwrap();
122+
let bytes: Vec<u8> = rmp_serde::to_vec(estimator.as_ref().unwrap()).unwrap();
123123
Spi::get_one_with_args::<i64>(
124124
"INSERT INTO pgml_rust.files (model_id, path, part, data) VALUES($1, 'estimator.rmp', 0, $2) RETURNING id",
125125
vec![
@@ -131,11 +131,11 @@ impl Model {
131131
}
132132
Algorithm::xgboost => {
133133
let mut dtrain =
134-
DMatrix::from_dense(&dataset.x_train(), dataset.num_train_rows).unwrap();
134+
DMatrix::from_dense(dataset.x_train(), dataset.num_train_rows).unwrap();
135135
let mut dtest =
136-
DMatrix::from_dense(&dataset.x_test(), dataset.num_test_rows).unwrap();
137-
dtrain.set_labels(&dataset.y_train()).unwrap();
138-
dtest.set_labels(&dataset.y_test()).unwrap();
136+
DMatrix::from_dense(dataset.x_test(), dataset.num_test_rows).unwrap();
137+
dtrain.set_labels(dataset.y_train()).unwrap();
138+
dtest.set_labels(dataset.y_test()).unwrap();
139139

140140
// specify datasets to evaluate against during training
141141
let evaluation_sets = &[(&dtrain, "train"), (&dtest, "test")];
@@ -146,8 +146,10 @@ impl Model {
146146
.objective(match project.task {
147147
Task::regression => xgboost::parameters::learning::Objective::RegLinear,
148148
Task::classification => {
149-
xgboost::parameters::learning::Objective::MultiSoftmax(dataset.distinct_labels())
150-
},
149+
xgboost::parameters::learning::Objective::MultiSoftmax(
150+
dataset.distinct_labels(),
151+
)
152+
}
151153
})
152154
.build()
153155
.unwrap();
@@ -207,12 +209,8 @@ impl Model {
207209
}
208210

209211
fn test(&mut self, project: &Project, dataset: &Dataset) {
210-
let metrics = self
211-
.estimator
212-
.as_ref()
213-
.unwrap()
214-
.test(project.task, &dataset);
215-
self.metrics = Some(JsonB(json!(metrics.clone())));
212+
let metrics = self.estimator.as_ref().unwrap().test(project.task, dataset);
213+
self.metrics = Some(JsonB(json!(metrics)));
216214
Spi::get_one_with_args::<i64>(
217215
"UPDATE pgml_rust.models SET metrics = $1 WHERE id = $2 RETURNING id",
218216
vec![

pgml-extension/pgml_rust/src/orm/project.rs

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@ impl Project {
2525
(PgBuiltInOids::INT8OID.oid(), id.into_datum()),
2626
])
2727
).first();
28-
if result.len() > 0 {
28+
if !result.is_empty() {
2929
project = Some(Project {
3030
id: result.get_datum(1).unwrap(),
3131
name: result.get_datum(2).unwrap(),
@@ -50,7 +50,7 @@ impl Project {
5050
(PgBuiltInOids::TEXTOID.oid(), name.into_datum()),
5151
])
5252
).first();
53-
if result.len() > 0 {
53+
if !result.is_empty() {
5454
project = Some(Project {
5555
id: result.get_datum(1).unwrap(),
5656
name: result.get_datum(2).unwrap(),
@@ -76,7 +76,7 @@ impl Project {
7676
(PgBuiltInOids::TEXTOID.oid(), task.to_string().into_datum()),
7777
])
7878
).first();
79-
if result.len() > 0 {
79+
if !result.is_empty() {
8080
project = Some(Project {
8181
id: result.get_datum(1).unwrap(),
8282
name: result.get_datum(2).unwrap(),

pgml-extension/pgml_rust/src/orm/sampling.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
use pgx::*;
22
use serde::Deserialize;
33

4-
#[derive(PostgresEnum, Copy, Clone, PartialEq, Debug, Deserialize)]
4+
#[derive(PostgresEnum, Copy, Clone, Eq, PartialEq, Debug, Deserialize)]
55
#[allow(non_camel_case_types)]
66
pub enum Sampling {
77
random,

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy