diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock index c9d5723db..74f0c7825 100644 --- a/pgml-sdks/pgml/Cargo.lock +++ b/pgml-sdks/pgml/Cargo.lock @@ -171,6 +171,12 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "base64ct" version = "1.6.0" @@ -244,6 +250,7 @@ dependencies = [ "iana-time-zone", "js-sys", "num-traits", + "serde", "wasm-bindgen", "windows-targets 0.52.0", ] @@ -267,7 +274,7 @@ dependencies = [ "anstream", "anstyle", "clap_lex", - "strsim", + "strsim 0.10.0", ] [[package]] @@ -457,8 +464,18 @@ version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850" dependencies = [ - "darling_core", - "darling_macro", + "darling_core 0.14.4", + "darling_macro 0.14.4", +] + +[[package]] +name = "darling" +version = "0.20.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83b2eb4d90d12bdda5ed17de686c2acb4c57914f8f921b8da7e112b5a36f3fe1" +dependencies = [ + "darling_core 0.20.9", + "darling_macro 0.20.9", ] [[package]] @@ -471,21 +488,46 @@ dependencies = [ "ident_case", "proc-macro2", "quote", - "strsim", + "strsim 0.10.0", "syn 1.0.109", ] +[[package]] +name = "darling_core" +version = "0.20.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "622687fe0bac72a04e5599029151f5796111b90f1baaa9b544d807a5e31cd120" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim 0.11.1", + "syn 2.0.48", +] + [[package]] name = "darling_macro" version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" dependencies = [ - "darling_core", + "darling_core 0.14.4", "quote", "syn 1.0.109", ] +[[package]] +name = "darling_macro" +version = "0.20.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "733cabb43482b1a1b53eee8583c2b9e8684d592215ea83efd305dd31bc2f0178" +dependencies = [ + "darling_core 0.20.9", + "quote", + "syn 2.0.48", +] + [[package]] name = "der" version = "0.7.8" @@ -504,6 +546,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" dependencies = [ "powerfmt", + "serde", ] [[package]] @@ -789,13 +832,19 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap", + "indexmap 2.2.2", "slab", "tokio", "tokio-util", "tracing", ] +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + [[package]] name = "hashbrown" version = "0.14.3" @@ -812,7 +861,7 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7" dependencies = [ - "hashbrown", + "hashbrown 0.14.3", ] [[package]] @@ -973,6 +1022,17 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown 0.12.3", + "serde", +] + [[package]] name = "indexmap" version = "2.2.2" @@ -980,7 +1040,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "824b2ae422412366ba479e8111fd301f7b5faece8149317bb81925979a53f520" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.14.3", + "serde", ] [[package]] @@ -1558,6 +1619,7 @@ dependencies = [ "sea-query-binder", "serde", "serde_json", + "serde_with", "sqlx", "tokio", "tracing", @@ -1822,7 +1884,7 @@ version = "0.11.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c6920094eb85afde5e4a138be3f2de8bbdf28000f0029e72c45025a56b042251" dependencies = [ - "base64", + "base64 0.21.7", "bytes", "encoding_rs", "futures-core", @@ -1951,7 +2013,7 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" dependencies = [ - "base64", + "base64 0.21.7", ] [[package]] @@ -2023,7 +2085,7 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "878cf3d57f0e5bfacd425cdaccc58b4c06d68a7b71c63fc28710a20c88676808" dependencies = [ - "darling", + "darling 0.14.4", "heck", "quote", "syn 1.0.109", @@ -2135,6 +2197,36 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_with" +version = "3.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ad483d2ab0149d5a5ebcd9972a3852711e0153d863bf5a5d0391d28883c4a20" +dependencies = [ + "base64 0.22.1", + "chrono", + "hex", + "indexmap 1.9.3", + "indexmap 2.2.2", + "serde", + "serde_derive", + "serde_json", + "serde_with_macros", + "time", +] + +[[package]] +name = "serde_with_macros" +version = "3.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65569b702f41443e8bc8bbb1c5779bd0450bbe723b56198980e80ec45780bce2" +dependencies = [ + "darling 0.20.9", + "proc-macro2", + "quote", + "syn 2.0.48", +] + [[package]] name = "sha1" version = "0.10.6" @@ -2302,7 +2394,7 @@ dependencies = [ "futures-util", "hashlink", "hex", - "indexmap", + "indexmap 2.2.2", "log", "memchr", "once_cell", @@ -2372,7 +2464,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e37195395df71fd068f6e2082247891bc11e3289624bbc776a0cdfa1ca7f1ea4" dependencies = [ "atoi", - "base64", + "base64 0.21.7", "bitflags 2.4.2", "byteorder", "bytes", @@ -2416,7 +2508,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6ac0ac3b7ccd10cc96c7ab29791a7dd236bd94021f31eec7ba3d46a74aa1c24" dependencies = [ "atoi", - "base64", + "base64 0.21.7", "bitflags 2.4.2", "byteorder", "crc", @@ -2492,6 +2584,12 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "subtle" version = "2.5.0" diff --git a/pgml-sdks/pgml/Cargo.toml b/pgml-sdks/pgml/Cargo.toml index ba4037bce..21474428b 100644 --- a/pgml-sdks/pgml/Cargo.toml +++ b/pgml-sdks/pgml/Cargo.toml @@ -46,6 +46,7 @@ inquire = "0.6" parking_lot = "0.12.1" once_cell = "1.19.0" url = "2.5.0" +serde_with = "3.8.1" [features] default = [] diff --git 
a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts index 9fa4e4954..f35e8efbb 100644 --- a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts +++ b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts @@ -74,7 +74,7 @@ it("can create builtins", () => { it("can search", async () => { let pipeline = pgml.newPipeline("test_j_p_cs", { - title: { semantic_search: { model: "intfloat/e5-small" } }, + title: { semantic_search: { model: "intfloat/e5-small-v2", parameters: { prompt: "passage: " } } }, body: { splitter: { model: "recursive_character" }, semantic_search: { @@ -92,17 +92,19 @@ it("can search", async () => { query: { full_text_search: { body: { query: "Test", boost: 1.2 } }, semantic_search: { - title: { query: "This is a test", boost: 2.0 }, + title: { + query: "This is a test", parameters: { prompt: "query: " }, boost: 2.0 + }, body: { query: "This is the body test", boost: 1.01 }, }, filter: { id: { $gt: 1 } }, - }, + }, limit: 10 }, pipeline, ); let ids = results["results"].map((r: any) => r["id"]); - expect(ids).toEqual([5, 4, 3]); + expect(ids).toEqual([4, 3, 5]); await collection.archive(); }); @@ -110,11 +112,10 @@ it("can search", async () => { // Test various vector searches /////////////////// /////////////////////////////////////////////////// - it("can vector search", async () => { - let pipeline = pgml.newPipeline("test_j_p_cvs_0", { + let pipeline = pgml.newPipeline("1", { title: { - semantic_search: { model: "intfloat/e5-small" }, + semantic_search: { model: "intfloat/e5-small-v2", parameters: { prompt: "passage: " } }, full_text_search: { configuration: "english" }, }, body: { @@ -132,7 +133,7 @@ it("can vector search", async () => { { query: { fields: { - title: { query: "Test document: 2", full_text_filter: "test" }, + title: { query: "Test document: 2", parameters: { prompt: "query: " }, full_text_filter: "test" }, body: { query: "Test document: 2" }, }, filter: { id: { "$gt": 2 } }, @@ -142,14 +143,14 @@ it("can vector search", async () => { pipeline, ); let ids = results.map(r => r["document"]["id"]); - expect(ids).toEqual([3, 4, 4, 3]); + expect(ids).toEqual([4, 3, 3, 4]); await collection.archive(); }); it("can vector search with query builder", async () => { - let model = pgml.newModel(); + let model = pgml.newModel("intfloat/e5-small-v2", "pgml", { prompt: "passage: " }); let splitter = pgml.newSplitter(); - let pipeline = pgml.newSingleFieldPipeline("test_j_p_cvswqb_0", model, splitter); + let pipeline = pgml.newSingleFieldPipeline("0", model, splitter); let collection = pgml.newCollection("test_j_c_cvswqb_2"); await collection.upsert_documents(generate_dummy_documents(3)); await collection.add_pipeline(pipeline); @@ -159,10 +160,101 @@ it("can vector search with query builder", async () => { .limit(10) .fetch_all(); let ids = results.map(r => r[2]["id"]); - expect(ids).toEqual([2, 1, 0]); + expect(ids).toEqual([1, 2, 0]); await collection.archive(); }); +/////////////////////////////////////////////////// +// Test rag /////////////////////////////////////// +/////////////////////////////////////////////////// + +it("can rag", async () => { + let pipeline = pgml.newPipeline("0", { + body: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "intfloat/e5-small-v2", + parameters: { prompt: "passage: " }, + }, + }, + }); + let collection = pgml.newCollection("test_j_c_cr_0") + await collection.add_pipeline(pipeline) + await 
collection.upsert_documents(generate_dummy_documents(5)) + const results = await collection.rag( + { + "CONTEXT": { + vector_search: { + query: { + fields: { + body: { query: "Test document: 2", parameters: { prompt: "query: " } }, + }, + }, + document: { keys: ["id"] }, + limit: 5, + }, + aggregate: { join: "\n" }, + }, + completion: { + model: "meta-llama/Meta-Llama-3-8B-Instruct", + prompt: "Some text with {CONTEXT}", + max_tokens: 10, + }, + }, + pipeline + ); + expect(results["rag"][0].length).toBeGreaterThan(0); + expect(results["sources"]["CONTEXT"].length).toBeGreaterThan(0); + await collection.archive() +}) + + +it("can rag stream", async () => { + let pipeline = pgml.newPipeline("0", { + body: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "intfloat/e5-small-v2", + parameters: { prompt: "passage: " }, + }, + }, + }); + let collection = pgml.newCollection("test_j_c_cr_0") + await collection.add_pipeline(pipeline) + await collection.upsert_documents(generate_dummy_documents(5)) + const results = await collection.rag_stream( + { + "CONTEXT": { + vector_search: { + query: { + fields: { + body: { query: "Test document: 2", parameters: { prompt: "query: " } }, + }, + }, + document: { keys: ["id"] }, + limit: 5, + }, + aggregate: { join: "\n" }, + }, + completion: { + model: "meta-llama/Meta-Llama-3-8B-Instruct", + prompt: "Some text with {CONTEXT}", + max_tokens: 10, + }, + }, + pipeline + ); + let output = []; + let it = results.stream(); + let result = await it.next(); + while (!result.done) { + output.push(result.value); + result = await it.next(); + } + expect(output.length).toBeGreaterThan(0); + await collection.archive() +}) + /////////////////////////////////////////////////// // Test document related functions //////////////// /////////////////////////////////////////////////// @@ -222,14 +314,14 @@ it("can order documents", async () => { /////////////////////////////////////////////////// it("can transformer pipeline", async () => { - const t = pgml.newTransformerPipeline("text-generation"); - const it = await t.transform(["AI is going to"], { max_new_tokens: 5 }); + const t = pgml.newTransformerPipeline("text-generation", "meta-llama/Meta-Llama-3-8B-Instruct"); + const it = await t.transform(["AI is going to"], { max_tokens: 5 }); expect(it.length).toBeGreaterThan(0) }); it("can transformer pipeline stream", async () => { - const t = pgml.newTransformerPipeline("text-generation"); - const it = await t.transform_stream("AI is going to", { max_new_tokens: 5 }); + const t = pgml.newTransformerPipeline("text-generation", "meta-llama/Meta-Llama-3-8B-Instruct"); + const it = await t.transform_stream("AI is going to", { max_tokens: 5 }); let result = await it.next(); let output = []; while (!result.done) { @@ -246,7 +338,7 @@ it("can transformer pipeline stream", async () => { it("can open source ai create", () => { const client = pgml.newOpenSourceAI(); const results = client.chat_completions_create( - "HuggingFaceH4/zephyr-7b-beta", + "meta-llama/Meta-Llama-3-8B-Instruct", [ { role: "system", @@ -257,6 +349,7 @@ it("can open source ai create", () => { content: "How many helicopters can a human eat in one sitting?", }, ], + 10 ); expect(results.choices.length).toBeGreaterThan(0); }); @@ -265,7 +358,7 @@ it("can open source ai create", () => { it("can open source ai create async", async () => { const client = pgml.newOpenSourceAI(); const results = await client.chat_completions_create_async( - "HuggingFaceH4/zephyr-7b-beta", + 
"meta-llama/Meta-Llama-3-8B-Instruct", [ { role: "system", @@ -276,6 +369,7 @@ it("can open source ai create async", async () => { content: "How many helicopters can a human eat in one sitting?", }, ], + 10 ); expect(results.choices.length).toBeGreaterThan(0); }); @@ -284,7 +378,7 @@ it("can open source ai create async", async () => { it("can open source ai create stream", () => { const client = pgml.newOpenSourceAI(); const it = client.chat_completions_create_stream( - "HuggingFaceH4/zephyr-7b-beta", + "meta-llama/Meta-Llama-3-8B-Instruct", [ { role: "system", @@ -295,10 +389,11 @@ it("can open source ai create stream", () => { content: "How many helicopters can a human eat in one sitting?", }, ], + 10 ); let result = it.next(); while (!result.done) { - expect(result.value.choices.length).toBeGreaterThan(0); + expect(result.value.choices.length).toBeGreaterThanOrEqual(0); result = it.next(); } }); @@ -306,7 +401,7 @@ it("can open source ai create stream", () => { it("can open source ai create stream async", async () => { const client = pgml.newOpenSourceAI(); const it = await client.chat_completions_create_stream_async( - "HuggingFaceH4/zephyr-7b-beta", + "meta-llama/Meta-Llama-3-8B-Instruct", [ { role: "system", @@ -317,10 +412,11 @@ it("can open source ai create stream async", async () => { content: "How many helicopters can a human eat in one sitting?", }, ], + 10 ); let result = await it.next(); while (!result.done) { - expect(result.value.choices.length).toBeGreaterThan(0); + expect(result.value.choices.length).toBeGreaterThanOrEqual(0); result = await it.next(); } }); diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py index e4186d4d3..87adf5ba7 100644 --- a/pgml-sdks/pgml/python/tests/test.py +++ b/pgml-sdks/pgml/python/tests/test.py @@ -83,7 +83,12 @@ async def test_can_search(): pipeline = pgml.Pipeline( "test_p_p_tcs_0", { - "title": {"semantic_search": {"model": "intfloat/e5-small"}}, + "title": { + "semantic_search": { + "model": "intfloat/e5-small-v2", + "parameters": {"prompt": "passage: "}, + } + }, "body": { "splitter": {"model": "recursive_character"}, "semantic_search": { @@ -102,7 +107,11 @@ async def test_can_search(): "query": { "full_text_search": {"body": {"query": "Test", "boost": 1.2}}, "semantic_search": { - "title": {"query": "This is a test", "boost": 2.0}, + "title": { + "query": "This is a test", + "parameters": {"prompt": "passage: "}, + "boost": 2.0, + }, "body": {"query": "This is the body test", "boost": 1.01}, }, "filter": {"id": {"$gt": 1}}, @@ -112,7 +121,7 @@ async def test_can_search(): pipeline, ) ids = [result["id"] for result in results["results"]] - assert ids == [5, 4, 3] + assert ids == [3, 5, 4] await collection.archive() @@ -127,12 +136,18 @@ async def test_can_vector_search(): "test_p_p_tcvs_0", { "title": { - "semantic_search": {"model": "intfloat/e5-small"}, + "semantic_search": { + "model": "intfloat/e5-small-v2", + "parameters": {"prompt": "passage: "}, + }, "full_text_search": {"configuration": "english"}, }, "text": { "splitter": {"model": "recursive_character"}, - "semantic_search": {"model": "intfloat/e5-small"}, + "semantic_search": { + "model": "intfloat/e5-small-v2", + "parameters": {"prompt": "passage: "}, + }, }, }, ) @@ -143,8 +158,15 @@ async def test_can_vector_search(): { "query": { "fields": { - "title": {"query": "Test document: 2", "full_text_filter": "test"}, - "text": {"query": "Test document: 2"}, + "title": { + "query": "Test document: 2", + "parameters": {"prompt": "passage: "}, + 
"full_text_filter": "test", + }, + "text": { + "query": "Test document: 2", + "parameters": {"prompt": "passage: "}, + }, }, "filter": {"id": {"$gt": 2}}, }, @@ -159,7 +181,7 @@ async def test_can_vector_search(): @pytest.mark.asyncio async def test_can_vector_search_with_query_builder(): - model = pgml.Model() + model = pgml.Model("intfloat/e5-small-v2", "pgml", {"prompt": "passage: "}) splitter = pgml.Splitter() pipeline = pgml.SingleFieldPipeline("test_p_p_tcvswqb_1", model, splitter) collection = pgml.Collection(name="test_p_c_tcvswqb_5") @@ -172,7 +194,106 @@ async def test_can_vector_search_with_query_builder(): .fetch_all() ) ids = [document["id"] for (_, _, document) in results] - assert ids == [2, 1, 0] + assert ids == [1, 2, 0] + await collection.archive() + + +################################################### +## Test RAG ####################################### +################################################### + + +@pytest.mark.asyncio +async def test_can_rag(): + pipeline = pgml.Pipeline( + "1", + { + "body": { + "splitter": {"model": "recursive_character"}, + "semantic_search": { + "model": "intfloat/e5-small-v2", + "parameters": {"prompt": "passage: "}, + }, + }, + }, + ) + collection = pgml.Collection("test_p_c_cr") + await collection.add_pipeline(pipeline) + await collection.upsert_documents(generate_dummy_documents(5)) + results = await collection.rag( + { + "CONTEXT": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "test", + "parameters": {"prompt": "query: "}, + }, + }, + }, + "document": {"keys": ["id"]}, + "limit": 5, + }, + "aggregate": {"join": "\n"}, + }, + "completion": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "prompt": "Some text with {CONTEXT}", + "max_tokens": 10, + }, + }, + pipeline, + ) + assert len(results["rag"][0]) > 0 + assert len(results["sources"]["CONTEXT"]) > 0 + await collection.archive() + + +@pytest.mark.asyncio +async def test_can_rag_stream(): + pipeline = pgml.Pipeline( + "1", + { + "body": { + "splitter": {"model": "recursive_character"}, + "semantic_search": { + "model": "intfloat/e5-small-v2", + "parameters": {"prompt": "passage: "}, + }, + }, + }, + ) + collection = pgml.Collection("test_p_c_crs") + await collection.add_pipeline(pipeline) + await collection.upsert_documents(generate_dummy_documents(5)) + results = await collection.rag_stream( + { + "CONTEXT": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "test", + "parameters": {"prompt": "query: "}, + }, + }, + }, + "document": {"keys": ["id"]}, + "limit": 5, + }, + "aggregate": {"join": "\n"}, + }, + "completion": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "prompt": "Some text with {CONTEXT}", + "max_tokens": 10, + }, + }, + pipeline, + ) + async for c in results.stream(): + assert len(c) > 0 await collection.archive() @@ -235,15 +356,19 @@ async def test_order_documents(): @pytest.mark.asyncio async def test_transformer_pipeline(): - t = pgml.TransformerPipeline("text-generation") + t = pgml.TransformerPipeline( + "text-generation", "meta-llama/Meta-Llama-3-8B-Instruct" + ) it = await t.transform(["AI is going to"], {"max_new_tokens": 5}) assert len(it) > 0 @pytest.mark.asyncio async def test_transformer_pipeline_stream(): - t = pgml.TransformerPipeline("text-generation") - it = await t.transform_stream("AI is going to", {"max_new_tokens": 5}) + t = pgml.TransformerPipeline( + "text-generation", "meta-llama/Meta-Llama-3-8B-Instruct" + ) + it = await t.transform_stream("AI is going to", {"max_tokens": 5}) 
total = [] async for c in it: total.append(c) @@ -258,7 +383,7 @@ async def test_transformer_pipeline_stream(): def test_open_source_ai_create(): client = pgml.OpenSourceAI() results = client.chat_completions_create( - "HuggingFaceH4/zephyr-7b-beta", + "meta-llama/Meta-Llama-3-8B-Instruct", [ { "role": "system", @@ -269,6 +394,7 @@ def test_open_source_ai_create(): "content": "How many helicopters can a human eat in one sitting?", }, ], + max_tokens=10, temperature=0.85, ) assert len(results["choices"]) > 0 @@ -278,7 +404,7 @@ def test_open_source_ai_create(): async def test_open_source_ai_create_async(): client = pgml.OpenSourceAI() results = await client.chat_completions_create_async( - "HuggingFaceH4/zephyr-7b-beta", + "meta-llama/Meta-Llama-3-8B-Instruct", [ { "role": "system", @@ -289,6 +415,7 @@ async def test_open_source_ai_create_async(): "content": "How many helicopters can a human eat in one sitting?", }, ], + max_tokens=10, temperature=0.85, ) assert len(results["choices"]) > 0 @@ -297,7 +424,7 @@ async def test_open_source_ai_create_async(): def test_open_source_ai_create_stream(): client = pgml.OpenSourceAI() results = client.chat_completions_create_stream( - "HuggingFaceH4/zephyr-7b-beta", + "meta-llama/Meta-Llama-3-8B-Instruct", [ { "role": "system", @@ -311,15 +438,17 @@ def test_open_source_ai_create_stream(): temperature=0.85, n=3, ) + output = [] for c in results: - assert len(c["choices"]) > 0 + output.append(c["choices"]) + assert len(output) > 0 @pytest.mark.asyncio async def test_open_source_ai_create_stream_async(): client = pgml.OpenSourceAI() results = await client.chat_completions_create_stream_async( - "HuggingFaceH4/zephyr-7b-beta", + "meta-llama/Meta-Llama-3-8B-Instruct", [ { "role": "system", @@ -333,8 +462,10 @@ async def test_open_source_ai_create_stream_async(): temperature=0.85, n=3, ) + output = [] async for c in results: - assert len(c["choices"]) > 0 + output.append(c["choices"]) + assert len(output) > 0 ################################################### diff --git a/pgml-sdks/pgml/src/builtins.rs b/pgml-sdks/pgml/src/builtins.rs index 652bf0b8c..6a4200457 100644 --- a/pgml-sdks/pgml/src/builtins.rs +++ b/pgml-sdks/pgml/src/builtins.rs @@ -84,7 +84,7 @@ impl Builtins { query.bind(task.0) }; let results = query.bind(inputs).bind(args).fetch_all(&pool).await?; - let results = results.get(0).unwrap().get::<serde_json::Value, _>(0); + let results = results.first().unwrap().get::<serde_json::Value, _>(0); Ok(Json(results)) } } @@ -108,7 +108,10 @@ mod tests { async fn can_transform() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let builtins = Builtins::new(None); - let task = Json::from(serde_json::json!("translation_en_to_fr")); + let task = Json::from(serde_json::json!({ + "task": "text-generation", + "model": "meta-llama/Meta-Llama-3-8B-Instruct" + })); let inputs = vec!["test1".to_string(), "test2".to_string()]; let results = builtins.transform(task, inputs, None).await?; assert!(results.as_array().is_some()); diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 2f1291e82..b5a34bbbd 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -21,7 +21,9 @@ use walkdir::WalkDir; use crate::debug_sqlx_query; use crate::filter_builder::FilterBuilder; use crate::pipeline::FieldAction; +use crate::rag_query_builder::build_rag_query; use crate::search_query_builder::build_search_query; +use crate::types::GeneralJsonAsyncIterator; use crate::vector_search_query_builder::build_vector_search_query; use crate::{
get_or_initialize_pool, models, order_by_builder, @@ -34,7 +36,39 @@ use crate::{ }; #[cfg(feature = "python")] -use crate::{pipeline::PipelinePython, query_builder::QueryBuilderPython, types::JsonPython}; +use crate::{ + pipeline::PipelinePython, + query_builder::QueryBuilderPython, + types::{GeneralJsonAsyncIteratorPython, JsonPython}, +}; + +/// A RAGStream Struct +#[derive(alias)] +#[allow(dead_code)] +pub struct RAGStream { + general_json_async_iterator: Option<GeneralJsonAsyncIterator>, + sources: Json, +} + +// Required that we implement clone for our rust-bridge macros but it will not be used +impl Clone for RAGStream { + fn clone(&self) -> Self { + panic!("Cannot clone RAGStream") + } +} + +#[alias_methods(stream, sources)] +impl RAGStream { + pub fn stream(&mut self) -> anyhow::Result<GeneralJsonAsyncIterator> { + self.general_json_async_iterator + .take() + .context("Cannot call stream method more than once") + } + + pub fn sources(&self) -> anyhow::Result<Json> { + panic!("Cannot get sources yet for RAG streaming") + } +} /// Our project tasks #[derive(Debug, Clone)] @@ -127,6 +161,8 @@ pub struct Collection { add_search_event, vector_search, query, + rag, + rag_stream, exists, archive, upsert_directory, @@ -315,6 +351,9 @@ impl Collection { let mp = MultiProgress::new(); mp.println(format!("Added Pipeline {}, Now Syncing...", pipeline.name))?; + + // TODO: Revisit this. If the pipeline is added but fails to sync, then it will be "out of sync" with the documents in the table + // This is rare, but could happen pipeline .resync(project_info, pool.acquire().await?.as_mut()) .await?; @@ -1086,6 +1125,50 @@ impl Collection { .collect()) } + #[instrument(skip(self))] + pub async fn rag(&self, query: Json, pipeline: &mut Pipeline) -> anyhow::Result<Json> { + let pool = get_or_initialize_pool(&self.database_url).await?; + let (built_query, values) = build_rag_query(query.clone(), self, pipeline, false).await?; + let mut results: Vec<(Json,)> = sqlx::query_as_with(&built_query, values) + .fetch_all(&pool) + .await?; + Ok(std::mem::take(&mut results[0].0)) + } + + #[instrument(skip(self))] + pub async fn rag_stream( + &self, + query: Json, + pipeline: &mut Pipeline, + ) -> anyhow::Result<RAGStream> { + let pool = get_or_initialize_pool(&self.database_url).await?; + + let (built_query, values) = build_rag_query(query.clone(), self, pipeline, true).await?; + + let mut transaction = pool.begin().await?; + + sqlx::query_with(&built_query, values) + .execute(&mut *transaction) + .await?; + + let s = futures::stream::try_unfold(transaction, move |mut transaction| async move { + let mut res: Vec<Json> = sqlx::query_scalar("FETCH 1 FROM c") + .fetch_all(&mut *transaction) + .await?; + if !res.is_empty() { + Ok(Some((std::mem::take(&mut res[0]), transaction))) + } else { + transaction.commit().await?; + Ok(None) + } + }); + + Ok(RAGStream { + general_json_async_iterator: Some(GeneralJsonAsyncIterator(Box::pin(s))), + sources: serde_json::json!({}).into(), + }) + } + /// Archives a [Collection] /// This will free up the name to be reused. It does not delete it.
/// diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index ddfc37341..8060e23f1 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -29,6 +29,7 @@ mod pipeline; mod queries; mod query_builder; mod query_runner; +mod rag_query_builder; mod remote_embeddings; mod search_query_builder; mod single_field_pipeline; @@ -281,6 +282,7 @@ fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> { mod tests { use super::*; use crate::types::Json; + use futures::StreamExt; use serde_json::json; fn generate_dummy_documents(count: usize) -> Vec<Json> { @@ -329,7 +331,7 @@ mod tests { #[tokio::test] async fn can_add_remove_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut pipeline = Pipeline::new("test_p_carp_58", Some(json!({}).into()))?; + let mut pipeline = Pipeline::new("0", Some(json!({}).into()))?; let mut collection = Collection::new("test_r_c_carp_1", None)?; assert!(collection.database_data.is_none()); collection.add_pipeline(&mut pipeline).await?; @@ -344,8 +346,8 @@ mod tests { #[tokio::test] async fn can_add_remove_pipelines() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut pipeline1 = Pipeline::new("test_r_p_carps_1", Some(json!({}).into()))?; - let mut pipeline2 = Pipeline::new("test_r_p_carps_2", Some(json!({}).into()))?; + let mut pipeline1 = Pipeline::new("0", Some(json!({}).into()))?; + let mut pipeline2 = Pipeline::new("1", Some(json!({}).into()))?; let mut collection = Collection::new("test_r_c_carps_11", None)?; collection.add_pipeline(&mut pipeline1).await?; collection.add_pipeline(&mut pipeline2).await?; @@ -354,7 +356,7 @@ mod tests { collection.remove_pipeline(&pipeline1).await?; let pipelines = collection.get_pipelines().await?; assert!(pipelines.len() == 1); - assert!(collection.get_pipeline("test_r_p_carps_1").await.is_err()); + assert!(collection.get_pipeline("0").await.is_err()); collection.archive().await?; Ok(()) } @@ -363,14 +365,17 @@ mod tests { async fn can_add_pipeline_and_upsert_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let collection_name = "test_r_c_capaud_107"; - let pipeline_name = "test_r_p_capaud_6"; + let pipeline_name = "0"; let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ "title": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } } }, "body": { @@ -382,9 +387,9 @@ } }, "semantic_search": { - "model": "hkunlp/instructor-base", + "model": "intfloat/e5-small-v2", "parameters": { - "instruction": "Represent the Wikipedia document for retrieval" + "prompt": "passage: " } }, "full_text_search": { @@ -521,14 +526,17 @@ let mut collection = Collection::new(collection_name, None)?; let documents = generate_dummy_documents(2); collection.upsert_documents(documents.clone(), None).await?; - let pipeline_name = "test_r_p_cudaap_9"; + let pipeline_name = "0"; let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ "title": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } } }, "body": { @@ -536,7 +544,10 @@ "model": "recursive_character" }, "semantic_search": { - "model": "intfloat/e5-small", + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } }, "full_text_search": { "configuration": "english" @@ -581,7 +592,7 @@ #[tokio::test] async fn disable_enable_pipeline()
-> anyhow::Result<()> { - let mut pipeline = Pipeline::new("test_p_dep_1", Some(json!({}).into()))?; + let mut pipeline = Pipeline::new("0", Some(json!({}).into()))?; let mut collection = Collection::new("test_r_c_dep_1", None)?; collection.add_pipeline(&mut pipeline).await?; let queried_pipeline = &collection.get_pipelines().await?[0]; @@ -601,14 +612,17 @@ mod tests { internal_init_logger(None, None).ok(); let collection_name = "test_r_c_cudaep_43"; let mut collection = Collection::new(collection_name, None)?; - let pipeline_name = "test_r_p_cudaep_9"; + let pipeline_name = "0"; let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ "title": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } } } }) @@ -646,14 +660,17 @@ mod tests { collection .upsert_documents(documents[..2].to_owned(), None) .await?; - let pipeline_name1 = "test_r_p_rpdt1_0"; + let pipeline_name1 = "0"; let mut pipeline = Pipeline::new( pipeline_name1, Some( json!({ "title": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } } }, "body": { @@ -661,7 +678,10 @@ mod tests { "model": "recursive_character" }, "semantic_search": { - "model": "intfloat/e5-small", + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } }, "full_text_search": { "configuration": "english" @@ -697,14 +717,17 @@ mod tests { .await?; assert!(tsvectors.len() == 8); - let pipeline_name2 = "test_r_p_rpdt2_0"; + let pipeline_name2 = "1"; let mut pipeline = Pipeline::new( pipeline_name2, Some( json!({ "title": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } } }, "body": { @@ -712,7 +735,10 @@ mod tests { "model": "recursive_character" }, "semantic_search": { - "model": "intfloat/e5-small", + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } }, "full_text_search": { "configuration": "english" @@ -792,16 +818,19 @@ mod tests { #[tokio::test] async fn pipeline_sync_status() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_pss_5"; + let collection_name = "test_r_c_pss_6"; let mut collection = Collection::new(collection_name, None)?; - let pipeline_name = "test_r_p_pss_0"; + let pipeline_name = "0"; let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ "title": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } }, "full_text_search": { "configuration": "english" @@ -902,14 +931,17 @@ mod tests { internal_init_logger(None, None).ok(); let collection_name = "test_r_c_cschpfp_4"; let mut collection = Collection::new(collection_name, None)?; - let pipeline_name = "test_r_p_cschpfp_0"; + let pipeline_name = "0"; let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ "title": { "semantic_search": { - "model": "intfloat/e5-small", + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + }, "hnsw": { "m": 100, "ef_construction": 200 @@ -948,18 +980,21 @@ mod tests { #[tokio::test] async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cswle_121"; + let collection_name = "test_r_c_cswle_123"; let mut collection = Collection::new(collection_name, None)?; let documents = generate_dummy_documents(10); 
collection.upsert_documents(documents.clone(), None).await?; - let pipeline_name = "test_r_p_cswle_9"; + let pipeline_name = "0"; let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ "title": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } }, "full_text_search": { "configuration": "english" @@ -970,12 +1005,15 @@ "model": "recursive_character" }, "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } }, "semantic_search": { - "model": "hkunlp/instructor-base", + "model": "intfloat/e5-small-v2", "parameters": { - "instruction": "Represent the Wikipedia document for retrieval" + "prompt": "passage: " } }, "full_text_search": { @@ -984,7 +1022,10 @@ }, "notes": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } } } }) @@ -1007,17 +1048,23 @@ "semantic_search": { "title": { "query": "This is a test", + "parameters": { + "prompt": "query: ", + }, "boost": 2.0 }, "body": { "query": "This is the body test", "parameters": { - "instruction": "Represent the Wikipedia question for retrieving supporting documents: ", + "prompt": "query: ", }, "boost": 1.01 }, "notes": { "query": "This is the notes test", + "parameters": { + "prompt": "query: ", + }, "boost": 1.01 } }, @@ -1039,7 +1086,7 @@ .iter() .map(|r| r["document"]["id"].as_u64().unwrap()) .collect(); - assert_eq!(ids, vec![9, 2, 7, 8, 3]); + assert_eq!(ids, vec![9, 3, 4, 7, 5]); let pool = get_or_initialize_pool(&None).await?; @@ -1064,7 +1111,7 @@ // Document ids are 1 based in the db not 0 based like they are here assert_eq!( search_results.iter().map(|sr| sr.2).collect::<Vec<_>>(), - vec![10, 3, 8, 9, 4] + vec![10, 4, 5, 8, 6] ); let event = json!({"clicked": true}); @@ -1097,14 +1144,17 @@ let mut collection = Collection::new(collection_name, None)?; let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; - let pipeline_name = "test_r_p_cswre_8"; + let pipeline_name = "0"; let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ "title": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } } }, "body": { @@ -1138,6 +1188,9 @@ "semantic_search": { "title": { "query": "This is a test", + "parameters": { + "prompt": "query: ", + }, "boost": 2.0 }, "body": { @@ -1163,7 +1216,7 @@ .iter() .map(|r| r["document"]["id"].as_u64().unwrap()) .collect(); - assert_eq!(ids, vec![2, 3, 7, 4, 8]); + assert_eq!(ids, vec![3, 9, 4, 7, 5]); collection.archive().await?; Ok(()) } @@ -1179,16 +1232,16 @@ let mut collection = Collection::new(collection_name, None)?; let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; - let pipeline_name = "test_r_p_cvswle_0"; + let pipeline_name = "0"; let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ "title": { "semantic_search": { - "model": "hkunlp/instructor-base", + "model": "intfloat/e5-small-v2", "parameters": { - "instruction": "Represent the Wikipedia document for retrieval" + "prompt": "passage: " } }, "full_text_search": { @@ -1200,7 +1253,10 @@ "model": "recursive_character" }, "semantic_search": { - "model": "intfloat/e5-small" + "model":
"intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } }, }, }) @@ -1216,13 +1272,16 @@ mod tests { "title": { "query": "Test document: 2", "parameters": { - "instruction": "Represent the Wikipedia document for retrieval" + "prompt": "passage: " }, "full_text_filter": "test", "boost": 1.2 }, "body": { "query": "Test document: 2", + "parameters": { + "prompt": "passage: " + }, "boost": 1.0 }, }, @@ -1232,6 +1291,11 @@ mod tests { } } }, + "document": { + "keys": [ + "id" + ] + }, "limit": 5 }) .into(), @@ -1242,7 +1306,7 @@ mod tests { .into_iter() .map(|r| r["document"]["id"].as_u64().unwrap()) .collect(); - assert_eq!(ids, vec![8, 4, 7, 6, 9]); + assert_eq!(ids, vec![4, 8, 5, 6, 9]); collection.archive().await?; Ok(()) } @@ -1254,14 +1318,17 @@ mod tests { let mut collection = Collection::new(collection_name, None)?; let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; - let pipeline_name = "test_r_p_cvswre_0"; + let pipeline_name = "0"; let mut pipeline = Pipeline::new( pipeline_name, Some( json!({ "title": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } }, "full_text_search": { "configuration": "english" @@ -1289,7 +1356,10 @@ mod tests { "fields": { "title": { "full_text_filter": "test", - "query": "Test document: 2" + "query": "Test document: 2", + "parameters": { + "prompt": "passage: " + }, }, "body": { "query": "Test document: 2" @@ -1311,7 +1381,7 @@ mod tests { .into_iter() .map(|r| r["document"]["id"].as_u64().unwrap()) .collect(); - assert_eq!(ids, vec![4, 5, 6, 7, 9]); + assert_eq!(ids, vec![4, 8, 5, 6, 9]); collection.archive().await?; Ok(()) } @@ -1321,12 +1391,15 @@ mod tests { internal_init_logger(None, None).ok(); let mut collection = Collection::new("test r_c_cvswqb_7", None)?; let mut pipeline = Pipeline::new( - "test_r_p_cvswqb_0", + "0", Some( json!({ "text": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } }, "full_text_search": { "configuration": "english" @@ -1342,7 +1415,16 @@ mod tests { collection.add_pipeline(&mut pipeline).await?; let results = collection .query() - .vector_recall("test query", &pipeline, None) + .vector_recall( + "test query", + &pipeline, + Some( + json!({ + "prompt": "query: " + }) + .into(), + ), + ) .limit(3) .filter( json!({ @@ -1369,6 +1451,108 @@ mod tests { Ok(()) } + #[tokio::test] + async fn can_vector_search_with_local_embeddings_and_specify_document_keys( + ) -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let collection_name = "test r_c_cvswleasdk_0"; + let mut collection = Collection::new(collection_name, None)?; + let documents = generate_dummy_documents(2); + collection.upsert_documents(documents.clone(), None).await?; + let pipeline_name = "0"; + let mut pipeline = Pipeline::new( + pipeline_name, + Some( + json!({ + "body": { + "splitter": { + "model": "recursive_character" + }, + "semantic_search": { + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } + }, + }, + }) + .into(), + ), + )?; + collection.add_pipeline(&mut pipeline).await?; + let results = collection + .vector_search( + json!({ + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "parameters": { + "prompt": "query: " + }, + }, + }, + }, + "document": { + "keys": [ + "id", + "title" + ] + }, + "limit": 5 + }) + .into(), + &mut pipeline, + ) + .await?; + 
assert!(results[0]["document"] + .as_object() + .unwrap() + .contains_key("id")); + assert!(results[0]["document"] + .as_object() + .unwrap() + .contains_key("title")); + assert!(!results[0]["document"] + .as_object() + .unwrap() + .contains_key("body")); + + let results = collection + .vector_search( + json!({ + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "parameters": { + "prompt": "query: " + }, + }, + }, + }, + "limit": 5 + }) + .into(), + &mut pipeline, + ) + .await?; + assert!(results[0]["document"] + .as_object() + .unwrap() + .contains_key("id")); + assert!(results[0]["document"] + .as_object() + .unwrap() + .contains_key("title")); + assert!(results[0]["document"] + .as_object() + .unwrap() + .contains_key("body")); + collection.archive().await?; + Ok(()) + } + /////////////////////////////// // Working With Documents ///// /////////////////////////////// @@ -1925,7 +2109,10 @@ mod tests { json!({ "title": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } }, "full_text_search": { "configuration": "english" @@ -1936,7 +2123,10 @@ mod tests { "model": "recursive_character" }, "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } }, "full_text_search": { "configuration": "english" @@ -1944,7 +2134,10 @@ mod tests { }, "notes": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } } } }) @@ -1959,4 +2152,630 @@ mod tests { collection.archive().await?; Ok(()) } + + /////////////////////////////// + // RAG //////////////////////// + /////////////////////////////// + + #[tokio::test] + async fn can_rag_with_local_embeddings() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let collection_name = "test r_c_crwle_1"; + let mut collection = Collection::new(collection_name, None)?; + let documents = generate_dummy_documents(10); + collection.upsert_documents(documents.clone(), None).await?; + let pipeline_name = "0"; + let mut pipeline = Pipeline::new( + pipeline_name, + Some( + json!({ + "body": { + "splitter": { + "model": "recursive_character" + }, + "semantic_search": { + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } + }, + }, + }) + .into(), + ), + )?; + collection.add_pipeline(&mut pipeline).await?; + + // Single variable test + let results = collection + .rag( + json!({ + "CONTEXT": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "parameters": { + "prompt": "query: " + } + }, + }, + }, + "document": { + "keys": [ + "id" + ] + }, + "limit": 5 + }, + "aggregate": { + "join": "\n" + } + }, + "completion": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "prompt": "Some text with {CONTEXT}", + "max_tokens": 10, + } + }) + .into(), + &mut pipeline, + ) + .await?; + assert!(!results["rag"].as_array().unwrap()[0] + .as_str() + .unwrap() + .is_empty()); + + // Multi-variable test + let results = collection + .rag( + json!({ + "CONTEXT": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "boost": 1.0, + "parameters": { + "prompt": "query: " + } + }, + }, + }, + "limit": 2 + }, + "aggregate": { + "join": "\n" + } + }, + "CONTEXT2": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 3", + "parameters": { + "prompt": "query: " + } + }, + } + }, + 
"document": { + "keys": [ + "id" + ] + }, + "limit": 2 + }, + "aggregate": { + "join": "\n" + } + }, + "completion": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "prompt": "Some text with {CONTEXT} AND {CONTEXT2}", + "max_tokens": 10 + } + }) + .into(), + &mut pipeline, + ) + .await?; + assert!(!results["rag"].as_array().unwrap()[0] + .as_str() + .unwrap() + .is_empty()); + + // Chat test + let results = collection + .rag( + json!({ + "CONTEXT": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "parameters": { + "prompt": "query: " + } + }, + }, + }, + "document": { + "keys": [ + "id" + ] + }, + "limit": 2 + }, + "aggregate": { + "join": "\n" + } + }, + "chat": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "messages": [ + { + "role": "system", + "content": "You are a friendly and helpful chatbot" + }, + { + "role": "user", + "content": "Some text with {CONTEXT}", + } + ], + "max_tokens": 10 + } + }) + .into(), + &mut pipeline, + ) + .await?; + assert!(!results["rag"].as_array().unwrap()[0] + .as_str() + .unwrap() + .is_empty()); + + // Multi-variable chat test + let results = collection + .rag( + json!({ + "CONTEXT": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "boost": 1.0, + "parameters": { + "prompt": "query: " + } + }, + }, + }, + "limit": 2 + }, + "aggregate": { + "join": "\n" + } + }, + "CONTEXT2": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 3", + "boost": 1.0, + "parameters": { + "prompt": "query: " + } + }, + } + }, + "limit": 2 + }, + "aggregate": { + "join": "\n" + } + }, + "chat": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "messages": [ + { + "role": "system", + "content": "You are a friendly and helpful chatbot" + }, + { + "role": "user", + "content": "Some text with {CONTEXT} AND {CONTEXT2}", + } + ], + "max_tokens": 10 + } + }) + .into(), + &mut pipeline, + ) + .await?; + assert!(!results["rag"].as_array().unwrap()[0] + .as_str() + .unwrap() + .is_empty()); + + // Chat test with custom SQL query + let results = collection + .rag( + json!({ + "CONTEXT": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "boost": 1.0, + "parameters": { + "prompt": "query: " + } + }, + }, + }, + "limit": 2 + }, + "aggregate": { + "join": "\n" + } + }, + "CUSTOM": { + "sql": "SELECT 'test'" + }, + "chat": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "messages": [ + { + "role": "system", + "content": "You are a friendly and helpful chatbot" + }, + { + "role": "user", + "content": "Some text with {CONTEXT} - {CUSTOM}", + } + ], + "max_tokens": 10 + } + }) + .into(), + &mut pipeline, + ) + .await?; + assert!(!results["rag"].as_array().unwrap()[0] + .as_str() + .unwrap() + .is_empty()); + + collection.archive().await?; + Ok(()) + } + + #[tokio::test] + async fn can_rag_stream_with_local_embeddings() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let collection_name = "test r_c_crswle_1"; + let mut collection = Collection::new(collection_name, None)?; + let documents = generate_dummy_documents(10); + collection.upsert_documents(documents.clone(), None).await?; + let pipeline_name = "0"; + let mut pipeline = Pipeline::new( + pipeline_name, + Some( + json!({ + "body": { + "splitter": { + "model": "recursive_character" + }, + "semantic_search": { + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } + }, + }, + }) + .into(), + ), + )?; + 
collection.add_pipeline(&mut pipeline).await?; + + // Single variable test + let mut results = collection + .rag_stream( + json!({ + "CONTEXT": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "parameters": { + "prompt": "query: " + } + }, + }, + }, + "document": { + "keys": [ + "id" + ] + }, + "limit": 5 + }, + "aggregate": { + "join": "\n" + } + }, + "completion": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "prompt": "Some text with {CONTEXT}", + "max_tokens": 10, + } + }) + .into(), + &mut pipeline, + ) + .await?; + let mut stream = results.stream()?; + while let Some(o) = stream.next().await { + o?; + } + + // Multi-variable test + let mut results = collection + .rag_stream( + json!({ + "CONTEXT": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "parameters": { + "prompt": "query: " + } + }, + }, + }, + "document": { + "keys": [ + "id" + ] + }, + "limit": 2 + }, + "aggregate": { + "join": "\n" + } + }, + "CONTEXT2": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "parameters": { + "prompt": "query: " + } + }, + }, + }, + "document": { + "keys": [ + "id" + ] + }, + "limit": 2 + }, + "aggregate": { + "join": "\n" + } + }, + "completion": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "prompt": "Some text with {CONTEXT} - {CONTEXT2}", + "max_tokens": 10, + } + }) + .into(), + &mut pipeline, + ) + .await?; + let mut stream = results.stream()?; + while let Some(o) = stream.next().await { + o?; + } + + // Single variable chat test + let mut results = collection + .rag_stream( + json!({ + "CONTEXT": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "parameters": { + "prompt": "query: " + } + }, + }, + }, + "document": { + "keys": [ + "id" + ] + }, + "limit": 5 + }, + "aggregate": { + "join": "\n" + } + }, + "chat": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "messages": [ + { + "role": "system", + "content": "You are a friendly and helpful chatbot" + }, + { + "role": "user", + "content": "Some text with {CONTEXT}", + } + ], + "max_tokens": 10 + } + }) + .into(), + &mut pipeline, + ) + .await?; + let mut stream = results.stream()?; + while let Some(o) = stream.next().await { + o?; + } + + // Multi-variable chat test + let mut results = collection + .rag_stream( + json!({ + "CONTEXT": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "parameters": { + "prompt": "query: " + } + }, + }, + }, + "document": { + "keys": [ + "id" + ] + }, + "limit": 2 + }, + "aggregate": { + "join": "\n" + } + }, + "CONTEXT2": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 2", + "parameters": { + "prompt": "query: " + } + }, + }, + }, + "document": { + "keys": [ + "id" + ] + }, + "limit": 2 + }, + "aggregate": { + "join": "\n" + } + }, + "chat": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "messages": [ + { + "role": "system", + "content": "You are a friendly and helpful chatbot" + }, + { + "role": "user", + "content": "Some text with {CONTEXT} - {CONTEXT2}", + } + ], + "max_tokens": 10 + } + }) + .into(), + &mut pipeline, + ) + .await?; + let mut stream = results.stream()?; + while let Some(o) = stream.next().await { + o?; + } + + // Raw SQL test + let mut results = collection + .rag_stream( + json!({ + "CONTEXT": { + "vector_search": { + "query": { + "fields": { + "body": { + "query": "Test document: 2", + 
"parameters": { + "prompt": "query: " + } + }, + }, + }, + "document": { + "keys": [ + "id" + ] + }, + "limit": 2 + }, + "aggregate": { + "join": "\n" + } + }, + "CUSTOM": { + "sql": "SELECT 'test'" + }, + "chat": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "messages": [ + { + "role": "system", + "content": "You are a friendly and helpful chatbot" + }, + { + "role": "user", + "content": "Some text with {CONTEXT} - {CUSTOM}", + } + ], + "max_tokens": 10 + } + }) + .into(), + &mut pipeline, + ) + .await?; + let mut stream = results.stream()?; + while let Some(o) = stream.next().await { + o?; + } + + collection.archive().await?; + Ok(()) + } } diff --git a/pgml-sdks/pgml/src/open_source_ai.rs b/pgml-sdks/pgml/src/open_source_ai.rs index e21397a31..f7348ad11 100644 --- a/pgml-sdks/pgml/src/open_source_ai.rs +++ b/pgml-sdks/pgml/src/open_source_ai.rs @@ -23,6 +23,15 @@ fn try_model_nice_name_to_model_name_and_parameters( model_name: &str, ) -> Option<(&'static str, Json)> { match model_name { + "meta-llama/Meta-Llama-3-8B-Instruct" => Some(( + "meta-llama/Meta-Llama-3-8B-Instruct", + serde_json::json!({ + "task": "conversationa", + "model": "meta-llama/Meta-Llama-3-8B-Instruct" + }) + .into(), + )), + "mistralai/Mistral-7B-Instruct-v0.1" => Some(( "mistralai/Mistral-7B-Instruct-v0.1", serde_json::json!({ @@ -201,7 +210,7 @@ impl OpenSourceAI { Ok(( TransformerPipeline::new( "conversational", - Some(model_name.to_string()), + model_name, Some(model.clone()), self.database_url.clone(), ), @@ -221,7 +230,7 @@ mistralai/Mistral-7B-v0.1 Ok(( TransformerPipeline::new( "conversational", - Some(real_model_name.to_string()), + real_model_name, Some(parameters.clone()), self.database_url.clone(), ), @@ -252,7 +261,9 @@ mistralai/Mistral-7B-v0.1 let md5_digest = md5::compute(to_hash.as_bytes()); let fingerprint = uuid::Uuid::from_slice(&md5_digest.0)?; - let mut args = serde_json::json!({ "max_new_tokens": max_tokens, "temperature": temperature, "do_sample": true, "num_return_sequences": n }); + // TODO: Add n + + let mut args = serde_json::json!({ "max_tokens": max_tokens, "temperature": temperature }); if let Some(t) = chat_template .or_else(|| try_get_model_chat_template(&model_name).map(|s| s.to_string())) { @@ -340,7 +351,9 @@ mistralai/Mistral-7B-v0.1 let md5_digest = md5::compute(to_hash.as_bytes()); let fingerprint = uuid::Uuid::from_slice(&md5_digest.0)?; - let mut args = serde_json::json!({ "max_new_tokens": max_tokens, "temperature": temperature, "do_sample": true, "num_return_sequences": n }); + // TODO: Add n + + let mut args = serde_json::json!({ "max_tokens": max_tokens, "temperature": temperature }); if let Some(t) = chat_template .or_else(|| try_get_model_chat_template(&model_name).map(|s| s.to_string())) { @@ -420,7 +433,7 @@ mod tests { #[test] fn can_open_source_ai_create() -> anyhow::Result<()> { let client = OpenSourceAI::new(None); - let results = client.chat_completions_create(Json::from_serializable("HuggingFaceH4/zephyr-7b-beta"), vec![ + let results = client.chat_completions_create(Json::from_serializable("meta-llama/Meta-Llama-3-8B-Instruct"), vec![ serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(), serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(), ], Some(10), None, Some(3), None)?; @@ -431,7 +444,7 @@ mod tests { #[sqlx::test] fn can_open_source_ai_create_async() -> anyhow::Result<()> { let client = OpenSourceAI::new(None); - 
let results = client.chat_completions_create_async(Json::from_serializable("HuggingFaceH4/zephyr-7b-beta"), vec![ + let results = client.chat_completions_create_async(Json::from_serializable("meta-llama/Meta-Llama-3-8B-Instruct"), vec![ serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(), serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(), ], Some(10), None, Some(3), None).await?; @@ -442,7 +455,7 @@ mod tests { #[sqlx::test] fn can_open_source_ai_create_stream_async() -> anyhow::Result<()> { let client = OpenSourceAI::new(None); - let mut stream = client.chat_completions_create_stream_async(Json::from_serializable("HuggingFaceH4/zephyr-7b-beta"), vec![ + let mut stream = client.chat_completions_create_stream_async(Json::from_serializable("meta-llama/Meta-Llama-3-8B-Instruct"), vec![ serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(), serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(), ], Some(10), None, Some(3), None).await?; @@ -455,7 +468,7 @@ mod tests { #[test] fn can_open_source_ai_create_stream() -> anyhow::Result<()> { let client = OpenSourceAI::new(None); - let iterator = client.chat_completions_create_stream(Json::from_serializable("HuggingFaceH4/zephyr-7b-beta"), vec![ + let iterator = client.chat_completions_create_stream(Json::from_serializable("meta-llama/Meta-Llama-3-8B-Instruct"), vec![ serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(), serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(), ], Some(10), None, Some(3), None)?; diff --git a/pgml-sdks/pgml/src/query_builder.rs b/pgml-sdks/pgml/src/query_builder.rs index 4250f9db1..ca496d3a0 100644 --- a/pgml-sdks/pgml/src/query_builder.rs +++ b/pgml-sdks/pgml/src/query_builder.rs @@ -71,7 +71,7 @@ impl QueryBuilder { self.pipeline = Some(pipeline.clone()); self.query["query"]["fields"]["text"]["query"] = json!(query); if let Some(query_parameters) = query_parameters { - self.query["query"]["fields"]["text"]["model_parameters"] = query_parameters.0; + self.query["query"]["fields"]["text"]["parameters"] = query_parameters.0; } self } diff --git a/pgml-sdks/pgml/src/rag_query_builder.rs b/pgml-sdks/pgml/src/rag_query_builder.rs new file mode 100644 index 000000000..4f4279260 --- /dev/null +++ b/pgml-sdks/pgml/src/rag_query_builder.rs @@ -0,0 +1,375 @@ +use sea_query::{ + Alias, CommonTableExpression, Expr, PostgresQueryBuilder, Query, SimpleExpr, WithClause, +}; +use sea_query_binder::{SqlxBinder, SqlxValues}; +use serde::{Deserialize, Serialize}; +use serde_with::{serde_as, FromInto}; +use std::collections::HashMap; + +use crate::{ + collection::Collection, + debug_sea_query, models, + pipeline::Pipeline, + types::{CustomU64Convertor, IntoTableNameAndSchema, Json}, + vector_search_query_builder::{build_sqlx_query, ValidQuery}, +}; + +const fn default_temperature() -> f32 { + 1. +} +const fn default_max_tokens() -> u64 { + 1000000 +} +const fn default_top_p() -> f32 { + 1. +} +const fn default_presence_penalty() -> f32 { + 0. 
+} + +#[allow(dead_code)] +const fn default_n() -> u64 { + 0 +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(deny_unknown_fields)] +struct ValidAggregate { + join: String, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(deny_unknown_fields)] +struct VectorSearch { + vector_search: ValidQuery, + aggregate: ValidAggregate, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(deny_unknown_fields)] +struct RawSQL { + sql: String, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(deny_unknown_fields)] +#[serde(untagged)] +enum ValidVariable { + VectorSearch(VectorSearch), + RawSQL(RawSQL), +} + +#[serde_as] +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(deny_unknown_fields)] +struct ValidCompletion { + model: String, + prompt: String, + #[serde(default = "default_temperature")] + temperature: f32, + // Need this when coming from JavaScript as everything is an f64 from JS + #[serde(default = "default_max_tokens")] + #[serde_as(as = "FromInto<CustomU64Convertor>")] + max_tokens: u64, + #[serde(default = "default_top_p")] + top_p: f32, + #[serde(default = "default_presence_penalty")] + presence_penalty: f32, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +struct ChatMessage { + role: String, + content: String, +} + +#[serde_as] +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(deny_unknown_fields)] +struct ValidChat { + model: String, + messages: Vec<ChatMessage>, + #[serde(default = "default_temperature")] + temperature: f32, + // Need this when coming from JavaScript as everything is an f64 from JS + #[serde(default = "default_max_tokens")] + #[serde_as(as = "FromInto<CustomU64Convertor>")] + max_tokens: u64, + #[serde(default = "default_top_p")] + top_p: f32, + #[serde(default = "default_presence_penalty")] + presence_penalty: f32, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +struct ValidRAG { + completion: Option<ValidCompletion>, + chat: Option<ValidChat>, + #[serde(flatten)] + variables: HashMap<String, ValidVariable>, +} + +#[derive(Debug, Clone)] +struct CompletionRAG { + completion: ValidCompletion, + prompt_expr: SimpleExpr, +} + +#[derive(Debug, Clone)] +struct FormattedMessage { + content_expr: SimpleExpr, + message: ChatMessage, +} + +#[derive(Debug, Clone)] +struct ChatRAG { + chat: ValidChat, + messages: Vec<FormattedMessage>, +} + +#[derive(Debug, Clone)] +enum ValidRAGWrapper { + Completion(CompletionRAG), + Chat(ChatRAG), +} + +impl TryFrom<ValidRAG> for ValidRAGWrapper { + type Error = anyhow::Error; + + fn try_from(rag: ValidRAG) -> Result<Self, Self::Error> { + match (rag.completion, rag.chat) { + (None, None) => anyhow::bail!("Must provide either `completion` or `chat`"), + (None, Some(chat)) => Ok(ValidRAGWrapper::Chat(ChatRAG { + messages: chat + .messages + .iter() + .map(|c| FormattedMessage { + content_expr: Expr::cust_with_values("$1", [c.content.clone()]), + message: c.clone(), + }) + .collect(), + chat, + })), + (Some(completion), None) => Ok(ValidRAGWrapper::Completion(CompletionRAG { + prompt_expr: Expr::cust_with_values("$1", [completion.prompt.clone()]), + completion, + })), + (Some(_), Some(_)) => anyhow::bail!("Cannot provide both `completion` and `chat`"), + } + } +} + +pub async fn build_rag_query( + query: Json, + collection: &Collection, + pipeline: &Pipeline, + stream: bool, +) -> anyhow::Result<(String, SqlxValues)> { + let rag: ValidRAG = serde_json::from_value(query.0)?; + + // Convert it to something more convenient to work with + let mut rag_f: ValidRAGWrapper = rag.clone().try_into()?; + + // Confirm that all variables are uppercase + if !rag.variables.keys().all(|f| &f.to_uppercase() == f) { +
anyhow::bail!("All variables in RAG query must be uppercase") + } + + let mut final_query = Query::select(); + + let mut with_clause = WithClause::new(); + let pipeline_table = format!("{}.pipelines", collection.name); + let mut pipeline_cte = Query::select(); + pipeline_cte + .from(pipeline_table.to_table_tuple()) + .columns([models::PipelineIden::Schema]) + .and_where(Expr::col(models::PipelineIden::Name).eq(&pipeline.name)); + let mut pipeline_cte = CommonTableExpression::from_select(pipeline_cte); + pipeline_cte.table_name(Alias::new("pipeline")); + with_clause.cte(pipeline_cte); + + let mut json_objects = Vec::new(); + + for (var_name, var_query) in rag.variables.iter() { + let (var_replace_select, var_source) = match var_query { + ValidVariable::VectorSearch(vector_search) => { + let (sqlx_select_statement, sqlx_ctes) = build_sqlx_query( + serde_json::json!(vector_search.vector_search).into(), + collection, + pipeline, + false, + Some(var_name), + ) + .await?; + for cte in sqlx_ctes { + with_clause.cte(cte); + } + let mut sqlx_query = CommonTableExpression::from_select(sqlx_select_statement); + sqlx_query.table_name(Alias::new(var_name)); + with_clause.cte(sqlx_query); + ( + format!( + r#"(SELECT string_agg(chunk, '{}') FROM "{var_name}")"#, + vector_search.aggregate.join + ), + format!( + r#"(SELECT json_agg(jsonb_build_object('chunk', chunk, 'document', document, 'score', score)) FROM "{var_name}")"# + ), + ) + } + ValidVariable::RawSQL(sql) => (format!("({})", sql.sql), format!("({})", sql.sql)), + }; + + if !stream { + json_objects.push(format!("'{var_name}', {var_source}")); + } + + match &mut rag_f { + ValidRAGWrapper::Completion(completion) => { + completion.prompt_expr = Expr::cust_with_expr( + format!("replace($1, '{{{var_name}}}', {var_replace_select})"), + completion.prompt_expr.clone(), + ); + } + ValidRAGWrapper::Chat(chat) => { + for message in &mut chat.messages { + if message.message.content.contains(&format!("{{{var_name}}}")) { + message.content_expr = Expr::cust_with_expr( + format!("replace($1, '{{{var_name}}}', {var_replace_select})"), + message.content_expr.clone(), + ) + } + } + } + } + } + + let transform_expr = match rag_f { + ValidRAGWrapper::Completion(completion) => { + let mut args = serde_json::json!(completion.completion); + args.as_object_mut().unwrap().remove("model"); + args.as_object_mut().unwrap().remove("prompt"); + let args_expr = Expr::cust_with_values("$1", [args]); + + let task_expr = Expr::cust_with_values( + "$1", + [serde_json::json!({ + "task": "text-generation", + "model": completion.completion.model + })], + ); + + if stream { + Expr::cust_with_exprs( + " + pgml.transform_stream( + task => $1, + input => $2, + args => $3 + ) + ", + [task_expr, completion.prompt_expr, args_expr], + ) + } else { + Expr::cust_with_exprs( + " + pgml.transform( + task => $1, + inputs => zzzzz_zzzzz_start $2 zzzzz_zzzzz_end, + args => $3 + ) + ", + [task_expr, completion.prompt_expr, args_expr], + ) + } + } + ValidRAGWrapper::Chat(chat) => { + let mut args = serde_json::json!(chat.chat); + args.as_object_mut().unwrap().remove("model"); + args.as_object_mut().unwrap().remove("messages"); + let args_expr = Expr::cust_with_values("$1", [args]); + + let task_expr = Expr::cust_with_values( + "$1", + [serde_json::json!({ + "task": "conversational", + "model": chat.chat.model + })], + ); + + let dollar_string = chat + .messages + .iter() + .enumerate() + .map(|(i, _c)| format!("${}", i + 1)) + .collect::<Vec<String>>() + .join(", "); + let prompt_exprs =
chat.messages.into_iter().map(|cm| { + let role_expr = Expr::cust_with_values("$1", [cm.message.role]); + Expr::cust_with_exprs( + "jsonb_build_object('role', $1, 'content', $2)", + [role_expr, cm.content_expr], + ) + }); + let inputs_expr = Expr::cust_with_exprs(format!("{dollar_string}"), prompt_exprs); + + if stream { + Expr::cust_with_exprs( + " + pgml.transform_stream( + task => $1, + inputs => zzzzz_zzzzz_start $2 zzzzz_zzzzz_end, + args => $3 + ) + ", + [task_expr, inputs_expr, args_expr], + ) + } else { + Expr::cust_with_exprs( + " + pgml.transform( + task => $1, + inputs => zzzzz_zzzzz_start $2 zzzzz_zzzzz_end, + args => $3 + ) + ", + [task_expr, inputs_expr, args_expr], + ) + } + } + }; + + if stream { + final_query.expr(transform_expr); + } else { + let sources = format!(",'sources', jsonb_build_object({})", json_objects.join(",")); + final_query.expr(Expr::cust_with_expr( + format!( + r#" + jsonb_build_object( + 'rag', + $1{sources} + ) + "# + ), + transform_expr, + )); + } + + let (sql, values) = final_query + .with(with_clause) + .build_sqlx(PostgresQueryBuilder); + + let sql = sql.replace("zzzzz_zzzzz_start", "ARRAY["); + let sql = sql.replace("zzzzz_zzzzz_end", "]"); + + let sql = if stream { + format!("DECLARE c CURSOR FOR {sql}") + } else { + sql + }; + + debug_sea_query!(RAG, sql, values); + + Ok((sql, values)) +} diff --git a/pgml-sdks/pgml/src/search_query_builder.rs b/pgml-sdks/pgml/src/search_query_builder.rs index 3fb6a0db4..e76371541 100644 --- a/pgml-sdks/pgml/src/search_query_builder.rs +++ b/pgml-sdks/pgml/src/search_query_builder.rs @@ -1,12 +1,12 @@ use anyhow::Context; -use serde::Deserialize; -use std::collections::HashMap; - use sea_query::{ Alias, CommonTableExpression, Expr, Func, JoinType, Order, PostgresQueryBuilder, Query, SimpleExpr, WithClause, }; use sea_query_binder::{SqlxBinder, SqlxValues}; +use serde::Deserialize; +use serde_with::{serde_as, FromInto}; +use std::collections::HashMap; use crate::{ collection::Collection, @@ -16,7 +16,7 @@ use crate::{ models, pipeline::Pipeline, remote_embeddings::build_remote_embeddings, - types::{IntoTableNameAndSchema, Json, SIden}, + types::{CustomU64Convertor, IntoTableNameAndSchema, Json, SIden}, }; #[derive(Debug, Deserialize)] @@ -42,13 +42,19 @@ struct ValidQueryActions { filter: Option<Json>, } +const fn default_limit() -> u64 { + 10 +} + +#[serde_as] #[derive(Debug, Deserialize)] #[serde(deny_unknown_fields)] struct ValidQuery { query: ValidQueryActions, // Need this when coming from JavaScript as everything is an f64 from JS - #[serde(default, deserialize_with = "crate::utils::deserialize_u64")] - limit: Option<u64>, + #[serde(default = "default_limit")] + #[serde_as(as = "FromInto<CustomU64Convertor>")] + limit: u64, } pub async fn build_search_query( @@ -57,7 +63,7 @@ pipeline: &Pipeline, ) -> anyhow::Result<(String, SqlxValues)> { let valid_query: ValidQuery = serde_json::from_value(query.0.clone())?; - let limit = valid_query.limit.unwrap_or(10); + let limit = valid_query.limit; let pipeline_table = format!("{}.pipelines", collection.name); let documents_table = format!("{}.documents", collection.name); diff --git a/pgml-sdks/pgml/src/transformer_pipeline.rs b/pgml-sdks/pgml/src/transformer_pipeline.rs index 7a6141675..f7911a56d 100644 --- a/pgml-sdks/pgml/src/transformer_pipeline.rs +++ b/pgml-sdks/pgml/src/transformer_pipeline.rs @@ -10,7 +10,7 @@ pub struct TransformerPipeline { database_url: Option<String>, } -use crate::types::GeneralJsonAsyncIterator; +use crate::types::{CustomU64Convertor,
GeneralJsonAsyncIterator}; use crate::{get_or_initialize_pool, types::Json}; #[cfg(feature = "python")] @@ -25,22 +25,18 @@ impl TransformerPipeline { /// * `model` - The model to use /// * `args` - The arguments to pass to the task /// * `database_url` - The database url to use. If None, the `PGML_DATABASE_URL` environment variable will be used - pub fn new( - task: &str, - model: Option<String>, - args: Option<Json>, - database_url: Option<String>, - ) -> Self { + pub fn new(task: &str, model: &str, args: Option<Json>, database_url: Option<String>) -> Self { let mut args = args.unwrap_or_default(); let a = args.as_object_mut().expect("args must be an object"); a.insert("task".to_string(), task.to_string().into()); - if let Some(m) = model { - a.insert("model".to_string(), m.into()); - } + a.insert("model".to_string(), model.into()); + // We must convert any floating point values to integers or our extension will get angry - if let Some(v) = a.remove("gpu_layers") { - let int_v = v.as_f64().expect("gpu_layers must be an integer") as i64; - a.insert("gpu_layers".to_string(), int_v.into()); + for field in ["gpu_layers"] { + if let Some(v) = a.remove(field) { + let x: u64 = CustomU64Convertor(v).into(); + a.insert(field.to_string(), x.into()); + } } Self { @@ -57,7 +53,21 @@ #[instrument(skip(self))] pub async fn transform(&self, inputs: Vec<Json>, args: Option<Json>) -> anyhow::Result<Json> { let pool = get_or_initialize_pool(&self.database_url).await?; - let args = args.unwrap_or_default(); + let mut args = args.unwrap_or_default(); + let a = args.as_object_mut().context("args must be an object")?; + + // Backwards compatible + if let Some(x) = a.remove("max_new_tokens") { + a.insert("max_tokens".to_string(), x); + } + + // We must convert any floating point values to integers or our extension will get angry + for field in ["max_tokens", "n"] { + if let Some(v) = a.remove(field) { + let x: u64 = CustomU64Convertor(v).into(); + a.insert(field.to_string(), x.into()); + } + } // We set the task in the new constructor so we can unwrap here let results = if self.task["task"].as_str().unwrap() == "conversational" { @@ -85,7 +95,7 @@ .fetch_all(&pool) .await?
}; - let results = results.get(0).unwrap().get::<serde_json::Value, _>(0); + let results = results.first().unwrap().get::<serde_json::Value, _>(0); Ok(Json(results)) } @@ -100,9 +110,24 @@ batch_size: Option<i32>, ) -> anyhow::Result<GeneralJsonAsyncIterator> { let pool = get_or_initialize_pool(&self.database_url).await?; - let args = args.unwrap_or_default(); let batch_size = batch_size.unwrap_or(1); + let mut args = args.unwrap_or_default(); + let a = args.as_object_mut().context("args must be an object")?; + + // Backwards compatible + if let Some(x) = a.remove("max_new_tokens") { + a.insert("max_tokens".to_string(), x); + } + + // We must convert any floating point values to integers or our extension will get angry + for field in ["max_tokens", "n"] { + if let Some(v) = a.remove(field) { + let x: u64 = CustomU64Convertor(v).into(); + a.insert(field.to_string(), x.into()); + } + } + let mut transaction = pool.begin().await?; // We set the task in the new constructor so we can unwrap here if self.task["task"].as_str().unwrap() == "conversational" { @@ -178,29 +203,7 @@ mod tests { #[sqlx::test] async fn transformer_pipeline_can_transform() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let t = TransformerPipeline::new( - "translation_en_to_fr", - Some("t5-base".to_string()), - None, - None, - ); - let results = t - .transform( - vec![ - serde_json::Value::String("How are you doing today?".to_string()).into(), - serde_json::Value::String("How are you doing today?".to_string()).into(), - ], - None, - ) - .await?; - assert!(results.as_array().is_some()); - Ok(()) - } - - #[sqlx::test] - async fn transformer_pipeline_can_transform_with_default_model() -> anyhow::Result<()> { - internal_init_logger(None, None).ok(); - let t = TransformerPipeline::new("translation_en_to_fr", None, None, None); + let t = TransformerPipeline::new("translation_en_to_fr", "t5-base", None, None); let results = t .transform( vec![ @@ -219,13 +222,8 @@ internal_init_logger(None, None).ok(); let t = TransformerPipeline::new( "text-generation", - Some("TheBloke/zephyr-7B-beta-GPTQ".to_string()), - Some( - serde_json::json!({ - "model_type": "mistral", "revision": "main", "device_map": "auto" - }) - .into(), - ), + "meta-llama/Meta-Llama-3-8B-Instruct", + None, None, ); let mut stream = t diff --git a/pgml-sdks/pgml/src/types.rs b/pgml-sdks/pgml/src/types.rs index 86cd4ea2c..2d47de710 100644 --- a/pgml-sdks/pgml/src/types.rs +++ b/pgml-sdks/pgml/src/types.rs @@ -4,8 +4,32 @@ use itertools::Itertools; use rust_bridge::alias_manual; use sea_query::Iden; use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; use std::ops::{Deref, DerefMut}; +#[derive(Serialize, Deserialize)] +pub struct CustomU64Convertor(pub Value); + +impl From<u64> for CustomU64Convertor { + fn from(value: u64) -> Self { + Self(json!(value)) + } +} + +impl From<CustomU64Convertor> for u64 { + fn from(value: CustomU64Convertor) -> Self { + if value.0.is_f64() { + value.0.as_f64().unwrap() as u64 + } else if value.0.is_i64() { + value.0.as_i64().unwrap() as u64 + } else if value.0.is_u64() { + value.0.as_u64().unwrap() + } else { + panic!("Cannot convert value into u64") + } + } +} + /// A wrapper around `serde_json::Value` #[derive(alias_manual, sqlx::Type, Debug, Clone, Deserialize, PartialEq, Eq)] #[sqlx(transparent)] diff --git a/pgml-sdks/pgml/src/utils.rs b/pgml-sdks/pgml/src/utils.rs index c1d447bb0..47718231f 100644 --- a/pgml-sdks/pgml/src/utils.rs +++ b/pgml-sdks/pgml/src/utils.rs @@ -5,10 +5,6 @@ use std::fs; use std::path::Path; use std::time::Duration; -use
serde::de::{self, Visitor}; -use serde::Deserializer; -use std::fmt; - /// A more type flexible version of format! #[macro_export] macro_rules! query_builder { @@ -100,40 +96,3 @@ pub fn get_file_contents(path: &Path) -> anyhow::Result<String> { .with_context(|| format!("Error reading file: {}", path.display()))?, }) } - -struct U64Visitor; -impl<'de> Visitor<'de> for U64Visitor { - type Value = u64; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("some number") - } - - fn visit_i32(self, value: i32) -> Result<Self::Value, E> - where - E: de::Error, - { - Ok(value as u64) - } - - fn visit_u64(self, value: u64) -> Result<Self::Value, E> - where - E: de::Error, - { - Ok(value) - } - - fn visit_f64(self, value: f64) -> Result<Self::Value, E> - where - E: de::Error, - { - Ok(value as u64) - } -} - -pub fn deserialize_u64<'de, D>(deserializer: D) -> Result<Option<u64>, D::Error> -where - D: Deserializer<'de>, -{ - deserializer.deserialize_u64(U64Visitor).map(Some) -} diff --git a/pgml-sdks/pgml/src/vector_search_query_builder.rs b/pgml-sdks/pgml/src/vector_search_query_builder.rs index 6c0381b19..1b5976eba 100644 --- a/pgml-sdks/pgml/src/vector_search_query_builder.rs +++ b/pgml-sdks/pgml/src/vector_search_query_builder.rs @@ -1,12 +1,12 @@ use anyhow::Context; -use serde::Deserialize; -use std::collections::HashMap; - use sea_query::{ Alias, CommonTableExpression, Expr, Func, JoinType, Order, PostgresQueryBuilder, Query, - WithClause, + SelectStatement, WithClause, }; use sea_query_binder::{SqlxBinder, SqlxValues}; +use serde::{Deserialize, Serialize}; +use serde_with::{serde_as, FromInto}; +use std::collections::HashMap; use crate::{ collection::Collection, @@ -16,10 +16,10 @@ use crate::{ models, pipeline::Pipeline, remote_embeddings::build_remote_embeddings, - types::{IntoTableNameAndSchema, Json, SIden}, + types::{CustomU64Convertor, IntoTableNameAndSchema, Json, SIden}, }; -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Serialize, Clone)] #[serde(deny_unknown_fields)] struct ValidField { query: String, @@ -28,31 +28,49 @@ struct ValidField { boost: Option<f32>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Serialize, Clone)] #[serde(deny_unknown_fields)] struct ValidQueryActions { fields: Option<HashMap<String, ValidField>>, filter: Option<Json>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Serialize, Clone)] #[serde(deny_unknown_fields)] -struct ValidQuery { +struct ValidDocument { + keys: Option<Vec<String>>, +} + +const fn default_limit() -> u64 { + 10 +} + +#[serde_as] +#[derive(Debug, Deserialize, Serialize, Clone)] +// #[serde(deny_unknown_fields)] +pub struct ValidQuery { query: ValidQueryActions, // Need this when coming from JavaScript as everything is an f64 from JS - #[serde(default, deserialize_with = "crate::utils::deserialize_u64")] - limit: Option<u64>, + #[serde(default = "default_limit")] + #[serde_as(as = "FromInto<CustomU64Convertor>")] + limit: u64, + // Document related items + document: Option<ValidDocument>, } -pub async fn build_vector_search_query( +pub async fn build_sqlx_query( query: Json, collection: &Collection, pipeline: &Pipeline, -) -> anyhow::Result<(String, SqlxValues)> { + include_pipeline_table_cte: bool, + prefix: Option<&str>, +) -> anyhow::Result<(SelectStatement, Vec<CommonTableExpression>)> { let valid_query: ValidQuery = serde_json::from_value(query.0)?; - let limit = valid_query.limit.unwrap_or(10); + let limit = valid_query.limit; let fields = valid_query.query.fields.unwrap_or_default(); + let prefix = prefix.unwrap_or(""); + if fields.is_empty() { anyhow::bail!("at least one field is required to search over") }
@@ -61,16 +79,18 @@ pub async fn build_vector_search_query( let documents_table = format!("{}.documents", collection.name); let mut queries = Vec::new(); - let mut with_clause = WithClause::new(); + let mut ctes = Vec::new(); - let mut pipeline_cte = Query::select(); - pipeline_cte - .from(pipeline_table.to_table_tuple()) - .columns([models::PipelineIden::Schema]) - .and_where(Expr::col(models::PipelineIden::Name).eq(&pipeline.name)); - let mut pipeline_cte = CommonTableExpression::from_select(pipeline_cte); - pipeline_cte.table_name(Alias::new("pipeline")); - with_clause.cte(pipeline_cte); + if include_pipeline_table_cte { + let mut pipeline_cte = Query::select(); + pipeline_cte + .from(pipeline_table.to_table_tuple()) + .columns([models::PipelineIden::Schema]) + .and_where(Expr::col(models::PipelineIden::Name).eq(&pipeline.name)); + let mut pipeline_cte = CommonTableExpression::from_select(pipeline_cte); + pipeline_cte.table_name(Alias::new("pipeline")); + ctes.push(pipeline_cte); + } for (key, vf) in fields { let model_runtime = pipeline @@ -116,15 +136,15 @@ pub async fn build_vector_search_query( Alias::new("embedding"), ); let mut embedding_cte = CommonTableExpression::from_select(embedding_cte); - embedding_cte.table_name(Alias::new(format!("{key}_embedding"))); - with_clause.cte(embedding_cte); + embedding_cte.table_name(Alias::new(format!("{prefix}{key}_embedding"))); + ctes.push(embedding_cte); query .expr(Expr::cust(format!( - r#"(1 - (embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector)) * {boost} AS score"# + r#"(1 - (embeddings.embedding <=> (SELECT embedding FROM "{prefix}{key}_embedding")::vector)) * {boost} AS score"# ))) .order_by_expr(Expr::cust(format!( - r#"embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector"# + r#"embeddings.embedding <=> (SELECT embedding FROM "{prefix}{key}_embedding")::vector"# )), Order::Asc); } ModelRuntime::OpenAI => { @@ -155,7 +175,9 @@ pub async fn build_vector_search_query( // Build the score CTE query .expr(Expr::cust_with_values( - r#"(1 - (embeddings.embedding <=> $1::vector)) {boost} AS score"#, + format!( + r#"(1 - (embeddings.embedding <=> $1::vector)) * {boost} AS score"# + ), [embedding.clone()], )) .order_by_expr( @@ -214,12 +236,28 @@ pub async fn build_vector_search_query( } let mut wrapper_query = Query::select(); + + // Allows filtering on which keys to return with the document + if let Some(document) = &valid_query.document { + if let Some(keys) = &document.keys { + let document_queries = keys + .iter() + .map(|key| format!("'{key}', document #> '{{{key}}}'")) + .collect::<Vec<String>>() + .join(","); + wrapper_query.expr_as( + Expr::cust(format!("jsonb_build_object({document_queries})")), + Alias::new("document"), + ); + } else { + wrapper_query.column(SIden::Str("document")); + } + } else { + wrapper_query.column(SIden::Str("document")); + } + wrapper_query - .columns([ - SIden::Str("document"), - SIden::Str("chunk"), - SIden::Str("score"), - ]) + .columns([SIden::Str("chunk"), SIden::Str("score")]) .from_subquery(query, Alias::new("s")); queries.push(wrapper_query); @@ -236,6 +274,19 @@ pub async fn build_vector_search_query( .order_by(SIden::Str("score"), Order::Desc) .limit(limit); + Ok((query, ctes)) +} + +pub async fn build_vector_search_query( + query: Json, + collection: &Collection, + pipeline: &Pipeline, +) -> anyhow::Result<(String, SqlxValues)> { + let (query, ctes) = build_sqlx_query(query, collection, pipeline, true, None).await?; + let mut with_clause = WithClause::new(); + for cte in ctes { +
with_clause.cte(cte); + } let (sql, values) = query.with(with_clause).build_sqlx(PostgresQueryBuilder); debug_sea_query!(VECTOR_SEARCH, sql, values); diff --git a/pgml-sdks/rust-bridge/rust-bridge-macros/src/python.rs b/pgml-sdks/rust-bridge/rust-bridge-macros/src/python.rs index a453bf14f..1b472e899 100644 --- a/pgml-sdks/rust-bridge/rust-bridge-macros/src/python.rs +++ b/pgml-sdks/rust-bridge/rust-bridge-macros/src/python.rs @@ -72,7 +72,7 @@ pub fn generate_python_alias(parsed: DeriveInput) -> proc_macro::TokenStream { let expanded = quote! { #[cfg(feature = "python")] #[pyo3::pyclass(name = #wrapped_type_name)] - #[derive(Clone, Debug)] + #[derive(Clone)] pub struct #name_ident { pub wrapped: std::boxed::Box<#wrapped_type_ident> }
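Reviewer note: below is a minimal end-to-end sketch of how the new RAG query surface fits together, pieced from the test JSON at the top of this patch and the `ValidRAG` schema in rag_query_builder.rs. The `Collection::rag_stream` entry point, both constructors' exact signatures, and the collection/pipeline names are assumptions for illustration; the real wiring lives in collection.rs, outside this patch.

// Sketch only, not part of the patch. Assumes `Collection::rag_stream` is the
// entry point that calls build_rag_query with stream = true.
use futures::StreamExt;
use pgml::{Collection, Pipeline};

async fn rag_stream_example() -> anyhow::Result<()> {
    // Names are illustrative; constructor signatures assumed from the SDK tests.
    let mut collection = Collection::new("test_collection", None)?;
    let mut pipeline = Pipeline::new("test_pipeline", None)?;

    // Variable keys must be uppercase (build_rag_query bails otherwise); each one
    // is spliced into the chat messages via replace(..., '{CONTEXT}', ...).
    let results = collection
        .rag_stream(
            serde_json::json!({
                "CONTEXT": {
                    "vector_search": {
                        "query": { "fields": { "body": { "query": "Test document" } } },
                        "document": { "keys": ["id"] },
                        "limit": 2
                    },
                    "aggregate": { "join": "\n" }
                },
                "chat": {
                    "model": "meta-llama/Meta-Llama-3-8B-Instruct",
                    "messages": [
                        { "role": "system", "content": "You are a friendly and helpful chatbot" },
                        { "role": "user", "content": "Answer using {CONTEXT}" }
                    ],
                    "max_tokens": 10
                }
            })
            .into(),
            &mut pipeline,
        )
        .await?;

    // Streaming path: build_rag_query wraps the SQL in DECLARE c CURSOR FOR ...,
    // so chunks arrive incrementally rather than as one jsonb_build_object result.
    let mut stream = results.stream()?;
    while let Some(chunk) = stream.next().await {
        println!("{:?}", chunk?);
    }
    Ok(())
}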