1
- use pgx:: * ;
1
+ /// Scikit-Learn implementation.
2
+ ///
3
+ /// Scikit needs no introduction. It implements dozens of industry-standard
4
+ /// algorithms used in data science and machine learning.
5
+ ///
6
+ /// It uses numpy as its dense matrix.
7
+ ///
8
+ /// Our implementation below calls into Python wrappers
9
+ /// defined in `src/engines/wrappers.py`.
2
10
use pyo3:: prelude:: * ;
3
11
use pyo3:: types:: PyTuple ;
4
12
5
- use std:: collections:: HashMap ;
6
-
7
13
use crate :: orm:: algorithm:: Algorithm ;
8
14
use crate :: orm:: dataset:: Dataset ;
9
15
use crate :: orm:: estimator:: SklearnBox ;
10
16
use crate :: orm:: task:: Task ;
11
17
18
+ use pgx:: * ;
19
+
12
20
#[ pg_extern]
13
21
pub fn sklearn_version ( ) -> String {
14
22
let mut version = String :: new ( ) ;
@@ -25,7 +33,7 @@ pub fn sklearn_train(
25
33
task : Task ,
26
34
algorithm : Algorithm ,
27
35
dataset : & Dataset ,
28
- hyperparams : & JsonB ,
36
+ hyperparams : & serde_json :: Map < std :: string :: String , serde_json :: Value > ,
29
37
) -> SklearnBox {
30
38
let module = include_str ! ( concat!(
31
39
env!( "CARGO_MANIFEST_DIR" ) ,
@@ -75,12 +83,15 @@ pub fn sklearn_train(
75
83
SklearnBox :: new ( estimator)
76
84
}
77
85
78
- pub fn sklearn_test ( estimator : & SklearnBox , x_test : & [ f32 ] , num_features : usize ) -> Vec < f32 > {
86
+ pub fn sklearn_test ( estimator : & SklearnBox , dataset : & Dataset ) -> Vec < f32 > {
79
87
let module = include_str ! ( concat!(
80
88
env!( "CARGO_MANIFEST_DIR" ) ,
81
89
"/src/engines/wrappers.py"
82
90
) ) ;
83
91
92
+ let x_test = dataset. x_test ( ) ;
93
+ let num_features = dataset. num_features ;
94
+
84
95
let y_hat: Vec < f32 > = Python :: with_gil ( |py| -> Vec < f32 > {
85
96
let module = PyModule :: from_code ( py, module, "" , "" ) . unwrap ( ) ;
86
97
let predictor = module. getattr ( "predictor" ) . unwrap ( ) ;
0 commit comments