Skip to content

Commit 829b62e

Browse files
author
Montana Low
committed
add a draft schema to support snapshots and multiple training runs for a project
1 parent 14b1f61 commit 829b62e

File tree

6 files changed

+214
-13
lines changed

6 files changed

+214
-13
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ PostgresML aims to be the easiest way to gain value from machine learning. Anyon
55
Getting started is as easy as creating a `table` or `view` that holds the training data, and then registering that with PostgresML.
66

77
```sql
8-
SELECT pgml.create_regression('Red Wine Quality', training_data_table_or_view_name, label_column_name);
8+
SELECT pgml.model_regression('Red Wine Quality', training_data_table_or_view_name, label_column_name);
99
```
1010

1111
And predict novel datapoints:
@@ -23,7 +23,7 @@ LIMIT 3;
2323
(3 rows)
2424
```
2525

26-
PostgresML similarly supports classification to predict numeric scores rather than classes for novel data.
26+
PostgresML similarly supports classification to predict discrete classes rather than numeric scores for novel data.
2727

2828
```sql
2929
SELECT pgml.create_classification('Handwritten Digit Classifier', pgml.mnist_training_data, label_column_name);

benchmarks.sql

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,23 @@
1+
--
2+
-- CREATE EXTENSION
3+
--
4+
CREATE EXTENSION IF NOT EXISTS plpython3u;
5+
6+
CREATE OR REPLACE FUNCTION pg_call()
7+
RETURNS INT
8+
AS $$
9+
BEGIN
10+
RETURN 1;
11+
END;
12+
$$ LANGUAGE plpgsql;
13+
14+
CREATE OR REPLACE FUNCTION py_call()
15+
RETURNS INT
16+
AS $$
17+
return 1;
18+
$$ LANGUAGE plpython3u;
19+
20+
\timing on
21+
SELECT generate_series(1, 50000), pg_call(); -- Time: 20.679 ms
22+
SELECT generate_series(1, 50000), py_call(); -- Time: 67.355 ms
23+

pgml/pgml/model.py

Lines changed: 95 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,95 @@
1+
import plpy
2+
3+
class Regression:
4+
"""Provides continuous real number predictions learned from the training data.
5+
"""
6+
def __init__(
7+
model_name: str,
8+
relation_name: str,
9+
y_column_name: str,
10+
implementation: str = "sklearn.linear_model"
11+
) -> None:
12+
"""Create a regression model from a table or view filled with training data.
13+
14+
Args:
15+
model_name (str): a human friendly identifier
16+
relation_name (str): the table or view that stores the training data
17+
y_column_name (str): the column in the training data that acts as the label
18+
implementation (str, optional): the algorithm used to implement the regression. Defaults to "sklearn.linear_model".
19+
"""
20+
21+
data_source = f"SELECT * FROM {table_name}"
22+
23+
# Start training.
24+
start = plpy.execute(f"""
25+
INSERT INTO pgml.model_versions
26+
(name, data_source, y_column)
27+
VALUES
28+
('{table_name}', '{data_source}', '{y}')
29+
RETURNING *""", 1)
30+
31+
id_ = start[0]["id"]
32+
name = f"{table_name}_{id_}"
33+
34+
destination = models_directory(plpy)
35+
36+
# Train!
37+
pickle, msq, r2 = train(plpy.cursor(data_source), y_column=y, name=name, destination=destination)
38+
X = []
39+
y = []
40+
columns = []
41+
42+
for row in all_rows(cursor):
43+
row = row.copy()
44+
45+
if y_column not in row:
46+
PgMLException(
47+
f"Column `{y}` not found. Did you name your `y_column` correctly?"
48+
)
49+
50+
y_ = row.pop(y_column)
51+
x_ = []
52+
53+
# Always pull the columns in the same order from the row.
54+
# Python dict iteration is not always in the same order (hash table).
55+
if not columns:
56+
for col in row:
57+
columns.append(col)
58+
59+
for column in columns:
60+
x_.append(row[column])
61+
X.append(x_)
62+
y.append(y_)
63+
64+
X_train, X_test, y_train, y_test = train_test_split(X, y)
65+
66+
# Just linear regression for now, but can add many more later.
67+
lr = LinearRegression()
68+
lr.fit(X_train, y_train)
69+
70+
# Test
71+
y_pred = lr.predict(X_test)
72+
msq = mean_squared_error(y_test, y_pred)
73+
r2 = r2_score(y_test, y_pred)
74+
75+
path = os.path.join(destination, name)
76+
77+
if save:
78+
with open(path, "wb") as f:
79+
pickle.dump(lr, f)
80+
81+
return path, msq, r2
82+
83+
84+
plpy.execute(f"""
85+
UPDATE pgml.model_versions
86+
SET pickle = '{pickle}',
87+
successful = true,
88+
mean_squared_error = '{msq}',
89+
r2_score = '{r2}',
90+
ended_at = clock_timestamp()
91+
WHERE id = {id_}""")
92+
93+
return name
94+
95+
model

pgml/pgml/sql.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
"""Tools to run SQL.
22
"""
33
import os
4+
import plpy
45

56

67
def all_rows(cursor):
@@ -14,7 +15,7 @@ def all_rows(cursor):
1415
yield row
1516

1617

17-
def models_directory(plpy):
18+
def models_directory():
1819
"""Get the directory where we store our models."""
1920
data_directory = plpy.execute(
2021
"""

sql/install.sql

Lines changed: 92 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,87 @@
1+
SET client_min_messages TO WARNING;
12

23
-- Create the PL/Python3 extension.
34
CREATE EXTENSION IF NOT EXISTS plpython3u;
45

6+
---
7+
--- Create schema for models.
8+
---
59
DROP SCHEMA pgml CASCADE;
610
CREATE SCHEMA IF NOT EXISTS pgml;
711

12+
CREATE OR REPLACE FUNCTION pgml.auto_updated_at(tbl regclass)
13+
RETURNS VOID
14+
AS $$
15+
DECLARE name_parts TEXT[];
16+
DECLARE name TEXT;
17+
BEGIN
18+
name_parts := string_to_array(tbl::TEXT, '.');
19+
name := name_parts[array_upper(name_parts, 1)];
20+
21+
EXECUTE format('DROP TRIGGER IF EXISTS %s_auto_updated_at ON %s', name, tbl);
22+
EXECUTE format('CREATE TRIGGER %s_auto_updated_at BEFORE UPDATE ON %s
23+
FOR EACH ROW EXECUTE PROCEDURE pgml.set_updated_at()', name, tbl);
24+
END;
25+
$$
26+
LANGUAGE plpgsql;
27+
28+
CREATE OR REPLACE FUNCTION pgml.set_updated_at()
29+
RETURNS TRIGGER
30+
AS $$
31+
BEGIN
32+
IF (
33+
NEW IS DISTINCT FROM OLD
34+
AND NEW.updated_at IS NOT DISTINCT FROM OLD.updated_at
35+
) THEN
36+
NEW.updated_at := CURRENT_TIMESTAMP;
37+
END IF;
38+
RETURN new;
39+
END;
40+
$$
41+
LANGUAGE plpgsql;
42+
43+
CREATE TABLE pgml.projects(
44+
id BIGSERIAL PRIMARY KEY,
45+
name TEXT NOT NULL,
46+
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
47+
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP
48+
);
49+
SELECT pgml.auto_updated_at('pgml.projects');
50+
51+
CREATE TABLE pgml.snapshots(
52+
id BIGSERIAL PRIMARY KEY,
53+
relation TEXT NOT NULL,
54+
y TEXT NOT NULL,
55+
validation_ratio FLOAT4 NOT NULL,
56+
validation_strategy TEXT NOT NULL,
57+
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
58+
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP
59+
);
60+
SELECT pgml.auto_updated_at('pgml.snapshots');
61+
62+
CREATE TABLE pgml.models(
63+
id BIGSERIAL PRIMARY KEY,
64+
project_id BIGINT,
65+
snapshot_id BIGINT,
66+
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
67+
updated_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
68+
pickle BYTEA,
69+
CONSTRAINT project_id_fk FOREIGN KEY(project_id) REFERENCES pgml.projects(id),
70+
CONSTRAINT snapshot_id_fk FOREIGN KEY(snapshot_id) REFERENCES pgml.snapshots(id)
71+
);
72+
SELECT pgml.auto_updated_at('pgml.models');
73+
74+
CREATE TABLE pgml.promotions(
75+
project_id BIGINT NOT NULL,
76+
model_id BIGINT NOT NULL,
77+
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
78+
CONSTRAINT project_id_fk FOREIGN KEY(project_id) REFERENCES pgml.projects(id),
79+
CONSTRAINT model_id_fk FOREIGN KEY(model_id) REFERENCES pgml.models(id)
80+
);
81+
CREATE INDEX promotions_project_id_created_at_idx ON pgml.promotions(project_id, created_at);
82+
SELECT pgml.auto_updated_at('pgml.promotions');
83+
84+
885
---
986
--- Extension version.
1087
---
@@ -15,20 +92,28 @@ AS $$
1592
return pgml.version()
1693
$$ LANGUAGE plpython3u;
1794

95+
CREATE OR REPLACE FUNCTION pgml.model_regression(model_name TEXT, relation_name TEXT, y_column_name TEXT, algorithm TEXT)
96+
RETURNS VOID
97+
AS $$
98+
import pgml
99+
pgml.model.regression(model_name, relation_name, y_column_name, algorithm)
100+
$$ LANGUAGE plpython3u;
101+
102+
18103
---
19104
--- Track table versions.
20105
---
21106
CREATE TABLE pgml.model_versions(
22107
id BIGSERIAL PRIMARY KEY,
23-
name VARCHAR,
24-
location VARCHAR NULL,
108+
name VARCHAR NOT NULL,
25109
data_source TEXT,
26110
y_column VARCHAR,
27111
started_at TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP,
28112
ended_at TIMESTAMP WITHOUT TIME ZONE NULL,
29113
mean_squared_error DOUBLE PRECISION,
30114
r2_score DOUBLE PRECISION,
31-
successful BOOL NULL
115+
successful BOOL NULL,
116+
pickle BYTEA
32117
);
33118

34119
---
@@ -54,14 +139,14 @@ AS $$
54139
id_ = start[0]["id"]
55140
name = f"{table_name}_{id_}"
56141

57-
destination = models_directory(plpy)
142+
destination = models_directory()
58143

59144
# Train!
60-
location, msq, r2 = train(plpy.cursor(data_source), y_column=y, name=name, destination=destination)
145+
pickle, msq, r2 = train(plpy.cursor(data_source), y_column=y, name=name, destination=destination)
61146

62147
plpy.execute(f"""
63148
UPDATE pgml.model_versions
64-
SET location = '{location}',
149+
SET pickle = '{pickle}',
65150
successful = true,
66151
mean_squared_error = '{msq}',
67152
r2_score = '{r2}',
@@ -85,7 +170,7 @@ AS $$
85170
if model_name in SD:
86171
model = SD[model_name]
87172
else:
88-
SD[model_name] = load(model_name, models_directory(plpy))
173+
SD[model_name] = load(model_name, models_directory())
89174
model = SD[model_name]
90175

91176
scores = model.predict([features,])

sql/test.sql

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -6,9 +6,6 @@
66

77
SELECT pgml.version();
88

9-
-- Validate our wine data.
10-
SELECT pgml.validate('wine_quality_red');
11-
129
-- Train twice
1310
SELECT pgml.train('wine_quality_red', 'quality');
1411

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy