Skip to content

Commit f75114b

Browse files
authored
LLM fine-tuning (#1350)
1 parent 790e4f9 commit f75114b

File tree

15 files changed

+1993
-105
lines changed

15 files changed

+1993
-105
lines changed

README.md

Lines changed: 737 additions & 0 deletions
Large diffs are not rendered by default.

pgml-extension/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,5 @@
1414
.DS_Store
1515

1616

17+
# venv
18+
pgml-venv

pgml-extension/Cargo.lock

Lines changed: 38 additions & 18 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pgml-extension/Cargo.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "pgml"
3-
version = "2.8.2"
3+
version = "2.8.3"
44
edition = "2021"
55

66
[lib]
@@ -39,8 +39,8 @@ openblas-src = { version = "0.10", features = ["cblas", "system"] }
3939
ndarray = { version = "0.15.6", features = ["serde", "blas"] }
4040
ndarray-stats = "0.5.1"
4141
parking_lot = "0.12"
42-
pgrx = "=0.11.2"
43-
pgrx-pg-sys = "=0.11.2"
42+
pgrx = "=0.11.3"
43+
pgrx-pg-sys = "=0.11.3"
4444
pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
4545
rand = "0.8"
4646
rmp-serde = { version = "1.1" }
@@ -51,7 +51,7 @@ typetag = "0.2"
5151
xgboost = { git = "https://github.com/postgresml/rust-xgboost", branch = "master" }
5252

5353
[dev-dependencies]
54-
pgrx-tests = "=0.11.2"
54+
pgrx-tests = "=0.11.3"
5555

5656
[build-dependencies]
5757
vergen = { version = "8", features = ["build", "git", "gitcl"] }

pgml-extension/requirements.linux.txt

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
accelerate==0.25.0
1+
accelerate==0.27.2
22
aiohttp==3.9.1
33
aiosignal==1.3.1
44
annotated-types==0.6.0
55
anyio==4.2.0
6+
appdirs==1.4.4
67
async-timeout==4.0.3
78
attrs==23.1.0
89
auto-gptq==0.6.0
910
bitsandbytes==0.41.3.post2
11+
black==24.1.1
1012
catboost==1.2.2
1113
certifi==2023.11.17
1214
charset-normalizer==3.3.2
@@ -20,13 +22,18 @@ dataclasses-json==0.6.3
2022
datasets==2.15.0
2123
deepspeed==0.12.5
2224
dill==0.3.7
25+
docker-pycreds==0.4.0
26+
docstring-parser==0.15
2327
einops==0.7.0
28+
evaluate==0.4.1
2429
exceptiongroup==1.2.0
2530
filelock==3.13.1
2631
fonttools==4.47.0
2732
frozenlist==1.4.1
2833
fsspec==2023.10.0
2934
gekko==1.0.6
35+
gitdb==4.0.11
36+
GitPython==3.1.41
3037
graphviz==0.20.1
3138
greenlet==3.0.2
3239
hjson==3.1.0
@@ -45,9 +52,11 @@ langchain-core==0.1.1
4552
langsmith==0.0.72
4653
lightgbm==4.1.0
4754
lxml==4.9.3
55+
markdown-it-py==3.0.0
4856
MarkupSafe==2.1.3
4957
marshmallow==3.20.1
5058
matplotlib==3.8.2
59+
mdurl==0.1.2
5160
mpmath==1.3.0
5261
multidict==6.0.4
5362
multiprocess==0.70.15
@@ -72,8 +81,10 @@ optimum==1.16.1
7281
orjson==3.9.10
7382
packaging==23.2
7483
pandas==2.1.4
84+
pathspec==0.12.1
7585
peft==0.7.1
7686
Pillow==10.1.0
87+
platformdirs==4.2.0
7788
plotly==5.18.0
7889
portalocker==2.8.2
7990
protobuf==4.25.1
@@ -83,13 +94,16 @@ pyarrow==11.0.0
8394
pyarrow-hotfix==0.6
8495
pydantic==2.5.2
8596
pydantic_core==2.14.5
97+
Pygments==2.17.2
8698
pynvml==11.5.0
8799
pyparsing==3.1.1
88100
python-dateutil==2.8.2
89101
pytz==2023.3.post1
90102
PyYAML==6.0.1
91103
regex==2023.10.3
92104
requests==2.31.0
105+
responses==0.18.0
106+
rich==13.7.1
93107
rouge==1.0.1
94108
sacrebleu==2.4.0
95109
sacremoses==0.1.1
@@ -98,23 +112,30 @@ scikit-learn==1.3.2
98112
scipy==1.11.4
99113
sentence-transformers==2.5.1
100114
sentencepiece==0.1.99
115+
sentry-sdk==1.40.2
116+
setproctitle==1.3.3
117+
shtab==1.6.5
101118
six==1.16.0
119+
smmap==5.0.1
102120
sniffio==1.3.0
103121
SQLAlchemy==2.0.23
104122
sympy==1.12
105123
tabulate==0.9.0
106124
tenacity==8.2.3
107125
threadpoolctl==3.2.0
108126
tokenizers==0.15.0
127+
tomli==2.0.1
109128
torch==2.1.2
110129
torchaudio==2.1.2
111130
torchvision==0.16.2
112131
tqdm==4.66.1
113132
transformers==4.38.2
114133
transformers-stream-generator==0.0.4
115134
triton==2.1.0
135+
trl==0.7.10
116136
typing-inspect==0.9.0
117137
typing_extensions==4.9.0
138+
tyro==0.7.2
118139
tzdata==2023.3
119140
urllib3==2.1.0
120141
xformers==0.0.23.post1
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
-- Add conversation and text-pair-classification task types
2+
ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'conversation';
3+
ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'text-pair-classification';
4+
5+
-- Create pgml.logs table
6+
CREATE TABLE IF NOT EXISTS pgml.logs (
7+
id SERIAL PRIMARY KEY,
8+
model_id BIGINT,
9+
project_id BIGINT,
10+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
11+
logs JSONB
12+
);

pgml-extension/src/api.rs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -816,7 +816,7 @@ fn tune(
816816
project_name: &str,
817817
task: default!(Option<&str>, "NULL"),
818818
relation_name: default!(Option<&str>, "NULL"),
819-
y_column_name: default!(Option<&str>, "NULL"),
819+
_y_column_name: default!(Option<&str>, "NULL"),
820820
model_name: default!(Option<&str>, "NULL"),
821821
hyperparams: default!(JsonB, "'{}'"),
822822
test_size: default!(f32, 0.25),
@@ -874,9 +874,7 @@ fn tune(
874874

875875
let snapshot = Snapshot::create(
876876
relation_name,
877-
Some(vec![y_column_name
878-
.expect("You must pass a `y_column_name` when you pass a `relation_name`")
879-
.to_string()]),
877+
None,
880878
test_size,
881879
test_sampling,
882880
materialize_snapshot,
@@ -898,13 +896,14 @@ fn tune(
898896
// algorithm will be transformers, stash the model_name in a hyperparam for v1 compatibility.
899897
let mut hyperparams = hyperparams.0.as_object().unwrap().clone();
900898
hyperparams.insert(String::from("model_name"), json!(model_name));
899+
hyperparams.insert(String::from("project_name"), json!(project_name));
901900
let hyperparams = JsonB(json!(hyperparams));
902901

903902
// # Default repeatable random state when possible
904903
// let algorithm = Model.algorithm_from_name_and_task(algorithm, task);
905904
// if "random_state" in algorithm().get_params() and "random_state" not in hyperparams:
906905
// hyperparams["random_state"] = 0
907-
let model = Model::tune(&project, &mut snapshot, &hyperparams);
906+
let model = Model::finetune(&project, &mut snapshot, &hyperparams);
908907
let new_metrics: &serde_json::Value = &model.metrics.unwrap().0;
909908
let new_metrics = new_metrics.as_object().unwrap();
910909

@@ -928,18 +927,19 @@ fn tune(
928927
Some(true) | None => {
929928
if let Ok(Some(deployed_metrics)) = deployed_metrics {
930929
let deployed_metrics = deployed_metrics.0.as_object().unwrap();
931-
if project.task.value_is_better(
932-
deployed_metrics
933-
.get(&project.task.default_target_metric())
934-
.unwrap()
935-
.as_f64()
936-
.unwrap(),
937-
new_metrics
938-
.get(&project.task.default_target_metric())
939-
.unwrap()
940-
.as_f64()
941-
.unwrap(),
942-
) {
930+
931+
let deployed_value = deployed_metrics
932+
.get(&project.task.default_target_metric())
933+
.and_then(|value| value.as_f64())
934+
.unwrap_or_default(); // Default to 0.0 if the key is not present or conversion fails
935+
936+
// Get the value for the default target metric from new_metrics or provide a default value
937+
let new_value = new_metrics
938+
.get(&project.task.default_target_metric())
939+
.and_then(|value| value.as_f64())
940+
.unwrap_or_default(); // Default to 0.0 if the key is not present or conversion fails
941+
942+
if project.task.value_is_better(deployed_value, new_value) {
943943
deploy = false;
944944
}
945945
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy