Content-Length: 4646 | pFad | http://github.com/postgresml/postgresml/pull/1374.diff

thub.com diff --git a/pgml-extension/src/bindings/mod.rs b/pgml-extension/src/bindings/mod.rs index d877f490a..9c9449103 100644 --- a/pgml-extension/src/bindings/mod.rs +++ b/pgml-extension/src/bindings/mod.rs @@ -3,10 +3,31 @@ use std::fmt::Debug; use anyhow::{anyhow, Result}; #[allow(unused_imports)] // used for test macros use pgrx::*; -use pyo3::{PyResult, Python}; +use pyo3::{pyfunction, PyResult, Python}; use crate::orm::*; +#[pyfunction] +fn r_insert_logs(project_id: i64, model_id: i64, logs: String) -> PyResult { + let id_value = Spi::get_one_with_args::( + "INSERT INTO pgml.logs (project_id, model_id, logs) VALUES ($1, $2, $3::JSONB) RETURNING id;", + vec![ + (PgBuiltInOids::INT8OID.oid(), project_id.into_datum()), + (PgBuiltInOids::INT8OID.oid(), model_id.into_datum()), + (PgBuiltInOids::TEXTOID.oid(), logs.into_datum()), + ], + ) + .unwrap() + .unwrap(); + Ok(format!("Inserted logs with id: {}", id_value)) +} + +#[pyfunction] +fn r_print_info(info: String) -> PyResult { + info!("{}", info); + Ok(info) +} + #[cfg(feature = "python")] #[macro_export] macro_rules! create_pymodule { @@ -16,11 +37,11 @@ macro_rules! create_pymodule { pyo3::Python::with_gil(|py| -> anyhow::Result> { use $crate::bindings::TracebackError; let src = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), $pyfile)); - Ok( - pyo3::types::PyModule::from_code(py, src, "transformers.py", "__main__") - .format_traceback(py)? - .into(), - ) + let module = pyo3::types::PyModule::from_code(py, src, "transformers.py", "__main__") + .format_traceback(py)?; + module.add_function(wrap_pyfunction!($crate::bindings::r_insert_logs, module)?)?; + module.add_function(wrap_pyfunction!($crate::bindings::r_print_info, module)?)?; + Ok(module.into()) }) }); }; diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index 42ac43fe0..f3a6d63d4 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -55,7 +55,6 @@ from trl import SFTTrainer, DataCollatorForCompletionOnlyLM from trl.trainer import ConstantLengthDataset from peft import LoraConfig, get_peft_model -from pypgrx import print_info, insert_logs from abc import abstractmethod transformers.logging.set_verbosity_info() @@ -1017,8 +1016,7 @@ def on_log(self, args, state, control, logs=None, **kwargs): logs["step"] = state.global_step logs["max_steps"] = state.max_steps logs["timestamp"] = str(datetime.now()) - print_info(json.dumps(logs, indent=4)) - insert_logs(self.project_id, self.model_id, json.dumps(logs)) + r_print_info(json.dumps(logs, indent=4)) class FineTuningBase: @@ -1100,9 +1098,9 @@ def print_number_of_trainable_model_parameters(self, model): trainable_model_params += param.numel() # Calculate and print the number and percentage of trainable parameters - print_info(f"Trainable model parameters: {trainable_model_params}") - print_info(f"All model parameters: {all_model_params}") - print_info( + r_print_info(f"Trainable model parameters: {trainable_model_params}") + r_print_info(f"All model parameters: {all_model_params}") + r_print_info( f"Percentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%" ) @@ -1398,7 +1396,7 @@ def __init__( "bias": "none", "task_type": "CAUSAL_LM", } - print_info( + r_print_info( "LoRA configuration are not set. Using default parameters" + json.dumps(self.lora_config_params) ) @@ -1465,7 +1463,7 @@ def formatting_prompts_func(example): peft_config=LoraConfig(**self.lora_config_params), callbacks=[PGMLCallback(self.project_id, self.model_id)], ) - print_info("Creating Supervised Fine Tuning trainer done. Training ... ") + r_print_info("Creating Supervised Fine Tuning trainer done. Training ... ") # Train self.trainer.train()








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: http://github.com/postgresml/postgresml/pull/1374.diff

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy