@@ -276,6 +276,7 @@ fn train_joint(
276
276
deploy = false ;
277
277
}
278
278
}
279
+ _ => error ! ( "Training only supports `classification` and `regression` task types." )
279
280
}
280
281
}
281
282
}
@@ -345,6 +346,9 @@ fn deploy(
345
346
"{predicate}\n ORDER BY models.metrics->>'f1' DESC NULLS LAST"
346
347
) ;
347
348
}
349
+
350
+ _ => todo ! ( "Training only supports `classification` and `regression` task types." )
351
+
348
352
} ,
349
353
350
354
Strategy :: most_recent => {
@@ -525,6 +529,163 @@ pub fn transform_string(
525
529
) )
526
530
}
527
531
532
/// Trains ("tunes") a model for `project_name` and optionally deploys it.
///
/// Steps, in order:
/// 1. Look up the project by name. If it does not exist it is created with
///    `task`; a missing project combined with a NULL `task` raises an error.
/// 2. If `task` is supplied and differs from the existing project's task, raise
///    an error.
/// 3. Pick the snapshot: when `relation_name` is NULL the project's last
///    snapshot is reused, otherwise a new snapshot of `relation_name` is created
///    with `y_column_name` as the only label column.
/// 4. Create the model with `Model::create`.
/// 5. Decide whether to deploy (see below) and deploy the model if so.
///
/// # Deployment rule
/// * `automatic_deploy = false`: never deploy.
/// * Otherwise the model is deployed unless the project's most recent
///   deployment has a strictly greater key metric (`f1` for classification,
///   `r2` for regression) than the new model.
/// * If there is no previous deployment (or the lookup fails), the model is
///   deployed.
///
/// # Returns
/// A single row `(status, task, algorithm, deployed)`.
/// NOTE(review): the `status` column carries the project name; the column name
/// is part of the SQL interface and therefore left untouched here.
///
/// # Panics / errors
/// * No snapshot exists and `relation_name` is NULL.
/// * `relation_name` is given without `y_column_name`.
/// * The new model has no metrics, or metrics are not a JSON object.
/// * The project's task is neither classification nor regression while a
///   previous deployment exists (comparison not implemented yet).
#[cfg(feature = "python")]
#[allow(clippy::too_many_arguments)] // Every argument is a SQL-level parameter.
#[pg_extern]
fn tune(
    project_name: &str,
    task: default!(Option<Task>, "NULL"),
    relation_name: default!(Option<&str>, "NULL"),
    y_column_name: default!(Option<&str>, "NULL"),
    // The default is raw SQL, so the enum label must be a quoted literal
    // (same convention as `test_sampling` below).
    algorithm: default!(Algorithm, "'transformers'"),
    hyperparams: default!(JsonB, "'{}'"),
    search: default!(Option<Search>, "NULL"),
    search_params: default!(JsonB, "'{}'"),
    search_args: default!(JsonB, "'{}'"),
    test_size: default!(f32, 0.25),
    test_sampling: default!(Sampling, "'last'"),
    runtime: default!(Option<Runtime>, "NULL"),
    automatic_deploy: default!(Option<bool>, true),
    materialize_snapshot: default!(bool, false),
    preprocess: default!(JsonB, "'{}'"),
) -> TableIterator<
    'static,
    (
        name!(status, String),
        name!(task, String),
        name!(algorithm, String),
        name!(deployed, bool),
    ),
> {
    let project = match Project::find_by_name(project_name) {
        Some(project) => project,
        None => Project::create(
            project_name,
            match task {
                Some(task) => task,
                None => error!("Project `{}` does not exist. To create a new project, provide the task (regression or classification).", project_name),
            },
        ),
    };

    // An explicitly requested task must agree with the stored project task.
    if matches!(task, Some(requested) if requested != project.task) {
        error!("Project `{:?}` already exists with a different task: `{:?}`. Create a new project instead.", project.name, project.task);
    }

    let mut snapshot = match relation_name {
        None => {
            let snapshot = project
                .last_snapshot()
                .expect("You must pass a `relation_name` and `y_column_name` to snapshot the first time you train a model.");

            info!("Using existing snapshot from {}", snapshot.snapshot_name(),);

            snapshot
        }

        Some(relation_name) => {
            info!(
                "Snapshotting table \"{}\", this may take a little while...",
                relation_name
            );

            let snapshot = Snapshot::create(
                relation_name,
                vec![y_column_name
                    .expect("You must pass a `y_column_name` when you pass a `relation_name`")
                    .to_string()],
                test_size,
                test_sampling,
                materialize_snapshot,
                preprocess,
            );

            if materialize_snapshot {
                info!(
                    "Snapshot of table \"{}\" created and saved in {}",
                    relation_name,
                    snapshot.snapshot_name(),
                );
            }

            snapshot
        }
    };

    // TODO: default to a repeatable random state when the algorithm accepts a
    // `random_state` hyperparameter and the caller did not provide one.
    let model = Model::create(
        &project,
        &mut snapshot,
        algorithm,
        hyperparams,
        search,
        search_params,
        search_args,
        runtime,
    );

    let new_metrics: &serde_json::Value = &model.metrics.unwrap().0;
    let new_metrics = new_metrics.as_object().unwrap();

    // Metrics of the model behind the project's most recent deployment.
    let deployed_metrics = Spi::get_one_with_args::<JsonB>(
        "
        SELECT models.metrics
        FROM pgml.models
        JOIN pgml.deployments
            ON deployments.model_id = models.id
        JOIN pgml.projects
            ON projects.id = deployments.project_id
        WHERE projects.name = $1
        ORDER by deployments.created_at DESC
        LIMIT 1;",
        vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())],
    );

    let deploy = match automatic_deploy {
        Some(false) => false,
        // Deploy only if metrics are not worse than the previous model.
        Some(true) | None => match deployed_metrics {
            Ok(Some(deployed_metrics)) => {
                let deployed_metrics = deployed_metrics.0.as_object().unwrap();

                // `Option<f64>` ordering treats a missing metric (`None`) as
                // smaller than any present value, so an absent key no longer
                // panics: it simply loses the comparison.
                let regressed = |metric: &str| {
                    deployed_metrics.get(metric).and_then(|value| value.as_f64())
                        > new_metrics.get(metric).and_then(|value| value.as_f64())
                };

                match project.task {
                    Task::classification => !regressed("f1"),
                    Task::regression => !regressed("r2"),
                    _ => todo!("Deploy tuned based on new metrics."),
                }
            }
            // Nothing deployed yet (or the lookup failed): deploy the new model.
            _ => true,
        },
    };

    if deploy {
        project.deploy(model.id);
    }

    TableIterator::new(
        vec![(
            project.name,
            project.task.to_string(),
            model.algorithm.to_string(),
            deploy,
        )]
        .into_iter(),
    )
}
687
+
688
+
528
689
#[ cfg( feature = "python" ) ]
529
690
#[ pg_extern( name = "sklearn_f1_score" ) ]
530
691
pub fn sklearn_f1_score ( ground_truth : Vec < f32 > , y_hat : Vec < f32 > ) -> f32 {
@@ -811,3 +972,7 @@ mod tests {
811
972
load_all ( "/tmp" ) ;
812
973
}
813
974
}
975
+
976
+
977
+
978
+
0 commit comments