Skip to content

organize python related modules #962

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pgml-extension/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ sacremoses==0.0.53
scikit-learn==1.3.0
sentencepiece==0.1.99
sentence-transformers==2.2.2
tokenizers==0.13.3
torch==2.0.1
torchaudio==2.0.2
torchvision==0.15.2
Expand Down
70 changes: 14 additions & 56 deletions pgml-extension/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@ use pgrx::iter::{SetOfIterator, TableIterator};
use pgrx::*;

#[cfg(feature = "python")]
use pyo3::prelude::*;
use serde_json::json;

#[cfg(feature = "python")]
use crate::bindings::sklearn::package_version;
use crate::orm::*;

macro_rules! unwrap_or_error {
Expand All @@ -25,38 +23,13 @@ macro_rules! unwrap_or_error {
#[cfg(feature = "python")]
#[pg_extern]
pub fn activate_venv(venv: &str) -> bool {
unwrap_or_error!(crate::bindings::venv::activate_venv(venv))
unwrap_or_error!(crate::bindings::python::activate_venv(venv))
}

#[cfg(feature = "python")]
#[pg_extern(immutable, parallel_safe)]
pub fn validate_python_dependencies() -> bool {
unwrap_or_error!(crate::bindings::venv::activate());

Python::with_gil(|py| {
let sys = PyModule::import(py, "sys").unwrap();
let version: String = sys.getattr("version").unwrap().extract().unwrap();
info!("Python version: {version}");
for module in ["xgboost", "lightgbm", "numpy", "sklearn"] {
match py.import(module) {
Ok(_) => (),
Err(e) => {
panic!(
"The {module} package is missing. Install it with `sudo pip3 install {module}`\n{e}"
);
}
}
}
});

let sklearn = unwrap_or_error!(package_version("sklearn"));
let xgboost = unwrap_or_error!(package_version("xgboost"));
let lightgbm = unwrap_or_error!(package_version("lightgbm"));
let numpy = unwrap_or_error!(package_version("numpy"));

info!("Scikit-learn {sklearn}, XGBoost {xgboost}, LightGBM {lightgbm}, NumPy {numpy}",);

true
unwrap_or_error!(crate::bindings::python::validate_dependencies())
}

#[cfg(not(feature = "python"))]
Expand All @@ -66,8 +39,7 @@ pub fn validate_python_dependencies() {}
#[cfg(feature = "python")]
#[pg_extern]
pub fn python_package_version(name: &str) -> String {
unwrap_or_error!(crate::bindings::venv::activate());
unwrap_or_error!(package_version(name))
unwrap_or_error!(crate::bindings::python::package_version(name))
}

#[cfg(not(feature = "python"))]
Expand All @@ -79,13 +51,19 @@ pub fn python_package_version(name: &str) {
#[cfg(feature = "python")]
#[pg_extern]
pub fn python_pip_freeze() -> TableIterator<'static, (name!(package, String),)> {
unwrap_or_error!(crate::bindings::venv::activate());
unwrap_or_error!(crate::bindings::python::pip_freeze())
}

let packages = unwrap_or_error!(crate::bindings::venv::freeze())
.into_iter()
.map(|package| (package,));
/// SQL-exposed function (via `#[pg_extern]`) that returns the value produced by
/// `crate::bindings::python::version()`.
///
/// Only compiled when the `python` feature is enabled. The `Result` from the bindings layer is
/// passed through `unwrap_or_error!`.
/// NOTE(review): the macro body is not visible here; presumably it reports a Postgres error on
/// `Err` instead of returning. Confirm against the macro definition.
#[cfg(feature = "python")]
#[pg_extern]
fn python_version() -> String {
unwrap_or_error!(crate::bindings::python::version())
}

TableIterator::new(packages)
/// Fallback used when the extension is built without the `python` feature.
///
/// Instead of a version string, returns a fixed message telling the operator to rebuild with
/// `--features python`.
#[cfg(not(feature = "python"))]
#[pg_extern]
pub fn python_version() -> String {
    "Python is not installed, recompile with `--features python`".to_string()
}

#[pg_extern]
Expand All @@ -104,26 +82,6 @@ pub fn validate_shared_library() {
}
}

#[cfg(feature = "python")]
#[pg_extern]
fn python_version() -> String {
unwrap_or_error!(crate::bindings::venv::activate());
let mut version = String::new();

Python::with_gil(|py| {
let sys = PyModule::import(py, "sys").unwrap();
version = sys.getattr("version").unwrap().extract().unwrap();
});

version
}

#[cfg(not(feature = "python"))]
#[pg_extern]
pub fn python_version() -> String {
String::from("Python is not installed, recompile with `--features python`")
}

#[pg_extern(immutable, parallel_safe)]
fn version() -> String {
crate::VERSION.to_string()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ use pyo3::types::PyTuple;

use crate::{bindings::TracebackError, create_pymodule};

create_pymodule!("/src/bindings/langchain.py");
create_pymodule!("/src/bindings/langchain/langchain.py");

pub fn chunk(splitter: &str, text: &str, kwargs: &serde_json::Value) -> Result<Vec<String>> {
crate::bindings::venv::activate()?;
crate::bindings::python::activate()?;

let kwargs = serde_json::to_string(kwargs).unwrap();

Expand Down
4 changes: 2 additions & 2 deletions pgml-extension/src/bindings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ pub mod langchain;
pub mod lightgbm;
pub mod linfa;
#[cfg(feature = "python")]
pub mod python;
#[cfg(feature = "python")]
pub mod sklearn;
#[cfg(feature = "python")]
pub mod transformers;
#[cfg(feature = "python")]
pub mod venv;
pub mod xgboost;

pub type Fit = fn(dataset: &Dataset, hyperparams: &Hyperparams) -> Result<Box<dyn Bindings>>;
Expand Down
91 changes: 91 additions & 0 deletions pgml-extension/src/bindings/python/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
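//! Python runtime helpers: virtualenv activation, `pip freeze` listing, dependency validation,
//! and interpreter/package version lookup.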

use anyhow::Result;
use once_cell::sync::Lazy;
use pgrx::iter::TableIterator;
use pgrx::*;
use pyo3::prelude::*;
use pyo3::types::PyTuple;

use crate::config::get_config;
use crate::{bindings::TracebackError, create_pymodule};

static CONFIG_NAME: &str = "pgml.venv";

create_pymodule!("/src/bindings/python/python.py");

/// Activates the Python virtualenv identified by `venv`.
///
/// Looks up the `activate_venv` attribute on the `PY_MODULE` Python module, calls it with `venv`
/// as the single positional argument, and extracts the returned Python object as a `bool`.
///
/// # Errors
///
/// Returns an error if the attribute lookup, the Python call, or the `bool` extraction fails.
pub fn activate_venv(venv: &str) -> Result<bool> {
    Python::with_gil(|py| -> Result<bool> {
        let module = get_module!(PY_MODULE);
        let py_fn: Py<PyAny> = module.getattr(py, "activate_venv")?;

        // One-element argument tuple holding the venv path as a Python string.
        let args = PyTuple::new(py, &[venv.to_string().into_py(py)]);
        let outcome: Py<PyAny> = py_fn.call1(py, args)?;

        Ok(outcome.extract(py)?)
    })
}

/// Activates the virtualenv named by the `pgml.venv` configuration value, if one is set.
///
/// When the setting is absent, returns `Ok(false)` without calling [`activate_venv`]; otherwise
/// returns whatever [`activate_venv`] returns for the configured value.
pub fn activate() -> Result<bool> {
    if let Some(venv) = get_config(CONFIG_NAME) {
        activate_venv(&venv)
    } else {
        Ok(false)
    }
}

/// Lists installed Python packages as single-column rows.
///
/// Calls [`activate`] first, then invokes the `freeze` function of the `PY_MODULE` Python module
/// and extracts its return value as a `Vec<String>`. Each string becomes one `(package,)` row of
/// the returned [`TableIterator`].
///
/// # Errors
///
/// Returns an error if activation, the Python call, or the extraction fails.
pub fn pip_freeze() -> Result<TableIterator<'static, (name!(package, String),)>> {
    activate()?;

    let frozen: Vec<String> = Python::with_gil(|py| -> Result<Vec<String>> {
        let module = get_module!(PY_MODULE);
        let listing = module.getattr(py, "freeze")?.call0(py)?;

        Ok(listing.extract(py)?)
    })?;

    let rows = frozen.into_iter().map(|entry| (entry,));
    Ok(TableIterator::new(rows))
}

pub fn validate_dependencies() -> Result<bool> {
activate()?;
Python::with_gil(|py| {
let sys = PyModule::import(py, "sys").unwrap();
let version: String = sys.getattr("version").unwrap().extract().unwrap();
info!("Python version: {version}");
for module in ["xgboost", "lightgbm", "numpy", "sklearn"] {
match py.import(module) {
Ok(_) => (),
Err(e) => {
panic!(
"The {module} package is missing. Install it with `sudo pip3 install {module}`\n{e}"
);
}
}
}
});

let sklearn = package_version("sklearn")?;
let xgboost = package_version("xgboost")?;
let lightgbm = package_version("lightgbm")?;
let numpy = package_version("numpy")?;

info!("Scikit-learn {sklearn}, XGBoost {xgboost}, LightGBM {lightgbm}, NumPy {numpy}",);

Ok(true)
}

/// Returns the Python interpreter's `sys.version` string.
///
/// Calls [`activate`] before querying the interpreter.
///
/// # Errors
///
/// Returns an error, rather than panicking, if activation fails, if `sys` cannot be imported, or
/// if `sys.version` cannot be read and extracted as a string.
pub fn version() -> Result<String> {
    activate()?;

    Python::with_gil(|py| -> Result<String> {
        let sys = PyModule::import(py, "sys")?;
        Ok(sys.getattr("version")?.extract()?)
    })
}

/// Returns the `__version__` attribute of the Python package `name`.
///
/// Calls [`activate`] before importing the package.
///
/// # Errors
///
/// Returns an error if activation fails, if the package cannot be imported, or if `__version__`
/// is missing or cannot be extracted as a string.
pub fn package_version(name: &str) -> Result<String> {
    activate()?;

    Python::with_gil(|py| -> Result<String> {
        let module = py.import(name)?;
        let attr = module.getattr("__version__")?;
        let reported: String = attr.extract()?;

        Ok(reported)
    })
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ use once_cell::sync::Lazy;
use pyo3::prelude::*;
use pyo3::types::PyTuple;

use crate::bindings::Bindings;
use crate::{
bindings::{Bindings, TracebackError},
create_pymodule,
orm::*,
};

use crate::{bindings::TracebackError, create_pymodule, orm::*};

create_pymodule!("/src/bindings/sklearn.py");
create_pymodule!("/src/bindings/sklearn/sklearn.py");

macro_rules! wrap_fit {
($fn_name:tt, $task:literal) => {
Expand Down Expand Up @@ -355,10 +357,3 @@ pub fn cluster_metrics(
Ok(scores)
})
}

pub fn package_version(name: &str) -> Result<String> {
Python::with_gil(|py| {
let package = py.import(name)?;
Ok(package.getattr("__version__")?.extract()?)
})
}
12 changes: 6 additions & 6 deletions pgml-extension/src/bindings/transformers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub fn transform(
args: &serde_json::Value,
inputs: Vec<&str>,
) -> Result<serde_json::Value> {
crate::bindings::venv::activate()?;
crate::bindings::python::activate()?;

whitelist::verify_task(task)?;

Expand Down Expand Up @@ -70,7 +70,7 @@ pub fn embed(
inputs: Vec<&str>,
kwargs: &serde_json::Value,
) -> Result<Vec<Vec<f32>>> {
crate::bindings::venv::activate()?;
crate::bindings::python::activate()?;

let kwargs = serde_json::to_string(kwargs)?;
Python::with_gil(|py| -> Result<Vec<Vec<f32>>> {
Expand Down Expand Up @@ -101,7 +101,7 @@ pub fn tune(
hyperparams: &JsonB,
path: &Path,
) -> Result<HashMap<String, f64>> {
crate::bindings::venv::activate()?;
crate::bindings::python::activate()?;

let task = task.to_string();
let hyperparams = serde_json::to_string(&hyperparams.0)?;
Expand Down Expand Up @@ -131,7 +131,7 @@ pub fn tune(
}

pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result<Vec<String>> {
crate::bindings::venv::activate()?;
crate::bindings::python::activate()?;

Python::with_gil(|py| -> Result<Vec<String>> {
let generate = get_module!(PY_MODULE)
Expand Down Expand Up @@ -219,7 +219,7 @@ pub fn load_dataset(
limit: Option<usize>,
kwargs: &serde_json::Value,
) -> Result<usize> {
crate::bindings::venv::activate()?;
crate::bindings::python::activate()?;

let kwargs = serde_json::to_string(kwargs)?;

Expand Down Expand Up @@ -376,7 +376,7 @@ pub fn load_dataset(
}

pub fn clear_gpu_cache(memory_usage: Option<f32>) -> Result<bool> {
crate::bindings::venv::activate().unwrap();
crate::bindings::python::activate().unwrap();

Python::with_gil(|py| -> Result<bool> {
let clear_gpu_cache: Py<PyAny> = get_module!(PY_MODULE)
Expand Down
2 changes: 2 additions & 0 deletions pgml-extension/src/bindings/transformers/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
DataCollatorWithPadding,
DefaultDataCollator,
GenerationConfig,
PegasusForConditionalGeneration,
PegasusTokenizer,
TrainingArguments,
Trainer,
)
Expand Down
40 changes: 0 additions & 40 deletions pgml-extension/src/bindings/venv.rs

This file was deleted.

2 changes: 1 addition & 1 deletion pgml-extension/src/orm/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ impl Model {
};

if runtime == Runtime::python {
crate::bindings::venv::activate().unwrap();
crate::bindings::python::activate().unwrap();
}

let dataset = snapshot.tabular_dataset();
Expand Down
Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy