Skip to content

Commit 8e885b9

Browse files
author
Montana Low
committed
sketch out the regression model training cycle
1 parent 829b62e commit 8e885b9

File tree

3 files changed

+136
-76
lines changed

3 files changed

+136
-76
lines changed

pgml/pgml/model.py

Lines changed: 109 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,95 +1,143 @@
1+
from cmath import e
12
import plpy
23

4+
from sklearn.linear_model import LinearRegression
5+
from sklearn.model_selection import train_test_split
6+
from sklearn.metrics import mean_squared_error, r2_score
7+
8+
import pickle
9+
10+
from pgml.exceptions import PgMLException
11+
12+
def awesome():
13+
print("hi")
14+
15+
316
class Regression:
417
"""Provides continuous real number predictions learned from the training data.
518
"""
619
def __init__(
7-
model_name: str,
20+
self,
21+
project_name: str,
822
relation_name: str,
923
y_column_name: str,
10-
implementation: str = "sklearn.linear_model"
24+
algorithm: str = "sklearn.linear_model",
25+
test_size: float or int = 0.1,
26+
test_sampling: str = "random"
1127
) -> None:
1228
"""Create a regression model from a table or view filled with training data.
1329
1430
Args:
15-
model_name (str): a human friendly identifier
31+
project_name (str): a human friendly identifier
1632
relation_name (str): the table or view that stores the training data
1733
y_column_name (str): the column in the training data that acts as the label
18-
implementation (str, optional): the algorithm used to implement the regression. Defaults to "sklearn.linear_model".
34+
algorithm (str, optional): the algorithm used to implement the regression. Defaults to "sklearn.linear_model".
35+
test_size (float or int, optional): If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the test split. If int, represents the absolute number of test samples. If None, the value is set to the complement of the train size. If train_size is also None, it will be set to 0.25.
36+
test_sampling: (str, optional): How to sample to create the test data. Defaults to "random". Valid values are ["first", "last", "random"].
1937
"""
2038

21-
data_source = f"SELECT * FROM {table_name}"
22-
23-
# Start training.
24-
start = plpy.execute(f"""
25-
INSERT INTO pgml.model_versions
26-
(name, data_source, y_column)
27-
VALUES
28-
('{table_name}', '{data_source}', '{y}')
29-
RETURNING *""", 1)
30-
31-
id_ = start[0]["id"]
32-
name = f"{table_name}_{id_}"
33-
34-
destination = models_directory(plpy)
39+
plpy.warning("snapshot")
40+
# Create a snapshot of the relation
41+
snapshot = plpy.execute(f"INSERT INTO pgml.snapshots (relation, y, test_size, test_sampling, status) VALUES ('{relation_name}', '{y_column_name}', {test_size}, '{test_sampling}', 'new') RETURNING *", 1)[0]
42+
plpy.execute(f"""CREATE TABLE pgml.snapshot_{snapshot['id']} AS SELECT * FROM "{relation_name}";""")
43+
plpy.execute(f"UPDATE pgml.snapshots SET status = 'created' WHERE id = {snapshot['id']}")
44+
45+
plpy.warning("project")
46+
# Find or create the project
47+
project = plpy.execute(f"SELECT * FROM pgml.projects WHERE name = '{project_name}'", 1)
48+
plpy.warning(f"project {project}")
49+
if (project.nrows == 1):
50+
plpy.warning("project found")
51+
project = project[0]
52+
else:
53+
try:
54+
project = plpy.execute(f"INSERT INTO pgml.projects (name) VALUES ('{project_name}') RETURNING *", 1)
55+
plpy.warning(f"project inserted {project}")
56+
if (project.nrows() == 1):
57+
project = project[0]
58+
59+
except Exception as e: # handle race condition to insert
60+
plpy.warning(f"project retry: #{e}")
61+
project = plpy.execute(f"SELECT * FROM pgml.projects WHERE name = '{project_name}'", 1)[0]
62+
63+
plpy.warning("model")
64+
# Create the model
65+
model = plpy.execute(f"INSERT INTO pgml.models (project_id, snapshot_id, algorithm, status) VALUES ({project['id']}, {snapshot['id']}, '{algorithm}', 'training') RETURNING *")[0]
66+
67+
plpy.warning("data")
68+
# Prepare the data
69+
data = plpy.execute(f"SELECT * FROM pgml.snapshot_{snapshot['id']}")
70+
71+
# Sanity check the data
72+
if data.nrows == 0:
73+
PgMLException(
74+
f"Relation `{y_column_name}` contains no rows. Did you pass the correct `relation_name`?"
75+
)
76+
if y_column_name not in data[0]:
77+
PgMLException(
78+
f"Column `{y_column_name}` not found. Did you pass the correct `y_column_name`?"
79+
)
80+
81+
# Always pull the columns in the same order from the row.
82+
# Python dict iteration is not always in the same order (hash table).
83+
columns = []
84+
for col in data[0]:
85+
if col != y_column_name:
86+
columns.append(col)
3587

36-
# Train!
37-
pickle, msq, r2 = train(plpy.cursor(data_source), y_column=y, name=name, destination=destination)
88+
# Split the label from the features
3889
X = []
3990
y = []
40-
columns = []
41-
42-
for row in all_rows(cursor):
43-
row = row.copy()
44-
45-
if y_column not in row:
46-
PgMLException(
47-
f"Column `{y}` not found. Did you name your `y_column` correctly?"
48-
)
49-
50-
y_ = row.pop(y_column)
91+
for row in data:
92+
plpy.warning(f"row: {row}")
93+
y_ = row.pop(y_column_name)
5194
x_ = []
5295

53-
# Always pull the columns in the same order from the row.
54-
# Python dict iteration is not always in the same order (hash table).
55-
if not columns:
56-
for col in row:
57-
columns.append(col)
58-
5996
for column in columns:
6097
x_.append(row[column])
98+
6199
X.append(x_)
62100
y.append(y_)
63101

64-
X_train, X_test, y_train, y_test = train_test_split(X, y)
65-
66-
# Just linear regression for now, but can add many more later.
67-
lr = LinearRegression()
68-
lr.fit(X_train, y_train)
69-
102+
# Split into training and test sets
103+
plpy.warning("split")
104+
if (test_sampling == 'random'):
105+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=0)
106+
else:
107+
if (test_sampling == 'first'):
108+
X.reverse()
109+
y.reverse()
110+
if isinstance(split, float):
111+
split = 1.0 - split
112+
split = test_size
113+
if isinstance(split, float):
114+
split = int(test_size * X.len())
115+
X_train, X_test, y_train, y_test = X[0:split], X[split:X.len()-1], y[0:split], y[split:y.len()-1]
116+
117+
# TODO normalize and clean data
118+
119+
plpy.warning("train")
120+
# Train the model
121+
algo = LinearRegression()
122+
algo.fit(X_train, y_train)
123+
124+
plpy.warning("test")
70125
# Test
71-
y_pred = lr.predict(X_test)
126+
y_pred = algo.predict(X_test)
72127
msq = mean_squared_error(y_test, y_pred)
73128
r2 = r2_score(y_test, y_pred)
74129

75-
path = os.path.join(destination, name)
76-
77-
if save:
78-
with open(path, "wb") as f:
79-
pickle.dump(lr, f)
80-
81-
return path, msq, r2
82-
130+
plpy.warning("save")
131+
# Save the model
132+
weights = pickle.dumps(algo)
83133

84134
plpy.execute(f"""
85-
UPDATE pgml.model_versions
86-
SET pickle = '{pickle}',
87-
successful = true,
135+
UPDATE pgml.models
136+
SET pickle = '\\x{weights.hex()}',
137+
status = 'successful',
88138
mean_squared_error = '{msq}',
89-
r2_score = '{r2}',
90-
ended_at = clock_timestamp()
91-
WHERE id = {id_}""")
92-
93-
return name
139+
r2_score = '{r2}'
140+
WHERE id = {model['id']}
141+
""")
94142

95-
model
143+
# TODO: promote the model?

sql/install.sql

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,28 +47,35 @@ CREATE TABLE pgml.projects(
4747
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP
4848
);
4949
SELECT pgml.auto_updated_at('pgml.projects');
50+
CREATE UNIQUE INDEX projects_name_idx ON pgml.projects(name);
5051

5152
CREATE TABLE pgml.snapshots(
5253
id BIGSERIAL PRIMARY KEY,
5354
relation TEXT NOT NULL,
5455
y TEXT NOT NULL,
55-
validation_ratio FLOAT4 NOT NULL,
56-
validation_strategy TEXT NOT NULL,
56+
test_size FLOAT4 NOT NULL,
57+
test_sampling TEXT NOT NULL,
58+
status TEXT NOT NULL,
5759
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
5860
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP
5961
);
6062
SELECT pgml.auto_updated_at('pgml.snapshots');
6163

6264
CREATE TABLE pgml.models(
6365
id BIGSERIAL PRIMARY KEY,
64-
project_id BIGINT,
65-
snapshot_id BIGINT,
66+
project_id BIGINT NOT NULL,
67+
snapshot_id BIGINT NOT NULL,
68+
algorithm TEXT NOT NULL,
69+
status TEXT NOT NULL,
6670
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
6771
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
72+
mean_squared_error DOUBLE PRECISION,
73+
r2_score DOUBLE PRECISION,
6874
pickle BYTEA,
6975
CONSTRAINT project_id_fk FOREIGN KEY(project_id) REFERENCES pgml.projects(id),
7076
CONSTRAINT snapshot_id_fk FOREIGN KEY(snapshot_id) REFERENCES pgml.snapshots(id)
7177
);
78+
CREATE INDEX models_project_id_created_at_idx ON pgml.models(project_id, created_at);
7279
SELECT pgml.auto_updated_at('pgml.models');
7380

7481
CREATE TABLE pgml.promotions(
@@ -92,11 +99,12 @@ AS $$
9299
return pgml.version()
93100
$$ LANGUAGE plpython3u;
94101

95-
CREATE OR REPLACE FUNCTION pgml.model_regression(model_name TEXT, relation_name TEXT, y_column_name TEXT, algorithm TEXT)
102+
CREATE OR REPLACE FUNCTION pgml.model_regression(project_name TEXT, relation_name TEXT, y_column_name TEXT)
96103
RETURNS VOID
97104
AS $$
98105
import pgml
99-
pgml.model.regression(model_name, relation_name, y_column_name, algorithm)
106+
from pgml.model import Regression
107+
Regression(project_name, relation_name, y_column_name)
100108
$$ LANGUAGE plpython3u;
101109

102110

sql/test.sql

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,19 @@
77
SELECT pgml.version();
88

99
-- Train twice
10-
SELECT pgml.train('wine_quality_red', 'quality');
10+
-- SELECT pgml.train('wine_quality_red', 'quality');
1111

12-
SELECT * FROM pgml.model_versions;
12+
-- SELECT * FROM pgml.model_versions;
13+
14+
-- \timing
15+
-- WITH latest_model AS (
16+
-- SELECT name || '_' || id AS model_name FROM pgml.model_versions ORDER BY id DESC LIMIT 1
17+
-- )
18+
-- SELECT pgml.score(
19+
-- (SELECT model_name FROM latest_model), -- last model we just trained
20+
-- 7.4, 0.7, 0, 1.9, 0.076, 11, 34, 0.99, 2, 0.5, 9.4 -- features as variadic arguments
21+
-- ) AS score;
1322

1423
\timing
15-
WITH latest_model AS (
16-
SELECT name || '_' || id AS model_name FROM pgml.model_versions ORDER BY id DESC LIMIT 1
17-
)
18-
SELECT pgml.score(
19-
(SELECT model_name FROM latest_model), -- last model we just trained
20-
7.4, 0.7, 0, 1.9, 0.076, 11, 34, 0.99, 2, 0.5, 9.4 -- features as variadic arguments
21-
) AS score;
24+
25+
SELECT pgml.model_regression('Red Wine', 'wine_quality_red', 'quality');

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy