From 5f696048292fb1cd1aef7db3c34b5705a15fe71d Mon Sep 17 00:00:00 2001 From: Lev Date: Sat, 16 Apr 2022 11:23:31 -0700 Subject: [PATCH 1/2] Allow to retrain the same project --- pgml/pgml/__init__.py | 2 +- pgml/pgml/model.py | 20 +++++++++++++++----- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/pgml/pgml/__init__.py b/pgml/pgml/__init__.py index 1342ac7bd..78af849ea 100644 --- a/pgml/pgml/__init__.py +++ b/pgml/pgml/__init__.py @@ -1,2 +1,2 @@ def version(): - return "0.3" + return "0.4" diff --git a/pgml/pgml/model.py b/pgml/pgml/model.py index a073b9543..d8657c927 100644 --- a/pgml/pgml/model.py +++ b/pgml/pgml/model.py @@ -420,19 +420,29 @@ def train( test_size (float or int, optional): If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the test split. If int, represents the absolute number of test samples. If None, the value is set to the complement of the train size. If train_size is also None, it will be set to 0.25. test_sampling: (str, optional): How to sample to create the test data. Defaults to "random". Valid values are ["first", "last", "random"]. """ - project = Project.create(project_name, objective) - snapshot = Snapshot.create(relation_name, y_column_name, test_size, test_sampling) - best_model = None - best_error = None if objective == "regression": algorithms = ["linear", "random_forest"] elif objective == "classification": algorithms = ["random_forest"] else: raise PgMLException( - f"Unknown objective '{objective}', available options are: regression, classification" + f"Unknown objective `{objective}`, available options are: regression, classification." ) + try: + project = Project.find_by_name(project_name) + except PgMLException: + project = Project.create(project_name, objective) + + if project.objective != objective: + raise PgMLException( + f"Project `{project_name}` already exists with a different objective: `{project.objective}`. Create a new project instead." 
+ ) + + snapshot = Snapshot.create(relation_name, y_column_name, test_size, test_sampling) + best_model = None + best_error = None + for algorithm_name in algorithms: model = Model.create(project, snapshot, algorithm_name) model.fit(snapshot) From c24419eb3ab8808529ab3bec39097003a9bd72c4 Mon Sep 17 00:00:00 2001 From: Lev Date: Sat, 16 Apr 2022 11:48:51 -0700 Subject: [PATCH 2/2] rollbacks --- pgml/pgml/model.py | 16 +++++++++++++--- sql/install.sql | 4 ++-- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/pgml/pgml/model.py b/pgml/pgml/model.py index d8657c927..4c265a984 100644 --- a/pgml/pgml/model.py +++ b/pgml/pgml/model.py @@ -440,13 +440,23 @@ def train( ) snapshot = Snapshot.create(relation_name, y_column_name, test_size, test_sampling) - best_model = None - best_error = None + deployed = Model.find_deployed(project.id) + + # Let's assume that the deployed model is better for now. + best_model = deployed + best_error = best_model.mean_squared_error if best_model else None for algorithm_name in algorithms: model = Model.create(project, snapshot, algorithm_name) model.fit(snapshot) + + # Find the better model and deploy that. if best_error is None or model.mean_squared_error < best_error: best_error = model.mean_squared_error best_model = model - best_model.deploy() + + if deployed and deployed.id == best_model.id: + return "rolled back" + else: + best_model.deploy() + return "deployed" diff --git a/sql/install.sql b/sql/install.sql index d44fef64a..e18fdc9aa 100644 --- a/sql/install.sql +++ b/sql/install.sql @@ -109,9 +109,9 @@ RETURNS TABLE(project_name TEXT, objective TEXT, status TEXT) AS $$ from pgml.model import train - train(project_name, objective, relation_name, y_column_name) + status = train(project_name, objective, relation_name, y_column_name) - return [(project_name, objective, "deployed")] + return [(project_name, objective, status)] $$ LANGUAGE plpython3u; --- pFad - Phonifier reborn

pFad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy