Commit 995e0b2

data loading in rust
1 parent 69986e4 commit 995e0b2

4 files changed (+163, -161 lines)

pgml-extension/src/api.rs (16 additions, 9 deletions)

@@ -485,18 +485,23 @@ fn snapshot(
 #[pg_extern]
 fn load_dataset(
     source: &str,
+    subset: default!(Option<String>, "NULL"),
     limit: default!(Option<i64>, "NULL"),
+    kwargs: default!(JsonB, "'{}'"),
 ) -> TableIterator<'static, (name!(table_name, String), name!(rows, i64))> {
     // cast limit since pgx doesn't support usize
     let limit: Option<usize> = limit.map(|limit| limit.try_into().unwrap());
     let (name, rows) = match source {
-        "breast_cancer" => crate::orm::dataset::load_breast_cancer(limit),
-        "diabetes" => crate::orm::dataset::load_diabetes(limit),
-        "digits" => crate::orm::dataset::load_digits(limit),
-        "iris" => crate::orm::dataset::load_iris(limit),
-        "linnerud" => crate::orm::dataset::load_linnerud(limit),
-        "wine" => crate::orm::dataset::load_wine(limit),
-        _ => error!("Unknown source: `{source}`"),
+        "breast_cancer" => dataset::load_breast_cancer(limit),
+        "diabetes" => dataset::load_diabetes(limit),
+        "digits" => dataset::load_digits(limit),
+        "iris" => dataset::load_iris(limit),
+        "linnerud" => dataset::load_linnerud(limit),
+        "wine" => dataset::load_wine(limit),
+        _ => {
+            let rows = crate::bindings::transformers::load_dataset(source, subset, limit, &kwargs.0);
+            (source.into(), rows as i64)
+        },
     };

     TableIterator::new(vec![(name, rows)].into_iter())
@@ -537,7 +542,7 @@ fn tune(
     task: default!(Option<Task>, "NULL"),
     relation_name: default!(Option<&str>, "NULL"),
     y_column_name: default!(Option<&str>, "NULL"),
-    algorithm: default!(Algorithm, "transformers"),
+    algorithm: default!(Option<&str>, "NULL"),
     hyperparams: default!(JsonB, "'{}'"),
     search: default!(Option<Search>, "NULL"),
     search_params: default!(JsonB, "'{}'"),
@@ -608,14 +613,16 @@ fn tune(
         }
     };

+    let model_name = algorithm;
+
     // # Default repeatable random state when possible
     // let algorithm = Model.algorithm_from_name_and_task(algorithm, task);
     // if "random_state" in algorithm().get_params() and "random_state" not in hyperparams:
     //     hyperparams["random_state"] = 0
     let model = Model::create(
         &project,
         &mut snapshot,
-        algorithm,
+        Algorithm::transformers,
         hyperparams,
         search,
         search_params,
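
The key change above is the catch-all arm in load_dataset: any source that is not one of the built-in toy datasets is forwarded, together with the new subset and kwargs arguments, to crate::bindings::transformers::load_dataset (the Hugging Face loader), and the returned row count is surfaced to SQL. Below is a minimal, hypothetical pgx sketch of the same shape: optional SQL arguments via default! and a one-row TableIterator result. The function and argument names are illustrative only and not part of this commit; it assumes the pgx prelude used elsewhere in this extension.

    use pgx::prelude::*;

    // Hypothetical example mirroring the signature pattern used by load_dataset:
    // optional SQL arguments default to NULL, and the result is a single
    // (table_name, rows) tuple exposed to SQL as a one-row table.
    #[pg_extern]
    fn load_example(
        source: &str,
        limit: default!(Option<i64>, "NULL"),
    ) -> TableIterator<'static, (name!(table_name, String), name!(rows, i64))> {
        // A real implementation would dispatch on `source`, as load_dataset does above.
        let rows = limit.unwrap_or(0);
        TableIterator::new(vec![(source.to_string(), rows)].into_iter())
    }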

pgml-extension/src/bindings/sklearn.rs (50 additions, 84 deletions)

@@ -9,13 +9,25 @@
 /// defined in `src/bindings/sklearn.py`.
 use std::collections::HashMap;

+use once_cell::sync::Lazy;
 use pyo3::prelude::*;
 use pyo3::types::PyTuple;

 use crate::bindings::Bindings;

 use crate::orm::*;

+static PY_MODULE: Lazy<Py<PyModule>> = Lazy::new(||
+    Python::with_gil(|py| -> Py<PyModule> {
+        let src = include_str!(concat!(
+            env!("CARGO_MANIFEST_DIR"),
+            "/src/bindings/sklearn.py"
+        ));
+
+        PyModule::from_code(py, src, "", "").unwrap().into()
+    })
+);
+
 pub fn linear_regression(dataset: &Dataset, hyperparams: &Hyperparams) -> Box<dyn Bindings> {
     fit(dataset, hyperparams, "linear_regression")
 }
@@ -290,17 +302,11 @@ fn fit(
     hyperparams: &Hyperparams,
     algorithm_task: &'static str,
 ) -> Box<dyn Bindings> {
-    let module = include_str!(concat!(
-        env!("CARGO_MANIFEST_DIR"),
-        "/src/bindings/sklearn.py"
-    ));
-
     let hyperparams = serde_json::to_string(hyperparams).unwrap();

     let (estimator, predict, predict_proba) =
         Python::with_gil(|py| -> (Py<PyAny>, Py<PyAny>, Py<PyAny>) {
-            let module = PyModule::from_code(py, module, "", "").unwrap();
-            let estimator: Py<PyAny> = module.getattr("estimator").unwrap().into();
+            let estimator: Py<PyAny> = PY_MODULE.getattr(py, "estimator").unwrap().into();

             let train: Py<PyAny> = estimator
                 .call1(
@@ -321,20 +327,20 @@ fn fit(
                 .call1(py, PyTuple::new(py, &[&dataset.x_train, &dataset.y_train]))
                 .unwrap();

-            let predict: Py<PyAny> = module
-                .getattr("predictor")
+            let predict: Py<PyAny> = PY_MODULE
+                .getattr(py, "predictor")
                 .unwrap()
-                .call1(PyTuple::new(py, &[&estimator]))
+                .call1(py, PyTuple::new(py, &[&estimator]))
                 .unwrap()
-                .extract()
+                .extract(py)
                 .unwrap();

-            let predict_proba: Py<PyAny> = module
-                .getattr("predictor_proba")
+            let predict_proba: Py<PyAny> = PY_MODULE
+                .getattr(py, "predictor_proba")
                 .unwrap()
-                .call1(PyTuple::new(py, &[&estimator]))
+                .call1(py, PyTuple::new(py, &[&estimator]))
                 .unwrap()
-                .extract()
+                .extract(py)
                 .unwrap();

             (estimator, predict, predict_proba)
@@ -389,17 +395,11 @@ impl Bindings for Estimator {

     /// Serialize self to bytes
     fn to_bytes(&self) -> Vec<u8> {
-        let module = include_str!(concat!(
-            env!("CARGO_MANIFEST_DIR"),
-            "/src/bindings/sklearn.py"
-        ));
-
         Python::with_gil(|py| -> Vec<u8> {
-            let module = PyModule::from_code(py, module, "", "").unwrap();
-            let save = module.getattr("save").unwrap();
-            save.call1(PyTuple::new(py, &[&self.estimator]))
+            let save = PY_MODULE.getattr(py, "save").unwrap();
+            save.call1(py, PyTuple::new(py, &[&self.estimator]))
                 .unwrap()
-                .extract()
+                .extract(py)
                 .unwrap()
         })
     }
@@ -409,34 +409,28 @@ impl Bindings for Estimator {
     where
         Self: Sized,
     {
-        let module = include_str!(concat!(
-            env!("CARGO_MANIFEST_DIR"),
-            "/src/bindings/sklearn.py"
-        ));
-
         Python::with_gil(|py| -> Box<dyn Bindings> {
-            let module = PyModule::from_code(py, module, "", "").unwrap();
-            let load = module.getattr("load").unwrap();
+            let load = PY_MODULE.getattr(py, "load").unwrap();
             let estimator: Py<PyAny> = load
-                .call1(PyTuple::new(py, &[bytes]))
+                .call1(py,PyTuple::new(py, &[bytes]))
                 .unwrap()
-                .extract()
+                .extract(py)
                 .unwrap();

-            let predict: Py<PyAny> = module
-                .getattr("predictor")
+            let predict: Py<PyAny> = PY_MODULE
+                .getattr(py,"predictor")
                 .unwrap()
-                .call1(PyTuple::new(py, &[&estimator]))
+                .call1(py,PyTuple::new(py, &[&estimator]))
                 .unwrap()
-                .extract()
+                .extract(py)
                 .unwrap();

-            let predict_proba: Py<PyAny> = module
-                .getattr("predictor_proba")
+            let predict_proba: Py<PyAny> = PY_MODULE
+                .getattr(py, "predictor_proba")
                 .unwrap()
-                .call1(PyTuple::new(py, &[&estimator]))
+                .call1(py,PyTuple::new(py, &[&estimator]))
                 .unwrap()
-                .extract()
+                .extract(py)
                 .unwrap();

             Box::new(Estimator {
@@ -449,18 +443,12 @@ impl Bindings for Estimator {
 }

 fn sklearn_metric(name: &str, ground_truth: &[f32], y_hat: &[f32]) -> f32 {
-    let module = include_str!(concat!(
-        env!("CARGO_MANIFEST_DIR"),
-        "/src/bindings/sklearn.py"
-    ));
-
     Python::with_gil(|py| -> f32 {
-        let module = PyModule::from_code(py, module, "", "").unwrap();
-        let calculate_metric = module.getattr("calculate_metric").unwrap();
+        let calculate_metric = PY_MODULE.getattr(py, "calculate_metric").unwrap();
         let wrapper: Py<PyAny> = calculate_metric
-            .call1(PyTuple::new(py, &[name]))
+            .call1(py,PyTuple::new(py, &[name]))
             .unwrap()
-            .extract()
+            .extract(py)
             .unwrap();

         let score: f32 = wrapper
@@ -490,18 +478,12 @@ pub fn recall(ground_truth: &[f32], y_hat: &[f32]) -> f32 {
 }

 pub fn confusion_matrix(ground_truth: &[f32], y_hat: &[f32]) -> Vec<Vec<f32>> {
-    let module = include_str!(concat!(
-        env!("CARGO_MANIFEST_DIR"),
-        "/src/bindings/sklearn.py"
-    ));
-
     Python::with_gil(|py| -> Vec<Vec<f32>> {
-        let module = PyModule::from_code(py, module, "", "").unwrap();
-        let calculate_metric = module.getattr("calculate_metric").unwrap();
+        let calculate_metric = PY_MODULE.getattr(py, "calculate_metric").unwrap();
         let wrapper: Py<PyAny> = calculate_metric
-            .call1(PyTuple::new(py, &["confusion_matrix"]))
+            .call1(py,PyTuple::new(py, &["confusion_matrix"]))
             .unwrap()
-            .extract()
+            .extract(py)
             .unwrap();

         let matrix: Vec<Vec<f32>> = wrapper
@@ -515,18 +497,12 @@ pub fn confusion_matrix(ground_truth: &[f32], y_hat: &[f32]) -> Vec<Vec<f32>> {
 }

 pub fn regression_metrics(ground_truth: &[f32], y_hat: &[f32]) -> HashMap<String, f32> {
-    let module = include_str!(concat!(
-        env!("CARGO_MANIFEST_DIR"),
-        "/src/bindings/sklearn.py"
-    ));
-
     Python::with_gil(|py| -> HashMap<String, f32> {
-        let module = PyModule::from_code(py, module, "", "").unwrap();
-        let calculate_metric = module.getattr("regression_metrics").unwrap();
+        let calculate_metric = PY_MODULE.getattr(py,"regression_metrics").unwrap();
         let scores: HashMap<String, f32> = calculate_metric
-            .call1(PyTuple::new(py, &[ground_truth, y_hat]))
+            .call1(py,PyTuple::new(py, &[ground_truth, y_hat]))
             .unwrap()
-            .extract()
+            .extract(py)
             .unwrap();

         scores
@@ -538,18 +514,12 @@ pub fn classification_metrics(
     y_hat: &[f32],
     num_classes: usize,
 ) -> HashMap<String, f32> {
-    let module = include_str!(concat!(
-        env!("CARGO_MANIFEST_DIR"),
-        "/src/bindings/sklearn.py"
-    ));
-
     let mut scores = Python::with_gil(|py| -> HashMap<String, f32> {
-        let module = PyModule::from_code(py, module, "", "").unwrap();
-        let calculate_metric = module.getattr("classification_metrics").unwrap();
+        let calculate_metric = PY_MODULE.getattr(py, "classification_metrics").unwrap();
         let scores: HashMap<String, f32> = calculate_metric
-            .call1(PyTuple::new(py, &[ground_truth, y_hat]))
+            .call1(py,PyTuple::new(py, &[ground_truth, y_hat]))
             .unwrap()
-            .extract()
+            .extract(py)
             .unwrap();

         scores
@@ -564,12 +534,8 @@ pub fn classification_metrics(
 }

 pub fn package_version(name: &str) -> String {
-    let mut version = String::new();
-
-    Python::with_gil(|py| {
+    Python::with_gil(|py| -> String {
         let package = py.import(name).unwrap();
-        version = package.getattr("__version__").unwrap().extract().unwrap();
-    });
-
-    version
+        package.getattr("__version__").unwrap().extract().unwrap()
+    })
 }
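
The refactor above replaces the per-call include_str! + PyModule::from_code dance with a single process-wide Lazy<Py<PyModule>>: the embedded Python source is compiled once, and every call site then fetches attributes from the cached handle under the GIL, which is why the calls change from module.getattr("x") to PY_MODULE.getattr(py, "x"), call1(py, ...), and extract(py). A minimal standalone sketch of the same pattern, assuming the once_cell and pyo3 crates already used in this file and an embeddable Python interpreter (pyo3's "auto-initialize" feature); the helper module and the double function are made up for illustration:

    use once_cell::sync::Lazy;
    use pyo3::prelude::*;
    use pyo3::types::PyTuple;

    // Compile the embedded Python source once; later calls reuse the cached
    // module handle instead of re-running PyModule::from_code.
    static HELPERS: Lazy<Py<PyModule>> = Lazy::new(|| {
        Python::with_gil(|py| -> Py<PyModule> {
            let src = "def double(x):\n    return x * 2\n";
            PyModule::from_code(py, src, "helpers.py", "helpers")
                .unwrap()
                .into()
        })
    });

    fn double(x: i64) -> i64 {
        Python::with_gil(|py| {
            HELPERS
                .getattr(py, "double")             // Py<T> methods take the GIL token...
                .unwrap()
                .call1(py, PyTuple::new(py, &[x])) // ...for calls...
                .unwrap()
                .extract(py)                       // ...and for extraction back into Rust.
                .unwrap()
        })
    }

    fn main() {
        assert_eq!(double(21), 42);
    }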

pgml-extension/src/bindings/transformers.py (15 additions, 57 deletions)

@@ -14,75 +14,33 @@ def transform(task, args, inputs):

     return json.dumps(pipe(inputs, **args))

-def load_dataset(name, subset, limit: None, **kwargs):
+def load_dataset(name, subset, limit: None, kwargs: "{}"):
+    kwargs = json.loads(kwargs)
+
     if limit:
         dataset = datasets.load_dataset(name, subset, split=f"train[:{limit}]", **kwargs)
     else:
         dataset = datasets.load_dataset(name, subset, **kwargs)

+    dict = None
     if isinstance(dataset, datasets.Dataset):
-        sample = dataset[0]
+        sample = dataset.to_dict()
     elif isinstance(dataset, datasets.DatasetDict):
-        sample = dataset["train"][0]
+        dict = {}
+        # Merge train/test splits, we'll re-split back in PostgresML.
+        for name, split in dataset.items():
+            for field, values in split.to_dict().items():
+                if field in dict:
+                    dict[field] += values
+                else:
+                    dict[field] = values
     else:
         raise PgMLException(f"Unhandled dataset type: {type(dataset)}")

-    columns = OrderedDict()
-    for key, value in sample.items():
-        column = c(key)
-        columns[column] = _PYTHON_TO_PG_MAP[type(value)]
-
-    table_name = f"pgml.{c(name)}"
-    plpy.execute(f"DROP TABLE IF EXISTS {table_name}")
-    plpy.execute(f"""CREATE TABLE {table_name} ({", ".join([f"{name} {type}" for name, type in columns.items()])})""")
-
-    if isinstance(dataset, datasets.Dataset):
-        load_dataset_rows(dataset, table_name)
-    elif isinstance(dataset, datasets.DatasetDict):
-        for name, rows in dataset.items():
-            if name == "unsupervised":
-                # postgresml doesn't provide unsupervised learning methods
-                continue
-            load_dataset_rows(rows, table_name)
-
-
-def load_dataset_rows(rows, table_name):
-    for row in rows:
-        plpy.execute(
-            f"""INSERT INTO {table_name} ({", ".join([c(v) for v in row.keys()])})
-            VALUES ({", ".join([q(v) for v in row.values()])})"""
-        )
-
-
-def transform(task, args, inputs):
-    cache = args.pop("cache", True)
-
-    # construct the cache key from task
-    key = task
-    if type(key) == dict:
-        key = tuple(sorted(key.items()))
-
-    if cache and key in _pipeline_cache:
-        pipe = _pipeline_cache.get(key)
-    else:
-        with timer("Initializing pipeline"):
-            if type(task) == str:
-                pipe = transformers.pipeline(task)
-            else:
-                pipe = transformers.pipeline(**task)
-        if cache:
-            _pipeline_cache[key] = pipe
-
-    if pipe.task == "question-answering":
-        inputs = [json.loads(input) for input in inputs]
-
-    with timer("inference"):
-        result = pipe(inputs, **args)
-
-    return result
+    return json.dumps(dict)


-class Model(BaseModel):
+class Model:
     @property
     def algorithm(self):
         if self._algorithm is None:
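
With this change the Python side no longer creates tables or runs INSERTs through plpy; load_dataset now returns the whole dataset as one column-oriented JSON object (train/test splits merged), and table creation moves to the Rust caller referenced in api.rs. That caller is not part of this diff, so the following is only a hypothetical serde_json sketch of how a column-oriented payload like the one produced above could be decoded on the Rust side; the function name and payload are illustrative:

    use std::collections::HashMap;

    use serde_json::Value;

    // Decode {"text": ["a", "b"], "label": [0, 1]} into per-column vectors.
    // Illustration only; the real consumer lives in the transformers bindings,
    // which are not shown in this commit.
    fn decode_columns(payload: &str) -> HashMap<String, Vec<Value>> {
        serde_json::from_str(payload).expect("expected a column-oriented JSON object")
    }

    fn main() {
        let payload = r#"{"text": ["a", "b"], "label": [0, 1]}"#;
        let columns = decode_columns(payload);
        assert_eq!(columns["label"].len(), 2); // two rows after merging splits
    }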
