From 455861eb1f759c1ede53cc97f58291d615a0bf5f Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Fri, 12 Jul 2024 12:35:22 +0200 Subject: [PATCH 1/2] Separate embedding kwargs into init kwargs and encode kwargs --- .../src/bindings/transformers/transformers.py | 26 +++++++++++++++---- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index baa2c2500..ea2df12b9 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -527,8 +527,8 @@ def rank(transformer, query, documents, kwargs): return rank_using(model, query, documents, kwargs) -def create_embedding(transformer): - return SentenceTransformer(transformer) +def create_embedding(transformer, kwargs): + return SentenceTransformer(transformer, **kwargs) def embed_using(model, transformer, inputs, kwargs): @@ -545,16 +545,32 @@ def embed_using(model, transformer, inputs, kwargs): def embed(transformer, inputs, kwargs): kwargs = orjson.loads(kwargs) - ensure_device(kwargs) + init_kwarg_keys = [ + "device", + "trust_remote_code", + "revision", + "model_kwargs", + "tokenizer_kwargs", + "config_kwargs", + "truncate_dim", + "token", + ] + init_kwargs = { + key: value for key, value in kwargs.items() if key in init_kwarg_keys + } + encode_kwargs = { + key: value for key, value in kwargs.items() if key not in init_kwarg_keys + } + if transformer not in __cache_sentence_transformer_by_name: __cache_sentence_transformer_by_name[transformer] = create_embedding( - transformer + transformer, init_kwargs ) model = __cache_sentence_transformer_by_name[transformer] - return embed_using(model, transformer, inputs, kwargs) + return embed_using(model, transformer, inputs, encode_kwargs) def clear_gpu_cache(memory_usage: None): From 465f38d51542c6a5105b31fb2b239c635c42b75e Mon Sep 17 00:00:00 2001 From: Montana Low Date: Fri, 12 Jul 2024 07:51:55 -0700 Subject: [PATCH 2/2] move embedding tests into their own file --- pgml-extension/examples/embedding.sql | 7 +++++++ pgml-extension/examples/image_classification.sql | 2 +- pgml-extension/tests/test.sql | 1 + 3 files changed, 9 insertions(+), 1 deletion(-) create mode 100644 pgml-extension/examples/embedding.sql diff --git a/pgml-extension/examples/embedding.sql b/pgml-extension/examples/embedding.sql new file mode 100644 index 000000000..4e6c5968d --- /dev/null +++ b/pgml-extension/examples/embedding.sql @@ -0,0 +1,7 @@ +\timing on + +SELECT pgml.embed('Alibaba-NLP/gte-base-en-v1.5', 'hi mom', '{"trust_remote_code": true}'); +SELECT pgml.embed('Alibaba-NLP/gte-base-en-v1.5', 'hi mom', '{"device": "cuda", "trust_remote_code": true}'); +SELECT pgml.embed('Alibaba-NLP/gte-base-en-v1.5', 'hi mom', '{"device": "cpu", "trust_remote_code": true}'); +SELECT pgml.embed('hkunlp/instructor-xl', 'hi mom', '{"instruction": "Encode it with love"}'); +SELECT pgml.embed('mixedbread-ai/mxbai-embed-large-v1', 'test', '{"prompt": "test prompt: "}'); diff --git a/pgml-extension/examples/image_classification.sql b/pgml-extension/examples/image_classification.sql index f9a7888a6..24e363e4a 100644 --- a/pgml-extension/examples/image_classification.sql +++ b/pgml-extension/examples/image_classification.sql @@ -66,7 +66,7 @@ SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'xgboost', hyperpara -- runtimes SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'linear', runtime => 'python'); -SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'linear', runtime => 'rust'); +--SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'linear', runtime => 'rust'); --SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'xgboost', runtime => 'python', hyperparams => '{"n_estimators": 10}'); -- too slow SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'xgboost', runtime => 'rust', hyperparams => '{"n_estimators": 10}'); diff --git a/pgml-extension/tests/test.sql b/pgml-extension/tests/test.sql index 2256e0ca4..10ffb4339 100644 --- a/pgml-extension/tests/test.sql +++ b/pgml-extension/tests/test.sql @@ -31,5 +31,6 @@ SELECT pgml.load_dataset('wine'); \i examples/vectors.sql \i examples/chunking.sql \i examples/preprocessing.sql +\i examples/embedding.sql -- transformers are generally too slow to run in the test suite --\i examples/transformers.sql pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy