9
9
/// defined in `src/bindings/sklearn.py`.
10
10
use std:: collections:: HashMap ;
11
11
12
+ use once_cell:: sync:: Lazy ;
12
13
use pyo3:: prelude:: * ;
13
14
use pyo3:: types:: PyTuple ;
14
15
15
16
use crate :: bindings:: Bindings ;
16
17
17
18
use crate :: orm:: * ;
18
19
20
+ static PY_MODULE : Lazy < Py < PyModule > > = Lazy :: new ( ||
21
+ Python :: with_gil ( |py| -> Py < PyModule > {
22
+ let src = include_str ! ( concat!(
23
+ env!( "CARGO_MANIFEST_DIR" ) ,
24
+ "/src/bindings/sklearn.py"
25
+ ) ) ;
26
+
27
+ PyModule :: from_code ( py, src, "" , "" ) . unwrap ( ) . into ( )
28
+ } )
29
+ ) ;
30
+
19
31
pub fn linear_regression ( dataset : & Dataset , hyperparams : & Hyperparams ) -> Box < dyn Bindings > {
20
32
fit ( dataset, hyperparams, "linear_regression" )
21
33
}
@@ -290,17 +302,11 @@ fn fit(
290
302
hyperparams : & Hyperparams ,
291
303
algorithm_task : & ' static str ,
292
304
) -> Box < dyn Bindings > {
293
- let module = include_str ! ( concat!(
294
- env!( "CARGO_MANIFEST_DIR" ) ,
295
- "/src/bindings/sklearn.py"
296
- ) ) ;
297
-
298
305
let hyperparams = serde_json:: to_string ( hyperparams) . unwrap ( ) ;
299
306
300
307
let ( estimator, predict, predict_proba) =
301
308
Python :: with_gil ( |py| -> ( Py < PyAny > , Py < PyAny > , Py < PyAny > ) {
302
- let module = PyModule :: from_code ( py, module, "" , "" ) . unwrap ( ) ;
303
- let estimator: Py < PyAny > = module. getattr ( "estimator" ) . unwrap ( ) . into ( ) ;
309
+ let estimator: Py < PyAny > = PY_MODULE . getattr ( py, "estimator" ) . unwrap ( ) . into ( ) ;
304
310
305
311
let train: Py < PyAny > = estimator
306
312
. call1 (
@@ -321,20 +327,20 @@ fn fit(
321
327
. call1 ( py, PyTuple :: new ( py, & [ & dataset. x_train , & dataset. y_train ] ) )
322
328
. unwrap ( ) ;
323
329
324
- let predict: Py < PyAny > = module
325
- . getattr ( "predictor" )
330
+ let predict: Py < PyAny > = PY_MODULE
331
+ . getattr ( py , "predictor" )
326
332
. unwrap ( )
327
- . call1 ( PyTuple :: new ( py, & [ & estimator] ) )
333
+ . call1 ( py , PyTuple :: new ( py, & [ & estimator] ) )
328
334
. unwrap ( )
329
- . extract ( )
335
+ . extract ( py )
330
336
. unwrap ( ) ;
331
337
332
- let predict_proba: Py < PyAny > = module
333
- . getattr ( "predictor_proba" )
338
+ let predict_proba: Py < PyAny > = PY_MODULE
339
+ . getattr ( py , "predictor_proba" )
334
340
. unwrap ( )
335
- . call1 ( PyTuple :: new ( py, & [ & estimator] ) )
341
+ . call1 ( py , PyTuple :: new ( py, & [ & estimator] ) )
336
342
. unwrap ( )
337
- . extract ( )
343
+ . extract ( py )
338
344
. unwrap ( ) ;
339
345
340
346
( estimator, predict, predict_proba)
@@ -389,17 +395,11 @@ impl Bindings for Estimator {
389
395
390
396
/// Serialize self to bytes
391
397
fn to_bytes ( & self ) -> Vec < u8 > {
392
- let module = include_str ! ( concat!(
393
- env!( "CARGO_MANIFEST_DIR" ) ,
394
- "/src/bindings/sklearn.py"
395
- ) ) ;
396
-
397
398
Python :: with_gil ( |py| -> Vec < u8 > {
398
- let module = PyModule :: from_code ( py, module, "" , "" ) . unwrap ( ) ;
399
- let save = module. getattr ( "save" ) . unwrap ( ) ;
400
- save. call1 ( PyTuple :: new ( py, & [ & self . estimator ] ) )
399
+ let save = PY_MODULE . getattr ( py, "save" ) . unwrap ( ) ;
400
+ save. call1 ( py, PyTuple :: new ( py, & [ & self . estimator ] ) )
401
401
. unwrap ( )
402
- . extract ( )
402
+ . extract ( py )
403
403
. unwrap ( )
404
404
} )
405
405
}
@@ -409,34 +409,28 @@ impl Bindings for Estimator {
409
409
where
410
410
Self : Sized ,
411
411
{
412
- let module = include_str ! ( concat!(
413
- env!( "CARGO_MANIFEST_DIR" ) ,
414
- "/src/bindings/sklearn.py"
415
- ) ) ;
416
-
417
412
Python :: with_gil ( |py| -> Box < dyn Bindings > {
418
- let module = PyModule :: from_code ( py, module, "" , "" ) . unwrap ( ) ;
419
- let load = module. getattr ( "load" ) . unwrap ( ) ;
413
+ let load = PY_MODULE . getattr ( py, "load" ) . unwrap ( ) ;
420
414
let estimator: Py < PyAny > = load
421
- . call1 ( PyTuple :: new ( py, & [ bytes] ) )
415
+ . call1 ( py , PyTuple :: new ( py, & [ bytes] ) )
422
416
. unwrap ( )
423
- . extract ( )
417
+ . extract ( py )
424
418
. unwrap ( ) ;
425
419
426
- let predict: Py < PyAny > = module
427
- . getattr ( "predictor" )
420
+ let predict: Py < PyAny > = PY_MODULE
421
+ . getattr ( py , "predictor" )
428
422
. unwrap ( )
429
- . call1 ( PyTuple :: new ( py, & [ & estimator] ) )
423
+ . call1 ( py , PyTuple :: new ( py, & [ & estimator] ) )
430
424
. unwrap ( )
431
- . extract ( )
425
+ . extract ( py )
432
426
. unwrap ( ) ;
433
427
434
- let predict_proba: Py < PyAny > = module
435
- . getattr ( "predictor_proba" )
428
+ let predict_proba: Py < PyAny > = PY_MODULE
429
+ . getattr ( py , "predictor_proba" )
436
430
. unwrap ( )
437
- . call1 ( PyTuple :: new ( py, & [ & estimator] ) )
431
+ . call1 ( py , PyTuple :: new ( py, & [ & estimator] ) )
438
432
. unwrap ( )
439
- . extract ( )
433
+ . extract ( py )
440
434
. unwrap ( ) ;
441
435
442
436
Box :: new ( Estimator {
@@ -449,18 +443,12 @@ impl Bindings for Estimator {
449
443
}
450
444
451
445
fn sklearn_metric ( name : & str , ground_truth : & [ f32 ] , y_hat : & [ f32 ] ) -> f32 {
452
- let module = include_str ! ( concat!(
453
- env!( "CARGO_MANIFEST_DIR" ) ,
454
- "/src/bindings/sklearn.py"
455
- ) ) ;
456
-
457
446
Python :: with_gil ( |py| -> f32 {
458
- let module = PyModule :: from_code ( py, module, "" , "" ) . unwrap ( ) ;
459
- let calculate_metric = module. getattr ( "calculate_metric" ) . unwrap ( ) ;
447
+ let calculate_metric = PY_MODULE . getattr ( py, "calculate_metric" ) . unwrap ( ) ;
460
448
let wrapper: Py < PyAny > = calculate_metric
461
- . call1 ( PyTuple :: new ( py, & [ name] ) )
449
+ . call1 ( py , PyTuple :: new ( py, & [ name] ) )
462
450
. unwrap ( )
463
- . extract ( )
451
+ . extract ( py )
464
452
. unwrap ( ) ;
465
453
466
454
let score: f32 = wrapper
@@ -490,18 +478,12 @@ pub fn recall(ground_truth: &[f32], y_hat: &[f32]) -> f32 {
490
478
}
491
479
492
480
pub fn confusion_matrix ( ground_truth : & [ f32 ] , y_hat : & [ f32 ] ) -> Vec < Vec < f32 > > {
493
- let module = include_str ! ( concat!(
494
- env!( "CARGO_MANIFEST_DIR" ) ,
495
- "/src/bindings/sklearn.py"
496
- ) ) ;
497
-
498
481
Python :: with_gil ( |py| -> Vec < Vec < f32 > > {
499
- let module = PyModule :: from_code ( py, module, "" , "" ) . unwrap ( ) ;
500
- let calculate_metric = module. getattr ( "calculate_metric" ) . unwrap ( ) ;
482
+ let calculate_metric = PY_MODULE . getattr ( py, "calculate_metric" ) . unwrap ( ) ;
501
483
let wrapper: Py < PyAny > = calculate_metric
502
- . call1 ( PyTuple :: new ( py, & [ "confusion_matrix" ] ) )
484
+ . call1 ( py , PyTuple :: new ( py, & [ "confusion_matrix" ] ) )
503
485
. unwrap ( )
504
- . extract ( )
486
+ . extract ( py )
505
487
. unwrap ( ) ;
506
488
507
489
let matrix: Vec < Vec < f32 > > = wrapper
@@ -515,18 +497,12 @@ pub fn confusion_matrix(ground_truth: &[f32], y_hat: &[f32]) -> Vec<Vec<f32>> {
515
497
}
516
498
517
499
pub fn regression_metrics ( ground_truth : & [ f32 ] , y_hat : & [ f32 ] ) -> HashMap < String , f32 > {
518
- let module = include_str ! ( concat!(
519
- env!( "CARGO_MANIFEST_DIR" ) ,
520
- "/src/bindings/sklearn.py"
521
- ) ) ;
522
-
523
500
Python :: with_gil ( |py| -> HashMap < String , f32 > {
524
- let module = PyModule :: from_code ( py, module, "" , "" ) . unwrap ( ) ;
525
- let calculate_metric = module. getattr ( "regression_metrics" ) . unwrap ( ) ;
501
+ let calculate_metric = PY_MODULE . getattr ( py, "regression_metrics" ) . unwrap ( ) ;
526
502
let scores: HashMap < String , f32 > = calculate_metric
527
- . call1 ( PyTuple :: new ( py, & [ ground_truth, y_hat] ) )
503
+ . call1 ( py , PyTuple :: new ( py, & [ ground_truth, y_hat] ) )
528
504
. unwrap ( )
529
- . extract ( )
505
+ . extract ( py )
530
506
. unwrap ( ) ;
531
507
532
508
scores
@@ -538,18 +514,12 @@ pub fn classification_metrics(
538
514
y_hat : & [ f32 ] ,
539
515
num_classes : usize ,
540
516
) -> HashMap < String , f32 > {
541
- let module = include_str ! ( concat!(
542
- env!( "CARGO_MANIFEST_DIR" ) ,
543
- "/src/bindings/sklearn.py"
544
- ) ) ;
545
-
546
517
let mut scores = Python :: with_gil ( |py| -> HashMap < String , f32 > {
547
- let module = PyModule :: from_code ( py, module, "" , "" ) . unwrap ( ) ;
548
- let calculate_metric = module. getattr ( "classification_metrics" ) . unwrap ( ) ;
518
+ let calculate_metric = PY_MODULE . getattr ( py, "classification_metrics" ) . unwrap ( ) ;
549
519
let scores: HashMap < String , f32 > = calculate_metric
550
- . call1 ( PyTuple :: new ( py, & [ ground_truth, y_hat] ) )
520
+ . call1 ( py , PyTuple :: new ( py, & [ ground_truth, y_hat] ) )
551
521
. unwrap ( )
552
- . extract ( )
522
+ . extract ( py )
553
523
. unwrap ( ) ;
554
524
555
525
scores
@@ -564,12 +534,8 @@ pub fn classification_metrics(
564
534
}
565
535
566
536
pub fn package_version ( name : & str ) -> String {
567
- let mut version = String :: new ( ) ;
568
-
569
- Python :: with_gil ( |py| {
537
+ Python :: with_gil ( |py| -> String {
570
538
let package = py. import ( name) . unwrap ( ) ;
571
- version = package. getattr ( "__version__" ) . unwrap ( ) . extract ( ) . unwrap ( ) ;
572
- } ) ;
573
-
574
- version
539
+ package. getattr ( "__version__" ) . unwrap ( ) . extract ( ) . unwrap ( )
540
+ } )
575
541
}
0 commit comments