diff --git a/pgml-extension/src/bindings/langchain/mod.rs b/pgml-extension/src/bindings/langchain/mod.rs index 00ee593fd..7d8d2582f 100644 --- a/pgml-extension/src/bindings/langchain/mod.rs +++ b/pgml-extension/src/bindings/langchain/mod.rs @@ -1,10 +1,9 @@ use anyhow::Result; -use once_cell::sync::Lazy; use pgrx::*; use pyo3::prelude::*; use pyo3::types::PyTuple; -use crate::{bindings::TracebackError, create_pymodule}; +use crate::create_pymodule; create_pymodule!("/src/bindings/langchain/langchain.py"); diff --git a/pgml-extension/src/bindings/python/mod.rs b/pgml-extension/src/bindings/python/mod.rs index 7f527b0fc..9ab7300c0 100644 --- a/pgml-extension/src/bindings/python/mod.rs +++ b/pgml-extension/src/bindings/python/mod.rs @@ -1,14 +1,13 @@ //! Use virtualenv. use anyhow::Result; -use once_cell::sync::Lazy; use pgrx::iter::TableIterator; use pgrx::*; use pyo3::prelude::*; use pyo3::types::PyTuple; use crate::config::get_config; -use crate::{bindings::TracebackError, create_pymodule}; +use crate::create_pymodule; static CONFIG_NAME: &str = "pgml.venv"; diff --git a/pgml-extension/src/bindings/sklearn/mod.rs b/pgml-extension/src/bindings/sklearn/mod.rs index 05e85d97c..4b8ce6625 100644 --- a/pgml-extension/src/bindings/sklearn/mod.rs +++ b/pgml-extension/src/bindings/sklearn/mod.rs @@ -11,15 +11,10 @@ use pgrx::*; use std::collections::HashMap; use anyhow::Result; -use once_cell::sync::Lazy; use pyo3::prelude::*; use pyo3::types::PyTuple; -use crate::{ - bindings::{Bindings, TracebackError}, - create_pymodule, - orm::*, -}; +use crate::{bindings::Bindings, create_pymodule, orm::*}; create_pymodule!("/src/bindings/sklearn/sklearn.py"); diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs index 91158f860..c4e262761 100644 --- a/pgml-extension/src/bindings/transformers/mod.rs +++ b/pgml-extension/src/bindings/transformers/mod.rs @@ -4,7 +4,6 @@ use std::str::FromStr; use std::{collections::HashMap, path::Path}; use anyhow::{anyhow, bail, Context, Result}; -use once_cell::sync::Lazy; use pgrx::*; use pyo3::prelude::*; use pyo3::types::PyTuple; @@ -47,22 +46,22 @@ pub fn transform( ) .format_traceback(py)?; - Ok(output.extract(py).format_traceback(py)?) + output.extract(py).format_traceback(py) })?; Ok(serde_json::from_str(&results)?) } pub fn get_model_from(task: &Value) -> Result { - Ok(Python::with_gil(|py| -> Result { + Python::with_gil(|py| -> Result { let get_model_from = get_module!(PY_MODULE) .getattr(py, "get_model_from") .format_traceback(py)?; let model = get_model_from .call1(py, PyTuple::new(py, &[task.to_string().into_py(py)])) .format_traceback(py)?; - Ok(model.extract(py).format_traceback(py)?) - })?) + model.extract(py).format_traceback(py) + }) } pub fn embed( @@ -91,7 +90,7 @@ pub fn embed( ) .format_traceback(py)?; - Ok(output.extract(py).format_traceback(py)?) + output.extract(py).format_traceback(py) }) } @@ -126,7 +125,7 @@ pub fn tune( ) .format_traceback(py)?; - Ok(output.extract(py).format_traceback(py)?) + output.extract(py).format_traceback(py) }) } @@ -176,7 +175,7 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result o, }; - Ok(result.extract(py).format_traceback(py)?) + result.extract(py).format_traceback(py) }) } @@ -227,7 +226,7 @@ pub fn load_dataset( let load_dataset: Py = get_module!(PY_MODULE) .getattr(py, "load_dataset") .format_traceback(py)?; - Ok(load_dataset + load_dataset .call1( py, PyTuple::new( @@ -242,7 +241,7 @@ pub fn load_dataset( ) .format_traceback(py)? .extract(py) - .format_traceback(py)?) + .format_traceback(py) })?; let table_name = format!("pgml.\"{}\"", name); diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index f220be89d..98c843691 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -241,29 +241,38 @@ def transform(task, args, inputs): return orjson.dumps(pipe(inputs, **args), default=orjson_default).decode() -def embed(transformer, inputs, kwargs): - kwargs = orjson.loads(kwargs) +def create_embedding(transformer): + instructor = transformer.startswith("hkunlp/instructor") + klass = INSTRUCTOR if instructor else SentenceTransformer + return klass(transformer) + + +def embed_using(model, transformer, inputs, kwargs): + if isinstance(kwargs, str): + kwargs = orjson.loads(kwargs) - ensure_device(kwargs) instructor = transformer.startswith("hkunlp/instructor") - if instructor: - klass = INSTRUCTOR - texts_with_instructions = [] instruction = kwargs.pop("instruction") for text in inputs: texts_with_instructions.append([instruction, text]) inputs = texts_with_instructions - else: - klass = SentenceTransformer + + return model.encode(inputs, **kwargs) + + +def embed(transformer, inputs, kwargs): + kwargs = orjson.loads(kwargs) + + ensure_device(kwargs) if transformer not in __cache_sentence_transformer_by_name: - __cache_sentence_transformer_by_name[transformer] = klass(transformer) + __cache_sentence_transformer_by_name[transformer] = create_embedding(transformer) model = __cache_sentence_transformer_by_name[transformer] - return model.encode(inputs, **kwargs) + return embed_using(model, transformer, inputs, kwargs) def clear_gpu_cache(memory_usage: None): diff --git a/pgml-extension/src/orm/model.rs b/pgml-extension/src/orm/model.rs index f87ff736a..89a23888c 100644 --- a/pgml-extension/src/orm/model.rs +++ b/pgml-extension/src/orm/model.rs @@ -378,12 +378,12 @@ impl Model { Ok(()) })?; - Ok(model.ok_or_else(|| { + model.ok_or_else(|| { anyhow!( "pgml.models WHERE id = {:?} could not be loaded. Does it exist?", id ) - })?) + }) } pub fn find_cached(id: i64) -> Result> { pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy