Skip to content

Commit aebd36d

Browse files
authored
Scikit hyperparameter search (#333)
1 parent 4ed6faa commit aebd36d

File tree

5 files changed

+170
-9
lines changed

5 files changed

+170
-9
lines changed

pgml-extension/pgml_rust/sql/schema.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ CREATE TABLE IF NOT EXISTS pgml_rust.models(
8181
hyperparams JSONB NOT NULL,
8282
status TEXT NOT NULL,
8383
metrics JSONB,
84-
search pgml_rust.search,
84+
search TEXT,
8585
search_params JSONB NOT NULL,
8686
search_args JSONB NOT NULL,
8787
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),

pgml-extension/pgml_rust/src/engines/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,7 @@ pub mod engine;
22
pub mod sklearn;
33
pub mod smartcore;
44
pub mod xgboost;
5+
6+
use serde_json;
7+
8+
pub type Hyperparams = serde_json::Map<std::string::String, serde_json::Value>;

pgml-extension/pgml_rust/src/engines/sklearn.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
use pyo3::prelude::*;
1111
use pyo3::types::PyTuple;
1212

13+
use crate::engines::Hyperparams;
1314
use crate::orm::algorithm::Algorithm;
1415
use crate::orm::dataset::Dataset;
1516
use crate::orm::estimator::SklearnBox;
17+
use crate::orm::search::Search;
1618
use crate::orm::task::Task;
1719

1820
use pgx::*;
@@ -171,3 +173,60 @@ pub fn sklearn_load(data: &Vec<u8>) -> SklearnBox {
171173
SklearnBox::new(estimator)
172174
})
173175
}
176+
177+
/// Hyperparameter search using Scikit's
178+
/// RandomizedSearchCV or GridSearchCV.
179+
pub fn sklearn_search(
180+
task: Task,
181+
algorithm: Algorithm,
182+
search: Search,
183+
dataset: &Dataset,
184+
hyperparams: &Hyperparams,
185+
search_params: &Hyperparams,
186+
) -> (SklearnBox, Hyperparams) {
187+
let module = include_str!(concat!(
188+
env!("CARGO_MANIFEST_DIR"),
189+
"/src/engines/wrappers.py"
190+
));
191+
192+
let algorithm_name = match task {
193+
Task::regression => match algorithm {
194+
Algorithm::linear => "linear_regression",
195+
_ => todo!(),
196+
},
197+
198+
Task::classification => match algorithm {
199+
Algorithm::linear => "linear_classification",
200+
_ => todo!(),
201+
},
202+
};
203+
204+
Python::with_gil(|py| -> (SklearnBox, Hyperparams) {
205+
let module = PyModule::from_code(py, module, "", "").unwrap();
206+
let estimator_search = module.getattr("estimator_search").unwrap();
207+
let train = estimator_search
208+
.call1(PyTuple::new(
209+
py,
210+
&[
211+
algorithm_name.into_py(py),
212+
dataset.num_features.into_py(py),
213+
serde_json::to_string(hyperparams).unwrap().into_py(py),
214+
serde_json::to_string(search_params).unwrap().into_py(py),
215+
search.to_string().into_py(py),
216+
None::<String>.into_py(py),
217+
],
218+
))
219+
.unwrap();
220+
221+
let (estimator, hyperparams): (Py<PyAny>, String) = train
222+
.call1(PyTuple::new(py, &[dataset.x_train(), dataset.y_train()]))
223+
.unwrap()
224+
.extract()
225+
.unwrap();
226+
227+
let estimator = SklearnBox::new(estimator);
228+
let hyperparams: Hyperparams = serde_json::from_str::<Hyperparams>(&hyperparams).unwrap();
229+
230+
(estimator, hyperparams)
231+
})
232+
}

pgml-extension/pgml_rust/src/engines/wrappers.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def estimator_joint(algorithm_name, num_features, num_targets, hyperparams):
7575
"""Returns the correct estimator based on algorithm names we defined
7676
internally (see dict above).
7777
78-
78+
7979
Parameters:
8080
- algorithm_name: The human-readable name of the algorithm (see dict above).
8181
- num_features: The number of features in X.
@@ -101,6 +101,70 @@ def train(X_train, y_train):
101101
return train
102102

103103

104+
def estimator_search_joint(algorithm_name, num_features, num_targets, hyperparams, search_params, search, search_args):
105+
"""Hyperparameter search.
106+
107+
Parameters:
108+
- algorithm_name: The human-readable name of the algorithm (see dict above).
109+
- num_features: The number of features in X.
110+
- num_targets: For joint training (more than one y target).
111+
- hyperparams: JSON of hyperparameters.
112+
- search_params: Hyperparameters to search (see Scikit docs for examples).
113+
- search: Type of search to do, grid or random.
114+
- search_args: See Scikit docs for examples.
115+
116+
Return:
117+
A tuple of Estimator and chosen hyperparameters.
118+
"""
119+
if search_args is None:
120+
search_args = {}
121+
else:
122+
search_args = json.loads(search_args)
123+
124+
if search is None:
125+
search = "grid"
126+
127+
search_params = json.loads(search_params)
128+
hyperparams = json.loads(hyperparams)
129+
130+
if search == "random":
131+
algorithm = sklearn.model_selection.RandomizedSearchCV(
132+
_ALGORITHM_MAP[algorithm_name](**hyperparams),
133+
search_params,
134+
)
135+
elif search == "grid":
136+
algorithm = sklearn.model_selection.GridSearchCV(
137+
_ALGORITHM_MAP[algorithm_name](**hyperparams),
138+
search_params,
139+
)
140+
else:
141+
raise Exception(f"search can be 'grid' or 'random', got: '{search}'")
142+
143+
def train(X_train, y_train):
144+
X_train = np.asarray(X_train).reshape((-1, num_features))
145+
y_train = np.asarray(y_train).reshape((-1, num_targets))
146+
147+
algorithm.fit(X_train, y_train)
148+
149+
return (algorithm.best_estimator_, json.dumps(algorithm.best_params_))
150+
151+
return train
152+
153+
154+
def estimator_search(algorithm_name, num_features, hyperparams, search_params, search, search_args):
155+
"""Hyperparameter search.
156+
157+
Parameters:
158+
- algorithm_name: The human-readable name of the algorithm (see dict above).
159+
- num_features: The number of features in X.
160+
- hyperparams: JSON of hyperparameters.
161+
- search_params: Hyperparameters to search (see Scikit docs for examples).
162+
- search: Type of search to do, grid or random.
163+
- search_args: See Scikit docs for examples.
164+
"""
165+
return estimator_search_joint(algorithm_name, num_features, 1, hyperparams, search_params, search, search_args)
166+
167+
104168
def test(estimator, X_test):
105169
"""Validate the estimator using the test dataset.
106170
@@ -134,6 +198,7 @@ def predictor_joint(estimator, num_features, num_targets):
134198
- num_features: The number of features in X.
135199
- num_targets: Used in joint models (more than 1 y target).
136200
"""
201+
137202
def predict(X):
138203
X = np.asarray(X).reshape((-1, num_features))
139204
y_hat = estimator.predict(X)
@@ -149,7 +214,7 @@ def predict(X):
149214

150215
def save(estimator):
151216
"""Save the estimator as bytes (pickle).
152-
217+
153218
Parameters:
154219
- estimator: Scikit-Learn estimator, instantiated.
155220

pgml-extension/pgml_rust/src/orm/model.rs

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use crate::orm::Project;
1111
use crate::orm::Search;
1212
use crate::orm::Snapshot;
1313

14-
use crate::engines::sklearn::{sklearn_save, sklearn_train};
14+
use crate::engines::sklearn::{sklearn_save, sklearn_search, sklearn_train};
1515
use crate::engines::smartcore::{smartcore_save, smartcore_train};
1616
use crate::engines::xgboost::{xgboost_save, xgboost_train};
1717

@@ -67,7 +67,7 @@ impl Model {
6767
Spi::connect(|client| {
6868
let result = client.select("
6969
INSERT INTO pgml_rust.models (project_id, snapshot_id, algorithm, hyperparams, status, search, search_params, search_args, engine)
70-
VALUES ($1, $2, $3, $4, $5, $6::pgml_rust.search, $7, $8, $9)
70+
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
7171
RETURNING id, project_id, snapshot_id, algorithm, hyperparams, status, metrics, search, search_params, search_args, created_at, updated_at;",
7272
Some(1),
7373
Some(vec![
@@ -76,7 +76,10 @@ impl Model {
7676
(PgBuiltInOids::TEXTOID.oid(), algorithm.to_string().into_datum()),
7777
(PgBuiltInOids::JSONBOID.oid(), hyperparams.into_datum()),
7878
(PgBuiltInOids::TEXTOID.oid(), "new".to_string().into_datum()),
79-
(PgBuiltInOids::TEXTOID.oid(), search.into_datum()),
79+
(PgBuiltInOids::TEXTOID.oid(), match search {
80+
Some(search) => Some(search.to_string()),
81+
None => None,
82+
}.into_datum()),
8083
(PgBuiltInOids::JSONBOID.oid(), search_params.into_datum()),
8184
(PgBuiltInOids::JSONBOID.oid(), search_args.into_datum()),
8285
(PgBuiltInOids::TEXTOID.oid(), engine.to_string().into_datum()),
@@ -117,13 +120,30 @@ impl Model {
117120
fn fit(&mut self, project: &Project, dataset: &Dataset) {
118121
// Get the hyperparameters.
119122
let hyperparams: &serde_json::Value = &self.hyperparams.0;
120-
let hyperparams = hyperparams.as_object().unwrap();
123+
let mut hyperparams = hyperparams.as_object().unwrap().clone();
121124

122125
// Train the estimator. We are getting the estimator struct and
123126
// its serialized form to save into the `models` table.
124127
let (estimator, bytes): (Box<dyn Estimator>, Vec<u8>) = match self.engine {
125128
Engine::sklearn => {
126-
let estimator = sklearn_train(project.task, self.algorithm, dataset, &hyperparams);
129+
let estimator = match self.search {
130+
Some(search) => {
131+
let (estimator, chosen_hyperparams) = sklearn_search(
132+
project.task,
133+
self.algorithm,
134+
search,
135+
dataset,
136+
&hyperparams,
137+
&self.search_params.0.as_object().unwrap(),
138+
);
139+
140+
hyperparams.extend(chosen_hyperparams);
141+
142+
estimator
143+
}
144+
145+
None => sklearn_train(project.task, self.algorithm, dataset, &hyperparams),
146+
};
127147

128148
let bytes = sklearn_save(&estimator);
129149

@@ -150,7 +170,7 @@ impl Model {
150170
_ => todo!(),
151171
};
152172

153-
// Save the estimator
173+
// Save the estimator.
154174
Spi::get_one_with_args::<i64>(
155175
"INSERT INTO pgml_rust.files (model_id, path, part, data) VALUES($1, 'estimator.rmp', 0, $2) RETURNING id",
156176
vec![
@@ -159,6 +179,19 @@ impl Model {
159179
]
160180
).unwrap();
161181

182+
// Save the hyperparams after search.
183+
Spi::get_one_with_args::<i64>(
184+
"UPDATE pgml_rust.models SET hyperparams = $1::jsonb WHERE id = $2 RETURNING id",
185+
vec![
186+
(
187+
PgBuiltInOids::TEXTOID.oid(),
188+
serde_json::to_string(&hyperparams).unwrap().into_datum(),
189+
),
190+
(PgBuiltInOids::INT8OID.oid(), self.id.into_datum()),
191+
],
192+
)
193+
.unwrap();
194+
162195
self.estimator = Some(estimator);
163196
}
164197

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy