Skip to content

Commit a2101d3

Browse files
authored
Move things around (#331)
1 parent 53d9d21 commit a2101d3

File tree

8 files changed

+1096
-985
lines changed

8 files changed

+1096
-985
lines changed

pgml-extension/pgml_rust/src/api.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ fn train(
4040
search_args: default!(JsonB, "'{}'"),
4141
test_size: default!(f32, 0.25),
4242
test_sampling: default!(Sampling, "'last'"),
43-
engine: default!(Engine, "'sklearn'"),
43+
engine: Option<default!(Engine, "NULL")>,
4444
) -> impl std::iter::Iterator<
4545
Item = (
4646
name!(project, String),
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
pub mod engine;
22
pub mod sklearn;
33
pub mod smartcore;
4+
pub mod xgboost;

pgml-extension/pgml_rust/src/engines/sklearn.rs

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
1-
use pgx::*;
1+
/// Scikit-Learn implementation.
2+
///
3+
/// Scikit needs no introduction. It implements dozens of industry-standard
4+
/// algorithms used in data science and machine learning.
5+
///
6+
/// It uses numpy as its dense matrix.
7+
///
8+
/// Our implementation below calls into Python wrappers
9+
/// defined in `src/engines/wrappers.py`.
210
use pyo3::prelude::*;
311
use pyo3::types::PyTuple;
412

5-
use std::collections::HashMap;
6-
713
use crate::orm::algorithm::Algorithm;
814
use crate::orm::dataset::Dataset;
915
use crate::orm::estimator::SklearnBox;
1016
use crate::orm::task::Task;
1117

18+
use pgx::*;
19+
1220
#[pg_extern]
1321
pub fn sklearn_version() -> String {
1422
let mut version = String::new();
@@ -25,7 +33,7 @@ pub fn sklearn_train(
2533
task: Task,
2634
algorithm: Algorithm,
2735
dataset: &Dataset,
28-
hyperparams: &JsonB,
36+
hyperparams: &serde_json::Map<std::string::String, serde_json::Value>,
2937
) -> SklearnBox {
3038
let module = include_str!(concat!(
3139
env!("CARGO_MANIFEST_DIR"),
@@ -75,12 +83,15 @@ pub fn sklearn_train(
7583
SklearnBox::new(estimator)
7684
}
7785

78-
pub fn sklearn_test(estimator: &SklearnBox, x_test: &[f32], num_features: usize) -> Vec<f32> {
86+
pub fn sklearn_test(estimator: &SklearnBox, dataset: &Dataset) -> Vec<f32> {
7987
let module = include_str!(concat!(
8088
env!("CARGO_MANIFEST_DIR"),
8189
"/src/engines/wrappers.py"
8290
));
8391

92+
let x_test = dataset.x_test();
93+
let num_features = dataset.num_features;
94+
8495
let y_hat: Vec<f32> = Python::with_gil(|py| -> Vec<f32> {
8596
let module = PyModule::from_code(py, module, "", "").unwrap();
8697
let predictor = module.getattr("predictor").unwrap();

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy