Commit 53c7e43

Merge pull request #14 from postgresml/montana/api
add a new example
2 parents 415f2e1 + a9f9dc2 commit 53c7e43

6 files changed: +252 -52 lines changed


examples/digits/run.sql

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+-- This example trains models on the sklearn digits dataset,
+-- which is a copy of the test set of the UCI ML hand-written digits datasets
+-- https://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits
+--
+-- The final result after a few seconds of training is not terrible. Maybe not perfect
+-- enough for mission-critical applications, but it's telling how quickly "off the shelf"
+-- solutions can solve problems these days.
+SELECT pgml.load_dataset('digits');
+
+-- view the dataset
+SELECT * FROM pgml.digits;
+
+-- train a simple model to classify the data
+SELECT pgml.train('Handwritten Digit Image Classifier', 'classification', 'pgml.digits', 'target');
+
+-- check out the predictions
+SELECT target, pgml.predict('Handwritten Digit Image Classifier', image) AS prediction
+FROM pgml.digits
+LIMIT 10;
+
+-- train some more models with different algorithms
+SELECT pgml.train('Handwritten Digit Image Classifier', 'classification', 'pgml.digits', 'target', 'svm');
+SELECT pgml.train('Handwritten Digit Image Classifier', 'classification', 'pgml.digits', 'target', 'random_forest');
+SELECT pgml.train('Handwritten Digit Image Classifier', 'classification', 'pgml.digits', 'target', 'gradient_boosting_trees');
+-- TODO SELECT pgml.train('Handwritten Digit Image Classifier', 'classification', 'pgml.digits', 'target', 'dense_neural_network');
+-- check out all that hard work
+SELECT * FROM pgml.trained_models;
+
+-- deploy the random_forest model for prediction use
+SELECT pgml.deploy('Handwritten Digit Image Classifier', 'random_forest');
+-- check out that throughput
+SELECT * FROM pgml.deployed_models;
+
+-- do some hyperparameter tuning
+-- TODO SELECT pgml.hypertune(100, 'Handwritten Digit Image Classifier', 'classification', 'pgml.digits', 'target', 'gradient_boosted_trees');
+-- deploy the "best" model for prediction use
+SELECT pgml.deploy('Handwritten Digit Image Classifier', 'best_fit');
+
+-- check out the improved predictions
+SELECT target, pgml.predict('Handwritten Digit Image Classifier', image) AS prediction
+FROM pgml.digits
+LIMIT 10;
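
For anyone who prefers to drive this example from application code rather than psql, here is a minimal sketch of the same calls issued through a Python client. It is not part of this commit: psycopg2 and the connection DSN are assumptions, so adjust them to your environment.

import psycopg2

# Hypothetical DSN for a local database that has the pgml extension installed.
conn = psycopg2.connect("dbname=pgml_development")
conn.autocommit = True

with conn.cursor() as cur:
    # Load the digits dataset and train the default classifier, as in run.sql above.
    cur.execute("SELECT pgml.load_dataset('digits')")
    cur.execute(
        "SELECT pgml.train('Handwritten Digit Image Classifier', 'classification', 'pgml.digits', 'target')"
    )
    # Compare a few labels against the deployed model's predictions.
    cur.execute(
        """
        SELECT target, pgml.predict('Handwritten Digit Image Classifier', image) AS prediction
        FROM pgml.digits
        LIMIT 10
        """
    )
    for target, prediction in cur.fetchall():
        print(target, prediction)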

pgml/pgml/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 def version():
-    return "0.4.1"
+    return "0.4.2"

pgml/pgml/datasets.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+import plpy
+from sklearn.datasets import load_digits as d
+
+from pgml.sql import q
+from pgml.exceptions import PgMLException
+
+def load(source: str):
+    if source == "digits":
+        load_digits()
+    else:
+        raise PgMLException(f"Invalid dataset name: {source}. Valid values are ['digits'].")
+    return "OK"
+
+def load_digits():
+    dataset = d()
+    a = plpy.execute("DROP TABLE IF EXISTS pgml.digits")
+    a = plpy.execute("CREATE TABLE pgml.digits (image SMALLINT[], target INTEGER)")
+    a = plpy.execute(f"""COMMENT ON TABLE pgml.digits IS {q(dataset["DESCR"])}""")
+    for X, y in zip(dataset["data"], dataset["target"]):
+        X = ",".join("%i" % x for x in list(X))
+        plpy.execute(f"""INSERT INTO pgml.digits (image, target) VALUES ('{{{X}}}', {y})""")
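
To make the generated SQL concrete, here is a standalone sketch (not part of the commit) that formats one digits sample exactly the way load_digits() above does, printing the INSERT with its SMALLINT[] array literal instead of running it through plpy. It only needs scikit-learn.

from sklearn.datasets import load_digits

dataset = load_digits()
X, y = dataset["data"][0], dataset["target"][0]

# Same formatting as the loader: 64 pixel intensities become a Postgres array literal.
image_literal = ",".join("%i" % x for x in list(X))
print(f"INSERT INTO pgml.digits (image, target) VALUES ('{{{image_literal}}}', {y})")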

pgml/pgml/model.py

Lines changed: 107 additions & 34 deletions
@@ -1,14 +1,23 @@
+from re import M
 import plpy
-from sklearn.linear_model import LinearRegression
-from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
+from sklearn.linear_model import LinearRegression, LogisticRegression
+from sklearn.svm import SVR, SVC
+from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier, GradientBoostingRegressor, GradientBoostingClassifier
 from sklearn.model_selection import train_test_split
-from sklearn.metrics import mean_squared_error, r2_score
+from sklearn.metrics import mean_squared_error, r2_score, f1_score, precision_score, recall_score
 
 import pickle
+import json
 
 from pgml.exceptions import PgMLException
 from pgml.sql import q
 
+def flatten(S):
+    if S == []:
+        return S
+    if isinstance(S[0], list):
+        return flatten(S[0]) + flatten(S[1:])
+    return S[:1] + flatten(S[1:])
 
 class Project(object):
     """
@@ -124,6 +133,14 @@ def deployed_model(self):
         self._deployed_model = Model.find_deployed(self.id)
         return self._deployed_model
 
+    def deploy(self, algorithm_name):
+        model = None
+        if algorithm_name == "best_fit":
+            model = Model.find_by_project_and_best_fit(self)
+        else:
+            model = Model.find_by_project_id_and_algorithm_name(self.id, algorithm_name)
+        model.deploy()
+        return model
 
 class Snapshot(object):
     """
@@ -178,7 +195,7 @@ def create(
         plpy.execute(
             f"""
             CREATE TABLE pgml."snapshot_{snapshot.id}" AS
-            SELECT * FROM "{snapshot.relation_name}";
+            SELECT * FROM {snapshot.relation_name};
             """
         )
         snapshot.__dict__ = dict(
@@ -232,6 +249,7 @@ def data(self):
             for column in columns:
                 x_.append(row[column])
 
+            x_ = flatten(x_) # TODO be smart about flattening X depending on algorithm
             X.append(x_)
             y.append(y_)
 
@@ -262,8 +280,7 @@ class Model(object):
         status (str): The current status of the model, e.g. 'new', 'training' or 'successful'
         created_at (Timestamp): when this model was created
         updated_at (Timestamp): when this model was last updated
-        mean_squared_error (float):
-        r2_score (float):
+        metrics (dict): key performance indicators for the model
         pickle (bytes): the serialized version of the model parameters
         algorithm: the in memory version of the model parameters that can make predictions
     """
@@ -320,6 +337,63 @@ def find_deployed(cls, project_id: int):
         model.__init__()
         return model
 
+    @classmethod
+    def find_by_project_id_and_algorithm_name(cls, project_id: int, algorithm_name: str):
+        """
+        Args:
+            project_id (int): The project id
+            algorithm_name (str): The algorithm
+        Returns:
+            Model: most recently created model that fits the criteria
+        """
+        result = plpy.execute(
+            f"""
+            SELECT models.*
+            FROM pgml.models
+            WHERE algorithm_name = {q(algorithm_name)}
+            AND project_id = {q(project_id)}
+            ORDER BY models.created_at DESC
+            LIMIT 1
+            """
+        )
+        if len(result) == 0:
+            return None
+
+        model = Model()
+        model.__dict__ = dict(result[0])
+        model.__init__()
+        return model
+
+    @classmethod
+    def find_by_project_and_best_fit(cls, project: Project):
+        """
+        Args:
+            project (Project): The project
+        Returns:
+            Model: the model with the best metrics for the project
+        """
+        if project.objective == "regression":
+            metric = "mean_squared_error"
+        elif project.objective == "classification":
+            metric = "f1"
+
+        result = plpy.execute(
+            f"""
+            SELECT models.*
+            FROM pgml.models
+            WHERE project_id = {q(project.id)}
+            ORDER BY models.metrics->>{q(metric)} DESC
+            LIMIT 1
+            """
+        )
+        if len(result) == 0:
+            return None
+
+        model = Model()
+        model.__dict__ = dict(result[0])
+        model.__init__()
+        return model
+
     def __init__(self):
         self._algorithm = None
         self._project = None
@@ -342,8 +416,13 @@ def algorithm(self):
         else:
             self._algorithm = {
                 "linear_regression": LinearRegression,
+                "linear_classification": LogisticRegression,
+                "svm_regression": SVR,
+                "svm_classification": SVC,
                 "random_forest_regression": RandomForestRegressor,
                 "random_forest_classification": RandomForestClassifier,
+                "gradient_boosting_trees_regression": GradientBoostingRegressor,
+                "gradient_boosting_trees_classification": GradientBoostingClassifier,
             }[self.algorithm_name + "_" + self.project.objective]()
 
         return self._algorithm
@@ -362,8 +441,14 @@ def fit(self, snapshot: Snapshot):
 
         # Test
        y_pred = self.algorithm.predict(X_test)
-        msq = mean_squared_error(y_test, y_pred)
-        r2 = r2_score(y_test, y_pred)
+        metrics = {}
+        if self.project.objective == "regression":
+            metrics["mean_squared_error"] = mean_squared_error(y_test, y_pred)
+            metrics["r2"] = r2_score(y_test, y_pred)
+        elif self.project.objective == "classification":
+            metrics["f1"] = f1_score(y_test, y_pred, average="weighted")
+            metrics["precision"] = precision_score(y_test, y_pred, average="weighted")
+            metrics["recall"] = recall_score(y_test, y_pred, average="weighted")
 
         # Save the model
         self.__dict__ = dict(
@@ -372,8 +457,7 @@ def fit(self, snapshot: Snapshot):
             UPDATE pgml.models
             SET pickle = '\\x{pickle.dumps(self.algorithm).hex()}',
                 status = 'successful',
-                mean_squared_error = {q(msq)},
-                r2_score = {q(r2)}
+                metrics = {q(json.dumps(metrics))}
             WHERE id = {q(self.id)}
            RETURNING *
             """
@@ -398,6 +482,7 @@ def predict(self, data: list):
         Returns:
             float or int: scores for regressions or ints for classifications
         """
+        # TODO: add metrics for tracking prediction volume/accuracy by model
         return self.algorithm.predict(data)
 
 
@@ -406,6 +491,7 @@ def train(
     objective: str,
     relation_name: str,
     y_column_name: str,
+    algorithm_name: str = "linear",
     test_size: float or int = 0.1,
     test_sampling: str = "random",
 ):
@@ -416,15 +502,14 @@ def train(
         objective (str): Defaults to "regression". Valid values are ["regression", "classification"].
         relation_name (str): the table or view that stores the training data
        y_column_name (str): the column in the training data that acts as the label
-        algorithm (str, optional): the algorithm used to implement the objective. Defaults to "linear". Valid values are ["linear", "random_forest"].
+        algorithm_name (str, optional): the algorithm used to implement the objective. Defaults to "linear". Valid values are ["linear", "svm", "random_forest", "gradient_boosting"].
         test_size (float or int, optional): If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the test split. If int, represents the absolute number of test samples. If None, the value is set to the complement of the train size. If train_size is also None, it will be set to 0.25.
         test_sampling: (str, optional): How to sample to create the test data. Defaults to "random". Valid values are ["first", "last", "random"].
     """
-    if objective == "regression":
-        algorithms = ["linear", "random_forest"]
-    elif objective == "classification":
-        algorithms = ["random_forest"]
-    else:
+    if algorithm_name is None:
+        algorithm_name = "linear"
+
+    if objective not in ["regression", "classification"]:
         raise PgMLException(
             f"Unknown objective `{objective}`, available options are: regression, classification."
         )
@@ -440,23 +525,11 @@
     )
 
     snapshot = Snapshot.create(relation_name, y_column_name, test_size, test_sampling)
-    deployed = Model.find_deployed(project.id)
-
-    # Let's assume that the deployed model is better for now.
-    best_model = deployed
-    best_error = best_model.mean_squared_error if best_model else None
-
-    for algorithm_name in algorithms:
-        model = Model.create(project, snapshot, algorithm_name)
-        model.fit(snapshot)
+    model = Model.create(project, snapshot, algorithm_name)
+    model.fit(snapshot)
 
-        # Find the better model and deploy that.
-        if best_error is None or model.mean_squared_error < best_error:
-            best_error = model.mean_squared_error
-            best_model = model
-
-    if deployed and deployed.id == best_model.id:
-        return "rolled back"
-    else:
-        best_model.deploy()
+    if project.deployed_model is None:
+        model.deploy()
         return "deployed"
+    else:
+        return "not deployed"
