Skip to content

Commit debd9ae

Browse files
Separate embedding kwargs into init kwargs and encode kwargs (#1555)
Co-authored-by: Montana Low <montana.low@gmail.com>
1 parent fec164a commit debd9ae

File tree

4 files changed

+30
-6
lines changed

4 files changed

+30
-6
lines changed

pgml-extension/examples/embedding.sql

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
\timing on
2+
3+
SELECT pgml.embed('Alibaba-NLP/gte-base-en-v1.5', 'hi mom', '{"trust_remote_code": true}');
4+
SELECT pgml.embed('Alibaba-NLP/gte-base-en-v1.5', 'hi mom', '{"device": "cuda", "trust_remote_code": true}');
5+
SELECT pgml.embed('Alibaba-NLP/gte-base-en-v1.5', 'hi mom', '{"device": "cpu", "trust_remote_code": true}');
6+
SELECT pgml.embed('hkunlp/instructor-xl', 'hi mom', '{"instruction": "Encode it with love"}');
7+
SELECT pgml.embed('mixedbread-ai/mxbai-embed-large-v1', 'test', '{"prompt": "test prompt: "}');

pgml-extension/examples/image_classification.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'xgboost', hyperpara
6666

6767
-- runtimes
6868
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'linear', runtime => 'python');
69-
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'linear', runtime => 'rust');
69+
--SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'linear', runtime => 'rust');
7070

7171
--SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'xgboost', runtime => 'python', hyperparams => '{"n_estimators": 10}'); -- too slow
7272
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'xgboost', runtime => 'rust', hyperparams => '{"n_estimators": 10}');

pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -527,8 +527,8 @@ def rank(transformer, query, documents, kwargs):
527527
return rank_using(model, query, documents, kwargs)
528528

529529

530-
def create_embedding(transformer):
531-
return SentenceTransformer(transformer)
530+
def create_embedding(transformer, kwargs):
531+
return SentenceTransformer(transformer, **kwargs)
532532

533533

534534
def embed_using(model, transformer, inputs, kwargs):
@@ -545,16 +545,32 @@ def embed_using(model, transformer, inputs, kwargs):
545545

546546
def embed(transformer, inputs, kwargs):
547547
kwargs = orjson.loads(kwargs)
548-
549548
ensure_device(kwargs)
550549

550+
init_kwarg_keys = [
551+
"device",
552+
"trust_remote_code",
553+
"revision",
554+
"model_kwargs",
555+
"tokenizer_kwargs",
556+
"config_kwargs",
557+
"truncate_dim",
558+
"token",
559+
]
560+
init_kwargs = {
561+
key: value for key, value in kwargs.items() if key in init_kwarg_keys
562+
}
563+
encode_kwargs = {
564+
key: value for key, value in kwargs.items() if key not in init_kwarg_keys
565+
}
566+
551567
if transformer not in __cache_sentence_transformer_by_name:
552568
__cache_sentence_transformer_by_name[transformer] = create_embedding(
553-
transformer
569+
transformer, init_kwargs
554570
)
555571
model = __cache_sentence_transformer_by_name[transformer]
556572

557-
return embed_using(model, transformer, inputs, kwargs)
573+
return embed_using(model, transformer, inputs, encode_kwargs)
558574

559575

560576
def clear_gpu_cache(memory_usage: None):

pgml-extension/tests/test.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,6 @@ SELECT pgml.load_dataset('wine');
3131
\i examples/vectors.sql
3232
\i examples/chunking.sql
3333
\i examples/preprocessing.sql
34+
\i examples/embedding.sql
3435
-- transformers are generally too slow to run in the test suite
3536
--\i examples/transformers.sql

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy