diff --git a/.github/workflows/javascript-sdk.yml b/.github/workflows/javascript-sdk.yml index 168d54c59..f86800a25 100644 --- a/.github/workflows/javascript-sdk.yml +++ b/.github/workflows/javascript-sdk.yml @@ -18,7 +18,7 @@ jobs: runs-on: ${{ matrix.os }} defaults: run: - working-directory: pgml-sdks/rust/pgml/javascript + working-directory: pgml-sdks/pgml/javascript steps: - uses: actions/checkout@v3 - uses: actions-rs/toolchain@v1 @@ -34,14 +34,11 @@ jobs: run: | npm i npm run build-release - mv index.node ${{ matrix.neon-out-name }} - - name: Display output files - run: ls -R - name: Upload built .node file uses: actions/upload-artifact@v3 with: name: node-artifacts - path: pgml-sdks/rust/pgml/javascript/${{ matrix.neon-out-name }} + path: pgml-sdks/pgml/javascript/dist/${{ matrix.neon-out-name }} retention-days: 1 # publish-javascript-sdk: # needs: build-javascript-sdk diff --git a/.github/workflows/python-sdk.yml b/.github/workflows/python-sdk.yml index fc562778b..e8d042fff 100644 --- a/.github/workflows/python-sdk.yml +++ b/.github/workflows/python-sdk.yml @@ -14,7 +14,7 @@ jobs: runs-on: ${{ matrix.os }} defaults: run: - working-directory: pgml-sdks/rust/pgml + working-directory: pgml-sdks/pgml steps: - uses: actions/checkout@v3 - uses: actions-rs/toolchain@v1 @@ -62,7 +62,7 @@ jobs: runs-on: macos-latest defaults: run: - working-directory: pgml-sdks/rust/pgml + working-directory: pgml-sdks/pgml steps: - uses: actions/checkout@v3 - uses: actions-rs/toolchain@v1 @@ -101,7 +101,7 @@ jobs: python-version: ["3.8", "3.9", "3.10", "3.11"] defaults: run: - working-directory: pgml-sdks\rust\pgml + working-directory: pgml-sdks\pgml steps: - uses: actions/checkout@v3 - uses: actions-rs/toolchain@v1 diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock index db7b5d11f..dc5b7dada 100644 --- a/pgml-sdks/pgml/Cargo.lock +++ b/pgml-sdks/pgml/Cargo.lock @@ -2,17 +2,6 @@ # It is not intended for manual editing. 
version = 3 -[[package]] -name = "ahash" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" -dependencies = [ - "getrandom", - "once_cell", - "version_check", -] - [[package]] name = "ahash" version = "0.8.3" @@ -20,6 +9,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" dependencies = [ "cfg-if", + "getrandom", "once_cell", "version_check", ] @@ -73,9 +63,9 @@ dependencies = [ [[package]] name = "atoi" -version = "1.0.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7c57d12312ff59c811c0643f4d80830505833c9ffaebd193d819392b265be8e" +checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" dependencies = [ "num-traits", ] @@ -88,15 +78,15 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "base64" -version = "0.13.1" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" +checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" [[package]] -name = "base64" -version = "0.21.2" +name = "base64ct" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" [[package]] name = "bitflags" @@ -104,6 +94,15 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "bitflags" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635" +dependencies = [ + "serde", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -171,6 +170,12 @@ dependencies = [ "windows-sys 0.45.0", ] +[[package]] +name = "const-oid" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28c122c3980598d243d63d9a704629a2d748d101f278052ff068be5a4423ab6f" + [[package]] name = "core-foundation" version = "0.9.3" @@ -275,6 +280,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "der" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fffa369a668c8af7dbf8b5e56c9f744fbd399949ed171606040001947de40b1c" +dependencies = [ + "const-oid", + "pem-rfc7468", + "zeroize", +] + [[package]] name = "digest" version = "0.10.7" @@ -282,30 +298,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", + "const-oid", "crypto-common", "subtle", ] -[[package]] -name = "dirs" -version = "4.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059" -dependencies = [ - "dirs-sys", -] - -[[package]] -name = "dirs-sys" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6" -dependencies = [ - "libc", - "redox_users", - "winapi", -] - [[package]] name = "dotenvy" version = 
"0.15.7" @@ -317,6 +314,9 @@ name = "either" version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" +dependencies = [ + "serde", +] [[package]] name = "encode_unicode" @@ -333,6 +333,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + [[package]] name = "errno" version = "0.3.1" @@ -354,6 +360,17 @@ dependencies = [ "libc", ] +[[package]] +name = "etcetera" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "136d1b5283a1ab77bd9257427ffd09d8667ced0570b6f938942bc7568ed5b943" +dependencies = [ + "cfg-if", + "home", + "windows-sys 0.48.0", +] + [[package]] name = "event-listener" version = "2.5.3" @@ -369,6 +386,18 @@ dependencies = [ "instant", ] +[[package]] +name = "flume" +version = "0.10.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1657b4441c3403d9f7b3409e47575237dac27b1b5726df654a6ecbf92f0f7577" +dependencies = [ + "futures-core", + "futures-sink", + "pin-project", + "spin 0.9.8", +] + [[package]] name = "fnv" version = "1.0.7" @@ -443,13 +472,13 @@ dependencies = [ [[package]] name = "futures-intrusive" -version = "0.4.2" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a604f7a68fbf8103337523b1fadc8ade7361ee3f112f7c680ad179651616aed5" +checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" dependencies = [ "futures-core", "lock_api", - "parking_lot 0.11.2", + "parking_lot", ] [[package]] @@ -532,7 +561,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap", + "indexmap 1.9.3", "slab", "tokio", "tokio-util", @@ -551,7 +580,7 @@ version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" dependencies = [ - "ahash 0.8.3", + "ahash", "allocator-api2", ] @@ -612,6 +641,15 @@ dependencies = [ "digest", ] +[[package]] +name = "home" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5444c27eef6923071f7ebcc33e3444508466a76f7a2b93da00ed6e19f30c1ddb" +dependencies = [ + "windows-sys 0.48.0", +] + [[package]] name = "http" version = "0.2.9" @@ -732,6 +770,16 @@ dependencies = [ "hashbrown 0.12.3", ] +[[package]] +name = "indexmap" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" +dependencies = [ + "equivalent", + "hashbrown 0.14.0", +] + [[package]] name = "indicatif" version = "0.17.6" @@ -751,6 +799,17 @@ version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" +[[package]] +name = "inherent" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce243b1bfa62ffc028f1cc3b6034ec63d649f3031bc8a4fbbb004e1ac17d1f68" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.28", +] + [[package]] name = "instant" version = "0.1.12" @@ -806,6 +865,9 @@ name = "lazy_static" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +dependencies 
= [ + "spin 0.5.2", +] [[package]] name = "libc" @@ -823,6 +885,23 @@ dependencies = [ "winapi", ] +[[package]] +name = "libm" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" + +[[package]] +name = "libsqlite3-sys" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "afc22eff61b133b115c6e8c74e818c628d6d5e7a502afea6f64dee076dd94326" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linux-raw-sys" version = "0.3.8" @@ -977,6 +1056,44 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-bigint-dig" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc84195820f291c7697304f3cbdadd1cb7199c0efc917ff5eafd71225c136151" +dependencies = [ + "byteorder", + "lazy_static", + "libm", + "num-integer", + "num-iter", + "num-traits", + "rand", + "smallvec", + "zeroize", +] + +[[package]] +name = "num-integer" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +dependencies = [ + "autocfg", + "num-traits", +] + +[[package]] +name = "num-iter" +version = "0.1.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.15" @@ -984,6 +1101,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -1014,7 +1132,7 @@ version = "0.10.55" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "345df152bc43501c5eb9e4654ff05f794effb78d4efe3d53abc158baddc0703d" dependencies = [ - "bitflags", + "bitflags 1.3.2", "cfg-if", "foreign-types", "libc", @@ -1040,6 +1158,15 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +[[package]] +name = "openssl-src" +version = "111.26.0+1.1.1u" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efc62c9f12b22b8f5208c23a7200a442b2e5999f8bdf80233852122b5a4f6f37" +dependencies = [ + "cc", +] + [[package]] name = "openssl-sys" version = "0.9.90" @@ -1048,6 +1175,7 @@ checksum = "374533b0e45f3a7ced10fcaeccca020e66656bc03dac384f852e4e5a7a8104a6" dependencies = [ "cc", "libc", + "openssl-src", "pkg-config", "vcpkg", ] @@ -1058,17 +1186,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" -[[package]] -name = "parking_lot" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" -dependencies = [ - "instant", - "lock_api", - "parking_lot_core 0.8.6", -] - [[package]] name = "parking_lot" version = "0.12.1" @@ -1076,21 +1193,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" dependencies = [ "lock_api", - "parking_lot_core 0.9.8", -] - -[[package]] -name = "parking_lot_core" -version = "0.8.6" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc" -dependencies = [ - "cfg-if", - "instant", - "libc", - "redox_syscall 0.2.16", - "smallvec", - "winapi", + "parking_lot_core", ] [[package]] @@ -1101,7 +1204,7 @@ checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" dependencies = [ "cfg-if", "libc", - "redox_syscall 0.3.5", + "redox_syscall", "smallvec", "windows-targets 0.48.0", ] @@ -1112,6 +1215,15 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79" +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.0" @@ -1120,7 +1232,7 @@ checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" [[package]] name = "pgml" -version = "0.9.0" +version = "0.9.1" dependencies = [ "anyhow", "async-trait", @@ -1146,6 +1258,26 @@ dependencies = [ "uuid", ] +[[package]] +name = "pin-project" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fda4ed1c6c173e3fc7a83629421152e01d7b1f9b7f65fb301e490e8cfc656422" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.28", +] + [[package]] name = "pin-project-lite" version = "0.2.9" @@ -1158,6 +1290,27 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkcs1" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" +dependencies = [ + "der", + "pkcs8", + "spki", +] + +[[package]] +name = "pkcs8" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +dependencies = [ + "der", + "spki", +] + [[package]] name = "pkg-config" version = "0.3.27" @@ -1195,7 +1348,7 @@ dependencies = [ "indoc", "libc", "memoffset", - "parking_lot 0.12.1", + "parking_lot", "pyo3-build-config", "pyo3-ffi", "pyo3-macros", @@ -1309,33 +1462,13 @@ dependencies = [ "getrandom", ] -[[package]] -name = "redox_syscall" -version = "0.2.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" -dependencies = [ - "bitflags", -] - [[package]] name = "redox_syscall" version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" dependencies = [ - "bitflags", -] - -[[package]] -name = "redox_users" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" -dependencies = [ - "getrandom", - "redox_syscall 0.2.16", - "thiserror", + "bitflags 1.3.2", ] [[package]] @@ -1361,7 +1494,7 @@ version = 
"0.11.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55" dependencies = [ - "base64 0.21.2", + "base64", "bytes", "encoding_rs", "futures-core", @@ -1401,12 +1534,34 @@ dependencies = [ "cc", "libc", "once_cell", - "spin", + "spin 0.5.2", "untrusted", "web-sys", "winapi", ] +[[package]] +name = "rsa" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ab43bb47d23c1a631b4b680199a45255dce26fa9ab2fa902581f624ff13e6a8" +dependencies = [ + "byteorder", + "const-oid", + "digest", + "num-bigint-dig", + "num-integer", + "num-iter", + "num-traits", + "pkcs1", + "pkcs8", + "rand_core", + "signature", + "spki", + "subtle", + "zeroize", +] + [[package]] name = "rust_bridge" version = "0.1.0" @@ -1438,7 +1593,7 @@ version = "0.37.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4d69718bf81c6127a49dc64e44a742e8bb9213c0ff8869a22c308f84c1d4ab06" dependencies = [ - "bitflags", + "bitflags 1.3.2", "errno", "io-lifetimes", "libc", @@ -1448,14 +1603,13 @@ dependencies = [ [[package]] name = "rustls" -version = "0.20.8" +version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fff78fc74d175294f4e83b28343315ffcfb114b156f0185e9741cb5570f50e2f" +checksum = "cd8d6c9f025a446bc4d18ad9632e69aec8f287aa84499ee335599fabd20c3fd8" dependencies = [ - "log", "ring", + "rustls-webpki", "sct", - "webpki", ] [[package]] @@ -1464,7 +1618,17 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" dependencies = [ - "base64 0.21.2", + "base64", +] + +[[package]] +name = "rustls-webpki" +version = "0.101.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "261e9e0888cba427c3316e6322805653c9425240b6fd96cee7cb671ab70ab8d0" +dependencies = [ + "ring", + "untrusted", ] [[package]] @@ -1500,10 +1664,11 @@ dependencies = [ [[package]] name = "sea-query" -version = "0.28.5" +version = "0.30.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbab99b8cd878ab7786157b7eb8df96333a6807cc6e45e8888c85b51534b401a" +checksum = "28c05a5bf6403834be253489bbe95fa9b1e5486bc843b61f60d26b5c9c1e244b" dependencies = [ + "inherent", "sea-query-attr", "sea-query-derive", "serde_json", @@ -1523,9 +1688,9 @@ dependencies = [ [[package]] name = "sea-query-binder" -version = "0.3.1" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cea85029985b40dfbf18318d85fe985c04db7c1b4e5e8e0a0a0cdff5f1e30f9" +checksum = "36bbb68df92e820e4d5aeb17b4acd5cc8b5d18b2c36a4dd6f4626aabfa7ab1b9" dependencies = [ "sea-query", "serde_json", @@ -1534,9 +1699,9 @@ dependencies = [ [[package]] name = "sea-query-derive" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63f62030c60f3a691f5fe251713b4e220b306e50a71e1d6f9cce1f24bb781978" +checksum = "bd78f2e0ee8e537e9195d1049b752e0433e2cac125426bccb7b5c3e508096117" dependencies = [ "heck", "proc-macro2", @@ -1551,7 +1716,7 @@ version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fc758eb7bffce5b308734e9b0c1468893cae9ff70ebf13e7090be8dcbcc83a8" dependencies = [ - "bitflags", + "bitflags 1.3.2", "core-foundation", "core-foundation-sys", "libc", @@ -1657,6 +1822,16 @@ dependencies = [ "lazy_static", ] +[[package]] +name 
= "signature" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e1788eed21689f9cf370582dfc467ef36ed9c707f073528ddafa8d83e3b8500" +dependencies = [ + "digest", + "rand_core", +] + [[package]] name = "slab" version = "0.4.8" @@ -1688,6 +1863,25 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] + +[[package]] +name = "spki" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d1e996ef02c474957d681f1b05213dfb0abab947b446a62d37770b23500184a" +dependencies = [ + "base64ct", + "der", +] + [[package]] name = "sqlformat" version = "0.2.1" @@ -1701,99 +1895,212 @@ dependencies = [ [[package]] name = "sqlx" -version = "0.6.3" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8de3b03a925878ed54a954f621e64bf55a3c1bd29652d0d1a17830405350188" +checksum = "8e58421b6bc416714d5115a2ca953718f6c621a51b68e4f4922aea5a4391a721" dependencies = [ "sqlx-core", "sqlx-macros", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", ] [[package]] name = "sqlx-core" -version = "0.6.3" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa8241483a83a3f33aa5fff7e7d9def398ff9990b2752b6c6112b83c6d246029" +checksum = "dd4cef4251aabbae751a3710927945901ee1d97ee96d757f6880ebb9a79bfd53" dependencies = [ - "ahash 0.7.6", + "ahash", "atoi", - "base64 0.13.1", - "bitflags", "byteorder", "bytes", "chrono", "crc", "crossbeam-queue", - "dirs", "dotenvy", "either", "event-listener", "futures-channel", "futures-core", "futures-intrusive", + "futures-io", "futures-util", "hashlink", "hex", - "hkdf", - "hmac", - "indexmap", - "itoa", - "libc", + "indexmap 2.0.0", "log", - "md-5", "memchr", "once_cell", "paste", "percent-encoding", - "rand", "rustls", "rustls-pemfile", "serde", "serde_json", - "sha1", "sha2", "smallvec", "sqlformat", - "sqlx-rt", - "stringprep", "thiserror", "time 0.3.22", + "tokio", "tokio-stream", + "tracing", "url", "uuid", "webpki-roots", - "whoami", ] [[package]] name = "sqlx-macros" -version = "0.6.3" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "208e3165167afd7f3881b16c1ef3f2af69fa75980897aac8874a0696516d12c2" +dependencies = [ + "proc-macro2", + "quote", + "sqlx-core", + "sqlx-macros-core", + "syn 1.0.109", +] + +[[package]] +name = "sqlx-macros-core" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9966e64ae989e7e575b19d7265cb79d7fc3cbbdf179835cb0d716f294c2049c9" +checksum = "8a4a8336d278c62231d87f24e8a7a74898156e34c1c18942857be2acb29c7dfc" dependencies = [ "dotenvy", "either", "heck", + "hex", "once_cell", "proc-macro2", "quote", + "serde", "serde_json", "sha2", "sqlx-core", - "sqlx-rt", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", "syn 1.0.109", + "tempfile", + "tokio", "url", ] [[package]] -name = "sqlx-rt" -version = "0.6.3" +name = "sqlx-mysql" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "804d3f245f894e61b1e6263c84b23ca675d96753b5abfd5cc8597d86806e8024" +checksum = "8ca69bf415b93b60b80dc8fda3cb4ef52b2336614d8da2de5456cc942a110482" 
dependencies = [ + "atoi", + "base64", + "bitflags 2.4.0", + "byteorder", + "bytes", + "chrono", + "crc", + "digest", + "dotenvy", + "either", + "futures-channel", + "futures-core", + "futures-io", + "futures-util", + "generic-array", + "hex", + "hkdf", + "hmac", + "itoa", + "log", + "md-5", + "memchr", "once_cell", - "tokio", - "tokio-rustls", + "percent-encoding", + "rand", + "rsa", + "serde", + "sha1", + "sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror", + "time 0.3.22", + "tracing", + "uuid", + "whoami", +] + +[[package]] +name = "sqlx-postgres" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0db2df1b8731c3651e204629dd55e52adbae0462fa1bdcbed56a2302c18181e" +dependencies = [ + "atoi", + "base64", + "bitflags 2.4.0", + "byteorder", + "chrono", + "crc", + "dotenvy", + "etcetera", + "futures-channel", + "futures-core", + "futures-io", + "futures-util", + "hex", + "hkdf", + "hmac", + "home", + "itoa", + "log", + "md-5", + "memchr", + "once_cell", + "rand", + "serde", + "serde_json", + "sha1", + "sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror", + "time 0.3.22", + "tracing", + "uuid", + "whoami", +] + +[[package]] +name = "sqlx-sqlite" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4c21bf34c7cae5b283efb3ac1bcc7670df7561124dc2f8bdc0b59be40f79a2" +dependencies = [ + "atoi", + "chrono", + "flume", + "futures-channel", + "futures-core", + "futures-executor", + "futures-intrusive", + "futures-util", + "libsqlite3-sys", + "log", + "percent-encoding", + "serde", + "sqlx-core", + "time 0.3.22", + "tracing", + "url", + "uuid", ] [[package]] @@ -1866,7 +2173,7 @@ dependencies = [ "autocfg", "cfg-if", "fastrand", - "redox_syscall 0.3.5", + "redox_syscall", "rustix", "windows-sys 0.48.0", ] @@ -1992,17 +2299,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "tokio-rustls" -version = "0.23.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59" -dependencies = [ - "rustls", - "tokio", - "webpki", -] - [[package]] name = "tokio-stream" version = "0.1.14" @@ -2041,6 +2337,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" dependencies = [ "cfg-if", + "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -2304,23 +2601,13 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "webpki" -version = "0.22.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd" -dependencies = [ - "ring", - "untrusted", -] - [[package]] name = "webpki-roots" -version = "0.22.6" +version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6c71e40d7d2c34a5106301fb632274ca37242cd0c9d3e64dbece371a40a2d87" +checksum = "b291546d5d9d1eab74f069c77749f2cb8504a12caa20f0f2de93ddbf6f411888" dependencies = [ - "webpki", + "rustls-webpki", ] [[package]] @@ -2328,10 +2615,6 @@ name = "whoami" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c70234412ca409cc04e864e89523cb0fc37f5e1344ebed5a3ebf4192b6b9f68" -dependencies = [ - "wasm-bindgen", - "web-sys", -] [[package]] name = "winapi" @@ -2504,3 +2787,9 @@ checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" dependencies = [ "winapi", ] + +[[package]] 
+name = "zeroize" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" diff --git a/pgml-sdks/pgml/Cargo.toml b/pgml-sdks/pgml/Cargo.toml index c854cb61b..b3d15786a 100644 --- a/pgml-sdks/pgml/Cargo.toml +++ b/pgml-sdks/pgml/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgml" -version = "0.9.0" +version = "0.9.1" edition = "2021" authors = ["PosgresML "] homepage = "https://postgresml.org/" @@ -15,7 +15,7 @@ crate-type = ["lib", "cdylib"] [dependencies] rust_bridge = {path = "../rust-bridge/rust-bridge", version = "0.1.0"} -sqlx = { version = "0.6", features = [ "runtime-tokio-rustls", "postgres", "json", "time", "uuid", "chrono"] } +sqlx = { version = "0.7", features = [ "runtime-tokio-rustls", "postgres", "json", "time", "uuid", "chrono"] } serde_json = "1.0.9" anyhow = "1.0.9" tokio = { version = "1.28.2", features = [ "macros" ] } @@ -26,10 +26,10 @@ neon = { version = "0.10", optional = true, default-features = false, features = itertools = "0.10.5" uuid = {version = "1.3.3", features = ["v4", "serde"] } md5 = "0.7.0" -sea-query = { version = "0.28.5", features = ["attr", "thread-safe", "with-json", "postgres-array"] } -sea-query-binder = { version = "0.3.1", features = ["sqlx-postgres", "with-json", "postgres-array"] } +sea-query = { version = "0.30.1", features = ["attr", "thread-safe", "with-json", "postgres-array"] } +sea-query-binder = { version = "0.5.0", features = ["sqlx-postgres", "with-json", "postgres-array"] } regex = "1.8.4" -reqwest = { version = "0.11", features = ["json"] } +reqwest = { version = "0.11", features = ["json", "native-tls-vendored"] } async-trait = "0.1.71" tracing = { version = "0.1.37" } tracing-subscriber = { version = "0.3.17", features = ["json"] } diff --git a/pgml-sdks/pgml/javascript/README.md b/pgml-sdks/pgml/javascript/README.md index 77a687833..de4acede9 100644 --- a/pgml-sdks/pgml/javascript/README.md +++ b/pgml-sdks/pgml/javascript/README.md @@ -208,9 +208,11 @@ const collection = pgml.newCollection("test_collection", CUSTOM_DATABASE_URL) ### Upserting Documents -Documents are dictionaries with two required keys: `id` and `text`. All other keys/value pairs are stored as metadata for the document. +The `upsert_documents` method can be used to insert new documents and update existing documents. -**Upsert documents with metadata** +New documents are dictionaries with two required keys: `id` and `text`. All other keys/value pairs are stored as metadata for the document. + +**Upsert new documents with metadata** ```javascript const documents = [ { @@ -228,6 +230,98 @@ const collection = pgml.newCollection("test_collection") await collection.upsert_documents(documents) ``` +Document metadata can be updated by upserting the document without the `text` key. 
+ +**Update document metadata** +```javascript +const documents = [ + { + id: "Document 1", + random_key: "this will be NEW metadata for the document" + }, + { + id: "Document 2", + random_key: "this will be NEW metadata for the document" + } +] +const collection = pgml.newCollection("test_collection") +await collection.upsert_documents(documents) +``` + +### Getting Documents + +Documents can be retrieved using the `get_documents` method on the collection object. + +**Get the first 100 documents** +```javascript +const collection = pgml.newCollection("test_collection") +const documents = await collection.get_documents({ limit: 100 }) +``` + +#### Pagination + +The JavaScript SDK supports limit-offset pagination and keyset pagination. + +**Limit-Offset pagination** +```javascript +const collection = pgml.newCollection("test_collection") +const documents = await collection.get_documents({ limit: 100, offset: 10 }) +``` + +**Keyset pagination** +```javascript +const collection = pgml.newCollection("test_collection") +const documents = await collection.get_documents({ limit: 100, last_row_id: 10 }) +``` + +The `last_row_id` can be taken from the `row_id` field of the last document returned and passed to the next `get_documents` call. + +#### Filtering + +Metadata and full text filtering are supported just like they are in vector recall. + +**Metadata and full text filtering** +```javascript +const collection = pgml.newCollection("test_collection") +const documents = await collection.get_documents({ + limit: 100, + offset: 10, + filter: { + metadata: { + id: { + $eq: 1 + } + }, + full_text_search: { + configuration: "english", + text: "Some full text query" + } + } +}) +``` + +### Deleting Documents + +Documents can be deleted with the `delete_documents` method on the collection object. + +Metadata and full text filtering are supported just like they are in vector recall. + +```javascript +const collection = pgml.newCollection("test_collection") +await collection.delete_documents({ + metadata: { + id: { + $eq: 1 + } + }, + full_text_search: { + configuration: "english", + text: "Some full text query" + } +}) +``` + ### Searching Collections The JavaScript SDK is specifically designed to provide powerful, flexible vector search. @@ -326,7 +420,7 @@ const results = await collection.query() .fetch_all() ``` -The above query would filter out all documents that do not have a key `special` with a value `True` or (have a key `uuid` equal to 1 and a key `index` less than 100). +The above query returns only documents that either have a key `special` with a value `true`, or have both a key `uuid` equal to 1 and a key `index` less than 100; all other documents are filtered out, as illustrated below.
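+
+To make the filter semantics concrete, here is an illustrative sketch; the sample documents below are hypothetical and are annotated with which side of the filter above they would fall on:
+```javascript
+const docs = [
+  { id: "a", text: "...", special: true, uuid: 2, index: 500 },  // kept: special is true
+  { id: "b", text: "...", special: false, uuid: 1, index: 50 },  // kept: uuid equals 1 and index is less than 100
+  { id: "c", text: "...", special: false, uuid: 3, index: 50 }   // filtered out: matches neither branch
+]
+```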
#### Full Text Filtering @@ -418,7 +512,7 @@ const model = pgml.newModel() const splitter = pgml.newSplitter() const pipeline = pgml.newPipeline("test_pipeline", model, splitter, { "full_text_search": { - active: True, + active: true, configuration: "english" } }) diff --git a/pgml-sdks/pgml/javascript/build.js b/pgml-sdks/pgml/javascript/build.js new file mode 100644 index 000000000..30fe55bfa --- /dev/null +++ b/pgml-sdks/pgml/javascript/build.js @@ -0,0 +1,48 @@ +const os = require("os"); +const { exec } = require("node:child_process"); + +const type = os.type(); +const arch = os.arch(); + +const set_name = (type, arch) => { + if (type == "Darwin" && arch == "x64") { + return "x86_64-apple-darwin-index.node"; + } else if (type == "Darwin" && arch == "arm64") { + return "aarch64-apple-darwin-index.node"; + } else if ((type == "Windows" || type == "Windows_NT") && arch == "x64") { + return "x86_64-pc-windows-gnu-index.node"; + } else if (type == "Linux" && arch == "x64") { + return "x86_64-unknown-linux-gnu-index.node"; + } else if (type == "Linux" && arch == "arm64") { + return "aarch64-unknown-linux-gnu-index.node"; + } else { + console.log("UNSUPPORTED TYPE OR ARCH:", type, arch); + process.exit(1); + } +}; + +let name = set_name(type, arch); + +let args = process.argv.slice(2); +let release = args.includes("--release"); + +let shell_args = + type == "Windows" || type == "Windows_NT" ? { shell: "powershell.exe" } : {}; + +exec( + ` + rm -r dist; + mkdir dist; + npx cargo-cp-artifact -nc "${name}" -- cargo build --message-format=json-render-diagnostics -F javascript ${release ? "--release" : ""}; + mv ${name} dist; + `, + shell_args, + (err, stdout, stderr) => { + if (err) { + console.log("ERR:", err); + } else { + console.log("STDOUT:", stdout); + console.log("STDERR:", stderr); + } + }, +); diff --git a/pgml-sdks/pgml/javascript/index.js b/pgml-sdks/pgml/javascript/index.js index 5ebc5b4d3..47ab75a8e 100644 --- a/pgml-sdks/pgml/javascript/index.js +++ b/pgml-sdks/pgml/javascript/index.js @@ -3,26 +3,22 @@ const os = require("os") const type = os.type() const arch = os.arch() -try { - const pgml = require("./index.node") +if (type == "Darwin" && arch == "x64") { + const pgml = require("./dist/x86_64-apple-darwin-index.node") module.exports = pgml -} catch (e) { - if (type == "Darwin" && arch == "x64") { - const pgml = require("./dist/x86_64-apple-darwin-index.node") - module.exports = pgml - } else if (type == "Darwin" && arch == "arm64") { - const pgml = require("./dist/aarch64-apple-darwin-index.node") - module.exports = pgml - } else if (type == "Windows" && arch == "x64") { - const pgml = require("./dist/x86_64-pc-windows-gnu-index.node") - module.exports = pgml - } else if (type == "Linux" && arch == "x64") { - const pgml = require("./dist/x86_64-unknown-linux-gnu-index.node") - module.exports = pgml - } else if (type == "Linux" && arch == "arm64") { - const pgml = require("./dist/aarch64-unknown-linux-gnu-index.node") - module.exports = pgml - } else { - console.log("UNSUPPORTED TYPE OR ARCH:", type, arch) - } +} else if (type == "Darwin" && arch == "arm64") { + const pgml = require("./dist/aarch64-apple-darwin-index.node") + module.exports = pgml +} else if ((type == "Windows" || type == "Windows_NT") && arch == "x64") { + const pgml = require("./dist/x86_64-pc-windows-gnu-index.node") + module.exports = pgml +} else if (type == "Linux" && arch == "x64") { + const pgml = require("./dist/x86_64-unknown-linux-gnu-index.node") + module.exports = pgml +} else if (type == "Linux" && 
arch == "arm64") { + const pgml = require("./dist/aarch64-unknown-linux-gnu-index.node") + module.exports = pgml +} else { + console.log("UNSUPPORTED TYPE OR ARCH:", type, arch) + process.exit(1); } diff --git a/pgml-sdks/pgml/javascript/package.json b/pgml-sdks/pgml/javascript/package.json index 771b8c24e..551b7156d 100644 --- a/pgml-sdks/pgml/javascript/package.json +++ b/pgml-sdks/pgml/javascript/package.json @@ -1,13 +1,17 @@ { "name": "pgml", - "version": "0.9.0", + "version": "0.9.1", "description": "Open Source Alternative for Building End-to-End Vector Search Applications without OpenAI & Pinecone", - "keywords": ["postgres", "machine learning", "vector databases", "embeddings"], + "keywords": [ + "postgres", + "machine learning", + "vector databases", + "embeddings" + ], "main": "index.js", "scripts": { - "build": "cargo-cp-artifact -nc index.node -- cargo build --message-format=json-render-diagnostics -F javascript", - "build-debug": "npm run build --", - "build-release": "npm run build -- --release" + "build": "node build.js", + "build-release": "node build.js --release" }, "author": { "name": "PostgresML", diff --git a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts index 5e5b76061..f4895edf4 100644 --- a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts +++ b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts @@ -21,6 +21,7 @@ const generate_dummy_documents = (count: number) => { project: "a10", uuid: i * 10, floating_uuid: i * 1.1, + test: null, name: `Test Document ${i}`, }); } @@ -156,3 +157,66 @@ it("pipeline to dict", async () => { expect(pipeline_dict["name"]).toBe("test_j_p_ptd_0"); await collection.archive(); }); + +/////////////////////////////////////////////////// +// Test document related functions //////////////// +/////////////////////////////////////////////////// + +it("can upsert and get documents", async () => { + let model = pgml.newModel(); + let splitter = pgml.newSplitter(); + let pipeline = pgml.newPipeline("test_p_p_cuagd_0", model, splitter, { + full_text_search: { active: true, configuration: "english" }, + }); + let collection = pgml.newCollection("test_p_c_cuagd_1"); + await collection.add_pipeline(pipeline); + await collection.upsert_documents(generate_dummy_documents(10)); + + let documents = await collection.get_documents(); + expect(documents).toHaveLength(10); + + documents = await collection.get_documents({ + offset: 1, + limit: 2, + filter: { metadata: { id: { $gt: 0 } } }, + }); + expect(documents).toHaveLength(2); + expect(documents[0]["document"]["id"]).toBe(2); + let last_row_id = documents[1]["row_id"]; + + documents = await collection.get_documents({ + filter: { + metadata: { id: { $gt: 3 } }, + full_text_search: { configuration: "english", text: "4" }, + }, + last_row_id: last_row_id, + }); + expect(documents).toHaveLength(1); + expect(documents[0]["document"]["id"]).toBe(4); + + await collection.archive(); +}); + +it("can delete documents", async () => { + let model = pgml.newModel(); + let splitter = pgml.newSplitter(); + let pipeline = pgml.newPipeline( + "test_p_p_cdd_0", + model, + splitter, + + { full_text_search: { active: true, configuration: "english" } }, + ); + let collection = pgml.newCollection("test_p_c_cdd_2"); + await collection.add_pipeline(pipeline); + await collection.upsert_documents(generate_dummy_documents(3)); + await collection.delete_documents({ + metadata: { id: { $gte: 0 } }, + full_text_search: { configuration: "english", text: 
"0" }, + }); + let documents = await collection.get_documents(); + expect(documents).toHaveLength(2); + expect(documents[0]["document"]["id"]).toBe(1); + + await collection.archive(); +}); diff --git a/pgml-sdks/pgml/pyproject.toml b/pgml-sdks/pgml/pyproject.toml index bd5d98d7f..df7bfa417 100644 --- a/pgml-sdks/pgml/pyproject.toml +++ b/pgml-sdks/pgml/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "maturin" [project] name = "pgml" requires-python = ">=3.7" -version = "0.9.0" +version = "0.9.1" description = "Python SDK is designed to facilitate the development of scalable vector search applications on PostgreSQL databases." authors = [ {name = "PostgresML", email = "team@postgresml.org"}, diff --git a/pgml-sdks/pgml/python/README.md b/pgml-sdks/pgml/python/README.md index b6b70aaea..a05c184ce 100644 --- a/pgml-sdks/pgml/python/README.md +++ b/pgml-sdks/pgml/python/README.md @@ -213,9 +213,11 @@ collection = Collection("test_collection", CUSTOM_DATABASE_URL) ### Upserting Documents -Documents are dictionaries with two required keys: `id` and `text`. All other keys/value pairs are stored as metadata for the document. +The `upsert_documents` method can be used to insert new documents and update existing documents. -**Upsert documents with metadata** +New documents are dictionaries with two required keys: `id` and `text`. All other keys/value pairs are stored as metadata for the document. + +**Upsert new documents with metadata** ```python documents = [ { @@ -233,6 +235,97 @@ collection = Collection("test_collection") await collection.upsert_documents(documents) ``` +Document metadata can be updated by upserting the document without the `text` key. + +**Update document metadata** +```python +documents = [ + { + "id": "Document 1", + "random_key": "this will be NEW metadata for the document" + }, + { + "id": "Document 2", + "random_key": "this will be NEW metadata for the document" + } +] +collection = Collection("test_collection") +await collection.upsert_documents(documents) +``` + +### Getting Documents + +Documents can be retrieved using the `get_documents` method on the collection object + +**Get the first 100 documents** +```python +collection = Collection("test_collection") +documents = await collection.get_documents({ "limit": 100 }) +``` + +#### Pagination + +The Python SDK supports limit-offset pagination and keyset pagination + +**Limit-Offset pagination** +```python +collection = Collection("test_collection") +documents = await collection.get_documents({ "limit": 100, "offset": 10 }) +``` + +**Keyset pagination** +```python +collection = Collection("test_collection") +documents = await collection.get_documents({ "limit": 100, "last_row_id": 10 }) +``` + +The `last_row_id` can be taken from the `row_id` field in the returned document's dictionary. + +#### Filtering + +Metadata and full text filtering are supported just like they are in vector recall. + +**Metadata and full text filtering** +```python +collection = Collection("test_collection") +documents = await collection.get_documents({ + "limit": 100, + "offset": 10, + "filter": { + "metadata": { + "id": { + "$eq": 1 + } + }, + "full_text_search": { + "configuration": "english", + "text": "Some full text query" + } + } +}) + +``` + +### Deleting Documents + +Documents can be deleted with the `delete_documents` method on the collection object. + +Metadata and full text filtering are supported just like they are in vector recall. 
+ +```python +collection = Collection("test_collection") +await collection.delete_documents({ + "metadata": { + "id": { + "$eq": 1 + } + }, + "full_text_search": { + "configuration": "english", + "text": "Some full text query" + } +}) +``` + ### Searching Collections The Python SDK is specifically designed to provide powerful, flexible vector search. @@ -350,7 +443,7 @@ results = ( .vector_recall("Here is some query", pipeline) .limit(10) .filter({ - "full_text": { + "full_text_search": { "configuration": "english", "text": "Match Me" } diff --git a/pgml-sdks/pgml/python/pgml/pgml.pyi b/pgml-sdks/pgml/python/pgml/pgml.pyi index 02895348d..9ef3103be 100644 --- a/pgml-sdks/pgml/python/pgml/pgml.pyi +++ b/pgml-sdks/pgml/python/pgml/pgml.pyi @@ -3,3 +3,89 @@ def py_init_logger(level: Optional[str] = "", format: Optional[str] = "") -> Non Json = Any DateTime = int + +# Top of file key: A12BECOD! +from typing import List, Dict, Optional, Self, Any + + +class Builtins: + def __init__(self, database_url: Optional[str] = "Default set in Rust. Please check the documentation.") -> Self: + ... + def query(self, query: str) -> QueryRunner: + ... + async def transform(self, task: Json, inputs: List[str], args: Optional[Json] = Any) -> Json: + ... + +class Collection: + def __init__(self, name: str, database_url: Optional[str] = "Default set in Rust. Please check the documentation.") -> Self: + ... + async def add_pipeline(self, pipeline: Pipeline) -> None: + ... + async def remove_pipeline(self, pipeline: Pipeline) -> None: + ... + async def enable_pipeline(self, pipeline: Pipeline) -> None: + ... + async def disable_pipeline(self, pipeline: Pipeline) -> None: + ... + async def upsert_documents(self, documents: List[Json]) -> None: + ... + async def get_documents(self, args: Optional[Json] = Any) -> List[Json]: + ... + async def delete_documents(self, filter: Json) -> None: + ... + async def vector_search(self, query: str, pipeline: Pipeline, query_parameters: Optional[Json] = Any, top_k: Optional[int] = 1) -> List[tuple[float, str, Json]]: + ... + async def archive(self) -> None: + ... + def query(self) -> QueryBuilder: + ... + async def get_pipelines(self) -> List[Pipeline]: + ... + async def get_pipeline(self, name: str) -> Pipeline: + ... + async def exists(self) -> bool: + ... + +class Model: + def __init__(self, name: Optional[str] = "Default set in Rust. Please check the documentation.", source: Optional[str] = "Default set in Rust. Please check the documentation.", parameters: Optional[Json] = Any) -> Self: + ... + +class Pipeline: + def __init__(self, name: str, model: Optional[Model] = Any, splitter: Optional[Splitter] = Any, parameters: Optional[Json] = Any) -> Self: + ... + async def get_status(self) -> PipelineSyncData: + ... + async def to_dict(self) -> Json: + ... + +class QueryBuilder: + def limit(self, limit: int) -> Self: + ... + def filter(self, filter: Json) -> Self: + ... + def vector_recall(self, query: str, pipeline: Pipeline, query_parameters: Optional[Json] = Any) -> Self: + ... + async def fetch_all(self) -> List[tuple[float, str, Json]]: + ... + def to_full_string(self) -> str: + ... + +class QueryRunner: + async def fetch_all(self) -> Json: + ... + async def execute(self) -> None: + ... + def bind_string(self, bind_value: str) -> Self: + ... + def bind_int(self, bind_value: int) -> Self: + ... + def bind_float(self, bind_value: float) -> Self: + ... + def bind_bool(self, bind_value: bool) -> Self: + ... + def bind_json(self, bind_value: Json) -> Self: + ... + +class Splitter: + def __init__(self, name: Optional[str] = "Default set in Rust. Please check the documentation.", parameters: Optional[Json] = Any) -> Self: + ... diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py index 88c19685d..a355b27a8 100644 --- a/pgml-sdks/pgml/python/tests/test.py +++ b/pgml-sdks/pgml/python/tests/test.py @@ -32,6 +32,7 @@ def generate_dummy_documents(count: int) -> List[Dict[str, Any]]: "project": "a10", "floating_uuid": i * 1.01, "uuid": i * 10, + "test": None, "name": "Test Document {}".format(i), } ) @@ -181,6 +182,74 @@ async def test_pipeline_to_dict(): await collection.archive() +################################################### +## Test document related functions ################ +################################################### + + +@pytest.mark.asyncio +async def test_upsert_and_get_documents(): + model = pgml.Model() + splitter = pgml.Splitter() + pipeline = pgml.Pipeline( + "test_p_p_tuagd_0", + model, + splitter, + {"full_text_search": {"active": True, "configuration": "english"}}, + ) + collection = pgml.Collection(name="test_p_c_tuagd_2") + await collection.add_pipeline( + pipeline, + ) + await collection.upsert_documents(generate_dummy_documents(10)) + + documents = await collection.get_documents() + assert len(documents) == 10 + + documents = await collection.get_documents( + {"offset": 1, "limit": 2, "filter": {"metadata": {"id": {"$gt": 0}}}} + ) + assert len(documents) == 2 and documents[0]["document"]["id"] == 2 + last_row_id = documents[-1]["row_id"] + + documents = await collection.get_documents( + { + "filter": { + "metadata": {"id": {"$gt": 3}}, + "full_text_search": {"configuration": "english", "text": "4"}, + }, + "last_row_id": last_row_id, + } + ) + assert len(documents) == 1 and documents[0]["document"]["id"] == 4 + + await collection.archive() + + +@pytest.mark.asyncio +async def test_delete_documents(): + model = pgml.Model() + splitter = pgml.Splitter() + pipeline = pgml.Pipeline( + "test_p_p_tdd_0", + model, + splitter, + {"full_text_search": {"active": True, "configuration": "english"}}, + ) + collection = pgml.Collection("test_p_c_tdd_1") + await collection.add_pipeline(pipeline) + await collection.upsert_documents(generate_dummy_documents(3)) + await collection.delete_documents( + { + "metadata": {"id": {"$gte": 0}}, + "full_text_search": {"configuration": "english", "text": "0"}, + } + ) + documents = await collection.get_documents() + assert len(documents) == 2 and documents[0]["document"]["id"] == 1 + await collection.archive() + + ################################################### ## Test with multiprocessing ###################### ################################################### diff --git a/pgml-sdks/pgml/src/builtins.rs b/pgml-sdks/pgml/src/builtins.rs index 3686c1c1b..60465c130 100644 --- a/pgml-sdks/pgml/src/builtins.rs +++ b/pgml-sdks/pgml/src/builtins.rs @@ -10,7 +10,6 @@ pub struct Builtins { use crate::{get_or_initialize_pool, query_runner::QueryRunner, types::Json}; - #[cfg(feature = "python")] use crate::{query_runner::QueryRunnerPython, types::JsonPython}; diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index ba0a1af3e..23fe6df42 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -2,6 +2,8 @@ use anyhow::Context; use indicatif::MultiProgress; use itertools::Itertools; use rust_bridge::{alias, alias_methods}; +use sea_query::{Alias, Expr, JoinType, Order, PostgresQueryBuilder, Query}; +use sea_query_binder::SqlxBinder; use sqlx::postgres::PgPool; use sqlx::Executor; use
sqlx::PgConnection; @@ -9,10 +11,18 @@ use std::borrow::Cow; use std::time::SystemTime; use tracing::{instrument, warn}; +use crate::filter_builder; use crate::{ - get_or_initialize_pool, model::ModelRuntime, models, pipeline::Pipeline, queries, - query_builder, query_builder::QueryBuilder, remote_embeddings::build_remote_embeddings, - splitter::Splitter, types::DateTime, types::Json, utils, + get_or_initialize_pool, + model::ModelRuntime, + models, + pipeline::Pipeline, + queries, query_builder, + query_builder::QueryBuilder, + remote_embeddings::build_remote_embeddings, + splitter::Splitter, + types::{DateTime, IntoTableNameAndSchema, Json, SIden, TryToNumeric}, + utils, }; #[cfg(feature = "python")] @@ -101,6 +111,7 @@ pub struct Collection { new, upsert_documents, get_documents, + delete_documents, get_pipelines, get_pipeline, add_pipeline, @@ -192,7 +203,7 @@ impl Collection { let project_id: i64 = sqlx::query_scalar("INSERT INTO pgml.projects (name, task) VALUES ($1, 'embedding'::pgml.task) ON CONFLICT (name) DO UPDATE SET task = EXCLUDED.task RETURNING id, task::TEXT") .bind(&self.name) - .fetch_one(&mut transaction) + .fetch_one(&mut *transaction) .await?; transaction @@ -202,7 +213,7 @@ impl Collection { let c: models::Collection = sqlx::query_as("INSERT INTO pgml.collections (name, project_id) VALUES ($1, $2) ON CONFLICT (name) DO NOTHING RETURNING *") .bind(&self.name) .bind(project_id) - .fetch_one(&mut transaction) + .fetch_one(&mut *transaction) .await?; let collection_database_data = CollectionDatabaseData { @@ -320,7 +331,7 @@ impl Collection { "DROP TABLE IF EXISTS %s", embeddings_table_name )) - .execute(&mut transaction) + .execute(&mut *transaction) .await?; // Need to delete from the tsvectors table only if no other pipelines use the @@ -331,7 +342,7 @@ impl Collection { self.pipelines_table_name)) .bind(parameters["full_text_search"]["configuration"].as_str()) .bind(database_data.id) - .execute(&mut transaction) + .execute(&mut *transaction) .await?; sqlx::query(&query_builder!( @@ -339,7 +350,7 @@ impl Collection { self.pipelines_table_name )) .bind(database_data.id) - .execute(&mut transaction) + .execute(&mut *transaction) .await?; transaction.commit().await?; @@ -541,60 +552,42 @@ impl Collection { /// serde_json::json!({"id": 1, "text": "hello world"}).into(), /// serde_json::json!({"id": 2, "text": "hello world"}).into(), /// ]; - /// collection.upsert_documents(documents, Some(true)).await?; + /// collection.upsert_documents(documents).await?; /// Ok(()) /// } /// ``` #[instrument(skip(self, documents))] - pub async fn upsert_documents( - &mut self, - documents: Vec<Json>, - strict: Option<bool>, - ) -> anyhow::Result<()> { + pub async fn upsert_documents(&mut self, documents: Vec<Json>) -> anyhow::Result<()> { let pool = get_or_initialize_pool(&self.database_url).await?; self.verify_in_database(false).await?; - let strict = strict.unwrap_or(true); - let progress_bar = utils::default_progress_bar(documents.len() as u64); progress_bar.println("Upserting Documents..."); - let documents: anyhow::Result<Vec<Option<(uuid::Uuid, String, Json)>>> = documents.into_iter().map(|mut document| { - let document = document - .as_object_mut() - .expect("Documents must be a vector of objects"); - let text = match document.remove("text") { - Some(t) => t, - None => { - if strict { - anyhow::bail!("`text` is not a key in document, throwing error. To supress this error, pass strict: false"); - } else { - warn!("`text` is not a key in document, skipping document. To throw an error instead, pass strict: true"); - } - return Ok(None) - } - }; - let text = text.as_str().context("`text` must be a string")?.to_string(); - - // We don't want the text included in the document metadata, but everything else - // should be in there - let metadata = serde_json::to_value(&document)?.into(); - - let md5_digest = match document.get("id") { - Some(k) => md5::compute(k.to_string().as_bytes()), - None => { - if strict { - anyhow::bail!("`id` is not a key in document, throwing error. To supress this error, pass strict: false"); - } else { - warn!("`id` is not a key in document, skipping document. To throw an error instead, pass strict: true"); - } - return Ok(None) - } - }; - let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?; - - Ok(Some((source_uuid, text, metadata))) - }).collect(); + let documents: anyhow::Result<Vec<Option<(uuid::Uuid, Option<String>, Json)>>> = documents + .into_iter() + .map(|mut document| { + let document = document + .as_object_mut() + .expect("Documents must be a vector of objects"); + let text = document + .remove("text") + .map(|t| t.as_str().expect("`text` must be a string").to_string()); + + // We don't want the text included in the document metadata, but everything else + // should be in there + let metadata = serde_json::to_value(&document)?.into(); + + let id = document + .get("id") + .context("`id` must be a key in document")? + .to_string(); + let md5_digest = md5::compute(id.as_bytes()); + let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?; + + Ok(Some((source_uuid, text, metadata))) + }) + .collect(); // We could continue chaining the above iterators but types become super annoying to // deal with, especially because we are dealing with async functions. This is much easier to read @@ -606,26 +599,41 @@ impl Collection { // We want the length before we filter out any None values let chunk_len = chunk.len(); // Filter out the None values - let chunk: Vec<&(uuid::Uuid, String, Json)> = + let mut chunk: Vec<&(uuid::Uuid, Option<String>, Json)> = chunk.iter().filter_map(|x| x.as_ref()).collect(); - // Make sure we didn't filter everything out + // If the chunk is empty, we can skip the rest of the loop if chunk.is_empty() { progress_bar.inc(chunk_len as u64); continue; } + // Split the chunk into two groups, one with text, and one with just metadata + let split_index = itertools::partition(&mut chunk, |(_, text, _)| text.is_some()); + let (text_chunk, metadata_chunk) = chunk.split_at(split_index); + + // Start the transaction let mut transaction = pool.begin().await?; - // First delete any documents that already have the same UUID then insert the new ones. + + // Update the metadata + sqlx::query(query_builder!( + "UPDATE %s d SET metadata = v.metadata FROM (SELECT UNNEST($1) source_uuid, UNNEST($2) metadata) v WHERE d.source_uuid = v.source_uuid", + self.documents_table_name + ).as_str()).bind(metadata_chunk.iter().map(|(source_uuid, _, _)| *source_uuid).collect::<Vec<uuid::Uuid>>()) + .bind(metadata_chunk.iter().map(|(_, _, metadata)| metadata.0.clone()).collect::<Vec<serde_json::Value>>()) + .execute(&mut *transaction).await?; + + // First delete any documents that already have the same UUID as documents in + // text_chunk, then insert the new ones. // We are essentially upserting in two steps sqlx::query(&query_builder!( "DELETE FROM %s WHERE source_uuid IN (SELECT source_uuid FROM %s WHERE source_uuid = ANY($1::uuid[]))", self.documents_table_name, self.documents_table_name )). - bind(&chunk.iter().map(|(source_uuid, _, _)| *source_uuid).collect::<Vec<uuid::Uuid>>()). - execute(&mut transaction).await?; - let query_string_values = (0..chunk.len()) + bind(&text_chunk.iter().map(|(source_uuid, _, _)| *source_uuid).collect::<Vec<uuid::Uuid>>()). + execute(&mut *transaction).await?; + let query_string_values = (0..text_chunk.len()) .map(|i| format!("(${}, ${}, ${})", i * 3 + 1, i * 3 + 2, i * 3 + 3)) .collect::<Vec<String>>() .join(","); @@ -635,11 +643,10 @@ ); let query = query_builder!(query_string, self.documents_table_name); let mut query = sqlx::query_scalar(&query); - for (source_uuid, text, metadata) in chunk.into_iter() { + for (source_uuid, text, metadata) in text_chunk.iter() { query = query.bind(source_uuid).bind(text).bind(metadata); } - - let ids: Vec<i64> = query.fetch_all(&mut transaction).await?; + let ids: Vec<i64> = query.fetch_all(&mut *transaction).await?; document_ids.extend(ids); progress_bar.inc(chunk_len as u64); transaction.commit().await?; } @@ -655,8 +662,7 @@ impl Collection { /// /// # Arguments /// - /// * `last_id` - The last id of the document to get. If none, starts at 0 - /// * `limit` - The number of documents to get. If none, gets 100 + /// * `args` - The filters and options to apply to the query /// /// # Example /// @@ -665,36 +671,190 @@ impl Collection { /// /// async fn example() -> anyhow::Result<()> { /// let mut collection = Collection::new("my_collection", None); - /// let documents = collection.get_documents(None, None).await?; + /// let documents = collection.get_documents(None).await?; /// Ok(()) /// } #[instrument(skip(self))] - pub async fn get_documents( - &mut self, - last_id: Option<i64>, - limit: Option<i64>, - ) -> anyhow::Result<Vec<Json>> { + pub async fn get_documents(&mut self, args: Option<Json>) -> anyhow::Result<Vec<Json>> { let pool = get_or_initialize_pool(&self.database_url).await?; - let documents: Vec<models::Document> = sqlx::query_as(&query_builder!( - "SELECT * FROM %s WHERE id > $1 ORDER BY id ASC LIMIT $2", - self.documents_table_name - )) - .bind(last_id.unwrap_or(0)) - .bind(limit.unwrap_or(100)) - .fetch_all(&pool) - .await?; - documents + + let mut args = args.unwrap_or_default().0; + let args = args.as_object_mut().context("args must be an object")?; + + // Get limit or set it to 1000 + let limit = args + .remove("limit") + .map(|l| l.try_to_u64()) + .unwrap_or(Ok(1000))?; + + let mut query = Query::select(); + query + .from_as( + self.documents_table_name.to_table_tuple(), + SIden::Str("documents"), + ) + .expr(Expr::cust("*")) // Adds the * in SELECT * FROM + .order_by((SIden::Str("documents"), SIden::Str("id")), Order::Asc) + .limit(limit); + + if let Some(last_row_id) = args.remove("last_row_id") { + let last_row_id = last_row_id + .try_to_u64() + .context("last_row_id must be an integer")?; + query.and_where(Expr::col((SIden::Str("documents"), SIden::Str("id"))).gt(last_row_id)); + } + + if let Some(offset) = args.remove("offset") { + let offset = offset.try_to_u64().context("offset must be an integer")?; + query.offset(offset); + } + + if let Some(mut filter) = args.remove("filter") { + let filter = filter + .as_object_mut() + .context("filter must be a Json object")?; + + if let Some(f) = filter.remove("metadata") { + query.cond_where( + filter_builder::FilterBuilder::new(f, "documents", "metadata").build(), + ); + } + if let Some(f) = filter.remove("full_text_search") { + let f = f + .as_object() + .context("Full text filter must be a Json object")?; + let configuration = f + .get("configuration") + .context("In full_text_search `configuration` is required")?
+                    .as_str()
+                    .context("In full_text_search `configuration` must be a string")?;
+                let filter_text = f
+                    .get("text")
+                    .context("In full_text_search `text` is required")?
+                    .as_str()
+                    .context("In full_text_search `text` must be a string")?;
+                query
+                    .join_as(
+                        JoinType::InnerJoin,
+                        self.documents_tsvectors_table_name.to_table_tuple(),
+                        Alias::new("documents_tsvectors"),
+                        Expr::col((SIden::Str("documents"), SIden::Str("id")))
+                            .equals((SIden::Str("documents_tsvectors"), SIden::Str("document_id"))),
+                    )
+                    .and_where(
+                        Expr::col((
+                            SIden::Str("documents_tsvectors"),
+                            SIden::Str("configuration"),
+                        ))
+                        .eq(configuration),
+                    )
+                    .and_where(Expr::cust_with_values(
+                        format!(
+                            "documents_tsvectors.ts @@ plainto_tsquery('{}', $1)",
+                            configuration
+                        ),
+                        [filter_text],
+                    ));
+            }
+        }
+
+        let (sql, values) = query.build_sqlx(PostgresQueryBuilder);
+        let documents: Vec<models::Document> =
+            sqlx::query_as_with(&sql, values).fetch_all(&pool).await?;
+        Ok(documents
             .into_iter()
-            .map(|d| {
-                serde_json::to_value(d)
-                    .map(|t| t.into())
-                    .map_err(|e| anyhow::anyhow!(e))
-            })
-            .collect()
+            .map(|d| d.into_user_friendly_json())
+            .collect())
+    }
+
+    /// Deletes documents in a [Collection]
+    ///
+    /// # Arguments
+    ///
+    /// * `filter` - The filters to apply
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// use pgml::Collection;
+    ///
+    /// async fn example() -> anyhow::Result<()> {
+    ///     let mut collection = Collection::new("my_collection", None);
+    ///     let documents = collection.delete_documents(serde_json::json!({
+    ///         "metadata": {
+    ///             "id": {
+    ///                 "eq": 1
+    ///             }
+    ///         }
+    ///     }).into()).await?;
+    ///     Ok(())
+    /// }
+    #[instrument(skip(self))]
+    pub async fn delete_documents(&mut self, mut filter: Json) -> anyhow::Result<()> {
+        let pool = get_or_initialize_pool(&self.database_url).await?;
+
+        let mut query = Query::delete();
+        query.from_table(self.documents_table_name.to_table_tuple());
+
+        let filter = filter
+            .as_object_mut()
+            .context("filter must be a Json object")?;
+
+        if let Some(f) = filter.remove("metadata") {
+            query
+                .cond_where(filter_builder::FilterBuilder::new(f, "documents", "metadata").build());
+        }
+
+        if let Some(mut f) = filter.remove("full_text_search") {
+            let f = f
+                .as_object_mut()
+                .context("Full text filter must be a Json object")?;
+            let configuration = f
+                .get("configuration")
+                .context("In full_text_search `configuration` is required")?
+                .as_str()
+                .context("In full_text_search `configuration` must be a string")?;
+            let filter_text = f
+                .get("text")
+                .context("In full_text_search `text` is required")?
+                .as_str()
+                .context("In full_text_search `text` must be a string")?;
+            let mut inner_select_query = Query::select();
+            inner_select_query
+                .from_as(
+                    self.documents_tsvectors_table_name.to_table_tuple(),
+                    SIden::Str("documents_tsvectors"),
+                )
+                .column(SIden::Str("document_id"))
+                .and_where(Expr::cust_with_values(
+                    format!(
+                        "documents_tsvectors.ts @@ plainto_tsquery('{}', $1)",
+                        configuration
+                    ),
+                    [filter_text],
+                ))
+                .and_where(
+                    Expr::col((
+                        SIden::Str("documents_tsvectors"),
+                        SIden::Str("configuration"),
+                    ))
+                    .eq(configuration),
+                );
+            query.and_where(
+                Expr::col((SIden::Str("documents"), SIden::Str("id")))
+                    .in_subquery(inner_select_query),
+            );
+        }
+
+        let (sql, values) = query.build_sqlx(PostgresQueryBuilder);
+        sqlx::query_with(&sql, values).fetch_all(&pool).await?;
+        Ok(())
     }
 
     #[instrument(skip(self))]
-    pub async fn sync_pipelines(&mut self, document_ids: Option<Vec<i64>>) -> anyhow::Result<()> {
+    pub(crate) async fn sync_pipelines(
+        &mut self,
+        document_ids: Option<Vec<i64>>,
+    ) -> anyhow::Result<()> {
         self.verify_in_database(false).await?;
         let pipelines = self.get_pipelines().await?;
         if !pipelines.is_empty() {
@@ -711,10 +871,6 @@ impl Collection {
                     .expect("Failed to execute pipeline");
             })
             .await;
-            // pipelines.into_iter().for_each
-            // for mut pipeline in pipelines {
-            //     pipeline.execute(&document_ids, mp.clone()).await?;
-            // }
             eprintln!("Done Syncing Pipelines\n");
         }
         Ok(())
@@ -878,14 +1034,14 @@ impl Collection {
         sqlx::query("UPDATE pgml.collections SET name = $1, active = FALSE where name = $2")
             .bind(&archive_table_name)
             .bind(&self.name)
-            .execute(&mut transaciton)
+            .execute(&mut *transaciton)
             .await?;
         sqlx::query(&query_builder!(
             "ALTER SCHEMA %s RENAME TO %s",
             &self.name,
             archive_table_name
         ))
-        .execute(&mut transaciton)
+        .execute(&mut *transaciton)
         .await?;
         transaciton.commit().await?;
         Ok(())
@@ -1062,45 +1218,3 @@ impl Collection {
         .unwrap()
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use crate::init_logger;
-
-    #[sqlx::test]
-    async fn can_upsert_documents() -> anyhow::Result<()> {
-        init_logger(None, None).ok();
-        let mut collection = Collection::new("test_r_c_cud_2", None);
-
-        // Test basic upsert
-        let documents = vec![
-            serde_json::json!({"id": 1, "text": "hello world"}).into(),
-            serde_json::json!({"text": "hello world"}).into(),
-        ];
-        collection
-            .upsert_documents(documents.clone(), Some(false))
-            .await?;
-        let document = &collection.get_documents(None, Some(1)).await?[0];
-        assert_eq!(document["text"], "hello world");
-
-        // Test strictness
-        assert!(collection
-            .upsert_documents(documents, Some(true))
-            .await
-            .is_err());
-
-        // Test upsert
-        let documents = vec![
-            serde_json::json!({"id": 1, "text": "hello world 2"}).into(),
-            serde_json::json!({"text": "hello world"}).into(),
-        ];
-        collection
-            .upsert_documents(documents.clone(), Some(false))
-            .await?;
-        let document = &collection.get_documents(None, Some(1)).await?[0];
-        assert_eq!(document["text"], "hello world 2");
-        collection.archive().await?;
-        Ok(())
-    }
-}
diff --git a/pgml-sdks/pgml/src/filter_builder.rs b/pgml-sdks/pgml/src/filter_builder.rs
index 20e6c1acc..cf32ffa4b 100644
--- a/pgml-sdks/pgml/src/filter_builder.rs
+++ b/pgml-sdks/pgml/src/filter_builder.rs
@@ -116,9 +116,7 @@ fn get_value_type(value: &serde_json::Value) -> String {
         get_value_type(value)
     } else if value.is_string() {
         "text".to_string()
-    } else if value.is_i64() {
-        "float8".to_string()
-    } else if value.is_f64() {
+    } else if value.is_i64() || value.is_f64() {
         "float8".to_string()
     } else if value.is_boolean() {
         "bool".to_string()
@@ -278,29 +276,35 @@ mod tests {
     }
 
     #[test]
-    fn eq_ne_comparison_operators() {
-        let basic_comparison_operators = vec!["", "NOT "];
-        let basic_comparison_operators_names = vec!["$eq", "$ne"];
-        for (operator, name) in basic_comparison_operators
-            .into_iter()
-            .zip(basic_comparison_operators_names.into_iter())
-        {
-            let sql = construct_filter_builder_with_json(json!({
-                "id": {name: 1},
-                "id2": {"id3": {name: "test"}},
-                "id4": {"id5": {"id6": {name: true}}},
-                "id7": {"id8": {"id9": {"id10": {name: [1, 2, 3]}}}}
-            }))
-            .build()
-            .to_valid_sql_query();
-            assert_eq!(
-                sql,
-                format!(
-                    r##"SELECT "id" FROM "test_table" WHERE {}"test_table"."metadata" @> E'{{\"id\":1}}' AND {}"test_table"."metadata" @> E'{{\"id2\":{{\"id3\":\"test\"}}}}' AND {}"test_table"."metadata" @> E'{{\"id4\":{{\"id5\":{{\"id6\":true}}}}}}' AND {}"test_table"."metadata" @> E'{{\"id7\":{{\"id8\":{{\"id9\":{{\"id10\":[1,2,3]}}}}}}}}'"##,
-                    operator, operator, operator, operator
-                )
-            );
-        }
+    fn eq_operator() {
+        let sql = construct_filter_builder_with_json(json!({
+            "id": {"$eq": 1},
+            "id2": {"id3": {"$eq": "test"}},
+            "id4": {"id5": {"id6": {"$eq": true}}},
+            "id7": {"id8": {"id9": {"id10": {"$eq": [1, 2, 3]}}}}
+        }))
+        .build()
+        .to_valid_sql_query();
+        assert_eq!(
+            sql,
+            r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata") @> E'{\"id\":1}' AND ("test_table"."metadata") @> E'{\"id2\":{\"id3\":\"test\"}}' AND ("test_table"."metadata") @> E'{\"id4\":{\"id5\":{\"id6\":true}}}' AND ("test_table"."metadata") @> E'{\"id7\":{\"id8\":{\"id9\":{\"id10\":[1,2,3]}}}}'"##
+        );
+    }
+
+    #[test]
+    fn ne_operator() {
+        let sql = construct_filter_builder_with_json(json!({
+            "id": {"$ne": 1},
+            "id2": {"id3": {"$ne": "test"}},
+            "id4": {"id5": {"id6": {"$ne": true}}},
+            "id7": {"id8": {"id9": {"id10": {"$ne": [1, 2, 3]}}}}
+        }))
+        .build()
+        .to_valid_sql_query();
+        assert_eq!(
+            sql,
+            r##"SELECT "id" FROM "test_table" WHERE (NOT ("test_table"."metadata")) @> E'{\"id\":1}' AND (NOT ("test_table"."metadata")) @> E'{\"id2\":{\"id3\":\"test\"}}' AND (NOT ("test_table"."metadata")) @> E'{\"id4\":{\"id5\":{\"id6\":true}}}' AND (NOT ("test_table"."metadata")) @> E'{\"id7\":{\"id8\":{\"id9\":{\"id10\":[1,2,3]}}}}'"##
+        );
     }
 
     #[test]
@@ -320,7 +324,7 @@ mod tests {
         assert_eq!(
             sql,
             format!(
-                r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata"#>>'{{id}}')::float8 {} 1 AND ("test_table"."metadata"#>>'{{id2,id3}}')::float8 {} 1"##,
+                r##"SELECT "id" FROM "test_table" WHERE (("test_table"."metadata"#>>'{{id}}')::float8) {} 1 AND (("test_table"."metadata"#>>'{{id2,id3}}')::float8) {} 1"##,
                 operator, operator
             )
         );
@@ -344,7 +348,7 @@ mod tests {
         assert_eq!(
             sql,
             format!(
-                r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata"#>>'{{id}}')::float8 {} (1) AND ("test_table"."metadata"#>>'{{id2,id3}}')::float8 {} (1)"##,
+                r##"SELECT "id" FROM "test_table" WHERE (("test_table"."metadata"#>>'{{id}}')::float8) {} (1) AND (("test_table"."metadata"#>>'{{id2,id3}}')::float8) {} (1)"##,
                 operator, operator
             )
         );
@@ -363,7 +367,7 @@ mod tests {
         .to_valid_sql_query();
         assert_eq!(
             sql,
-            r##"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}'"##
+            r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata") @> E'{\"id\":1}' AND ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}'"##
         );
     }
 
@@ -379,7 +383,7 @@ mod tests {
         .to_valid_sql_query();
         assert_eq!(
             sql,
-            r##"SELECT "id" FROM
"test_table" WHERE "test_table"."metadata" @> E'{\"id\":1}' OR "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}'"## + r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata") @> E'{\"id\":1}' OR ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}'"## ); } @@ -395,7 +399,7 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r##"SELECT "id" FROM "test_table" WHERE NOT ("test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}')"## + r##"SELECT "id" FROM "test_table" WHERE NOT (("test_table"."metadata") @> E'{\"id\":1}' AND ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}')"## ); } @@ -415,7 +419,7 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata" @> E'{\"id\":1}' OR "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}') AND "test_table"."metadata" @> E'{\"id4\":1}'"## + r##"SELECT "id" FROM "test_table" WHERE (("test_table"."metadata") @> E'{\"id\":1}' OR ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}') AND ("test_table"."metadata") @> E'{\"id4\":1}'"## ); let sql = construct_filter_builder_with_json(json!({ "$or": [ @@ -431,7 +435,7 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}') OR "test_table"."metadata" @> E'{\"id4\":1}'"## + r##"SELECT "id" FROM "test_table" WHERE (("test_table"."metadata") @> E'{\"id\":1}' AND ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}') OR ("test_table"."metadata") @> E'{\"id4\":1}'"## ); let sql = construct_filter_builder_with_json(json!({ "metadata": {"$or": [ @@ -443,7 +447,7 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r##"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"metadata\":{\"uuid\":\"1\"}}' OR "test_table"."metadata" @> E'{\"metadata\":{\"uuid2\":\"2\"}}'"## + r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata") @> E'{\"metadata\":{\"uuid\":\"1\"}}' OR ("test_table"."metadata") @> E'{\"metadata\":{\"uuid2\":\"2\"}}'"## ); } } diff --git a/pgml-sdks/pgml/src/languages/javascript.rs b/pgml-sdks/pgml/src/languages/javascript.rs index 6465c408d..06e158be2 100644 --- a/pgml-sdks/pgml/src/languages/javascript.rs +++ b/pgml-sdks/pgml/src/languages/javascript.rs @@ -1,5 +1,5 @@ use neon::prelude::*; -use rust_bridge::javascript::{IntoJsResult, FromJsType}; +use rust_bridge::javascript::{FromJsType, IntoJsResult}; use crate::{ pipeline::PipelineSyncData, @@ -54,7 +54,7 @@ impl IntoJsResult for Json { } Ok(js_object.upcast()) } - _ => panic!("Unsupported type for JSON conversion"), + serde_json::Value::Null => Ok(cx.null().upcast()), } } } @@ -113,6 +113,8 @@ impl FromJsType for Json { json.insert(key, json_value.0); } Ok(Self(serde_json::Value::Object(json))) + } else if arg.is_a::(cx) { + Ok(Self(serde_json::Value::Null)) } else { panic!("Unsupported type for Json conversion"); } diff --git a/pgml-sdks/pgml/src/languages/python.rs b/pgml-sdks/pgml/src/languages/python.rs index 728c2a0ce..3d81c9377 100644 --- a/pgml-sdks/pgml/src/languages/python.rs +++ b/pgml-sdks/pgml/src/languages/python.rs @@ -40,7 +40,7 @@ impl ToPyObject for Json { } dict.to_object(py) } - _ => panic!("Unsupported type for JSON conversion"), + serde_json::Value::Null => py.None(), } } } @@ -100,6 +100,9 @@ impl FromPyObject<'_> for Json { } Ok(Self(serde_json::Value::Array(json_values))) } else { + if ob.is_none() { + return Ok(Self(serde_json::Value::Null)); 
+            }
             panic!("Unsupported type for JSON conversion");
         }
     }
diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs
index 96ee99b6b..8c6c355ec 100644
--- a/pgml-sdks/pgml/src/lib.rs
+++ b/pgml-sdks/pgml/src/lib.rs
@@ -145,7 +145,7 @@ fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> {
 fn js_init_logger(
     mut cx: neon::context::FunctionContext,
 ) -> neon::result::JsResult<neon::types::JsUndefined> {
-    use rust_bridge::javascript::{IntoJsResult, FromJsType};
+    use rust_bridge::javascript::{FromJsType, IntoJsResult};
     let level = cx.argument_opt(0);
     let level = <Option<String>>::from_option_js_type(&mut cx, level)?;
     let format = cx.argument_opt(1);
@@ -170,6 +170,7 @@ fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> {
 mod tests {
     use super::*;
     use crate::{model::Model, pipeline::Pipeline, splitter::Splitter, types::Json};
+    use serde_json::json;
 
     fn generate_dummy_documents(count: usize) -> Vec<Json> {
         let mut documents = Vec::new();
@@ -188,6 +189,10 @@ mod tests {
         documents
     }
 
+    ///////////////////////////////
+    // Collection & Pipelines /////
+    ///////////////////////////////
+
     #[sqlx::test]
     async fn can_create_collection() -> anyhow::Result<()> {
         init_logger(None, None).ok();
@@ -310,7 +315,7 @@ mod tests {
         collection.add_pipeline(&mut pipeline1).await?;
         collection.add_pipeline(&mut pipeline2).await?;
         collection
-            .upsert_documents(generate_dummy_documents(3), Some(true))
+            .upsert_documents(generate_dummy_documents(3))
             .await?;
         let status_1 = pipeline1.get_status().await?;
         let status_2 = pipeline2.get_status().await?;
@@ -326,6 +331,10 @@ mod tests {
         Ok(())
     }
 
+    ///////////////////////////////
+    // Various Searches ///////////
+    ///////////////////////////////
+
     #[sqlx::test]
     async fn can_vector_search_with_local_embeddings() -> anyhow::Result<()> {
         init_logger(None, None).ok();
@@ -351,7 +360,7 @@ mod tests {
         // Recreate the pipeline to replicate a more accurate example
         let mut pipeline = Pipeline::new("test_r_p_cvswle_1", None, None, None);
         collection
-            .upsert_documents(generate_dummy_documents(3), None)
+            .upsert_documents(generate_dummy_documents(3))
             .await?;
         let results = collection
             .vector_search("Here is some query", &mut pipeline, None, None)
@@ -390,7 +399,7 @@ mod tests {
         // Recreate the pipeline to replicate a more accurate example
         let mut pipeline = Pipeline::new("test_r_p_cvswre_1", None, None, None);
         collection
-            .upsert_documents(generate_dummy_documents(3), None)
+            .upsert_documents(generate_dummy_documents(3))
             .await?;
         let results = collection
             .vector_search("Here is some query", &mut pipeline, None, None)
@@ -425,7 +434,7 @@ mod tests {
         // Recreate the pipeline to replicate a more accurate example
         let mut pipeline = Pipeline::new("test_r_p_cvswqb_1", None, None, None);
         collection
-            .upsert_documents(generate_dummy_documents(3), None)
+            .upsert_documents(generate_dummy_documents(3))
             .await?;
         let results = collection
             .query()
@@ -466,7 +475,7 @@ mod tests {
         // Recreate the pipeline to replicate a more accurate example
         let mut pipeline = Pipeline::new("test_r_p_cvswqbwre_1", None, None, None);
         collection
-            .upsert_documents(generate_dummy_documents(3), None)
+            .upsert_documents(generate_dummy_documents(3))
             .await?;
         let results = collection
             .query()
@@ -477,4 +486,438 @@ mod tests {
         collection.archive().await?;
         Ok(())
     }
+
+    #[sqlx::test]
+    async fn can_filter_documents() -> anyhow::Result<()> {
+        init_logger(None, None).ok();
+        let model = Model::new(None, None, None);
+        let splitter = Splitter::new(None, None);
+        let mut pipeline = Pipeline::new(
"test_r_p_cfd_1", + Some(model), + Some(splitter), + Some( + serde_json::json!({ + "full_text_search": { + "active": true, + "configuration": "english" + } + }) + .into(), + ), + ); + let mut collection = Collection::new("test_r_c_cfd_2", None); + collection.add_pipeline(&mut pipeline).await?; + collection + .upsert_documents(generate_dummy_documents(5)) + .await?; + + let filters = vec![ + (5, json!({}).into()), + ( + 3, + json!({ + "metadata": { + "id": { + "$lt": 3 + } + } + }) + .into(), + ), + ( + 1, + json!({ + "full_text_search": { + "configuration": "english", + "text": "1", + } + }) + .into(), + ), + ]; + + for (expected_result_count, filter) in filters { + let results = collection + .query() + .vector_recall("Here is some query", &mut pipeline, None) + .filter(filter) + .fetch_all() + .await?; + println!("{:?}", results); + assert_eq!(results.len(), expected_result_count); + } + + collection.archive().await?; + Ok(()) + } + + /////////////////////////////// + // Working With Documents ///// + /////////////////////////////// + + #[sqlx::test] + async fn can_upsert_and_filter_get_documents() -> anyhow::Result<()> { + init_logger(None, None).ok(); + let model = Model::default(); + let splitter = Splitter::default(); + let mut pipeline = Pipeline::new( + "test_r_p_cuafgd_1", + Some(model), + Some(splitter), + Some( + serde_json::json!({ + "full_text_search": { + "active": true, + "configuration": "english" + } + }) + .into(), + ), + ); + + let mut collection = Collection::new("test_r_c_cuagd_2", None); + collection.add_pipeline(&mut pipeline).await?; + + // Test basic upsert + let documents = vec![ + serde_json::json!({"id": 1, "random_key": 10, "text": "hello world 1"}).into(), + serde_json::json!({"id": 2, "random_key": 11, "text": "hello world 2"}).into(), + serde_json::json!({"id": 3, "random_key": 12, "text": "hello world 3"}).into(), + ]; + collection.upsert_documents(documents.clone()).await?; + let document = &collection.get_documents(None).await?[0]; + assert_eq!(document["document"]["text"], "hello world 1"); + + // Test upsert of text and metadata + let documents = vec![ + serde_json::json!({"id": 1, "text": "hello world new"}).into(), + serde_json::json!({"id": 2, "random_key": 12}).into(), + serde_json::json!({"id": 3, "random_key": 13}).into(), + ]; + collection.upsert_documents(documents.clone()).await?; + + let documents = collection + .get_documents(Some( + serde_json::json!({ + "filter": { + "metadata": { + "random_key": { + "$eq": 12 + } + } + } + }) + .into(), + )) + .await?; + assert_eq!(documents[0]["document"]["text"], "hello world 2"); + + let documents = collection + .get_documents(Some( + serde_json::json!({ + "filter": { + "metadata": { + "random_key": { + "$gte": 13 + } + } + } + }) + .into(), + )) + .await?; + assert_eq!(documents[0]["document"]["text"], "hello world 3"); + + let documents = collection + .get_documents(Some( + serde_json::json!({ + "filter": { + "full_text_search": { + "configuration": "english", + "text": "new" + } + } + }) + .into(), + )) + .await?; + assert_eq!(documents[0]["document"]["text"], "hello world new"); + assert_eq!(documents[0]["document"]["id"].as_i64().unwrap(), 1); + + collection.archive().await?; + Ok(()) + } + + #[sqlx::test] + async fn can_paginate_get_documents() -> anyhow::Result<()> { + init_logger(None, None).ok(); + let mut collection = Collection::new("test_r_c_cpgd_2", None); + collection + .upsert_documents(generate_dummy_documents(10)) + .await?; + + let documents = collection + .get_documents(Some( + 
+                serde_json::json!({
+                    "limit": 5,
+                    "offset": 0
+                })
+                .into(),
+            ))
+            .await?;
+        assert_eq!(
+            documents
+                .into_iter()
+                .map(|d| d["row_id"].as_i64().unwrap())
+                .collect::<Vec<_>>(),
+            vec![1, 2, 3, 4, 5]
+        );
+
+        let documents = collection
+            .get_documents(Some(
+                serde_json::json!({
+                    "limit": 2,
+                    "offset": 5
+                })
+                .into(),
+            ))
+            .await?;
+        let last_row_id = documents.last().unwrap()["row_id"].as_i64().unwrap();
+        assert_eq!(
+            documents
+                .into_iter()
+                .map(|d| d["row_id"].as_i64().unwrap())
+                .collect::<Vec<_>>(),
+            vec![6, 7]
+        );
+
+        let documents = collection
+            .get_documents(Some(
+                serde_json::json!({
+                    "limit": 2,
+                    "last_row_id": last_row_id
+                })
+                .into(),
+            ))
+            .await?;
+        let last_row_id = documents.last().unwrap()["row_id"].as_i64().unwrap();
+        assert_eq!(
+            documents
+                .into_iter()
+                .map(|d| d["row_id"].as_i64().unwrap())
+                .collect::<Vec<_>>(),
+            vec![8, 9]
+        );
+
+        let documents = collection
+            .get_documents(Some(
+                serde_json::json!({
+                    "limit": 1,
+                    "last_row_id": last_row_id
+                })
+                .into(),
+            ))
+            .await?;
+        assert_eq!(
+            documents
+                .into_iter()
+                .map(|d| d["row_id"].as_i64().unwrap())
+                .collect::<Vec<_>>(),
+            vec![10]
+        );
+
+        collection.archive().await?;
+        Ok(())
+    }
+
+    #[sqlx::test]
+    async fn can_filter_and_paginate_get_documents() -> anyhow::Result<()> {
+        init_logger(None, None).ok();
+        let model = Model::default();
+        let splitter = Splitter::default();
+        let mut pipeline = Pipeline::new(
+            "test_r_p_cfapgd_1",
+            Some(model),
+            Some(splitter),
+            Some(
+                serde_json::json!({
+                    "full_text_search": {
+                        "active": true,
+                        "configuration": "english"
+                    }
+                })
+                .into(),
+            ),
+        );
+
+        let mut collection = Collection::new("test_r_c_cfapgd_1", None);
+        collection.add_pipeline(&mut pipeline).await?;
+
+        collection
+            .upsert_documents(generate_dummy_documents(10))
+            .await?;
+
+        let documents = collection
+            .get_documents(Some(
+                serde_json::json!({
+                    "filter": {
+                        "metadata": {
+                            "id": {
+                                "$gte": 2
+                            }
+                        }
+                    },
+                    "limit": 2,
+                    "offset": 0
+                })
+                .into(),
+            ))
+            .await?;
+        assert_eq!(
+            documents
+                .into_iter()
+                .map(|d| d["document"]["id"].as_i64().unwrap())
+                .collect::<Vec<_>>(),
+            vec![2, 3]
+        );
+
+        let documents = collection
+            .get_documents(Some(
+                serde_json::json!({
+                    "filter": {
+                        "metadata": {
+                            "id": {
+                                "$lte": 5
+                            }
+                        }
+                    },
+                    "limit": 100,
+                    "offset": 4
+                })
+                .into(),
+            ))
+            .await?;
+        let last_row_id = documents.last().unwrap()["row_id"].as_i64().unwrap();
+        assert_eq!(
+            documents
+                .into_iter()
+                .map(|d| d["document"]["id"].as_i64().unwrap())
+                .collect::<Vec<_>>(),
+            vec![4, 5]
+        );
+
+        let documents = collection
+            .get_documents(Some(
+                serde_json::json!({
+                    "filter": {
+                        "full_text_search": {
+                            "configuration": "english",
+                            "text": "document"
+                        }
+                    },
+                    "limit": 100,
+                    "last_row_id": last_row_id
+                })
+                .into(),
+            ))
+            .await?;
+        assert_eq!(
+            documents
+                .into_iter()
+                .map(|d| d["document"]["id"].as_i64().unwrap())
+                .collect::<Vec<_>>(),
+            vec![6, 7, 8, 9]
+        );
+
+        collection.archive().await?;
+        Ok(())
+    }
+
+    #[sqlx::test]
+    async fn can_filter_and_delete_documents() -> anyhow::Result<()> {
+        init_logger(None, None).ok();
+        let model = Model::new(None, None, None);
+        let splitter = Splitter::new(None, None);
+        let mut pipeline = Pipeline::new(
+            "test_r_p_cfadd_1",
+            Some(model),
+            Some(splitter),
+            Some(
+                serde_json::json!({
+                    "full_text_search": {
+                        "active": true,
+                        "configuration": "english"
+                    }
+                })
+                .into(),
+            ),
+        );
+
+        let mut collection = Collection::new("test_r_c_cfadd_1", None);
+        collection.add_pipeline(&mut pipeline).await?;
+        collection
+            .upsert_documents(generate_dummy_documents(10))
+            .await?;
+
+        collection
+            .delete_documents(
+                serde_json::json!({
+                    "metadata": {
+                        "id": {
+                            "$lt": 2
+                        }
+                    }
+                })
+                .into(),
+            )
+            .await?;
+        let documents = collection.get_documents(None).await?;
+        assert_eq!(documents.len(), 8);
+        assert!(documents
+            .iter()
+            .all(|d| d["document"]["id"].as_i64().unwrap() >= 2));
+
+        collection
+            .delete_documents(
+                serde_json::json!({
+                    "full_text_search": {
+                        "configuration": "english",
+                        "text": "2"
+                    }
+                })
+                .into(),
+            )
+            .await?;
+        let documents = collection.get_documents(None).await?;
+        assert_eq!(documents.len(), 7);
+        assert!(documents
+            .iter()
+            .all(|d| d["document"]["id"].as_i64().unwrap() > 2));
+
+        collection
+            .delete_documents(
+                serde_json::json!({
+                    "metadata": {
+                        "id": {
+                            "$gte": 6
+                        }
+                    },
+                    "full_text_search": {
+                        "configuration": "english",
+                        "text": "6"
+                    }
+                })
+                .into(),
+            )
+            .await?;
+        let documents = collection.get_documents(None).await?;
+        assert_eq!(documents.len(), 6);
+        assert!(documents
+            .iter()
+            .all(|d| d["document"]["id"].as_i64().unwrap() != 6));
+
+        collection.archive().await?;
+        Ok(())
+    }
 }
diff --git a/pgml-sdks/pgml/src/model.rs b/pgml-sdks/pgml/src/model.rs
index 54f3fc5a0..07b2a1c98 100644
--- a/pgml-sdks/pgml/src/model.rs
+++ b/pgml-sdks/pgml/src/model.rs
@@ -20,7 +20,7 @@ use crate::types::JsonPython;
 /// annoying, but with the traits implimented below is a breeze and can be done just using .into
 
 /// Our model runtimes
-#[derive(Debug, Clone, Copy)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub enum ModelRuntime {
     Python,
     OpenAI,
diff --git a/pgml-sdks/pgml/src/models.rs b/pgml-sdks/pgml/src/models.rs
index 2b735f5a0..07440d4e3 100644
--- a/pgml-sdks/pgml/src/models.rs
+++ b/pgml-sdks/pgml/src/models.rs
@@ -69,6 +69,19 @@ pub struct Document {
     pub text: String,
 }
 
+impl Document {
+    pub fn into_user_friendly_json(mut self) -> Json {
+        self.metadata["text"] = self.text.into();
+        serde_json::json!({
+            "row_id": self.id,
+            "created_at": self.created_at,
+            "source_uuid": self.source_uuid,
+            "document": self.metadata,
+        })
+        .into()
+    }
+}
+
 // A collection of documents
 #[enum_def]
 #[derive(FromRow)]
diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs
index 4e7b2d709..87e632b34 100644
--- a/pgml-sdks/pgml/src/pipeline.rs
+++ b/pgml-sdks/pgml/src/pipeline.rs
@@ -659,7 +659,7 @@ impl Pipeline {
                 ),
                 embedding_length
             ))
-            .execute(&mut transaction)
+            .execute(&mut *transaction)
             .await?;
             transaction
                 .execute(
diff --git a/pgml-sdks/pgml/src/query_builder/query_builder.rs b/pgml-sdks/pgml/src/query_builder.rs
similarity index 91%
rename from pgml-sdks/pgml/src/query_builder/query_builder.rs
rename to pgml-sdks/pgml/src/query_builder.rs
index 9f70b49cc..a759cc7e4 100644
--- a/pgml-sdks/pgml/src/query_builder/query_builder.rs
+++ b/pgml-sdks/pgml/src/query_builder.rs
@@ -1,54 +1,25 @@
 use anyhow::Context;
-use itertools::Itertools;
 use rust_bridge::{alias, alias_methods};
 use sea_query::{
-    query::SelectStatement, Alias, CommonTableExpression, Expr, Func, Iden, JoinType, Order,
+    query::SelectStatement, Alias, CommonTableExpression, Expr, Func, JoinType, Order,
     PostgresQueryBuilder, Query, QueryStatementWriter, WithClause,
 };
 use sea_query_binder::SqlxBinder;
 use std::borrow::Cow;
 
 use crate::{
-    filter_builder, get_or_initialize_pool, models, pipeline::Pipeline,
-    remote_embeddings::build_remote_embeddings, types::Json, Collection,
+    filter_builder, get_or_initialize_pool,
+    model::ModelRuntime,
+    models,
+    pipeline::Pipeline,
+    remote_embeddings::build_remote_embeddings,
+    types::{IntoTableNameAndSchema, Json, SIden},
+    Collection,
 };
 
 #[cfg(feature = "python")]
 use crate::{pipeline::PipelinePython, types::JsonPython};
 
-#[derive(Clone)]
-enum SIden<'a> {
-    Str(&'a str),
-    String(String),
-}
-
-impl Iden for SIden<'_> {
-    fn unquoted(&self, s: &mut dyn std::fmt::Write) {
-        write!(
-            s,
-            "{}",
-            match self {
-                SIden::Str(s) => s,
-                SIden::String(s) => s.as_str(),
-            }
-        )
-        .unwrap();
-    }
-}
-
-trait IntoTableNameAndSchema {
-    fn to_table_tuple<'b>(&self) -> (SIden<'b>, SIden<'b>);
-}
-
-impl IntoTableNameAndSchema for String {
-    fn to_table_tuple<'b>(&self) -> (SIden<'b>, SIden<'b>) {
-        self.split('.')
-            .map(|s| SIden::String(s.to_string()))
-            .collect_tuple()
-            .expect("Malformed table name in IntoTableNameAndSchema")
-    }
-}
-
 #[derive(Clone, Debug)]
 struct QueryBuilderState {}
 
@@ -88,7 +59,7 @@ impl QueryBuilder {
         if let Some(f) = filter.remove("metadata") {
             self = self.filter_metadata(f);
         }
-        if let Some(f) = filter.remove("full_text") {
+        if let Some(f) = filter.remove("full_text_search") {
             self = self.filter_full_text(f);
         }
         self
@@ -131,7 +102,7 @@ impl QueryBuilder {
                 .eq(configuration),
             )
             .and_where(Expr::cust_with_values(
-                &format!(
+                format!(
                     "documents_tsvectors.ts @@ plainto_tsquery('{}', $1)",
                     configuration
                 ),
@@ -273,6 +244,11 @@ impl QueryBuilder {
                     .as_ref()
                     .context("Pipeline must be verified to perform vector search with remote embeddings")?;
 
+                // If the model runtime is python, the error was not caused by an unsupported runtime
+                if model.runtime == ModelRuntime::Python {
+                    return Err(anyhow::anyhow!(e));
+                }
+
                 let query_parameters = self.query_parameters.to_owned().unwrap_or_default();
 
                 let remote_embeddings =
diff --git a/pgml-sdks/pgml/src/query_builder/mod.rs b/pgml-sdks/pgml/src/query_builder/mod.rs
deleted file mode 100644
index 102e40e0b..000000000
--- a/pgml-sdks/pgml/src/query_builder/mod.rs
+++ /dev/null
@@ -1,2 +0,0 @@
-mod query_builder;
-pub use query_builder::*;
diff --git a/pgml-sdks/pgml/src/types.rs b/pgml-sdks/pgml/src/types.rs
index fa390517e..d3d1ce306 100644
--- a/pgml-sdks/pgml/src/types.rs
+++ b/pgml-sdks/pgml/src/types.rs
@@ -1,4 +1,7 @@
+use anyhow::Context;
+use itertools::Itertools;
 use rust_bridge::alias_manual;
+use sea_query::Iden;
 use serde::Serialize;
 use std::ops::{Deref, DerefMut};
 
@@ -39,6 +42,27 @@ impl Serialize for Json {
     }
 }
 
+pub(crate) trait TryToNumeric {
+    fn try_to_u64(&self) -> anyhow::Result<u64>;
+}
+
+impl TryToNumeric for serde_json::Value {
+    fn try_to_u64(&self) -> anyhow::Result<u64> {
+        match self {
+            serde_json::Value::Number(n) => {
+                if n.is_f64() {
+                    Ok(n.as_f64().unwrap() as u64)
+                } else if n.is_i64() {
+                    Ok(n.as_i64().unwrap() as u64)
+                } else {
+                    n.as_u64().context("limit must be an integer")
+                }
+            }
+            _ => Err(anyhow::anyhow!("Json value is not a number")),
+        }
+    }
+}
+
 /// A wrapper around sqlx::types::chrono::DateTime
 #[derive(sqlx::Type, Debug, Clone)]
 #[sqlx(transparent)]
@@ -50,3 +74,36 @@ impl Serialize for DateTime {
         self.0.timestamp().serialize(serializer)
     }
 }
+
+#[derive(Clone)]
+pub(crate) enum SIden<'a> {
+    Str(&'a str),
+    String(String),
+}
+
+impl Iden for SIden<'_> {
+    fn unquoted(&self, s: &mut dyn std::fmt::Write) {
+        write!(
+            s,
+            "{}",
+            match self {
+                SIden::Str(s) => s,
+                SIden::String(s) => s.as_str(),
+            }
+        )
+        .unwrap();
+    }
+}
+
+pub(crate) trait IntoTableNameAndSchema {
+    fn to_table_tuple<'b>(&self) -> (SIden<'b>, SIden<'b>) {
+}
+
+impl IntoTableNameAndSchema for String {
+    fn to_table_tuple<'b>(&self) -> (SIden<'b>, SIden<'b>) {
+        self.split('.')
+            .map(|s| SIden::String(s.to_string()))
+            .collect_tuple()
+            .expect("Malformed table name in IntoTableNameAndSchema")
+    }
+}
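
Usage note: taken together, this patch changes the document API surface in three ways — strict mode is gone from `upsert_documents`, `get_documents` now takes a single `args` object (`limit`, `offset`, `last_row_id`, `filter`), and `delete_documents` takes the same filter grammar. The sketch below strings the new calls together, mirroring the doc comments and tests in this diff; the tokio runtime, the `anyhow` dependency, and a `DATABASE_URL` pointing at a PostgresML database are assumptions, not part of the diff.

use pgml::Collection;

#[tokio::main] // assumed runtime; the SDK itself only requires some async executor
async fn main() -> anyhow::Result<()> {
    let mut collection = Collection::new("my_collection", None);

    // `text` is optional now: a document that omits it only updates metadata.
    collection
        .upsert_documents(vec![serde_json::json!({
            "id": 1,
            "text": "hello world",
            "category": "greetings"
        })
        .into()])
        .await?;

    // Pagination and filters ride in one `args` object.
    let documents = collection
        .get_documents(Some(
            serde_json::json!({
                "limit": 100,
                "filter": {
                    "metadata": { "category": { "$eq": "greetings" } }
                }
            })
            .into(),
        ))
        .await?;
    println!("fetched {} documents", documents.len());

    // Deletion accepts the same filter grammar.
    collection
        .delete_documents(
            serde_json::json!({
                "metadata": { "id": { "$eq": 1 } }
            })
            .into(),
        )
        .await?;

    Ok(())
}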
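
On the filter grammar itself, since the expected SQL strings in the reworked `filter_builder` tests are long: `$eq` compiles to JSONB containment over the whole `metadata` column, `$ne` to the negated form, and ordering comparisons extract the path with `#>>` and cast to `float8` (which is why `is_i64()` and `is_f64()` now share a branch in `get_value_type`). A condensed restatement follows; `construct_filter_builder_with_json` and `"test_table"` are the test module's own helper and fixture, and the `$gt`-to-`>` mapping is inferred from the operator loop in those tests.

use serde_json::json;

// {"id": {"$eq": 1}}  =>  ("test_table"."metadata") @> E'{"id":1}'
// {"id": {"$ne": 1}}  =>  (NOT ("test_table"."metadata")) @> E'{"id":1}'
// {"id": {"$gt": 1}}  =>  (("test_table"."metadata"#>>'{id}')::float8) > 1
let sql = construct_filter_builder_with_json(json!({ "id": { "$eq": 1 } }))
    .build()
    .to_valid_sql_query();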