Skip to content

Commit 3768179

Browse files
committed
Simplify global variables.
1 parent 2a1853c commit 3768179

File tree

4 files changed

+40
-48
lines changed

4 files changed

+40
-48
lines changed

pgml-extension/src/bindings/python/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ pub fn activate_venv(venv: &str) -> Result<bool> {
2121
}
2222

2323
pub fn activate() -> Result<bool> {
24-
match PGML_VENV.1.get() {
24+
match PGML_VENV.get() {
2525
Some(venv) => activate_venv(&venv.to_string_lossy()),
2626
None => Ok(false),
2727
}

pgml-extension/src/bindings/transformers/whitelist.rs

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,29 @@ use pgrx::{pg_schema, pg_test};
55
use serde_json::Value;
66
use std::ffi::CStr;
77

8-
use crate::config::{PGML_HF_TRUST_REMOTE_CODE, PGML_HF_TRUST_WHITELIST, PGML_HF_WHITELIST};
8+
use crate::config::{PGML_HF_TRUST_REMOTE_CODE, PGML_HF_TRUST_REMOTE_CODE_WHITELIST, PGML_HF_WHITELIST};
99

1010
/// Verify that the model in the task JSON is allowed based on the huggingface whitelists.
1111
pub fn verify_task(task: &Value) -> Result<(), Error> {
1212
let task_model = match get_model_name(task) {
1313
Some(model) => model.to_string(),
1414
None => return Ok(()),
1515
};
16-
let whitelisted_models = config_csv_list(&PGML_HF_WHITELIST.1);
16+
let whitelisted_models = config_csv_list(&PGML_HF_WHITELIST);
1717

1818
let model_is_allowed = whitelisted_models.is_empty() || whitelisted_models.contains(&task_model);
1919
if !model_is_allowed {
2020
bail!(
2121
"model {} is not whitelisted. Consider adding to {} in postgresql.conf",
2222
task_model,
23-
PGML_HF_WHITELIST.0
23+
"pgml.huggingface_whitelist"
2424
);
2525
}
2626

2727
let task_trust = get_trust_remote_code(task);
28-
let trust_remote_code = PGML_HF_TRUST_REMOTE_CODE.1.get();
28+
let trust_remote_code = PGML_HF_TRUST_REMOTE_CODE.get();
2929

30-
let trusted_models = config_csv_list(&PGML_HF_TRUST_WHITELIST.1);
30+
let trusted_models = config_csv_list(&PGML_HF_TRUST_REMOTE_CODE_WHITELIST);
3131

3232
let model_is_trusted = trusted_models.is_empty() || trusted_models.contains(&task_model);
3333

@@ -36,9 +36,9 @@ pub fn verify_task(task: &Value) -> Result<(), Error> {
3636
bail!(
3737
"model {} is not trusted to run remote code. Consider setting {} = 'true' or adding {} to {}",
3838
task_model,
39-
PGML_HF_TRUST_REMOTE_CODE.0,
39+
"pgml.huggingface_trust_remote_code",
4040
task_model,
41-
PGML_HF_TRUST_WHITELIST.0
41+
"pgml.huggingface_trust_remote_code_whitelist",
4242
);
4343
}
4444

@@ -129,7 +129,7 @@ mod tests {
129129
#[pg_test]
130130
fn test_empty_whitelist() {
131131
let model = "Salesforce/xgen-7b-8k-inst";
132-
set_config(PGML_HF_WHITELIST.0, "").unwrap();
132+
set_config("pgml.huggingface_whitelist", "").unwrap();
133133
let task_json = format!(json_template!(), model, false);
134134
let task: Value = serde_json::from_str(&task_json).unwrap();
135135
assert!(verify_task(&task).is_ok());
@@ -138,12 +138,12 @@ mod tests {
138138
#[pg_test]
139139
fn test_nonempty_whitelist() {
140140
let model = "Salesforce/xgen-7b-8k-inst";
141-
set_config(PGML_HF_WHITELIST.0, model).unwrap();
141+
set_config("pgml.huggingface_whitelist", model).unwrap();
142142
let task_json = format!(json_template!(), model, false);
143143
let task: Value = serde_json::from_str(&task_json).unwrap();
144144
assert!(verify_task(&task).is_ok());
145145

146-
set_config(PGML_HF_WHITELIST.0, "other_model").unwrap();
146+
set_config("pgml.huggingface_whitelist", "other_model").unwrap();
147147
let task_json = format!(json_template!(), model, false);
148148
let task: Value = serde_json::from_str(&task_json).unwrap();
149149
assert!(verify_task(&task).is_err());
@@ -152,8 +152,8 @@ mod tests {
152152
#[pg_test]
153153
fn test_trusted_model() {
154154
let model = "Salesforce/xgen-7b-8k-inst";
155-
set_config(PGML_HF_WHITELIST.0, model).unwrap();
156-
set_config(PGML_HF_TRUST_WHITELIST.0, model).unwrap();
155+
set_config("pgml.huggingface_whitelist", model).unwrap();
156+
set_config("pgml.huggingface_trust_remote_code_whitelist", model).unwrap();
157157

158158
let task_json = format!(json_template!(), model, false);
159159
let task: Value = serde_json::from_str(&task_json).unwrap();
@@ -163,7 +163,7 @@ mod tests {
163163
let task: Value = serde_json::from_str(&task_json).unwrap();
164164
assert!(verify_task(&task).is_err());
165165

166-
set_config(PGML_HF_TRUST_REMOTE_CODE.0, "true").unwrap();
166+
set_config("pgml.huggingface_trust_remote_code", "true").unwrap();
167167
let task_json = format!(json_template!(), model, false);
168168
let task: Value = serde_json::from_str(&task_json).unwrap();
169169
assert!(verify_task(&task).is_ok());
@@ -176,8 +176,8 @@ mod tests {
176176
#[pg_test]
177177
fn test_untrusted_model() {
178178
let model = "Salesforce/xgen-7b-8k-inst";
179-
set_config(PGML_HF_WHITELIST.0, model).unwrap();
180-
set_config(PGML_HF_TRUST_WHITELIST.0, "other_model").unwrap();
179+
set_config("pgml.huggingface_whitelist", model).unwrap();
180+
set_config("pgml.huggingface_trust_remote_code_whitelist", "other_model").unwrap();
181181

182182
let task_json = format!(json_template!(), model, false);
183183
let task: Value = serde_json::from_str(&task_json).unwrap();
@@ -187,7 +187,7 @@ mod tests {
187187
let task: Value = serde_json::from_str(&task_json).unwrap();
188188
assert!(verify_task(&task).is_err());
189189

190-
set_config(PGML_HF_TRUST_REMOTE_CODE.0, "true").unwrap();
190+
set_config("pgml.huggingface_trust_remote_code", "true").unwrap();
191191
let task_json = format!(json_template!(), model, false);
192192
let task: Value = serde_json::from_str(&task_json).unwrap();
193193
assert!(verify_task(&task).is_ok());

pgml-extension/src/config.rs

Lines changed: 22 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,58 @@
1-
use once_cell::sync::Lazy;
21
use pgrx::{GucContext, GucFlags, GucRegistry, GucSetting};
32
use std::ffi::CStr;
43

54
#[cfg(any(test, feature = "pg_test"))]
65
use pgrx::{pg_schema, pg_test};
76

8-
pub static PGML_VENV: Lazy<(&'static str, GucSetting<Option<&'static CStr>>)> =
9-
Lazy::new(|| ("pgml.venv", GucSetting::<Option<&'static CStr>>::new(None)));
10-
pub static PGML_HF_WHITELIST: Lazy<(&'static str, GucSetting<Option<&'static CStr>>)> = Lazy::new(|| {
11-
(
12-
"pgml.huggingface_whitelist",
13-
GucSetting::<Option<&'static CStr>>::new(None),
14-
)
15-
});
16-
pub static PGML_HF_TRUST_REMOTE_CODE: Lazy<(&'static str, GucSetting<bool>)> =
17-
Lazy::new(|| ("pgml.huggingface_trust_remote_code", GucSetting::<bool>::new(false)));
18-
pub static PGML_HF_TRUST_WHITELIST: Lazy<(&'static str, GucSetting<Option<&'static CStr>>)> = Lazy::new(|| {
19-
(
20-
"pgml.huggingface_trust_remote_code_whitelist",
21-
GucSetting::<Option<&'static CStr>>::new(None),
22-
)
23-
});
24-
pub static PGML_OMP_NUM_THREADS: Lazy<(&'static str, GucSetting<i32>)> =
25-
Lazy::new(|| ("pgml.omp_num_threads", GucSetting::<i32>::new(0)));
7+
pub static PGML_VENV: GucSetting<Option<&'static CStr>> = GucSetting::<Option<&'static CStr>>::new(None);
8+
pub static PGML_HF_WHITELIST: GucSetting<Option<&'static CStr>> = GucSetting::<Option<&'static CStr>>::new(None);
9+
pub static PGML_HF_TRUST_REMOTE_CODE: GucSetting<bool> = GucSetting::<bool>::new(false);
10+
pub static PGML_HF_TRUST_REMOTE_CODE_WHITELIST: GucSetting<Option<&'static CStr>> =
11+
GucSetting::<Option<&'static CStr>>::new(None);
12+
pub static PGML_OMP_NUM_THREADS: GucSetting<i32> = GucSetting::<i32>::new(0);
2613

2714
pub fn initialize_server_params() {
2815
GucRegistry::define_string_guc(
29-
PGML_VENV.0,
16+
"pgml.venv",
3017
"Python's virtual environment path",
3118
"",
32-
&PGML_VENV.1,
19+
&PGML_VENV,
3320
GucContext::Userset,
3421
GucFlags::default(),
3522
);
23+
3624
GucRegistry::define_string_guc(
37-
PGML_HF_WHITELIST.0,
25+
"pgml.huggingface_whitelist",
3826
"Models allowed to be downloaded from huggingface",
3927
"",
40-
&PGML_HF_WHITELIST.1,
28+
&PGML_HF_WHITELIST,
4129
GucContext::Userset,
4230
GucFlags::default(),
4331
);
32+
4433
GucRegistry::define_bool_guc(
45-
PGML_HF_TRUST_REMOTE_CODE.0,
34+
"pgml.huggingface_trust_remote_code",
4635
"Whether model can execute remote codes",
4736
"",
48-
&PGML_HF_TRUST_REMOTE_CODE.1,
37+
&PGML_HF_TRUST_REMOTE_CODE,
4938
GucContext::Userset,
5039
GucFlags::default(),
5140
);
41+
5242
GucRegistry::define_string_guc(
53-
PGML_HF_TRUST_WHITELIST.0,
43+
"pgml.huggingface_trust_remote_code_whitelist",
5444
"Models allowed to execute remote codes when pgml.hugging_face_trust_remote_code = 'on'",
5545
"",
56-
&PGML_HF_TRUST_WHITELIST.1,
46+
&PGML_HF_TRUST_REMOTE_CODE_WHITELIST,
5747
GucContext::Userset,
5848
GucFlags::default(),
5949
);
50+
6051
GucRegistry::define_int_guc(
61-
PGML_OMP_NUM_THREADS.0,
52+
"pgml.omp_num_threads",
6253
"Specifies the number of threads used by default of underlying OpenMP library. Only positive integers are valid",
6354
"",
64-
&PGML_OMP_NUM_THREADS.1,
55+
&PGML_OMP_NUM_THREADS,
6556
0,
6657
i32::max_value(),
6758
GucContext::Backend,
@@ -87,7 +78,8 @@ mod tests {
8778
let name = "pgml.huggingface_whitelist";
8879
let value = "meta-llama/Llama-2-7b";
8980
set_config(name, value).unwrap();
90-
assert_eq!(PGML_HF_WHITELIST.1.get().unwrap().to_string_lossy(), value);
81+
assert_eq!(PGML_HF_WHITELIST.get().unwrap().to_str().unwrap(), value);
82+
//assert_eq!((&*PGML_HF_WHITELIST).get().unwrap().to_str().unwrap(), value);
9183
}
9284

9385
#[pg_test]

pgml-extension/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ extern "C" {
2929
#[pg_guard]
3030
pub extern "C" fn _PG_init() {
3131
config::initialize_server_params();
32-
let omp_num_threads = config::PGML_OMP_NUM_THREADS.1.get();
32+
let omp_num_threads = config::PGML_OMP_NUM_THREADS.get();
3333
if omp_num_threads > 0 {
3434
unsafe {
3535
omp_set_num_threads(omp_num_threads);

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy