Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pgml-extension/pgml_rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ rmp-serde = { version = "1.1.0" }
typetag = "0.2"
pyo3 = { version = "0.17", features = ["auto-initialize"] }
heapless = "0.7.13"
parking_lot = "0.12"

[dev-dependencies]
pgx-tests = "=0.4.5"
Expand Down
4 changes: 2 additions & 2 deletions pgml-extension/pgml_rust/src/api.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use parking_lot::Mutex;
use std::collections::HashMap;
use std::fmt::Write;
use std::str::FromStr;
use std::sync::Mutex;

use once_cell::sync::Lazy;
use pgx::*;
Expand Down Expand Up @@ -244,7 +244,7 @@ fn deploy(

#[pg_extern]
fn predict(project_name: &str, features: Vec<f32>) -> f32 {
let mut projects = PROJECT_NAME_TO_PROJECT_ID.lock().unwrap();
let mut projects = PROJECT_NAME_TO_PROJECT_ID.lock();
let project_id = match projects.get(project_name) {
Some(project_id) => *project_id,
None => {
Expand Down
6 changes: 3 additions & 3 deletions pgml-extension/pgml_rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use once_cell::sync::Lazy; // 1.3.1
use pgx::*;
use std::collections::HashMap;
use std::fs;
use std::sync::Mutex;
use parking_lot::Mutex;
use xgboost::{Booster, DMatrix};

pub mod api;
Expand All @@ -25,7 +25,7 @@ static MODELS: Lazy<Mutex<HashMap<i64, Vec<u8>>>> = Lazy::new(|| Mutex::new(Hash

#[pg_extern]
fn model_predict(model_id: i64, features: Vec<f32>) -> f32 {
let mut guard = MODELS.lock().unwrap();
let mut guard = MODELS.lock();

match guard.get(&model_id) {
Some(data) => {
Expand Down Expand Up @@ -59,7 +59,7 @@ fn model_predict(model_id: i64, features: Vec<f32>) -> f32 {

#[pg_extern]
fn model_predict_batch(model_id: i64, features: Vec<f32>, num_rows: i32) -> Vec<f32> {
let mut guard = MODELS.lock().unwrap();
let mut guard = MODELS.lock();

if num_rows < 0 {
error!("Number of rows has to be greater than 0");
Expand Down
6 changes: 3 additions & 3 deletions pgml-extension/pgml_rust/src/orm/estimator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::collections::HashMap;
use std::fmt::Debug;
use std::str::FromStr;
use std::sync::Arc;
use std::sync::Mutex;
use parking_lot::Mutex;

use ndarray::{Array1, Array2};
use once_cell::sync::Lazy;
Expand All @@ -26,7 +26,7 @@ static DEPLOYED_ESTIMATORS_BY_MODEL_ID: Lazy<Mutex<HashMap<i64, Arc<Box<dyn Esti
pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Arc<Box<dyn Estimator>> {
// Get the estimator from process memory, if we already loaded it.
{
let estimators = DEPLOYED_ESTIMATORS_BY_MODEL_ID.lock().unwrap();
let estimators = DEPLOYED_ESTIMATORS_BY_MODEL_ID.lock();
if let Some(estimator) = estimators.get(&model_id) {
return estimator.clone();
}
Expand Down Expand Up @@ -88,7 +88,7 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Arc<Box<dyn Estimat
};

// Cache the estimator in process memory.
let mut estimators = DEPLOYED_ESTIMATORS_BY_MODEL_ID.lock().unwrap();
let mut estimators = DEPLOYED_ESTIMATORS_BY_MODEL_ID.lock();
estimators.insert(model_id, Arc::new(estimator));
estimators.get(&model_id).unwrap().clone()
}
Expand Down
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy