Skip to content

Commit 4612b4f

Browse files
committed
Refactor the initialization of GUC parameters.
Managing GUC parameters in different places is hard to maintain. This patch organizes the GUC definitions in a single place. We also use the define_xxx_guc() APIs to define these parameters, which will allow us to manage GucContext and GucFlags in the future. P.S. The test case test_trusted_model did not seem correct, so I fixed it in this patch.
1 parent 0842673 commit 4612b4f

File tree

6 files changed

+88
-44
lines changed

6 files changed

+88
-44
lines changed

pgml-extension/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pgml-extension/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ serde = { version = "1.0" }
4949
serde_json = { version = "1.0", features = ["preserve_order"] }
5050
typetag = "0.2"
5151
xgboost = { git = "https://github.com/postgresml/rust-xgboost", branch = "master" }
52+
lazy_static = "1.4.0"
5253

5354
[dev-dependencies]
5455
pgrx-tests = "=0.11.2"

pgml-extension/src/bindings/python/mod.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@ use pgrx::*;
66
use pyo3::prelude::*;
77
use pyo3::types::PyTuple;
88

9-
use crate::config::get_config;
9+
use crate::config::PGML_VENV;
1010
use crate::create_pymodule;
1111

12-
static CONFIG_NAME: &str = "pgml.venv";
13-
1412
create_pymodule!("/src/bindings/python/python.py");
1513

1614
pub fn activate_venv(venv: &str) -> Result<bool> {
@@ -23,8 +21,8 @@ pub fn activate_venv(venv: &str) -> Result<bool> {
2321
}
2422

2523
pub fn activate() -> Result<bool> {
26-
match get_config(CONFIG_NAME) {
27-
Some(venv) => activate_venv(&venv),
24+
match PGML_VENV.1.get() {
25+
Some(venv) => activate_venv(&venv.to_string_lossy()),
2826
None => Ok(false),
2927
}
3028
}

pgml-extension/src/bindings/transformers/whitelist.rs

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,54 @@
11
use anyhow::{bail, Error};
2+
use pgrx::GucSetting;
23
#[cfg(any(test, feature = "pg_test"))]
34
use pgrx::{pg_schema, pg_test};
45
use serde_json::Value;
6+
use std::ffi::CStr;
57

6-
use crate::config::get_config;
7-
8-
static CONFIG_HF_WHITELIST: &str = "pgml.huggingface_whitelist";
9-
static CONFIG_HF_TRUST_REMOTE_CODE_BOOL: &str = "pgml.huggingface_trust_remote_code";
10-
static CONFIG_HF_TRUST_WHITELIST: &str = "pgml.huggingface_trust_remote_code_whitelist";
8+
use crate::config::{PGML_HF_TRUST_REMOTE_CODE, PGML_HF_TRUST_WHITELIST, PGML_HF_WHITELIST};
119

1210
/// Verify that the model in the task JSON is allowed based on the huggingface whitelists.
1311
pub fn verify_task(task: &Value) -> Result<(), Error> {
1412
let task_model = match get_model_name(task) {
1513
Some(model) => model.to_string(),
1614
None => return Ok(()),
1715
};
18-
let whitelisted_models = config_csv_list(CONFIG_HF_WHITELIST);
16+
let whitelisted_models = config_csv_list(&PGML_HF_WHITELIST.1);
1917

2018
let model_is_allowed = whitelisted_models.is_empty() || whitelisted_models.contains(&task_model);
2119
if !model_is_allowed {
22-
bail!("model {task_model} is not whitelisted. Consider adding to {CONFIG_HF_WHITELIST} in postgresql.conf");
20+
bail!(
21+
"model {} is not whitelisted. Consider adding to {} in postgresql.conf",
22+
task_model,
23+
PGML_HF_WHITELIST.0
24+
);
2325
}
2426

2527
let task_trust = get_trust_remote_code(task);
26-
let trust_remote_code = get_config(CONFIG_HF_TRUST_REMOTE_CODE_BOOL)
27-
.map(|v| v == "true")
28-
.unwrap_or(true);
28+
let trust_remote_code = PGML_HF_TRUST_REMOTE_CODE.1.get();
2929

30-
let trusted_models = config_csv_list(CONFIG_HF_TRUST_WHITELIST);
30+
let trusted_models = config_csv_list(&PGML_HF_TRUST_WHITELIST.1);
3131

3232
let model_is_trusted = trusted_models.is_empty() || trusted_models.contains(&task_model);
3333

3434
let remote_code_allowed = trust_remote_code && model_is_trusted;
3535
if !remote_code_allowed && task_trust == Some(true) {
36-
bail!("model {task_model} is not trusted to run remote code. Consider setting {CONFIG_HF_TRUST_REMOTE_CODE_BOOL} = 'true' or adding {task_model} to {CONFIG_HF_TRUST_WHITELIST}");
36+
bail!(
37+
"model {} is not trusted to run remote code. Consider setting {} = 'true' or adding {} to {}",
38+
task_model,
39+
PGML_HF_TRUST_REMOTE_CODE.0,
40+
task_model,
41+
PGML_HF_TRUST_WHITELIST.0
42+
);
3743
}
3844

3945
Ok(())
4046
}
4147

42-
fn config_csv_list(name: &str) -> Vec<String> {
43-
match get_config(name) {
48+
fn config_csv_list(csv_list: &GucSetting<Option<&'static CStr>>) -> Vec<String> {
49+
match csv_list.get() {
4450
Some(value) => value
51+
.to_string_lossy()
4552
.trim_matches('"')
4653
.split(',')
4754
.filter_map(|s| if s.is_empty() { None } else { Some(s.to_string()) })
@@ -122,7 +129,7 @@ mod tests {
122129
#[pg_test]
123130
fn test_empty_whitelist() {
124131
let model = "Salesforce/xgen-7b-8k-inst";
125-
set_config(CONFIG_HF_WHITELIST, "").unwrap();
132+
set_config(PGML_HF_WHITELIST.0, "").unwrap();
126133
let task_json = format!(json_template!(), model, false);
127134
let task: Value = serde_json::from_str(&task_json).unwrap();
128135
assert!(verify_task(&task).is_ok());
@@ -131,12 +138,12 @@ mod tests {
131138
#[pg_test]
132139
fn test_nonempty_whitelist() {
133140
let model = "Salesforce/xgen-7b-8k-inst";
134-
set_config(CONFIG_HF_WHITELIST, model).unwrap();
141+
set_config(PGML_HF_WHITELIST.0, model).unwrap();
135142
let task_json = format!(json_template!(), model, false);
136143
let task: Value = serde_json::from_str(&task_json).unwrap();
137144
assert!(verify_task(&task).is_ok());
138145

139-
set_config(CONFIG_HF_WHITELIST, "other_model").unwrap();
146+
set_config(PGML_HF_WHITELIST.0, "other_model").unwrap();
140147
let task_json = format!(json_template!(), model, false);
141148
let task: Value = serde_json::from_str(&task_json).unwrap();
142149
assert!(verify_task(&task).is_err());
@@ -145,18 +152,18 @@ mod tests {
145152
#[pg_test]
146153
fn test_trusted_model() {
147154
let model = "Salesforce/xgen-7b-8k-inst";
148-
set_config(CONFIG_HF_WHITELIST, model).unwrap();
149-
set_config(CONFIG_HF_TRUST_WHITELIST, model).unwrap();
155+
set_config(PGML_HF_WHITELIST.0, model).unwrap();
156+
set_config(PGML_HF_TRUST_WHITELIST.0, model).unwrap();
150157

151158
let task_json = format!(json_template!(), model, false);
152159
let task: Value = serde_json::from_str(&task_json).unwrap();
153160
assert!(verify_task(&task).is_ok());
154161

155162
let task_json = format!(json_template!(), model, true);
156163
let task: Value = serde_json::from_str(&task_json).unwrap();
157-
assert!(verify_task(&task).is_ok());
164+
assert!(verify_task(&task).is_err());
158165

159-
set_config(CONFIG_HF_TRUST_REMOTE_CODE_BOOL, "true").unwrap();
166+
set_config(PGML_HF_TRUST_REMOTE_CODE.0, "true").unwrap();
160167
let task_json = format!(json_template!(), model, false);
161168
let task: Value = serde_json::from_str(&task_json).unwrap();
162169
assert!(verify_task(&task).is_ok());
@@ -169,8 +176,8 @@ mod tests {
169176
#[pg_test]
170177
fn test_untrusted_model() {
171178
let model = "Salesforce/xgen-7b-8k-inst";
172-
set_config(CONFIG_HF_WHITELIST, model).unwrap();
173-
set_config(CONFIG_HF_TRUST_WHITELIST, "other_model").unwrap();
179+
set_config(PGML_HF_WHITELIST.0, model).unwrap();
180+
set_config(PGML_HF_TRUST_WHITELIST.0, "other_model").unwrap();
174181

175182
let task_json = format!(json_template!(), model, false);
176183
let task: Value = serde_json::from_str(&task_json).unwrap();
@@ -180,7 +187,7 @@ mod tests {
180187
let task: Value = serde_json::from_str(&task_json).unwrap();
181188
assert!(verify_task(&task).is_err());
182189

183-
set_config(CONFIG_HF_TRUST_REMOTE_CODE_BOOL, "true").unwrap();
190+
set_config(PGML_HF_TRUST_REMOTE_CODE.0, "true").unwrap();
184191
let task_json = format!(json_template!(), model, false);
185192
let task: Value = serde_json::from_str(&task_json).unwrap();
186193
assert!(verify_task(&task).is_ok());

pgml-extension/src/config.rs

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,58 @@
1+
use lazy_static::lazy_static;
2+
use pgrx::{GucContext, GucFlags, GucRegistry, GucSetting};
13
use std::ffi::CStr;
24

35
#[cfg(any(test, feature = "pg_test"))]
46
use pgrx::{pg_schema, pg_test};
5-
use pgrx_pg_sys::AsPgCStr;
67

7-
pub fn get_config(name: &str) -> Option<String> {
8-
// SAFETY: name is not null because it is a Rust reference.
9-
let ptr = unsafe { pgrx_pg_sys::GetConfigOption(name.as_pg_cstr(), true, false) };
10-
(!ptr.is_null()).then(move || {
11-
// SAFETY: assuming pgrx_pg_sys is providing a valid, null terminated pointer.
12-
unsafe { CStr::from_ptr(ptr) }.to_string_lossy().to_string()
13-
})
8+
lazy_static! {
9+
pub static ref PGML_VENV: (&'static str, GucSetting<Option<&'static CStr>>) =
10+
("pgml.venv", GucSetting::<Option<&'static CStr>>::new(None));
11+
pub static ref PGML_HF_WHITELIST: (&'static str, GucSetting<Option<&'static CStr>>) = (
12+
"pgml.huggingface_whitelist",
13+
GucSetting::<Option<&'static CStr>>::new(None),
14+
);
15+
pub static ref PGML_HF_TRUST_REMOTE_CODE: (&'static str, GucSetting<bool>) =
16+
("pgml.huggingface_trust_remote_code", GucSetting::<bool>::new(false));
17+
pub static ref PGML_HF_TRUST_WHITELIST: (&'static str, GucSetting<Option<&'static CStr>>) = (
18+
"pgml.huggingface_trust_remote_code_whitelist",
19+
GucSetting::<Option<&'static CStr>>::new(None),
20+
);
21+
}
22+
23+
pub fn initialize_server_params() {
24+
GucRegistry::define_string_guc(
25+
PGML_VENV.0,
26+
"Python's virtual environment path",
27+
"",
28+
&PGML_VENV.1,
29+
GucContext::Userset,
30+
GucFlags::default(),
31+
);
32+
GucRegistry::define_string_guc(
33+
PGML_HF_WHITELIST.0,
34+
"Models allowed to be downloaded from huggingface",
35+
"",
36+
&PGML_HF_WHITELIST.1,
37+
GucContext::Userset,
38+
GucFlags::default(),
39+
);
40+
GucRegistry::define_bool_guc(
41+
PGML_HF_TRUST_REMOTE_CODE.0,
42+
"Whether model can execute remote codes",
43+
"",
44+
&PGML_HF_TRUST_REMOTE_CODE.1,
45+
GucContext::Userset,
46+
GucFlags::default(),
47+
);
48+
GucRegistry::define_string_guc(
49+
PGML_HF_TRUST_WHITELIST.0,
50+
"Models allowed to execute remote codes when pgml.hugging_face_trust_remote_code = 'on'",
51+
"",
52+
&PGML_HF_TRUST_WHITELIST.1,
53+
GucContext::Userset,
54+
GucFlags::default(),
55+
);
1456
}
1557

1658
#[cfg(any(test, feature = "pg_test"))]
@@ -26,17 +68,11 @@ pub fn set_config(name: &str, value: &str) -> Result<(), pgrx::spi::Error> {
2668
mod tests {
2769
use super::*;
2870

29-
#[pg_test]
30-
fn read_config_max_connections() {
31-
let name = "max_connections";
32-
assert_eq!(get_config(name), Some("100".into()));
33-
}
34-
3571
#[pg_test]
3672
fn read_pgml_huggingface_whitelist() {
3773
let name = "pgml.huggingface_whitelist";
3874
let value = "meta-llama/Llama-2-7b";
3975
set_config(name, value).unwrap();
40-
assert_eq!(get_config(name), Some(value.into()));
76+
assert_eq!(PGML_HF_WHITELIST.1.get().unwrap().to_string_lossy(), value);
4177
}
4278
}

pgml-extension/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ extension_sql_file!("../sql/schema.sql", name = "schema");
2424
#[cfg(not(feature = "use_as_lib"))]
2525
#[pg_guard]
2626
pub extern "C" fn _PG_init() {
27+
config::initialize_server_params();
2728
bindings::python::activate().expect("Error setting python venv");
2829
orm::project::init();
2930
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy