diff --git a/pgml-extension/examples/embedding.sql b/pgml-extension/examples/embedding.sql new file mode 100644 index 000000000..4e6c5968d --- /dev/null +++ b/pgml-extension/examples/embedding.sql @@ -0,0 +1,7 @@ +\timing on + +SELECT pgml.embed('Alibaba-NLP/gte-base-en-v1.5', 'hi mom', '{"trust_remote_code": true}'); +SELECT pgml.embed('Alibaba-NLP/gte-base-en-v1.5', 'hi mom', '{"device": "cuda", "trust_remote_code": true}'); +SELECT pgml.embed('Alibaba-NLP/gte-base-en-v1.5', 'hi mom', '{"device": "cpu", "trust_remote_code": true}'); +SELECT pgml.embed('hkunlp/instructor-xl', 'hi mom', '{"instruction": "Encode it with love"}'); +SELECT pgml.embed('mixedbread-ai/mxbai-embed-large-v1', 'test', '{"prompt": "test prompt: "}'); diff --git a/pgml-extension/examples/image_classification.sql b/pgml-extension/examples/image_classification.sql index f9a7888a6..24e363e4a 100644 --- a/pgml-extension/examples/image_classification.sql +++ b/pgml-extension/examples/image_classification.sql @@ -66,7 +66,7 @@ SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'xgboost', hyperpara -- runtimes SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'linear', runtime => 'python'); -SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'linear', runtime => 'rust'); +--SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'linear', runtime => 'rust'); --SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'xgboost', runtime => 'python', hyperparams => '{"n_estimators": 10}'); -- too slow SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'xgboost', runtime => 'rust', hyperparams => '{"n_estimators": 10}'); diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index baa2c2500..ea2df12b9 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -527,8 +527,8 @@ def rank(transformer, query, documents, kwargs): return rank_using(model, query, documents, kwargs) -def create_embedding(transformer): - return SentenceTransformer(transformer) +def create_embedding(transformer, kwargs): + return SentenceTransformer(transformer, **kwargs) def embed_using(model, transformer, inputs, kwargs): @@ -545,16 +545,32 @@ def embed_using(model, transformer, inputs, kwargs): def embed(transformer, inputs, kwargs): kwargs = orjson.loads(kwargs) - ensure_device(kwargs) + init_kwarg_keys = [ + "device", + "trust_remote_code", + "revision", + "model_kwargs", + "tokenizer_kwargs", + "config_kwargs", + "truncate_dim", + "token", + ] + init_kwargs = { + key: value for key, value in kwargs.items() if key in init_kwarg_keys + } + encode_kwargs = { + key: value for key, value in kwargs.items() if key not in init_kwarg_keys + } + if transformer not in __cache_sentence_transformer_by_name: __cache_sentence_transformer_by_name[transformer] = create_embedding( - transformer + transformer, init_kwargs ) model = __cache_sentence_transformer_by_name[transformer] - return embed_using(model, transformer, inputs, kwargs) + return embed_using(model, transformer, inputs, encode_kwargs) def clear_gpu_cache(memory_usage: None): diff --git a/pgml-extension/tests/test.sql b/pgml-extension/tests/test.sql index 2256e0ca4..10ffb4339 100644 --- a/pgml-extension/tests/test.sql +++ b/pgml-extension/tests/test.sql @@ -31,5 +31,6 @@ SELECT pgml.load_dataset('wine'); \i examples/vectors.sql \i examples/chunking.sql \i examples/preprocessing.sql +\i examples/embedding.sql -- transformers are generally too slow to run in the test suite --\i examples/transformers.sql
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: