Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pgml-extension/pgml_rust/sql/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ CREATE TABLE IF NOT EXISTS pgml_rust.models(
hyperparams JSONB NOT NULL,
status TEXT NOT NULL,
metrics JSONB,
search pgml_rust.search,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should stay as an enum value, you can convert to_string() in the call to scikit.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That doesn't work, it gives me an error about not finding the type. I don't know why but enums are a pain.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other enums work, you just have to cast to TEXT on the in and out. We don't have the OID for them AFAIK, although pgx might know better.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's probably the nullable property that's causing this then.

search TEXT,
search_params JSONB NOT NULL,
search_args JSONB NOT NULL,
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT clock_timestamp(),
Expand Down
4 changes: 4 additions & 0 deletions pgml-extension/pgml_rust/src/engines/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,7 @@ pub mod engine;
pub mod sklearn;
pub mod smartcore;
pub mod xgboost;

use serde_json;

pub type Hyperparams = serde_json::Map<std::string::String, serde_json::Value>;
59 changes: 59 additions & 0 deletions pgml-extension/pgml_rust/src/engines/sklearn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
use pyo3::prelude::*;
use pyo3::types::PyTuple;

use crate::engines::Hyperparams;
use crate::orm::algorithm::Algorithm;
use crate::orm::dataset::Dataset;
use crate::orm::estimator::SklearnBox;
use crate::orm::search::Search;
use crate::orm::task::Task;

use pgx::*;
Expand Down Expand Up @@ -171,3 +173,60 @@ pub fn sklearn_load(data: &Vec<u8>) -> SklearnBox {
SklearnBox::new(estimator)
})
}

/// Hyperparameter search using Scikit's
/// RandomizedSearchCV or GridSearchCV.
pub fn sklearn_search(
task: Task,
algorithm: Algorithm,
search: Search,
dataset: &Dataset,
hyperparams: &Hyperparams,
search_params: &Hyperparams,
) -> (SklearnBox, Hyperparams) {
let module = include_str!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/src/engines/wrappers.py"
));

let algorithm_name = match task {
Task::regression => match algorithm {
Algorithm::linear => "linear_regression",
_ => todo!(),
},

Task::classification => match algorithm {
Algorithm::linear => "linear_classification",
_ => todo!(),
},
};

Python::with_gil(|py| -> (SklearnBox, Hyperparams) {
let module = PyModule::from_code(py, module, "", "").unwrap();
let estimator_search = module.getattr("estimator_search").unwrap();
let train = estimator_search
.call1(PyTuple::new(
py,
&[
algorithm_name.into_py(py),
dataset.num_features.into_py(py),
serde_json::to_string(hyperparams).unwrap().into_py(py),
serde_json::to_string(search_params).unwrap().into_py(py),
search.to_string().into_py(py),
None::<String>.into_py(py),
],
))
.unwrap();

let (estimator, hyperparams): (Py<PyAny>, String) = train
.call1(PyTuple::new(py, &[dataset.x_train(), dataset.y_train()]))
.unwrap()
.extract()
.unwrap();

let estimator = SklearnBox::new(estimator);
let hyperparams: Hyperparams = serde_json::from_str::<Hyperparams>(&hyperparams).unwrap();

(estimator, hyperparams)
})
}
69 changes: 67 additions & 2 deletions pgml-extension/pgml_rust/src/engines/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def estimator_joint(algorithm_name, num_features, num_targets, hyperparams):
"""Returns the correct estimator based on algorithm names we defined
internally (see dict above).


Parameters:
- algorithm_name: The human-readable name of the algorithm (see dict above).
- num_features: The number of features in X.
Expand All @@ -101,6 +101,70 @@ def train(X_train, y_train):
return train


def estimator_search_joint(algorithm_name, num_features, num_targets, hyperparams, search_params, search, search_args):
"""Hyperparameter search.

Parameters:
- algorithm_name: The human-readable name of the algorithm (see dict above).
- num_features: The number of features in X.
- num_targets: For joint training (more than one y target).
- hyperparams: JSON of hyperparameters.
- search_params: Hyperparameters to search (see Scikit docs for examples).
- search: Type of search to do, grid or random.
- search_args: See Scikit docs for examples.

Return:
A tuple of Estimator and chosen hyperparameters.
"""
if search_args is None:
search_args = {}
else:
search_args = json.loads(search_args)

if search is None:
search = "grid"

search_params = json.loads(search_params)
hyperparams = json.loads(hyperparams)

if search == "random":
algorithm = sklearn.model_selection.RandomizedSearchCV(
_ALGORITHM_MAP[algorithm_name](**hyperparams),
search_params,
)
elif search == "grid":
algorithm = sklearn.model_selection.GridSearchCV(
_ALGORITHM_MAP[algorithm_name](**hyperparams),
search_params,
)
else:
raise Exception(f"search can be 'grid' or 'random', got: '{search}'")

def train(X_train, y_train):
X_train = np.asarray(X_train).reshape((-1, num_features))
y_train = np.asarray(y_train).reshape((-1, num_targets))

algorithm.fit(X_train, y_train)

return (algorithm.best_estimator_, json.dumps(algorithm.best_params_))

return train


def estimator_search(algorithm_name, num_features, hyperparams, search_params, search, search_args):
"""Hyperparameter search.

Parameters:
- algorithm_name: The human-readable name of the algorithm (see dict above).
- num_features: The number of features in X.
- hyperparams: JSON of hyperparameters.
- search_params: Hyperparameters to search (see Scikit docs for examples).
- search: Type of search to do, grid or random.
- search_args: See Scikit docs for examples.
"""
return estimator_search_joint(algorithm_name, num_features, 1, hyperparams, search_params, search, search_args)


def test(estimator, X_test):
"""Validate the estimator using the test dataset.

Expand Down Expand Up @@ -134,6 +198,7 @@ def predictor_joint(estimator, num_features, num_targets):
- num_features: The number of features in X.
- num_targets: Used in joint models (more than 1 y target).
"""

def predict(X):
X = np.asarray(X).reshape((-1, num_features))
y_hat = estimator.predict(X)
Expand All @@ -149,7 +214,7 @@ def predict(X):

def save(estimator):
"""Save the estimtator as bytes (pickle).

Parameters:
- estimator: Scikit-Learn estimator, instantiated.

Expand Down
45 changes: 39 additions & 6 deletions pgml-extension/pgml_rust/src/orm/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::orm::Project;
use crate::orm::Search;
use crate::orm::Snapshot;

use crate::engines::sklearn::{sklearn_save, sklearn_train};
use crate::engines::sklearn::{sklearn_save, sklearn_search, sklearn_train};
use crate::engines::smartcore::{smartcore_save, smartcore_train};
use crate::engines::xgboost::{xgboost_save, xgboost_train};

Expand Down Expand Up @@ -67,7 +67,7 @@ impl Model {
Spi::connect(|client| {
let result = client.select("
INSERT INTO pgml_rust.models (project_id, snapshot_id, algorithm, hyperparams, status, search, search_params, search_args, engine)
VALUES ($1, $2, $3, $4, $5, $6::pgml_rust.search, $7, $8, $9)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
RETURNING id, project_id, snapshot_id, algorithm, hyperparams, status, metrics, search, search_params, search_args, created_at, updated_at;",
Some(1),
Some(vec![
Expand All @@ -76,7 +76,10 @@ impl Model {
(PgBuiltInOids::TEXTOID.oid(), algorithm.to_string().into_datum()),
(PgBuiltInOids::JSONBOID.oid(), hyperparams.into_datum()),
(PgBuiltInOids::TEXTOID.oid(), "new".to_string().into_datum()),
(PgBuiltInOids::TEXTOID.oid(), search.into_datum()),
(PgBuiltInOids::TEXTOID.oid(), match search {
Some(search) => Some(search.to_string()),
None => None,
}.into_datum()),
(PgBuiltInOids::JSONBOID.oid(), search_params.into_datum()),
(PgBuiltInOids::JSONBOID.oid(), search_args.into_datum()),
(PgBuiltInOids::TEXTOID.oid(), engine.to_string().into_datum()),
Expand Down Expand Up @@ -117,13 +120,30 @@ impl Model {
fn fit(&mut self, project: &Project, dataset: &Dataset) {
// Get the hyperparameters.
let hyperparams: &serde_json::Value = &self.hyperparams.0;
let hyperparams = hyperparams.as_object().unwrap();
let mut hyperparams = hyperparams.as_object().unwrap().clone();

// Train the estimator. We are getting the estimator struct and
// it's serialized form to save into the `models` table.
let (estimator, bytes): (Box<dyn Estimator>, Vec<u8>) = match self.engine {
Engine::sklearn => {
let estimator = sklearn_train(project.task, self.algorithm, dataset, &hyperparams);
let estimator = match self.search {
Some(search) => {
let (estimator, chosen_hyperparams) = sklearn_search(
project.task,
self.algorithm,
search,
dataset,
&hyperparams,
&self.search_params.0.as_object().unwrap(),
);

hyperparams.extend(chosen_hyperparams);

estimator
}

None => sklearn_train(project.task, self.algorithm, dataset, &hyperparams),
};

let bytes = sklearn_save(&estimator);

Expand All @@ -150,7 +170,7 @@ impl Model {
_ => todo!(),
};

// Save the estimator
// Save the estimator.
Spi::get_one_with_args::<i64>(
"INSERT INTO pgml_rust.files (model_id, path, part, data) VALUES($1, 'estimator.rmp', 0, $2) RETURNING id",
vec![
Expand All @@ -159,6 +179,19 @@ impl Model {
]
).unwrap();

// Save the hyperparams after search
Spi::get_one_with_args::<i64>(
"UPDATE pgml_rust.models SET hyperparams = $1::jsonb WHERE id = $2 RETURNING id",
vec![
(
PgBuiltInOids::TEXTOID.oid(),
serde_json::to_string(&hyperparams).unwrap().into_datum(),
),
(PgBuiltInOids::INT8OID.oid(), self.id.into_datum()),
],
)
.unwrap();

self.estimator = Some(estimator);
}

Expand Down
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy