From 107e894aafd72ace80e0e99a5ad66df9d47f2c39 Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Fri, 15 Mar 2024 12:29:45 -0700
Subject: [PATCH] Moved python functions
---
pgml-extension/src/bindings/mod.rs | 33 +++++++++++++++----
.../src/bindings/transformers/transformers.py | 14 ++++----
2 files changed, 33 insertions(+), 14 deletions(-)
diff --git a/pgml-extension/src/bindings/mod.rs b/pgml-extension/src/bindings/mod.rs
index d877f490a..9c9449103 100644
--- a/pgml-extension/src/bindings/mod.rs
+++ b/pgml-extension/src/bindings/mod.rs
@@ -3,10 +3,31 @@ use std::fmt::Debug;
use anyhow::{anyhow, Result};
#[allow(unused_imports)] // used for test macros
use pgrx::*;
-use pyo3::{PyResult, Python};
+use pyo3::{pyfunction, PyResult, Python};
use crate::orm::*;
+#[pyfunction]
+fn r_insert_logs(project_id: i64, model_id: i64, logs: String) -> PyResult<String> {
+    let id_value = Spi::get_one_with_args::<i64>(
+        "INSERT INTO pgml.logs (project_id, model_id, logs) VALUES ($1, $2, $3::JSONB) RETURNING id;",
+        vec![
+            (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()),
+            (PgBuiltInOids::INT8OID.oid(), model_id.into_datum()),
+            (PgBuiltInOids::TEXTOID.oid(), logs.into_datum()),
+        ],
+    )
+    .unwrap()
+    .unwrap();
+    Ok(format!("Inserted logs with id: {}", id_value))
+}
+
+#[pyfunction]
+fn r_print_info(info: String) -> PyResult<String> {
+    info!("{}", info);
+    Ok(info)
+}
+
#[cfg(feature = "python")]
#[macro_export]
macro_rules! create_pymodule {
@@ -16,11 +37,11 @@ macro_rules! create_pymodule {
                pyo3::Python::with_gil(|py| -> anyhow::Result<pyo3::Py<pyo3::types::PyModule>> {
                    use $crate::bindings::TracebackError;
                    let src = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), $pyfile));
-                    Ok(
-                        pyo3::types::PyModule::from_code(py, src, "transformers.py", "__main__")
-                            .format_traceback(py)?
-                            .into(),
-                    )
+                    let module = pyo3::types::PyModule::from_code(py, src, "transformers.py", "__main__")
+                        .format_traceback(py)?;
+                    module.add_function(wrap_pyfunction!($crate::bindings::r_insert_logs, module)?)?;
+                    module.add_function(wrap_pyfunction!($crate::bindings::r_print_info, module)?)?;
+                    Ok(module.into())
                })
            });
    };
diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py
index 42ac43fe0..f3a6d63d4 100644
--- a/pgml-extension/src/bindings/transformers/transformers.py
+++ b/pgml-extension/src/bindings/transformers/transformers.py
@@ -55,7 +55,6 @@
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from trl.trainer import ConstantLengthDataset
from peft import LoraConfig, get_peft_model
-from pypgrx import print_info, insert_logs
from abc import abstractmethod
transformers.logging.set_verbosity_info()
@@ -1017,8 +1016,7 @@ def on_log(self, args, state, control, logs=None, **kwargs):
logs["step"] = state.global_step
logs["max_steps"] = state.max_steps
logs["timestamp"] = str(datetime.now())
- print_info(json.dumps(logs, indent=4))
- insert_logs(self.project_id, self.model_id, json.dumps(logs))
+ r_print_info(json.dumps(logs, indent=4))
class FineTuningBase:
@@ -1100,9 +1098,9 @@ def print_number_of_trainable_model_parameters(self, model):
                trainable_model_params += param.numel()

        # Calculate and print the number and percentage of trainable parameters
-        print_info(f"Trainable model parameters: {trainable_model_params}")
-        print_info(f"All model parameters: {all_model_params}")
-        print_info(
+        r_print_info(f"Trainable model parameters: {trainable_model_params}")
+        r_print_info(f"All model parameters: {all_model_params}")
+        r_print_info(
            f"Percentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"
        )
@@ -1398,7 +1396,7 @@ def __init__(
"bias": "none",
"task_type": "CAUSAL_LM",
}
- print_info(
+ r_print_info(
"LoRA configuration are not set. Using default parameters"
+ json.dumps(self.lora_config_params)
)
@@ -1465,7 +1463,7 @@ def formatting_prompts_func(example):
            peft_config=LoraConfig(**self.lora_config_params),
            callbacks=[PGMLCallback(self.project_id, self.model_id)],
        )
-        print_info("Creating Supervised Fine Tuning trainer done. Training ... ")
+        r_print_info("Creating Supervised Fine Tuning trainer done. Training ... ")

        # Train
        self.trainer.train()
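
Note on the create_pymodule! change above: after PyModule::from_code builds the module from transformers.py, the two #[pyfunction]s are attached to it with add_function, so the Python side can call r_print_info / r_insert_logs as plain module globals instead of importing them from pypgrx. Below is a minimal, self-contained sketch of that same pyo3 pattern, assuming the pre-0.21 pyo3 GIL-ref API (wrap_pyfunction! returning &PyCFunction); the inline Python source, the println!-based body, and the run() helper are illustrative stand-ins, not code from this PR.

use pyo3::prelude::*;
use pyo3::types::PyModule;

// Illustrative stand-in for the extension's r_print_info; the real one logs via pgrx's info!().
#[pyfunction]
fn r_print_info(info: String) -> PyResult<String> {
    println!("{}", info);
    Ok(info)
}

fn main() -> PyResult<()> {
    Python::with_gil(|py| {
        // Build a module from inline Python source, mirroring from_code(transformers.py) above.
        let src = r#"
def run():
    # r_print_info is injected from Rust before run() is called
    return r_print_info("hello from rust-backed logging")
"#;
        let module = PyModule::from_code(py, src, "example.py", "__main__")?;
        // Register the Rust function on the module's dict, as the patched macro does.
        module.add_function(wrap_pyfunction!(r_print_info, module)?)?;

        // The Python code resolves r_print_info from its module globals at call time.
        let out: String = module.getattr("run")?.call0()?.extract()?;
        assert_eq!(out, "hello from rust-backed logging");
        Ok(())
    })
}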