Skip to content

Commit 5a54db4

Browse files
authored
disambiguate api (#566)
1 parent 00cc108 commit 5a54db4

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

pgml-extension/examples/image_classification.sql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ SELECT left(image::text, 40) || ',...}', target FROM pgml.digits LIMIT 10;
2222
SELECT * FROM pgml.train('Handwritten Digits', 'classification', 'pgml.digits', 'target');
2323

2424
-- check out the predictions
25-
SELECT target, pgml.predict('Handwritten Digits', image::FLOAT4[]) AS prediction
25+
SELECT target, pgml.predict('Handwritten Digits', image) AS prediction
2626
FROM pgml.digits
2727
LIMIT 10;
2828

@@ -95,6 +95,6 @@ SELECT * FROM pgml.deploy('Handwritten Digits', 'rollback');
9595
SELECT * FROM pgml.deploy('Handwritten Digits', 'best_score', 'svm');
9696

9797
-- check out the improved predictions
98-
SELECT target, pgml.predict('Handwritten Digits', image::FLOAT4[]) AS prediction
98+
SELECT target, pgml.predict('Handwritten Digits', image) AS prediction
9999
FROM pgml.digits
100100
LIMIT 10;

pgml-extension/src/api.rs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,10 +414,35 @@ fn deploy(
414414
}
415415

416416
#[pg_extern(strict, name = "predict")]
417-
fn predict(project_name: &str, features: Vec<f32>) -> f32 {
417+
fn predict_f32(project_name: &str, features: Vec<f32>) -> f32 {
418418
predict_model(Project::get_deployed_model_id(project_name), features)
419419
}
420420

421+
#[pg_extern(strict, name = "predict")]
422+
fn predict_f64(project_name: &str, features: Vec<f64>) -> f32 {
423+
predict_f32(project_name, features.iter().map(|&i| i as f32).collect())
424+
}
425+
426+
#[pg_extern(strict, name = "predict")]
427+
fn predict_i16(project_name: &str, features: Vec<i16>) -> f32 {
428+
predict_f32(project_name, features.iter().map(|&i| i as f32).collect())
429+
}
430+
431+
#[pg_extern(strict, name = "predict")]
432+
fn predict_i32(project_name: &str, features: Vec<i32>) -> f32 {
433+
predict_f32(project_name, features.iter().map(|&i| i as f32).collect())
434+
}
435+
436+
#[pg_extern(strict, name = "predict")]
437+
fn predict_i64(project_name: &str, features: Vec<i64>) -> f32 {
438+
predict_f32(project_name, features.iter().map(|&i| i as f32).collect())
439+
}
440+
441+
#[pg_extern(strict, name = "predict")]
442+
fn predict_bool(project_name: &str, features: Vec<bool>) -> f32 {
443+
predict_f32(project_name, features.iter().map(|&i| i as u8 as f32).collect())
444+
}
445+
421446
#[pg_extern(strict, name = "predict_proba")]
422447
fn predict_proba(project_name: &str, features: Vec<f32>) -> Vec<f32> {
423448
predict_model_proba(Project::get_deployed_model_id(project_name), features)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy