Content-Length: 4391 | pFad | http://github.com/postgresml/postgresml/pull/8.patch
github.com
From 5f696048292fb1cd1aef7db3c34b5705a15fe71d Mon Sep 17 00:00:00 2001
From: Lev
Date: Sat, 16 Apr 2022 11:23:31 -0700
Subject: [PATCH 1/2] Allow retraining the same project
---
pgml/pgml/__init__.py | 2 +-
pgml/pgml/model.py | 20 +++++++++++++++-----
2 files changed, 16 insertions(+), 6 deletions(-)
diff --git a/pgml/pgml/__init__.py b/pgml/pgml/__init__.py
index 1342ac7bd..78af849ea 100644
--- a/pgml/pgml/__init__.py
+++ b/pgml/pgml/__init__.py
@@ -1,2 +1,2 @@
def version():
- return "0.3"
+ return "0.4"
diff --git a/pgml/pgml/model.py b/pgml/pgml/model.py
index a073b9543..d8657c927 100644
--- a/pgml/pgml/model.py
+++ b/pgml/pgml/model.py
@@ -420,19 +420,29 @@ def train(
test_size (float or int, optional): If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the test split. If int, represents the absolute number of test samples. If None, the value is set to the complement of the train size. If train_size is also None, it will be set to 0.25.
test_sampling: (str, optional): How to sample to create the test data. Defaults to "random". Valid values are ["first", "last", "random"].
"""
- project = Project.create(project_name, objective)
- snapshot = Snapshot.create(relation_name, y_column_name, test_size, test_sampling)
- best_model = None
- best_error = None
if objective == "regression":
algorithms = ["linear", "random_forest"]
elif objective == "classification":
algorithms = ["random_forest"]
else:
raise PgMLException(
- f"Unknown objective '{objective}', available options are: regression, classification"
+ f"Unknown objective `{objective}`, available options are: regression, classification."
)
+ try:
+ project = Project.find_by_name(project_name)
+ except PgMLException:
+ project = Project.create(project_name, objective)
+
+ if project.objective != objective:
+ raise PgMLException(
+ f"Project `{project_name}` already exists with a different objective: `{project.objective}`. Create a new project instead."
+ )
+
+ snapshot = Snapshot.create(relation_name, y_column_name, test_size, test_sampling)
+ best_model = None
+ best_error = None
+
for algorithm_name in algorithms:
model = Model.create(project, snapshot, algorithm_name)
model.fit(snapshot)
From c24419eb3ab8808529ab3bec39097003a9bd72c4 Mon Sep 17 00:00:00 2001
From: Lev
Date: Sat, 16 Apr 2022 11:48:51 -0700
Subject: [PATCH 2/2] rollbacks
---
pgml/pgml/model.py | 16 +++++++++++++---
sql/install.sql | 4 ++--
2 files changed, 15 insertions(+), 5 deletions(-)
diff --git a/pgml/pgml/model.py b/pgml/pgml/model.py
index d8657c927..4c265a984 100644
--- a/pgml/pgml/model.py
+++ b/pgml/pgml/model.py
@@ -440,13 +440,23 @@ def train(
)
snapshot = Snapshot.create(relation_name, y_column_name, test_size, test_sampling)
- best_model = None
- best_error = None
+ deployed = Model.find_deployed(project.id)
+
+ # Let's assume that the deployed model is better for now.
+ best_model = deployed
+ best_error = best_model.mean_squared_error if best_model else None
for algorithm_name in algorithms:
model = Model.create(project, snapshot, algorithm_name)
model.fit(snapshot)
+
+ # Find the better model and deploy that.
if best_error is None or model.mean_squared_error < best_error:
best_error = model.mean_squared_error
best_model = model
- best_model.deploy()
+
+ if deployed and deployed.id == best_model.id:
+ return "rolled back"
+ else:
+ best_model.deploy()
+ return "deployed"
diff --git a/sql/install.sql b/sql/install.sql
index d44fef64a..e18fdc9aa 100644
--- a/sql/install.sql
+++ b/sql/install.sql
@@ -109,9 +109,9 @@ RETURNS TABLE(project_name TEXT, objective TEXT, status TEXT)
AS $$
from pgml.model import train
- train(project_name, objective, relation_name, y_column_name)
+ status = train(project_name, objective, relation_name, y_column_name)
- return [(project_name, objective, "deployed")]
+ return [(project_name, objective, status)]
$$ LANGUAGE plpython3u;
---
--- a PPN by Garber Painting Akron. With Image Size Reduction included!
Fetched URL: http://github.com/postgresml/postgresml/pull/8.patch
Alternative Proxies:
Alternative Proxy
pFad Proxy
pFad v3 Proxy
pFad v4 Proxy