From 8a72cfff496fae16ade65166c9a79448d77b2ff2 Mon Sep 17 00:00:00 2001 From: Kevin Zimmerman <4733573+kczimm@users.noreply.github.com> Date: Tue, 19 Sep 2023 09:53:42 -0500 Subject: [PATCH 1/3] separate embed model creation and usage --- .../src/bindings/transformers/transformers.py | 29 ++++++++++++------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index f220be89d..98c843691 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -241,29 +241,38 @@ def transform(task, args, inputs): return orjson.dumps(pipe(inputs, **args), default=orjson_default).decode() -def embed(transformer, inputs, kwargs): - kwargs = orjson.loads(kwargs) +def create_embedding(transformer): + instructor = transformer.startswith("hkunlp/instructor") + klass = INSTRUCTOR if instructor else SentenceTransformer + return klass(transformer) + + +def embed_using(model, transformer, inputs, kwargs): + if isinstance(kwargs, str): + kwargs = orjson.loads(kwargs) - ensure_device(kwargs) instructor = transformer.startswith("hkunlp/instructor") - if instructor: - klass = INSTRUCTOR - texts_with_instructions = [] instruction = kwargs.pop("instruction") for text in inputs: texts_with_instructions.append([instruction, text]) inputs = texts_with_instructions - else: - klass = SentenceTransformer + + return model.encode(inputs, **kwargs) + + +def embed(transformer, inputs, kwargs): + kwargs = orjson.loads(kwargs) + + ensure_device(kwargs) if transformer not in __cache_sentence_transformer_by_name: - __cache_sentence_transformer_by_name[transformer] = klass(transformer) + __cache_sentence_transformer_by_name[transformer] = create_embedding(transformer) model = __cache_sentence_transformer_by_name[transformer] - return model.encode(inputs, **kwargs) + return embed_using(model, transformer, inputs, kwargs) def clear_gpu_cache(memory_usage: None): From 5bf5d001396268eba87c7e0c848f0b08f0659cdc Mon Sep 17 00:00:00 2001 From: Kevin Zimmerman <4733573+kczimm@users.noreply.github.com> Date: Tue, 19 Sep 2023 11:16:38 -0500 Subject: [PATCH 2/3] fix dead code --- pgml-extension/src/bindings/langchain/mod.rs | 3 +-- pgml-extension/src/bindings/python/mod.rs | 3 +-- pgml-extension/src/bindings/sklearn/mod.rs | 7 +------ pgml-extension/src/bindings/transformers/mod.rs | 1 - 4 files changed, 3 insertions(+), 11 deletions(-) diff --git a/pgml-extension/src/bindings/langchain/mod.rs b/pgml-extension/src/bindings/langchain/mod.rs index 00ee593fd..7d8d2582f 100644 --- a/pgml-extension/src/bindings/langchain/mod.rs +++ b/pgml-extension/src/bindings/langchain/mod.rs @@ -1,10 +1,9 @@ use anyhow::Result; -use once_cell::sync::Lazy; use pgrx::*; use pyo3::prelude::*; use pyo3::types::PyTuple; -use crate::{bindings::TracebackError, create_pymodule}; +use crate::create_pymodule; create_pymodule!("/src/bindings/langchain/langchain.py"); diff --git a/pgml-extension/src/bindings/python/mod.rs b/pgml-extension/src/bindings/python/mod.rs index 7f527b0fc..9ab7300c0 100644 --- a/pgml-extension/src/bindings/python/mod.rs +++ b/pgml-extension/src/bindings/python/mod.rs @@ -1,14 +1,13 @@ //! Use virtualenv. use anyhow::Result; -use once_cell::sync::Lazy; use pgrx::iter::TableIterator; use pgrx::*; use pyo3::prelude::*; use pyo3::types::PyTuple; use crate::config::get_config; -use crate::{bindings::TracebackError, create_pymodule}; +use crate::create_pymodule; static CONFIG_NAME: &str = "pgml.venv"; diff --git a/pgml-extension/src/bindings/sklearn/mod.rs b/pgml-extension/src/bindings/sklearn/mod.rs index 05e85d97c..4b8ce6625 100644 --- a/pgml-extension/src/bindings/sklearn/mod.rs +++ b/pgml-extension/src/bindings/sklearn/mod.rs @@ -11,15 +11,10 @@ use pgrx::*; use std::collections::HashMap; use anyhow::Result; -use once_cell::sync::Lazy; use pyo3::prelude::*; use pyo3::types::PyTuple; -use crate::{ - bindings::{Bindings, TracebackError}, - create_pymodule, - orm::*, -}; +use crate::{bindings::Bindings, create_pymodule, orm::*}; create_pymodule!("/src/bindings/sklearn/sklearn.py"); diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs index 91158f860..fbdeec4f8 100644 --- a/pgml-extension/src/bindings/transformers/mod.rs +++ b/pgml-extension/src/bindings/transformers/mod.rs @@ -4,7 +4,6 @@ use std::str::FromStr; use std::{collections::HashMap, path::Path}; use anyhow::{anyhow, bail, Context, Result}; -use once_cell::sync::Lazy; use pgrx::*; use pyo3::prelude::*; use pyo3::types::PyTuple; From 4e67b52774646f440513b3331bc7a9fe2381f403 Mon Sep 17 00:00:00 2001 From: Kevin Zimmerman <4733573+kczimm@users.noreply.github.com> Date: Tue, 19 Sep 2023 11:17:19 -0500 Subject: [PATCH 3/3] fix clippy lints --- .../src/bindings/transformers/mod.rs | 18 +++++++++--------- pgml-extension/src/orm/model.rs | 4 ++-- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs index fbdeec4f8..c4e262761 100644 --- a/pgml-extension/src/bindings/transformers/mod.rs +++ b/pgml-extension/src/bindings/transformers/mod.rs @@ -46,22 +46,22 @@ pub fn transform( ) .format_traceback(py)?; - Ok(output.extract(py).format_traceback(py)?) + output.extract(py).format_traceback(py) })?; Ok(serde_json::from_str(&results)?) } pub fn get_model_from(task: &Value) -> Result { - Ok(Python::with_gil(|py| -> Result { + Python::with_gil(|py| -> Result { let get_model_from = get_module!(PY_MODULE) .getattr(py, "get_model_from") .format_traceback(py)?; let model = get_model_from .call1(py, PyTuple::new(py, &[task.to_string().into_py(py)])) .format_traceback(py)?; - Ok(model.extract(py).format_traceback(py)?) - })?) + model.extract(py).format_traceback(py) + }) } pub fn embed( @@ -90,7 +90,7 @@ pub fn embed( ) .format_traceback(py)?; - Ok(output.extract(py).format_traceback(py)?) + output.extract(py).format_traceback(py) }) } @@ -125,7 +125,7 @@ pub fn tune( ) .format_traceback(py)?; - Ok(output.extract(py).format_traceback(py)?) + output.extract(py).format_traceback(py) }) } @@ -175,7 +175,7 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result o, }; - Ok(result.extract(py).format_traceback(py)?) + result.extract(py).format_traceback(py) }) } @@ -226,7 +226,7 @@ pub fn load_dataset( let load_dataset: Py = get_module!(PY_MODULE) .getattr(py, "load_dataset") .format_traceback(py)?; - Ok(load_dataset + load_dataset .call1( py, PyTuple::new( @@ -241,7 +241,7 @@ pub fn load_dataset( ) .format_traceback(py)? .extract(py) - .format_traceback(py)?) + .format_traceback(py) })?; let table_name = format!("pgml.\"{}\"", name); diff --git a/pgml-extension/src/orm/model.rs b/pgml-extension/src/orm/model.rs index f87ff736a..89a23888c 100644 --- a/pgml-extension/src/orm/model.rs +++ b/pgml-extension/src/orm/model.rs @@ -378,12 +378,12 @@ impl Model { Ok(()) })?; - Ok(model.ok_or_else(|| { + model.ok_or_else(|| { anyhow!( "pgml.models WHERE id = {:?} could not be loaded. Does it exist?", id ) - })?) + }) } pub fn find_cached(id: i64) -> Result> { pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy