|
| 1 | +from cmath import e |
1 | 2 | import plpy
|
2 | 3 |
|
| 4 | +from sklearn.linear_model import LinearRegression |
| 5 | +from sklearn.model_selection import train_test_split |
| 6 | +from sklearn.metrics import mean_squared_error, r2_score |
| 7 | + |
| 8 | +import pickle |
| 9 | + |
| 10 | +from pgml.exceptions import PgMLException |
| 11 | + |
| 12 | +def awesome(): |
| 13 | + print("hi") |
| 14 | + |
| 15 | + |
3 | 16 | class Regression:
|
4 | 17 | """Provides continuous real number predictions learned from the training data.
|
5 | 18 | """
|
6 | 19 | def __init__(
|
7 |
| - model_name: str, |
| 20 | + self, |
| 21 | + project_name: str, |
8 | 22 | relation_name: str,
|
9 | 23 | y_column_name: str,
|
10 |
| - implementation: str = "sklearn.linear_model" |
| 24 | + algorithm: str = "sklearn.linear_model", |
| 25 | + test_size: float or int = 0.1, |
| 26 | + test_sampling: str = "random" |
11 | 27 | ) -> None:
|
12 | 28 | """Create a regression model from a table or view filled with training data.
|
13 | 29 |
|
14 | 30 | Args:
|
15 |
| - model_name (str): a human friendly identifier |
| 31 | + project_name (str): a human friendly identifier |
16 | 32 | relation_name (str): the table or view that stores the training data
|
17 | 33 | y_column_name (str): the column in the training data that acts as the label
|
18 |
| - implementation (str, optional): the algorithm used to implement the regression. Defaults to "sklearn.linear_model". |
| 34 | + algorithm (str, optional): the algorithm used to implement the regression. Defaults to "sklearn.linear_model". |
| 35 | + test_size (float or int, optional): If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the test split. If int, represents the absolute number of test samples. If None, the value is set to the complement of the train size. If train_size is also None, it will be set to 0.25. |
| 36 | + test_sampling: (str, optional): How to sample to create the test data. Defaults to "random". Valid values are ["first", "last", "random"]. |
19 | 37 | """
|
20 | 38 |
|
21 |
| - data_source = f"SELECT * FROM {table_name}" |
22 |
| - |
23 |
| - # Start training. |
24 |
| - start = plpy.execute(f""" |
25 |
| - INSERT INTO pgml.model_versions |
26 |
| - (name, data_source, y_column) |
27 |
| - VALUES |
28 |
| - ('{table_name}', '{data_source}', '{y}') |
29 |
| - RETURNING *""", 1) |
30 |
| - |
31 |
| - id_ = start[0]["id"] |
32 |
| - name = f"{table_name}_{id_}" |
33 |
| - |
34 |
| - destination = models_directory(plpy) |
| 39 | + plpy.warning("snapshot") |
| 40 | + # Create a snapshot of the relation |
| 41 | + snapshot = plpy.execute(f"INSERT INTO pgml.snapshots (relation, y, test_size, test_sampling, status) VALUES ('{relation_name}', '{y_column_name}', {test_size}, '{test_sampling}', 'new') RETURNING *", 1)[0] |
| 42 | + plpy.execute(f"""CREATE TABLE pgml.snapshot_{snapshot['id']} AS SELECT * FROM "{relation_name}";""") |
| 43 | + plpy.execute(f"UPDATE pgml.snapshots SET status = 'created' WHERE id = {snapshot['id']}") |
| 44 | + |
| 45 | + plpy.warning("project") |
| 46 | + # Find or create the project |
| 47 | + project = plpy.execute(f"SELECT * FROM pgml.projects WHERE name = '{project_name}'", 1) |
| 48 | + plpy.warning(f"project {project}") |
| 49 | + if (project.nrows == 1): |
| 50 | + plpy.warning("project found") |
| 51 | + project = project[0] |
| 52 | + else: |
| 53 | + try: |
| 54 | + project = plpy.execute(f"INSERT INTO pgml.projects (name) VALUES ('{project_name}') RETURNING *", 1) |
| 55 | + plpy.warning(f"project inserted {project}") |
| 56 | + if (project.nrows() == 1): |
| 57 | + project = project[0] |
| 58 | + |
| 59 | + except Exception as e: # handle race condition to insert |
| 60 | + plpy.warning(f"project retry: #{e}") |
| 61 | + project = plpy.execute(f"SELECT * FROM pgml.projects WHERE name = '{project_name}'", 1)[0] |
| 62 | + |
| 63 | + plpy.warning("model") |
| 64 | + # Create the model |
| 65 | + model = plpy.execute(f"INSERT INTO pgml.models (project_id, snapshot_id, algorithm, status) VALUES ({project['id']}, {snapshot['id']}, '{algorithm}', 'training') RETURNING *")[0] |
| 66 | + |
| 67 | + plpy.warning("data") |
| 68 | + # Prepare the data |
| 69 | + data = plpy.execute(f"SELECT * FROM pgml.snapshot_{snapshot['id']}") |
| 70 | + |
| 71 | + # Sanity check the data |
| 72 | + if data.nrows == 0: |
| 73 | + PgMLException( |
| 74 | + f"Relation `{y_column_name}` contains no rows. Did you pass the correct `relation_name`?" |
| 75 | + ) |
| 76 | + if y_column_name not in data[0]: |
| 77 | + PgMLException( |
| 78 | + f"Column `{y_column_name}` not found. Did you pass the correct `y_column_name`?" |
| 79 | + ) |
| 80 | + |
| 81 | + # Always pull the columns in the same order from the row. |
| 82 | + # Python dict iteration is not always in the same order (hash table). |
| 83 | + columns = [] |
| 84 | + for col in data[0]: |
| 85 | + if col != y_column_name: |
| 86 | + columns.append(col) |
35 | 87 |
|
36 |
| - # Train! |
37 |
| - pickle, msq, r2 = train(plpy.cursor(data_source), y_column=y, name=name, destination=destination) |
| 88 | + # Split the label from the features |
38 | 89 | X = []
|
39 | 90 | y = []
|
40 |
| - columns = [] |
41 |
| - |
42 |
| - for row in all_rows(cursor): |
43 |
| - row = row.copy() |
44 |
| - |
45 |
| - if y_column not in row: |
46 |
| - PgMLException( |
47 |
| - f"Column `{y}` not found. Did you name your `y_column` correctly?" |
48 |
| - ) |
49 |
| - |
50 |
| - y_ = row.pop(y_column) |
| 91 | + for row in data: |
| 92 | + plpy.warning(f"row: {row}") |
| 93 | + y_ = row.pop(y_column_name) |
51 | 94 | x_ = []
|
52 | 95 |
|
53 |
| - # Always pull the columns in the same order from the row. |
54 |
| - # Python dict iteration is not always in the same order (hash table). |
55 |
| - if not columns: |
56 |
| - for col in row: |
57 |
| - columns.append(col) |
58 |
| - |
59 | 96 | for column in columns:
|
60 | 97 | x_.append(row[column])
|
| 98 | + |
61 | 99 | X.append(x_)
|
62 | 100 | y.append(y_)
|
63 | 101 |
|
64 |
| - X_train, X_test, y_train, y_test = train_test_split(X, y) |
65 |
| - |
66 |
| - # Just linear regression for now, but can add many more later. |
67 |
| - lr = LinearRegression() |
68 |
| - lr.fit(X_train, y_train) |
69 |
| - |
| 102 | + # Split into training and test sets |
| 103 | + plpy.warning("split") |
| 104 | + if (test_sampling == 'random'): |
| 105 | + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=0) |
| 106 | + else: |
| 107 | + if (test_sampling == 'first'): |
| 108 | + X.reverse() |
| 109 | + y.reverse() |
| 110 | + if isinstance(split, float): |
| 111 | + split = 1.0 - split |
| 112 | + split = test_size |
| 113 | + if isinstance(split, float): |
| 114 | + split = int(test_size * X.len()) |
| 115 | + X_train, X_test, y_train, y_test = X[0:split], X[split:X.len()-1], y[0:split], y[split:y.len()-1] |
| 116 | + |
| 117 | + # TODO normalize and clean data |
| 118 | + |
| 119 | + plpy.warning("train") |
| 120 | + # Train the model |
| 121 | + algo = LinearRegression() |
| 122 | + algo.fit(X_train, y_train) |
| 123 | + |
| 124 | + plpy.warning("test") |
70 | 125 | # Test
|
71 |
| - y_pred = lr.predict(X_test) |
| 126 | + y_pred = algo.predict(X_test) |
72 | 127 | msq = mean_squared_error(y_test, y_pred)
|
73 | 128 | r2 = r2_score(y_test, y_pred)
|
74 | 129 |
|
75 |
| - path = os.path.join(destination, name) |
76 |
| - |
77 |
| - if save: |
78 |
| - with open(path, "wb") as f: |
79 |
| - pickle.dump(lr, f) |
80 |
| - |
81 |
| - return path, msq, r2 |
82 |
| - |
| 130 | + plpy.warning("save") |
| 131 | + # Save the model |
| 132 | + weights = pickle.dumps(algo) |
83 | 133 |
|
84 | 134 | plpy.execute(f"""
|
85 |
| - UPDATE pgml.model_versions |
86 |
| - SET pickle = '{pickle}', |
87 |
| - successful = true, |
| 135 | + UPDATE pgml.models |
| 136 | + SET pickle = '\\x{weights.hex()}', |
| 137 | + status = 'successful', |
88 | 138 | mean_squared_error = '{msq}',
|
89 |
| - r2_score = '{r2}', |
90 |
| - ended_at = clock_timestamp() |
91 |
| - WHERE id = {id_}""") |
92 |
| - |
93 |
| - return name |
| 139 | + r2_score = '{r2}' |
| 140 | + WHERE id = {model['id']} |
| 141 | + """) |
94 | 142 |
|
95 |
| - model |
| 143 | + # TODO: promote the model? |
0 commit comments