Skip to content

Commit 9907aaa

Browse files
author
Montana Low
committed
sketch out the regression model training cycle
1 parent 829b62e commit 9907aaa

File tree

3 files changed

+132
-76
lines changed

3 files changed

+132
-76
lines changed

pgml/pgml/model.py

Lines changed: 105 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,95 +1,139 @@
1+
from cmath import e
12
import plpy
23

4+
from sklearn.linear_model import LinearRegression
5+
from sklearn.model_selection import train_test_split
6+
from sklearn.metrics import mean_squared_error, r2_score
7+
8+
import pickle
9+
10+
from pgml.exceptions import PgMLException
11+
312
class Regression:
    """Provides continuous real number predictions learned from the training data.
    """

    # Values accepted for the `test_sampling` constructor argument.
    _TEST_SAMPLING = ("first", "last", "random")

    def __init__(
        self,
        project_name: str,
        relation_name: str,
        y_column_name: str,
        algorithm: str = "sklearn.linear_model",
        test_size: float or int = 0.1,
        test_sampling: str = "random"
    ) -> None:
        """Create a regression model from a table or view filled with training data.

        The relation is copied into a snapshot table, the project row is found
        or created, a model row is inserted with status 'training', a linear
        regression is fitted and scored, and the pickled estimator and its
        metrics are written back to the model row.

        Args:
            project_name (str): a human friendly identifier
            relation_name (str): the table or view that stores the training data
            y_column_name (str): the column in the training data that acts as the label
            algorithm (str, optional): the algorithm used to implement the regression. Defaults to "sklearn.linear_model". The value is recorded on the model row; only LinearRegression is fitted for now.
            test_size (float or int, optional): If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the test split. If int, represents the absolute number of test samples. Defaults to 0.1.
            test_sampling: (str, optional): How to sample to create the test data. Defaults to "random". Valid values are ["first", "last", "random"].

        Raises:
            PgMLException: if `test_sampling` is not a valid value, the relation
                has no rows, `y_column_name` is not a column of the relation, or
                a "first"/"last" split would leave the train or test set empty.
        """
        # Fail before anything is written: an unknown strategy would otherwise
        # be recorded in the snapshot and only surface at split time.
        if test_sampling not in self._TEST_SAMPLING:
            raise PgMLException(
                f"Unknown test_sampling `{test_sampling}`. Valid values are {list(self._TEST_SAMPLING)}."
            )

        plpy.warning("snapshot")
        # Create a snapshot of the relation. Caller-supplied values go through
        # a prepared statement instead of being spliced into the SQL text.
        insert_snapshot = plpy.prepare(
            "INSERT INTO pgml.snapshots (relation, y, test_size, test_sampling, status) "
            "VALUES ($1, $2, $3, $4, 'new') RETURNING *",
            ["text", "text", "float4", "text"],
        )
        snapshot = plpy.execute(
            insert_snapshot,
            [relation_name, y_column_name, test_size, test_sampling],
            1,
        )[0]
        snapshot_id = int(snapshot["id"])
        # Identifiers cannot be bound as parameters, so quote the relation name.
        plpy.execute(
            f"CREATE TABLE pgml.snapshot_{snapshot_id} AS SELECT * FROM {plpy.quote_ident(relation_name)}"
        )
        plpy.execute(f"UPDATE pgml.snapshots SET status = 'created' WHERE id = {snapshot_id}")

        plpy.warning("project")
        project = self._find_or_create_project(project_name)

        plpy.warning("model")
        # Create the model
        insert_model = plpy.prepare(
            "INSERT INTO pgml.models (project_id, snapshot_id, algorithm, status) "
            "VALUES ($1, $2, $3, 'training') RETURNING *",
            ["int8", "int8", "text"],
        )
        model = plpy.execute(insert_model, [project["id"], snapshot_id, algorithm])[0]

        plpy.warning("data")
        # Prepare the data
        data = plpy.execute(f"SELECT * FROM pgml.snapshot_{snapshot_id}")

        # Sanity check the data
        if data.nrows() == 0:
            raise PgMLException(
                f"Relation `{relation_name}` contains no rows. Did you pass the correct `relation_name`?"
            )
        first_row = data[0]
        if y_column_name not in first_row:
            raise PgMLException(
                f"Column `{y_column_name}` not found. Did you pass the correct `y_column_name`?"
            )

        # Fix the feature column order once, from the first row, so every
        # sample lists its features in the same order.
        columns = [col for col in first_row if col != y_column_name]

        # Split the label from the features (per-row tracing intentionally
        # omitted: it emitted one WARNING per training row).
        X = [[row[col] for col in columns] for row in data]
        y = [row[y_column_name] for row in data]

        # Split into training and test sets
        plpy.warning("split")
        if test_sampling == "random":
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=test_size, random_state=0
            )
        else:
            X_train, X_test, y_train, y_test = self._sequential_split(
                X, y, test_size, test_sampling
            )

        # TODO normalize and clean data

        plpy.warning("train")
        # Train the model
        algo = LinearRegression()
        algo.fit(X_train, y_train)

        plpy.warning("test")
        # Test
        y_pred = algo.predict(X_test)
        msq = mean_squared_error(y_test, y_pred)
        r2 = r2_score(y_test, y_pred)

        plpy.warning("save")
        # Save the model. The pickled bytes are bound as a bytea parameter.
        weights = pickle.dumps(algo)
        update_model = plpy.prepare(
            "UPDATE pgml.models "
            "SET pickle = $1, status = 'successful', mean_squared_error = $2, r2_score = $3 "
            "WHERE id = $4",
            ["bytea", "float8", "float8", "int8"],
        )
        plpy.execute(update_model, [weights, float(msq), float(r2), model["id"]])

        # TODO: promote the model?

    @staticmethod
    def _find_or_create_project(project_name):
        """Return the pgml.projects row named `project_name`, inserting it if absent."""
        select_project = plpy.prepare(
            "SELECT * FROM pgml.projects WHERE name = $1", ["text"]
        )
        found = plpy.execute(select_project, [project_name], 1)
        plpy.warning(f"project {found}")
        if found.nrows() == 1:
            plpy.warning("project found")
            return found[0]

        try:
            insert_project = plpy.prepare(
                "INSERT INTO pgml.projects (name) VALUES ($1) RETURNING *", ["text"]
            )
            inserted = plpy.execute(insert_project, [project_name], 1)
            plpy.warning(f"project inserted {inserted}")
            return inserted[0]
        except plpy.SPIError as exc:  # handle race condition to insert
            plpy.warning(f"project retry: {exc}")
            return plpy.execute(select_project, [project_name], 1)[0]

    @staticmethod
    def _sequential_split(X, y, test_size, test_sampling):
        """Split X and y without shuffling, taking the test set from one end.

        Args:
            X (list): feature rows.
            y (list): labels, parallel to X.
            test_size (float or int): a float is a proportion of the samples
                (rounded up); an int is an absolute number of test samples.
            test_sampling (str): "first" takes the test set from the start of
                the data; any other value takes it from the end.

        Returns:
            tuple: (X_train, X_test, y_train, y_test)

        Raises:
            PgMLException: if the train or test set would be empty.
        """
        import math

        total = len(X)
        if isinstance(test_size, float):
            n_test = math.ceil(test_size * total)
        else:
            n_test = int(test_size)

        if not 0 < n_test < total:
            raise PgMLException(
                f"test_size `{test_size}` leaves an empty train or test set for {total} rows."
            )

        if test_sampling == "first":
            return X[n_test:], X[:n_test], y[n_test:], y[:n_test]

        boundary = total - n_test
        return X[:boundary], X[boundary:], y[:boundary], y[boundary:]

sql/install.sql

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,28 +47,35 @@ CREATE TABLE pgml.projects(
4747
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP
4848
);
4949
SELECT pgml.auto_updated_at('pgml.projects');
50+
CREATE UNIQUE INDEX projects_name_idx ON pgml.projects(name);
5051

5152
CREATE TABLE pgml.snapshots(
5253
id BIGSERIAL PRIMARY KEY,
5354
relation TEXT NOT NULL,
5455
y TEXT NOT NULL,
55-
validation_ratio FLOAT4 NOT NULL,
56-
validation_strategy TEXT NOT NULL,
56+
test_size FLOAT4 NOT NULL,
57+
test_sampling TEXT NOT NULL,
58+
status TEXT NOT NULL,
5759
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
5860
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP
5961
);
6062
SELECT pgml.auto_updated_at('pgml.snapshots');
6163

6264
CREATE TABLE pgml.models(
6365
id BIGSERIAL PRIMARY KEY,
64-
project_id BIGINT,
65-
snapshot_id BIGINT,
66+
project_id BIGINT NOT NULL,
67+
snapshot_id BIGINT NOT NULL,
68+
algorithm TEXT NOT NULL,
69+
status TEXT NOT NULL,
6670
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
6771
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
72+
mean_squared_error DOUBLE PRECISION,
73+
r2_score DOUBLE PRECISION,
6874
pickle BYTEA,
6975
CONSTRAINT project_id_fk FOREIGN KEY(project_id) REFERENCES pgml.projects(id),
7076
CONSTRAINT snapshot_id_fk FOREIGN KEY(snapshot_id) REFERENCES pgml.snapshots(id)
7177
);
78+
CREATE INDEX models_project_id_created_at_idx ON pgml.models(project_id, created_at);
7279
SELECT pgml.auto_updated_at('pgml.models');
7380

7481
CREATE TABLE pgml.promotions(
@@ -92,11 +99,12 @@ AS $$
9299
return pgml.version()
93100
$$ LANGUAGE plpython3u;
94101

95-
CREATE OR REPLACE FUNCTION pgml.model_regression(model_name TEXT, relation_name TEXT, y_column_name TEXT, algorithm TEXT)
102+
CREATE OR REPLACE FUNCTION pgml.model_regression(project_name TEXT, relation_name TEXT, y_column_name TEXT)
96103
RETURNS VOID
97104
AS $$
98105
import pgml
99-
pgml.model.regression(model_name, relation_name, y_column_name, algorithm)
106+
from pgml.model import Regression
107+
Regression(project_name, relation_name, y_column_name)
100108
$$ LANGUAGE plpython3u;
101109

102110

sql/test.sql

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,19 @@
77
SELECT pgml.version();
88

99
-- Train twice
10-
SELECT pgml.train('wine_quality_red', 'quality');
10+
-- SELECT pgml.train('wine_quality_red', 'quality');
1111

12-
SELECT * FROM pgml.model_versions;
12+
-- SELECT * FROM pgml.model_versions;
13+
14+
-- \timing
15+
-- WITH latest_model AS (
16+
-- SELECT name || '_' || id AS model_name FROM pgml.model_versions ORDER BY id DESC LIMIT 1
17+
-- )
18+
-- SELECT pgml.score(
19+
-- (SELECT model_name FROM latest_model), -- last model we just trained
20+
-- 7.4, 0.7, 0, 1.9, 0.076, 11, 34, 0.99, 2, 0.5, 9.4 -- features as variadic arguments
21+
-- ) AS score;
1322

1423
\timing
15-
WITH latest_model AS (
16-
SELECT name || '_' || id AS model_name FROM pgml.model_versions ORDER BY id DESC LIMIT 1
17-
)
18-
SELECT pgml.score(
19-
(SELECT model_name FROM latest_model), -- last model we just trained
20-
7.4, 0.7, 0, 1.9, 0.076, 11, 34, 0.99, 2, 0.5, 9.4 -- features as variadic arguments
21-
) AS score;
24+
25+
SELECT pgml.model_regression('Red Wine', 'wine_quality_red', 'quality');

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy