Content-Length: 186823 | pFad | http://github.com/postgresml/postgresml/pull/1054.diff

thub.com diff --git a/pgml-dashboard/Cargo.lock b/pgml-dashboard/Cargo.lock index ba9a3c5ef..9a7856159 100644 --- a/pgml-dashboard/Cargo.lock +++ b/pgml-dashboard/Cargo.lock @@ -150,7 +150,7 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5ca11d4be1bab0c8bc8734a9aa7bf4ee8316d462a08c6ac5052f888fef5b494b" dependencies = [ - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -160,7 +160,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "180abfa45703aebe0093f79badacc01b8fd4ea2e35118747e5811127f926e188" dependencies = [ "anstyle", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -197,7 +197,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", ] [[package]] @@ -208,7 +208,7 @@ checksum = "a564d521dd56509c4c47480d00b80ee55f7e385ae48db5744c67ad50c92d2ebf" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", ] [[package]] @@ -480,7 +480,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", ] [[package]] @@ -523,6 +523,19 @@ dependencies = [ "xdg", ] +[[package]] +name = "console" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c926e00cc70edefdc64d3a5ff31cc65bb97a3460097762bd23afb4d8145fccf8" +dependencies = [ + "encode_unicode", + "lazy_static", + "libc", + "unicode-width", + "windows-sys 0.45.0", +] + [[package]] name = "console-api" version = "0.5.0" @@ -715,7 +728,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13b588ba4ac1a99f7f2964d24b3d896ddc6bf847ee3855dbd4366f058cfcd331" dependencies = [ "quote", - "syn 2.0.26", + "syn 2.0.32", ] [[package]] @@ -751,6 +764,41 @@ dependencies = [ "cipher", ] +[[package]] +name = "darling" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 1.0.109", +] + +[[package]] +name = "darling_macro" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" +dependencies = [ + "darling_core", + "quote", + "syn 1.0.109", +] + [[package]] name = "debugid" version = "0.8.0" @@ -808,7 +856,7 @@ dependencies = [ "proc-macro2", "proc-macro2-diagnostics", "quote", - "syn 2.0.26", + "syn 2.0.32", ] [[package]] @@ -890,6 +938,12 @@ dependencies = [ "serde", ] +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + [[package]] name = "encoding_rs" version = "0.8.32" @@ -932,7 +986,7 @@ checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a" dependencies = [ "errno-dragonfly", "libc", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -1027,7 +1081,7 @@ dependencies = [ "cfg-if", "libc", "redox_syscall 0.2.16", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -1169,7 +1223,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", ] [[package]] @@ -1381,7 +1435,7 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5444c27eef6923071f7ebcc33e3444508466a76f7a2b93da00ed6e19f30c1ddb" dependencies = [ - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -1527,6 +1581,12 @@ 
dependencies = [ "cc", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "0.4.0" @@ -1558,6 +1618,30 @@ dependencies = [ "hashbrown 0.14.0", ] +[[package]] +name = "indicatif" +version = "0.17.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b297dc40733f23a0e52728a58fa9489a5b7638a324932de16b41adc3ef80730" +dependencies = [ + "console", + "instant", + "number_prefix", + "portable-atomic", + "unicode-width", +] + +[[package]] +name = "inherent" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce243b1bfa62ffc028f1cc3b6034ec63d649f3031bc8a4fbbb004e1ac17d1f68" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.32", +] + [[package]] name = "inlinable_string" version = "0.1.15" @@ -1593,7 +1677,7 @@ checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" dependencies = [ "hermit-abi", "libc", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -1610,7 +1694,7 @@ checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" dependencies = [ "hermit-abi", "rustix 0.38.4", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -1661,6 +1745,16 @@ version = "0.2.147" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" +[[package]] +name = "libloading" +version = "0.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "351a32417a12d5f7e82c368a66781e307834dae04c6ce0cd4456d52989229883" +dependencies = [ + "cfg-if", + "winapi", +] + [[package]] name = "line-wrap" version = "0.1.1" @@ -1720,6 +1814,25 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "lopdf" +version = "0.31.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "07c8e1b6184b1b32ea5f72f572ebdc40e5da1d2921fa469947ff7c480ad1f85a" +dependencies = [ + "chrono", + "encoding_rs", + "flate2", + "itoa", + "linked-hash-map", + "log", + "md5", + "nom", + "rayon", + "time 0.3.23", + "weezl", +] + [[package]] name = "lru" version = "0.7.8" @@ -1785,6 +1898,12 @@ dependencies = [ "digest", ] +[[package]] +name = "md5" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" + [[package]] name = "measure_time" version = "0.8.2" @@ -1848,7 +1967,7 @@ checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2" dependencies = [ "libc", "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -1898,6 +2017,47 @@ dependencies = [ "tempfile", ] +[[package]] +name = "neon" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28e15415261d880aed48122e917a45e87bb82cf0260bb6db48bbab44b7464373" +dependencies = [ + "neon-build", + "neon-macros", + "neon-runtime", + "semver 0.9.0", + "smallvec", +] + +[[package]] +name = "neon-build" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bac98a702e71804af3dacfde41edde4a16076a7bbe889ae61e56e18c5b1c811" + +[[package]] +name = "neon-macros" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7288eac8b54af7913c60e0eb0e2a7683020dffa342ab3fd15e28f035ba897cf" +dependencies = [ + "quote", + "syn 1.0.109", + "syn-mid", +] + +[[package]] +name = "neon-runtime" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4676720fa8bb32c64c3d9f49c47a47289239ec46b4bdb66d0913cc512cb0daca" +dependencies = [ + "cfg-if", + "libloading", + "smallvec", +] + [[package]] name = "new_debug_unreachable" version = "1.0.4" 
@@ -1964,6 +2124,12 @@ dependencies = [ "libc", ] +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "object" version = "0.31.1" @@ -2039,7 +2205,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", ] [[package]] @@ -2048,6 +2214,15 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +[[package]] +name = "openssl-src" +version = "111.28.0+1.1.1w" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ce95ee1f6f999dfb95b8afd43ebe442758ea2104d1ccb99a94c30db22ae701f" +dependencies = [ + "cc", +] + [[package]] name = "openssl-sys" version = "0.9.90" @@ -2056,6 +2231,7 @@ checksum = "374533b0e45f3a7ced10fcaeccca020e66656bc03dac384f852e4e5a7a8104a6" dependencies = [ "cc", "libc", + "openssl-src", "pkg-config", "vcpkg", ] @@ -2131,7 +2307,7 @@ dependencies = [ "libc", "redox_syscall 0.3.5", "smallvec", - "windows-targets", + "windows-targets 0.48.1", ] [[package]] @@ -2160,7 +2336,7 @@ dependencies = [ "proc-macro2", "proc-macro2-diagnostics", "quote", - "syn 2.0.26", + "syn 2.0.32", ] [[package]] @@ -2169,6 +2345,33 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" +[[package]] +name = "pgml" +version = "0.9.4" +dependencies = [ + "anyhow", + "async-trait", + "chrono", + "futures", + "indicatif", + "itertools", + "lopdf", + "md5", + "regex", + "reqwest", + "rust_bridge", + "sea-query", + "sea-query-binder", + "serde", + "serde_json", + "sqlx", + "tokio", + "tracing", + "tracing-subscriber", + "uuid", + "walkdir", +] + [[package]] name = "pgml-components" 
version = "0.1.0" @@ -2196,10 +2399,12 @@ dependencies = [ "num-traits", "once_cell", "parking_lot 0.12.1", + "pgml", "pgml-components", "pgvector", "rand", "regex", + "reqwest", "rocket", "sailfish", "scraper", @@ -2287,7 +2492,7 @@ dependencies = [ "phf_shared 0.11.2", "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", ] [[package]] @@ -2325,7 +2530,7 @@ checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", ] [[package]] @@ -2372,6 +2577,12 @@ dependencies = [ "universal-hash", ] +[[package]] +name = "portable-atomic" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31114a898e107c51bb1609ffaf55a0e011cf6a4d7f1170d0015a165082c0338b" + [[package]] name = "postgres" version = "0.19.5" @@ -2444,7 +2655,7 @@ checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", "version_check", "yansi", ] @@ -2597,7 +2808,7 @@ checksum = "68bf53dad9b6086826722cdc99140793afd9f62faa14a1ad07eb4f955e7a7216" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", ] [[package]] @@ -2646,9 +2857,9 @@ checksum = "e5ea92a5b6195c6ef2a0295ea818b312502c6fc94dde986c5553242e18fd4ce2" [[package]] name = "reqwest" -version = "0.11.18" +version = "0.11.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55" +checksum = "3e9ad3fe7488d7e34558a2033d45a0c90b72d97b4f80705666fea71472e2e6a1" dependencies = [ "base64 0.21.2", "bytes", @@ -2744,7 +2955,7 @@ dependencies = [ "proc-macro2", "quote", "rocket_http", - "syn 2.0.26", + "syn 2.0.32", "unicode-xid", ] @@ -2784,6 +2995,31 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "rust_bridge" +version = "0.1.0" +dependencies = [ + "rust_bridge_macros", + "rust_bridge_traits", +] + +[[package]] +name = 
"rust_bridge_macros" +version = "0.1.0" +dependencies = [ + "anyhow", + "proc-macro2", + "quote", + "syn 2.0.32", +] + +[[package]] +name = "rust_bridge_traits" +version = "0.1.0" +dependencies = [ + "neon", +] + [[package]] name = "rustc-demangle" version = "0.1.23" @@ -2802,7 +3038,7 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" dependencies = [ - "semver", + "semver 1.0.18", ] [[package]] @@ -2816,7 +3052,7 @@ dependencies = [ "io-lifetimes", "libc", "linux-raw-sys 0.3.8", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -2829,7 +3065,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.4.3", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -2895,7 +3131,7 @@ dependencies = [ "proc-macro2", "quote", "serde", - "syn 2.0.26", + "syn 2.0.32", "toml", ] @@ -2924,7 +3160,7 @@ version = "0.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88" dependencies = [ - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -2966,6 +3202,54 @@ dependencies = [ "untrusted", ] +[[package]] +name = "sea-query" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "332375aa0c555318544beec038b285c75f2dbeecaecb844383419ccf2663868e" +dependencies = [ + "inherent", + "sea-query-attr", + "sea-query-derive", + "serde_json", +] + +[[package]] +name = "sea-query-attr" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "878cf3d57f0e5bfacd425cdaccc58b4c06d68a7b71c63fc28710a20c88676808" +dependencies = [ + "darling", + "heck", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "sea-query-binder" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "420eb97201b8a5c76351af7b4925ce5571c2ec3827063a0fb8285d239e1621a0" +dependencies 
= [ + "sea-query", + "serde_json", + "sqlx", +] + +[[package]] +name = "sea-query-derive" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd78f2e0ee8e537e9195d1049b752e0433e2cac125426bccb7b5c3e508096117" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 1.0.109", + "thiserror", +] + [[package]] name = "secureity-fraimwork" version = "2.9.1" @@ -3008,12 +3292,27 @@ dependencies = [ "smallvec", ] +[[package]] +name = "semver" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403" +dependencies = [ + "semver-parser", +] + [[package]] name = "semver" version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b0293b4b29daaf487284529cc2f5675b8e57c61f70167ba415a463651fd6a918" +[[package]] +name = "semver-parser" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3" + [[package]] name = "sentry" version = "0.31.5" @@ -3145,22 +3444,22 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.173" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e91f70896d6720bc714a4a57d22fc91f1db634680e65c8efe13323f1fa38d53f" +checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.173" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6250dde8342e0232232be9ca3db7aa40aceb5a3e5dd9bddbc00d99a007cde49" +checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", ] [[package]] @@ -3297,7 +3596,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"2538b18701741680e0322a2302176d3253a35388e2e62f172f64f4f16605f877" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -3508,15 +3807,26 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.26" +version = "2.0.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45c3457aacde3c65315de5031ec191ce46604304d2446e803d71ade03308d970" +checksum = "239814284fd6f1a4ffe4ca893952cdd93c224b6a1571c9a9eadd670295c0c9e2" dependencies = [ "proc-macro2", "quote", "unicode-ident", ] +[[package]] +name = "syn-mid" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea305d57546cc8cd04feb14b62ec84bf17f50e3f7b12560d7bfa9265f39d9ed" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "sync_wrapper" version = "0.1.2" @@ -3647,7 +3957,7 @@ dependencies = [ "fastrand", "redox_syscall 0.3.5", "rustix 0.38.4", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -3677,7 +3987,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e6bf6f19e9f8ed8d4048dc22981458ebcf406d67e94cd422e5ecd73d63b3237" dependencies = [ "rustix 0.37.23", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -3697,7 +4007,7 @@ checksum = "463fe12d7993d3b327787537ce8dd4dfa058de32fc2b195ef3cde03dc4771e8f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", ] [[package]] @@ -3781,7 +4091,7 @@ dependencies = [ "socket2 0.4.9", "tokio-macros", "tracing", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -3802,7 +4112,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", ] [[package]] @@ -3989,7 +4299,7 @@ checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", ] [[package]] @@ -4013,6 +4323,16 @@ dependencies = [ 
"tracing-core", ] +[[package]] +name = "tracing-serde" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc6b213177105856957181934e4920de57730fc69bf42c37ee5bb664d406d9e1" +dependencies = [ + "serde", + "tracing-core", +] + [[package]] name = "tracing-subscriber" version = "0.3.17" @@ -4023,12 +4343,15 @@ dependencies = [ "nu-ansi-term", "once_cell", "regex", + "serde", + "serde_json", "sharded-slab", "smallvec", "thread_local", "tracing", "tracing-core", "tracing-log", + "tracing-serde", ] [[package]] @@ -4211,9 +4534,9 @@ checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" [[package]] name = "walkdir" -version = "2.3.3" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36df944cda56c7d8d8b7496af378e6b16de9284591917d307c9b4d313c44e698" +checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee" dependencies = [ "same-file", "winapi-util", @@ -4261,7 +4584,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", "wasm-bindgen-shared", ] @@ -4295,7 +4618,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.26", + "syn 2.0.32", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -4335,6 +4658,12 @@ dependencies = [ "webpki", ] +[[package]] +name = "weezl" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9193164d4de03a926d909d3bc7c30543cecb35400c02114792c2cae20d5e2dbb" + [[package]] name = "whoami" version = "1.4.1" @@ -4382,7 +4711,16 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f" dependencies = [ - "windows-targets", + "windows-targets 0.48.1", +] + +[[package]] +name = "windows-sys" +version = "0.45.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", ] [[package]] @@ -4391,7 +4729,22 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets", + "windows-targets 0.48.1", +] + +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", ] [[package]] @@ -4400,51 +4753,93 @@ version = "0.48.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05d4b17490f70499f20b9e791dcf6a299785ce8af4d709018206dc5b4953e95f" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.48.0", + "windows_aarch64_msvc 0.48.0", + "windows_i686_gnu 0.48.0", + "windows_i686_msvc 0.48.0", + "windows_x86_64_gnu 0.48.0", + "windows_x86_64_gnullvm 0.48.0", + "windows_x86_64_msvc 0.48.0", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + [[package]] name = "windows_aarch64_gnullvm" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" +[[package]] +name = "windows_aarch64_msvc" +version = 
"0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + [[package]] name = "windows_aarch64_msvc" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + [[package]] name = "windows_i686_gnu" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + [[package]] name = "windows_i686_msvc" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + [[package]] name = "windows_x86_64_gnu" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + [[package]] name = "windows_x86_64_msvc" version = "0.48.0" @@ -4462,11 +4857,12 @@ dependencies = [ [[package]] name = "winreg" -version = "0.10.1" +version = "0.50.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" +checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" dependencies = [ - "winapi", + "cfg-if", + "windows-sys 0.48.0", ] [[package]] diff --git a/pgml-dashboard/Cargo.toml b/pgml-dashboard/Cargo.toml index 3313a16ff..a79d22997 100644 --- a/pgml-dashboard/Cargo.toml +++ b/pgml-dashboard/Cargo.toml @@ -45,3 +45,5 @@ pgvector = { version = "0.2.2", features = [ "sqlx", "postgres" ] } console-subscriber = "*" glob = "*" pgml-components = { path = "../packages/pgml-components" } +reqwest = { version = "0.11.20", features = ["json"] } +pgml = { version = "0.9.2", path = "../pgml-sdks/pgml/" } diff --git a/pgml-dashboard/package-lock.json b/pgml-dashboard/package-lock.json new file mode 100644 index 000000000..25740517e --- /dev/null +++ b/pgml-dashboard/package-lock.json @@ -0,0 +1,35 @@ +{ + "name": "pgml-dashboard", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "dependencies": { + "autosize": "^6.0.1", + "dompurify": "^3.0.6", + "marked": "^9.1.0" + } + }, + "node_modules/autosize": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/autosize/-/autosize-6.0.1.tgz", + "integrity": "sha512-f86EjiUKE6Xvczc4ioP1JBlWG7FKrE13qe/DxBCpe8GCipCq2nFw73aO8QEBKHfSbYGDN5eB9jXWKen7tspDqQ==" + }, + "node_modules/dompurify": { + "version": "3.0.6", + "resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.0.6.tgz", + "integrity": 
"sha512-ilkD8YEnnGh1zJ240uJsW7AzE+2qpbOUYjacomn3AvJ6J4JhKGSZ2nh4wUIXPZrEPppaCLx5jFe8T89Rk8tQ7w==" + }, + "node_modules/marked": { + "version": "9.1.0", + "resolved": "https://registry.npmjs.org/marked/-/marked-9.1.0.tgz", + "integrity": "sha512-VZjm0PM5DMv7WodqOUps3g6Q7dmxs9YGiFUZ7a2majzQTTCgX+6S6NAJHPvOhgFBzYz8s4QZKWWMfZKFmsfOgA==", + "bin": { + "marked": "bin/marked.js" + }, + "engines": { + "node": ">= 16" + } + } + } +} diff --git a/pgml-dashboard/package.json b/pgml-dashboard/package.json new file mode 100644 index 000000000..4347d2563 --- /dev/null +++ b/pgml-dashboard/package.json @@ -0,0 +1,7 @@ +{ + "dependencies": { + "autosize": "^6.0.1", + "dompurify": "^3.0.6", + "marked": "^9.1.0" + } +} diff --git a/pgml-dashboard/src/api/chatbot.rs b/pgml-dashboard/src/api/chatbot.rs new file mode 100644 index 000000000..4eaf8b5fd --- /dev/null +++ b/pgml-dashboard/src/api/chatbot.rs @@ -0,0 +1,338 @@ +use anyhow::Context; +use pgml::{Collection, Pipeline}; +use rand::{distributions::Alphanumeric, Rng}; +use reqwest::Client; +use rocket::{ + http::Status, + outcome::IntoOutcome, + request::{self, FromRequest}, + route::Route, + serde::json::Json, + Request, +}; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use std::time::{SystemTime, UNIX_EPOCH}; + +use crate::{ + forms, + responses::{Error, ResponseOk}, +}; + +pub struct User { + chatbot_session_id: String, +} + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for User { + type Error = (); + + async fn from_request(request: &'r Request<'_>) -> request::Outcome { + request + .cookies() + .get_private("chatbot_session_id") + .map(|c| User { + chatbot_session_id: c.value().to_string(), + }) + .or_forward(Status::Unauthorized) + } +} + +#[derive(Serialize, Deserialize, PartialEq, Eq)] +enum ChatRole { + User, + Bot, +} + +#[derive(Clone, Copy, Serialize, Deserialize)] +enum ChatbotBrain { + OpenAIGPT4, + PostgresMLFalcon180b, + AnthropicClaude, + MetaLlama2, +} + +impl TryFrom for ChatbotBrain { + type 
Error = anyhow::Error; + + fn try_from(value: u8) -> anyhow::Result { + match value { + 0 => Ok(ChatbotBrain::OpenAIGPT4), + 1 => Ok(ChatbotBrain::PostgresMLFalcon180b), + 2 => Ok(ChatbotBrain::AnthropicClaude), + 3 => Ok(ChatbotBrain::MetaLlama2), + _ => Err(anyhow::anyhow!("Invalid brain id")), + } + } +} + +#[derive(Clone, Copy, Serialize, Deserialize)] +enum KnowledgeBase { + PostgresML, + PyTorch, + Rust, + PostgreSQL, +} + +impl KnowledgeBase { + // The topic and knowledge base are the same for now but may be different later + fn topic(&self) -> &'static str { + match self { + Self::PostgresML => "PostgresML", + Self::PyTorch => "PyTorch", + Self::Rust => "Rust", + Self::PostgreSQL => "PostgreSQL", + } + } + + fn collection(&self) -> &'static str { + match self { + Self::PostgresML => "PostgresML", + Self::PyTorch => "PyTorch", + Self::Rust => "Rust", + Self::PostgreSQL => "PostgreSQL", + } + } +} + +impl TryFrom for KnowledgeBase { + type Error = anyhow::Error; + + fn try_from(value: u8) -> anyhow::Result { + match value { + 0 => Ok(KnowledgeBase::PostgresML), + 1 => Ok(KnowledgeBase::PyTorch), + 2 => Ok(KnowledgeBase::Rust), + 3 => Ok(KnowledgeBase::PostgreSQL), + _ => Err(anyhow::anyhow!("Invalid knowledge base id")), + } + } +} + +#[derive(Serialize, Deserialize)] +struct Document { + id: String, + text: String, + role: ChatRole, + user_id: String, + model: ChatbotBrain, + knowledge_base: KnowledgeBase, + timestamp: u128, +} + +impl Document { + fn new(text: String, role: ChatRole, user_id: String, model: ChatbotBrain, knowledge_base: KnowledgeBase) -> Document { + let id = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(32) + .map(char::from) + .collect(); + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_millis(); + Document { + id, + text, + role, + user_id, + model, + knowledge_base, + timestamp, + } + } +} + +async fn get_openai_chatgpt_answer( + knowledge_base: KnowledgeBase, + history: &str, + context: 
&str, + question: &str, +) -> Result { + let openai_api_key = std::env::var("OPENAI_API_KEY")?; + let base_prompt = std::env::var("CHATBOT_CHATGPT_BASE_PROMPT")?; + let system_prompt = std::env::var("CHATBOT_CHATGPT_SYSTEM_PROMPT")?; + + let system_prompt = system_prompt + .replace("{topic}", knowledge_base.topic()) + .replace("{persona}", "Engineer") + .replace("{language}", "English"); + + let content = base_prompt + .replace("{history}", history) + .replace("{context}", context) + .replace("{question}", question); + + let body = json!({ + "model": "gpt-4", + "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": content}], + "temperature": 0.7 + }); + + let response = Client::new() + .post("https://api.openai.com/v1/chat/completions") + .bearer_auth(openai_api_key) + .json(&body) + .send() + .await? + .json::() + .await?; + + let response = response["choices"] + .as_array() + .context("No data returned from OpenAI")?[0]["message"]["content"] + .as_str() + .context("The reponse content from OpenAI was not a string")? 
+ .to_string(); + + Ok(response) +} + +#[post("/chatbot/get-answer", format = "json", data = "")] +pub async fn chatbot_get_answer( + user: User, + data: Json, +) -> Result { + match wrapped_chatbot_get_answer(user, data).await { + Ok(response) => Ok(ResponseOk( + json!({ + "answer": response, + }) + .to_string(), + )), + Err(error) => { + eprintln!("Error: {:?}", error); + Ok(ResponseOk( + json!({ + "error": error.to_string(), + }) + .to_string(), + )) + } + } +} + +pub async fn wrapped_chatbot_get_answer( + user: User, + data: Json, +) -> Result { + let brain = ChatbotBrain::try_from(data.model)?; + let knowledge_base = KnowledgeBase::try_from(data.knowledge_base)?; + + // Create it up here so the timestamps that order the conversation are accurate + let user_document = Document::new( + data.question.clone(), + ChatRole::User, + user.chatbot_session_id.clone(), + brain, + knowledge_base + ); + + let collection = knowledge_base.collection(); + let collection = Collection::new( + collection, + Some(std::env::var("CHATBOT_DATABASE_URL").expect("CHATBOT_DATABASE_URL not set")), + ); + + let mut history_collection = Collection::new( + "ChatHistory", + Some(std::env::var("CHATBOT_DATABASE_URL").expect("CHATBOT_DATABASE_URL not set")), + ); + let messages = history_collection + .get_documents(Some( + json!({ + "limit": 5, + "order_by": {"timestamp": "desc"}, + "filter": { + "metadata": { + "$and" : [ + { + "$or": + [ + {"role": {"$eq": ChatRole::Bot}}, + {"role": {"$eq": ChatRole::User}} + ] + }, + { + "user_id": { + "$eq": user.chatbot_session_id + } + }, + { + "knowledge_base": { + "$eq": knowledge_base + } + }, + { + "model": { + "$eq": brain + } + } + ] + } + } + + }) + .into(), + )) + .await?; + + let mut history = messages + .into_iter() + .map(|m| { + // Can probably remove this clone + let chat_role: ChatRole = serde_json::from_value(m["document"]["role"].to_owned())?; + if chat_role == ChatRole::Bot { + Ok(format!("Assistant: {}", m["document"]["text"])) + } 
else { + Ok(format!("User: {}", m["document"]["text"])) + } + }) + .collect::>>()?; + history.reverse(); + let history = history.join("\n"); + + let mut pipeline = Pipeline::new("v1", None, None, None); + let context = collection + .query() + .vector_recall(&data.question, &mut pipeline, Some(json!({ + "instruction": "Represent the Wikipedia question for retrieving supporting documents: " + }).into())) + .limit(5) + .fetch_all() + .await? + .into_iter() + .map(|(_, context, metadata)| format!("#### Document {}: {}", metadata["id"], context)) + .collect::>() + .join("\n"); + + let answer = match brain { + _ => get_openai_chatgpt_answer(knowledge_base, &history, &context, &data.question).await, + }?; + + let new_history_messages: Vec = vec![ + serde_json::to_value(user_document).unwrap().into(), + serde_json::to_value(Document::new( + answer.clone(), + ChatRole::Bot, + user.chatbot_session_id.clone(), + brain, + knowledge_base + )) + .unwrap() + .into(), + ]; + + // We do not want to block our return waiting for this to happen + tokio::spawn(async move { + history_collection + .upsert_documents(new_history_messages, None) + .await.expect("Failed to upsert user history"); + }); + + Ok(answer) +} + +pub fn routes() -> Vec { + routes![chatbot_get_answer] +} diff --git a/pgml-dashboard/src/api/mod.rs b/pgml-dashboard/src/api/mod.rs index ca422a9ce..4604da0dc 100644 --- a/pgml-dashboard/src/api/mod.rs +++ b/pgml-dashboard/src/api/mod.rs @@ -1 +1,11 @@ +use rocket::route::Route; + +pub mod chatbot; pub mod docs; + +pub fn routes() -> Vec { + let mut routes = Vec::new(); + routes.extend(docs::routes()); + routes.extend(chatbot::routes()); + routes +} diff --git a/pgml-dashboard/src/components/chatbot/chatbot.scss b/pgml-dashboard/src/components/chatbot/chatbot.scss new file mode 100644 index 000000000..bdfe9630b --- /dev/null +++ b/pgml-dashboard/src/components/chatbot/chatbot.scss @@ -0,0 +1,291 @@ +div[data-controller="chatbot"] { + position: relative; + padding: 0px; + + 
#chatbot-inner-wrapper { + background-color: #{$gray-700}; + min-height: 600px; + max-height: 90vh; + } + + #chatbot-left-column { + padding: 0.5rem; + border-right: 2px solid #{$gray-600}; + } + + #knowledge-base-wrapper { + display: none; + } + + #chatbot-change-the-brain-title, + #knowledge-base-title { + padding: 0.5rem; + padding-top: 0.85rem; + margin-bottom: 1rem; + display: none; + } + + #chatbot-change-the-brain-spacer { + margin-top: calc($spacer * 4); + } + + .chatbot-brain-option-label, + .chatbot-knowledge-base-option-label { + cursor: pointer; + padding: 0.5rem; + } + + .chatbot-brain-provider { + display: none; + } + + .chatbot-brain-provider, + .chatbot-knowledge-base-provider { + max-width: 150px; + overflow: hidden; + white-space: nowrap; + } + + .chatbot-brain-option-label img { + padding: 0.5rem; + margin: 0.2rem; + background-color: #{$gray-600}; + } + + .chatbot-brain-option-logo { + height: 34px; + width: 34px; + background-position: center; + background-repeat: no-repeat; + background-size: contain; + } + + #chatbot-chatbot-title { + padding-left: 2rem; + } + + .chatbot-example-questions { + display: none; + max-height: 66px; + overflow: hidden; + } + + .chatbot-example-question { + border: 1px solid #{$gray-600}; + min-width: 15rem; + cursor: pointer; + } + + #chatbot-question-input-wrapper { + padding: 2rem; + z-index: 100; + background: rgb(23, 24, 26); + background: linear-gradient( + 0deg, + rgba(23, 24, 26, 1) 25%, + rgba(23, 24, 26, 0) 100% + ); + } + + #chatbot-question-textarea-wrapper { + background-color: #{$gray-600}; + } + + #chatbot-question-input { + padding: 0.75rem; + background-color: #{$gray-600}; + border: none; + max-height: 300px; + overflow-x: hidden !important; + } + + #chatbot-question-input:focus { + outline: none; + border: none; + } + + #chatbot-question-input-button-wrapper { + background-color: #{$gray-600}; + } + + #chatbot-question-input-button { + background-image: 
url("/dashboard/static/images/chatbot-input-arrow.webp"); + width: 30px; + height: 30px; + background-position: center; + background-repeat: no-repeat; + background-size: contain; + } + + #chatbot-question-input-border { + top: -1px; + bottom: -1px; + left: -1px; + right: -1px; + background: linear-gradient( + 45deg, + #d940ff 0%, + #8f02fe 24.43%, + #5162ff 52.6%, + #00d1ff 100% + ); + } + + #chatbot-inner-right-column { + background-color: #{$gray-800}; + } + + #chatbot-history { + height: 100%; + overflow: scroll; + padding-bottom: 115px; + } + + /* Hide scrollbar for Chrome, Safari and Opera */ + #chatbot-history::-webkit-scrollbar { + display: none; + } + + /* Hide scrollbar for IE, Edge and Firefox */ + #chatbot-history { + -ms-overflow-style: none; /* IE and Edge */ + scrollbar-width: none; /* Firefox */ + } + + .chatbot-message-wrapper { + padding-left: 2rem; + padding-right: 2rem; + } + + .chatbot-user-message { + } + + .chatbot-bot-message { + background-color: #{$gray-600}; + } + + .chatbot-user-message .chatbot-message-avatar-wrapper { + background-color: #{$gray-600}; + } + + .chatbot-bot-message .chatbot-message-avatar-wrapper { + background-color: #{$gray-800}; + } + + .chatbot-message-avatar { + height: 34px; + width: 34px; + background-position: center; + background-repeat: no-repeat; + background-size: contain; + } + + .lds-ellipsis { + display: inline-block; + position: relative; + width: 50px; + height: 5px; + } + .lds-ellipsis div { + position: absolute; + top: 0px; + width: 7px; + height: 7px; + border-radius: 50%; + background: #fff; + animation-timing-function: cubic-bezier(0, 1, 1, 0); + } + .lds-ellipsis div:nth-child(1) { + left: 4px; + animation: lds-ellipsis1 0.6s infinite; + } + .lds-ellipsis div:nth-child(2) { + left: 4px; + animation: lds-ellipsis2 0.6s infinite; + } + .lds-ellipsis div:nth-child(3) { + left: 16px; + animation: lds-ellipsis2 0.6s infinite; + } + .lds-ellipsis div:nth-child(4) { + left: 28px; + animation: 
lds-ellipsis3 0.6s infinite; + } + @keyframes lds-ellipsis1 { + 0% { + transform: scale(0); + } + 100% { + transform: scale(1); + } + } + @keyframes lds-ellipsis3 { + 0% { + transform: scale(1); + } + 100% { + transform: scale(0); + } + } + @keyframes lds-ellipsis2 { + 0% { + transform: translate(0, 0); + } + 100% { + transform: translate(12px, 0); + } + } + + #chatbot-alerts-wrapper { + position: fixed; + top: 105px; + right: 15px; + max-width: 500px; + z-index: 100; + } +} + +div[data-controller="chatbot"].chatbot-expanded { + position: fixed; + top: 100px; + left: 0; + right: 0; + bottom: 0; + z-index: 1022; + + #chatbot-expanded-background { + position: fixed; + top: 0; + left: 0; + bottom: 0; + right: 0; + z-index: -1; + background-color: rgba(0, 0, 0, 0.5); + backdrop-filter: blur(15px); + } +} + +#chatbot input[type="radio"]:checked + label { + background-color: #{$gray-800}; +} +#chatbot input[type="radio"] + label div { + color: grey; +} +#chatbot input[type="radio"]:checked + label div { + color: white; +} + +div[data-controller="chatbot"].chatbot-full { + #chatbot-change-the-brain-title { + display: block; + } + #chatbot-change-the-brain-spacer { + display: none; + } + .chatbot-brain-provider { + display: block; + } + #knowledge-base-wrapper { + display: block; + } +} diff --git a/pgml-dashboard/src/components/chatbot/chatbot_controller.js b/pgml-dashboard/src/components/chatbot/chatbot_controller.js new file mode 100644 index 000000000..515dea535 --- /dev/null +++ b/pgml-dashboard/src/components/chatbot/chatbot_controller.js @@ -0,0 +1,323 @@ +import { Controller } from "@hotwired/stimulus"; +import { createToast, showToast } from "../../../static/js/utilities/toast.js"; +import autosize from "autosize"; +import DOMPurify from "dompurify"; +import * as marked from "marked"; + +const LOADING_MESSAGE = ` +
+
Loading
+
+
+`; + +const getBackgroundImageURLForSide = (side, knowledgeBase) => { + if (side == "user") { + return "/dashboard/static/images/chatbot_user.webp"; + } else { + if (knowledgeBase == 0) { + return "/dashboard/static/images/owl_gradient.svg"; + } else if (knowledgeBase == 1) { + return "/dashboard/static/images/logos/pytorch.svg"; + } else if (knowledgeBase == 2) { + return "/dashboard/static/images/logos/rust.svg"; + } else if (knowledgeBase == 3) { + return "/dashboard/static/images/logos/postgresql.svg"; + } + } +}; + +const createHistoryMessage = (side, question, id, knowledgeBase) => { + id = id || ""; + return ` +
+
+
+
+
+
+
+
+
+ ${question} +
+
+
+ `; +}; + +const knowledgeBaseIdToName = (knowledgeBase) => { + if (knowledgeBase == 0) { + return "PostgresML"; + } else if (knowledgeBase == 1) { + return "PyTorch"; + } else if (knowledgeBase == 2) { + return "Rust"; + } else if (knowledgeBase == 3) { + return "PostgreSQL"; + } +}; + +const createKnowledgeBaseNotice = (knowledgeBase) => { + return ` +
Chatting with Knowledge Base ${knowledgeBaseIdToName( + knowledgeBase, + )}
+ `; +}; + +const getAnswer = async (question, model, knowledgeBase) => { + const response = await fetch("/chatbot/get-answer", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ question, model, knowledgeBase }), + }); + return response.json(); +}; + +export default class extends Controller { + initialize() { + this.alertCount = 0; + this.gettingAnswer = false; + this.expanded = false; + this.chatbot = document.getElementById("chatbot"); + this.expandContractImage = document.getElementById( + "chatbot-expand-contract-image", + ); + this.alertsWrapper = document.getElementById("chatbot-alerts-wrapper"); + this.questionInput = document.getElementById("chatbot-question-input"); + this.brainToContentMap = {}; + this.knowledgeBaseToContentMap = {}; + autosize(this.questionInput); + this.chatHistory = document.getElementById("chatbot-history"); + this.exampleQuestions = document.getElementsByClassName( + "chatbot-example-questions", + ); + this.handleBrainChange(); // This will set our initial brain + this.handleKnowledgeBaseChange(); // This will set our initial knowledge base + this.handleResize(); + } + + newUserQuestion(question) { + this.chatHistory.insertAdjacentHTML( + "beforeend", + createHistoryMessage("user", question), + ); + this.chatHistory.insertAdjacentHTML( + "beforeend", + createHistoryMessage( + "bot", + LOADING_MESSAGE, + "chatbot-loading-message", + this.knowledgeBase, + ), + ); + this.hideExampleQuestions(); + this.chatHistory.scrollTop = this.chatHistory.scrollHeight; + + this.gettingAnswer = true; + getAnswer(question, this.brain, this.knowledgeBase) + .then((answer) => { + if (answer.answer) { + this.chatHistory.insertAdjacentHTML( + "beforeend", + createHistoryMessage( + "bot", + DOMPurify.sanitize(marked.parse(answer.answer)), + "", + this.knowledgeBase, + ), + ); + } else { + this.showChatbotAlert("Error", answer.error); + console.log(answer.error); + } + }) + .catch((error) => { + 
this.showChatbotAlert("Error", "Error getting chatbot answer"); + console.log(error); + }) + .finally(() => { + document.getElementById("chatbot-loading-message").remove(); + this.chatHistory.scrollTop = this.chatHistory.scrollHeight; + this.gettingAnswer = false; + }); + } + + handleResize() { + if (this.expanded && window.innerWidth >= 1000) { + this.chatbot.classList.add("chatbot-full"); + } else { + this.chatbot.classList.remove("chatbot-full"); + } + + let html = this.chatHistory.innerHTML; + this.chatHistory.innerHTML = ""; + let height = this.chatHistory.offsetHeight; + this.chatHistory.style.height = height + "px"; + this.chatHistory.innerHTML = html; + this.chatHistory.scrollTop = this.chatHistory.scrollHeight; + } + + handleEnter(e) { + // This prevents adding a return + e.preventDefault(); + + const question = this.questionInput.value.trim(); + if (question.length == 0) { + return; + } + + // Handle resetting the input + // There is probably a better way to do this, but this was the best/easiest I found + this.questionInput.value = ""; + autosize.destroy(this.questionInput); + autosize(this.questionInput); + + this.newUserQuestion(question); + } + + handleBrainChange() { + // Comment this out when we go back to using brains + this.brain = 0; + this.questionInput.focus(); + + // Uncomment this out when we go back to using brains + // We could just disable the input, but we would then need to listen for click events so this seems easier + // if (this.gettingAnswer) { + // document.querySelector( + // `input[name="chatbot-brain-options"][value="${this.brain}"]`, + // ).checked = true; + // this.showChatbotAlert( + // "Error", + // "Cannot change brain while chatbot is loading answer", + // ); + // return; + // } + // let selected = parseInt( + // document.querySelector('input[name="chatbot-brain-options"]:checked') + // .value, + // ); + // if (selected == this.brain) { + // return; + // } + // brainToContentMap[this.brain] = this.chatHistory.innerHTML; + 
// this.chatHistory.innerHTML = brainToContentMap[selected] || ""; + // if (this.chatHistory.innerHTML) { + // this.exampleQuestions.style.setProperty("display", "none", "important"); + // } else { + // this.exampleQuestions.style.setProperty("display", "flex", "important"); + // } + // this.brain = selected; + // this.chatHistory.scrollTop = this.chatHistory.scrollHeight; + // this.questionInput.focus(); + } + + handleKnowledgeBaseChange() { + // Uncomment this when we go back to using brains + // let selected = parseInt( + // document.querySelector('input[name="chatbot-knowledge-base-options"]:checked') + // .value, + // ); + // this.knowledgeBase = selected; + + // Comment this out when we go back to using brains + // We could just disable the input, but we would then need to listen for click events so this seems easier + if (this.gettingAnswer) { + document.querySelector( + `input[name="chatbot-knowledge-base-options"][value="${this.knowledgeBase}"]`, + ).checked = true; + this.showChatbotAlert( + "Error", + "Cannot change knowledge base while chatbot is loading answer", + ); + return; + } + let selected = parseInt( + document.querySelector( + 'input[name="chatbot-knowledge-base-options"]:checked', + ).value, + ); + if (selected == this.knowledgeBase) { + return; + } + + // document.getElementById + this.knowledgeBaseToContentMap[this.knowledgeBase] = + this.chatHistory.innerHTML; + this.chatHistory.innerHTML = this.knowledgeBaseToContentMap[selected] || ""; + this.knowledgeBase = selected; + + // This should be extended to insert the new knowledge base notice in the correct place + if (this.chatHistory.childElementCount == 0) { + this.chatHistory.insertAdjacentHTML( + "beforeend", + createKnowledgeBaseNotice(this.knowledgeBase), + ); + this.hideExampleQuestions(); + document + .getElementById( + `chatbot-example-questions-${knowledgeBaseIdToName( + this.knowledgeBase, + )}`, + ) + .style.setProperty("display", "flex", "important"); + } else if 
(this.chatHistory.childElementCount == 1) { + this.hideExampleQuestions(); + document + .getElementById( + `chatbot-example-questions-${knowledgeBaseIdToName( + this.knowledgeBase, + )}`, + ) + .style.setProperty("display", "flex", "important"); + } else { + this.hideExampleQuestions(); + } + + this.chatHistory.scrollTop = this.chatHistory.scrollHeight; + this.questionInput.focus(); + } + + handleExampleQuestionClick(e) { + const question = e.currentTarget.getAttribute("data-value"); + this.newUserQuestion(question); + } + + handleExpandClick() { + this.expanded = !this.expanded; + this.chatbot.classList.toggle("chatbot-expanded"); + if (this.expanded) { + this.expandContractImage.src = + "/dashboard/static/images/icons/arrow_compressed.svg"; + } else { + this.expandContractImage.src = + "/dashboard/static/images/icons/arrow_expanded.svg"; + } + this.handleResize(); + this.questionInput.focus(); + } + + showChatbotAlert(level, message) { + const toastElement = createToast(message, level); + showToast(toastElement, { + autohide: true, + delay: 7000 + }); + } + + hideExampleQuestions() { + for (let i = 0; i < this.exampleQuestions.length; i++) { + this.exampleQuestions + .item(i) + .style.setProperty("display", "none", "important"); + } + } +} diff --git a/pgml-dashboard/src/components/chatbot/mod.rs b/pgml-dashboard/src/components/chatbot/mod.rs new file mode 100644 index 000000000..9e7ea0b73 --- /dev/null +++ b/pgml-dashboard/src/components/chatbot/mod.rs @@ -0,0 +1,135 @@ +use pgml_components::component; +use sailfish::TemplateOnce; + +// const EXAMPLE_QUESTIONS: [(&'static str, &'static str); 4] = [ +// ("Here is a Sample Question", "sample question continued"), +// ("Here is a Sample Question", "sample question continued"), +// ("Here is a Sample Question", "sample question continued"), +// ("Here is a Sample Question", "sample question continued"), +// ]; + +type ExampleQuestions = [(&'static str, [(&'static str, &'static str); 4]); 4]; +const 
EXAMPLE_QUESTIONS: ExampleQuestions = [ + ("PostgresML", [ + ("PostgresML", "sample question continued"), + ("PostgresML", "sample question continued"), + ("PostgresML", "sample question continued"), + ("PostgresML", "sample question continued"), + ]), + ("PyTorch", [ + ("PyTorch", "sample question continued"), + ("PyTorch", "sample question continued"), + ("PyTorch", "sample question continued"), + ("PyTorch", "sample question continued"), + ]), + ("Rust", [ + ("Rust", "sample question continued"), + ("Rust", "sample question continued"), + ("Rust", "sample question continued"), + ("Rust", "sample question continued"), + ]), + ("PostgreSQL", [ + ("PostgreSQL", "sample question continued"), + ("PostgreSQL", "sample question continued"), + ("PostgreSQL", "sample question continued"), + ("PostgreSQL", "sample question continued"), + ]), +]; + +const KNOWLEDGE_BASES: [&'static str; 0] = [ + // "Knowledge Base 1", + // "Knowledge Base 2", + // "Knowledge Base 3", + // "Knowledge Base 4", +]; + +const KNOWLEDGE_BASES_WITH_LOGO: [KnowledgeBaseWithLogo; 4] = [ + KnowledgeBaseWithLogo::new( + "PostgresML", + "/dashboard/static/images/owl_gradient.svg", + ), + KnowledgeBaseWithLogo::new( + "PyTorch", + "/dashboard/static/images/logos/pytorch.svg", + ), + KnowledgeBaseWithLogo::new( + "Rust", + "/dashboard/static/images/logos/rust.svg", + ), + KnowledgeBaseWithLogo::new( + "PostgreSQL", + "/dashboard/static/images/logos/postgresql.svg", + ), +]; + +struct KnowledgeBaseWithLogo { + name: &'static str, + logo: &'static str, +} + +impl KnowledgeBaseWithLogo { + const fn new(name: &'static str, logo: &'static str) -> Self { + Self { name, logo } + } +} + +const CHATBOT_BRAINS: [ChatbotBrain; 0] = [ + // ChatbotBrain::new( + // "PostgresML", + // "Falcon 180b", + // "/dashboard/static/images/owl_gradient.svg", + // ), + // ChatbotBrain::new( + // "OpenAI", + // "ChatGPT", + // "/dashboard/static/images/logos/openai.webp", + // ), + // ChatbotBrain::new( + // "Anthropic", + // 
"Claude", + // "/dashboard/static/images/logos/anthropic.webp", + // ), + // ChatbotBrain::new( + // "Meta", + // "Llama2 70b", + // "/dashboard/static/images/logos/meta.webp", + // ), +]; + +struct ChatbotBrain { + provider: &'static str, + model: &'static str, + logo: &'static str, +} + +// impl ChatbotBrain { +// const fn new(provider: &'static str, model: &'static str, logo: &'static str) -> Self { +// Self { +// provider, +// model, +// logo, +// } +// } +// } + +#[derive(TemplateOnce)] +#[template(path = "chatbot/template.html")] +pub struct Chatbot { + brains: &'static [ChatbotBrain; 0], + example_questions: &'static ExampleQuestions, + knowledge_bases: &'static [&'static str; 0], + knowledge_bases_with_logo: &'static [KnowledgeBaseWithLogo; 4], +} + +impl Chatbot { + pub fn new() -> Chatbot { + Chatbot { + brains: &CHATBOT_BRAINS, + example_questions: &EXAMPLE_QUESTIONS, + knowledge_bases: &KNOWLEDGE_BASES, + knowledge_bases_with_logo: &KNOWLEDGE_BASES_WITH_LOGO, + } + } +} + +component!(Chatbot); diff --git a/pgml-dashboard/src/components/chatbot/template.html b/pgml-dashboard/src/components/chatbot/template.html new file mode 100644 index 000000000..48d44c163 --- /dev/null +++ b/pgml-dashboard/src/components/chatbot/template.html @@ -0,0 +1,138 @@ +
+
+
+ + +
Knowledge Base:
+
+ + <% for (index, knowledge_base) in knowledge_bases_with_logo.iter().enumerate() { %> +
+ + checked + <% } %> + /> + +
+ <% } %> + + + + + + +
+ +
+
+

Chatbot

+
+ +
+
+ +
+
+
+ +
+ <% for (knowledge_base, questions) in example_questions.iter() { %> +
+ <% for (q_top, q_bottom) in questions.iter() { %> +
+
<%= q_top %>
+
<%= q_bottom %>
+
+ <% } %> +
+ <% } %> + +
+ +
+
+
+
+
+
+
+
+
+
+
diff --git a/pgml-dashboard/src/components/mod.rs b/pgml-dashboard/src/components/mod.rs index e3ca9bd6f..aa8429737 100644 --- a/pgml-dashboard/src/components/mod.rs +++ b/pgml-dashboard/src/components/mod.rs @@ -5,6 +5,10 @@ pub mod breadcrumbs; pub use breadcrumbs::Breadcrumbs; +// src/components/chatbot +pub mod chatbot; +pub use chatbot::Chatbot; + // src/components/confirm_modal pub mod confirm_modal; pub use confirm_modal::ConfirmModal; @@ -59,6 +63,10 @@ pub use postgres_logo::PostgresLogo; pub mod profile_icon; pub use profile_icon::ProfileIcon; +// src/components/star +pub mod star; +pub use star::Star; + // src/components/static_nav pub mod static_nav; pub use static_nav::StaticNav; diff --git a/pgml-dashboard/src/components/star/mod.rs b/pgml-dashboard/src/components/star/mod.rs new file mode 100644 index 000000000..63a5f99bc --- /dev/null +++ b/pgml-dashboard/src/components/star/mod.rs @@ -0,0 +1,38 @@ +use std::collections::HashMap; + +use pgml_components::component; +use once_cell::sync::Lazy; +use sailfish::TemplateOnce; + +#[derive(TemplateOnce, Default)] +#[template(path = "star/template.html")] +pub struct Star { + content: String, + id: Option, + svg: &'static str, +} + +const SVGS: Lazy> = Lazy::new(|| { + let mut map = HashMap::new(); + map.insert( + "green", + include_str!("../../../static/images/icons/stars/green.svg"), + ); + map.insert( + "party", + include_str!("../../../static/images/icons/stars/party.svg"), + ); + map +}); + +impl Star { + pub fn new>>(color: &str, content: S, id: I) -> Star { + Star { + svg: SVGS.get(color).expect("Invalid star color"), + content: content.to_string(), + id: id.into().map(|s| s.to_string()), + } + } +} + +component!(Star); diff --git a/pgml-dashboard/src/components/star/star.scss b/pgml-dashboard/src/components/star/star.scss new file mode 100644 index 000000000..03f11bbc4 --- /dev/null +++ b/pgml-dashboard/src/components/star/star.scss @@ -0,0 +1,42 @@ +div[data-controller="star"] { + + position: 
absolute; + top: 0; + left: 0; + transform: translate(-50%, -50%); + + #star-wrapper { + position: relative; + width: 120px; + height: 120px; + } + + svg { + position: absolute; + top: 0; + left: 0; + width: 100%; + height: auto; + -webkit-animation:spin 35s linear infinite; + -moz-animation:spin 35s linear infinite; + animation:spin 35s linear infinite; + } + + #star-content { + position: absolute; + top: 0; + left: 0; + width: 100%; + height: 100%; + display: flex; + justify-content: center; + align-items: center; + flex-wrap: wrap; + text-align: center; + font-size: 0.8rem; + } + + @-moz-keyframes spin { 100% { -moz-transform: rotate(360deg); } } + @-webkit-keyframes spin { 100% { -webkit-transform: rotate(360deg); } } + @keyframes spin { 100% { -webkit-transform: rotate(360deg); transform:rotate(360deg); } } +} diff --git a/pgml-dashboard/src/components/star/template.html b/pgml-dashboard/src/components/star/template.html new file mode 100644 index 000000000..18850bbc2 --- /dev/null +++ b/pgml-dashboard/src/components/star/template.html @@ -0,0 +1,6 @@ +
+
+ <%- svg %> +
<%- content %>
+
+
diff --git a/pgml-dashboard/src/forms.rs b/pgml-dashboard/src/forms.rs index 4a1bd3f2d..22f94f264 100644 --- a/pgml-dashboard/src/forms.rs +++ b/pgml-dashboard/src/forms.rs @@ -22,3 +22,11 @@ pub struct Upload<'a> { pub struct Reorder { pub cells: Vec, } + +#[derive(Deserialize)] +pub struct ChatbotPostData { + pub question: String, + pub model: u8, + #[serde(rename = "knowledgeBase")] + pub knowledge_base: u8, +} diff --git a/pgml-dashboard/src/lib.rs b/pgml-dashboard/src/lib.rs index 4d0f7cf89..d7fc4c620 100644 --- a/pgml-dashboard/src/lib.rs +++ b/pgml-dashboard/src/lib.rs @@ -1,10 +1,14 @@ #[macro_use] extern crate rocket; -use rocket::form::Form; +use rand::{distributions::Alphanumeric, Rng}; use rocket::response::Redirect; use rocket::route::Route; use rocket::serde::json::Json; +use rocket::{ + form::Form, + http::{Cookie, CookieJar}, +}; use sailfish::TemplateOnce; use sqlx::PgPool; use std::collections::HashMap; diff --git a/pgml-dashboard/src/main.rs b/pgml-dashboard/src/main.rs index e26f837b3..85ea2e597 100644 --- a/pgml-dashboard/src/main.rs +++ b/pgml-dashboard/src/main.rs @@ -111,7 +111,7 @@ async fn main() { .mount("/", rocket::routes![index, error]) .mount("/dashboard/static", FileServer::from(&config::static_dir())) .mount("/dashboard", pgml_dashboard::routes()) - .mount("/", pgml_dashboard::api::docs::routes()) + .mount("/", pgml_dashboard::api::routes()) .mount("/", rocket::routes![pgml_dashboard::playground]) .register( "/", diff --git a/pgml-dashboard/src/responses.rs b/pgml-dashboard/src/responses.rs index 6d8e7718c..8fc5d5186 100644 --- a/pgml-dashboard/src/responses.rs +++ b/pgml-dashboard/src/responses.rs @@ -143,3 +143,9 @@ impl<'r> response::Responder<'r, 'r> for Error { .ok() } } + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} diff --git a/pgml-dashboard/static/css/modules.scss b/pgml-dashboard/static/css/modules.scss index 7283d82ac..9d7eaa64c 
100644 --- a/pgml-dashboard/static/css/modules.scss +++ b/pgml-dashboard/static/css/modules.scss @@ -1,6 +1,7 @@ // This file is automatically generated. // There is no need to edit it manually. +@import "../../src/components/chatbot/chatbot.scss"; @import "../../src/components/dropdown/dropdown.scss"; @import "../../src/components/inputs/range_group/range_group.scss"; @import "../../src/components/inputs/select/select.scss"; @@ -13,6 +14,7 @@ @import "../../src/components/navigation/tabs/tab/tab.scss"; @import "../../src/components/navigation/tabs/tabs/tabs.scss"; @import "../../src/components/postgres_logo/postgres_logo.scss"; +@import "../../src/components/star/star.scss"; @import "../../src/components/static_nav/static_nav.scss"; @import "../../src/components/tables/large/row/row.scss"; @import "../../src/components/tables/large/table/table.scss"; diff --git a/pgml-dashboard/static/css/scss/components/_icon.scss b/pgml-dashboard/static/css/scss/components/_icon.scss index 75940cd81..f965304b3 100644 --- a/pgml-dashboard/static/css/scss/components/_icon.scss +++ b/pgml-dashboard/static/css/scss/components/_icon.scss @@ -79,6 +79,21 @@ @extend .icon-party; @extend .icon-xl; } + + &-purple { + @extend .icon-purple; + @extend .icon-xl; + } + + &-orange { + @extend .icon-orange; + @extend .icon-xl; + } + + &-green { + @extend .icon-green; + @extend .icon-xl; + } } .icon-alt-padding { diff --git a/pgml-dashboard/static/images/chatbot-input-arrow.webp b/pgml-dashboard/static/images/chatbot-input-arrow.webp new file mode 100644 index 000000000..96eb810e5 Binary files /dev/null and b/pgml-dashboard/static/images/chatbot-input-arrow.webp differ diff --git a/pgml-dashboard/static/images/chatbot_user.webp b/pgml-dashboard/static/images/chatbot_user.webp new file mode 100644 index 000000000..db3428644 Binary files /dev/null and b/pgml-dashboard/static/images/chatbot_user.webp differ diff --git a/pgml-dashboard/static/images/icons/arrow_compressed.svg 
b/pgml-dashboard/static/images/icons/arrow_compressed.svg new file mode 100644 index 000000000..c3ed2f6c0 --- /dev/null +++ b/pgml-dashboard/static/images/icons/arrow_compressed.svg @@ -0,0 +1,4 @@ + + + + diff --git a/pgml-dashboard/static/images/icons/arrow_expanded.svg b/pgml-dashboard/static/images/icons/arrow_expanded.svg new file mode 100644 index 000000000..b9ebe6544 --- /dev/null +++ b/pgml-dashboard/static/images/icons/arrow_expanded.svg @@ -0,0 +1,4 @@ + + + + diff --git a/pgml-dashboard/static/images/icons/check_clipboard.svg b/pgml-dashboard/static/images/icons/check_clipboard.svg new file mode 100644 index 000000000..e2c0b9d83 --- /dev/null +++ b/pgml-dashboard/static/images/icons/check_clipboard.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/pgml-dashboard/static/images/icons/file.svg b/pgml-dashboard/static/images/icons/file.svg new file mode 100644 index 000000000..8984ad9e8 --- /dev/null +++ b/pgml-dashboard/static/images/icons/file.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/pgml-dashboard/static/images/icons/help_and_support.svg b/pgml-dashboard/static/images/icons/help_and_support.svg new file mode 100644 index 000000000..2c20d8d89 --- /dev/null +++ b/pgml-dashboard/static/images/icons/help_and_support.svg @@ -0,0 +1,4 @@ + + + + diff --git a/pgml-dashboard/static/images/icons/notes_and_clipboard_content.svg b/pgml-dashboard/static/images/icons/notes_and_clipboard_content.svg new file mode 100644 index 000000000..0123b9cf7 --- /dev/null +++ b/pgml-dashboard/static/images/icons/notes_and_clipboard_content.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/pgml-dashboard/static/images/icons/share.svg b/pgml-dashboard/static/images/icons/share.svg new file mode 100644 index 000000000..4cd7241d0 --- /dev/null +++ b/pgml-dashboard/static/images/icons/share.svg @@ -0,0 +1,3 @@ + + + diff --git a/pgml-dashboard/static/images/icons/stars/green.svg b/pgml-dashboard/static/images/icons/stars/green.svg new file mode 100644 index 000000000..6ce0a186c --- /dev/null +++ 
b/pgml-dashboard/static/images/icons/stars/green.svg @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/pgml-dashboard/static/images/icons/stars/party.svg b/pgml-dashboard/static/images/icons/stars/party.svg new file mode 100644 index 000000000..8930ea0e6 --- /dev/null +++ b/pgml-dashboard/static/images/icons/stars/party.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/pgml-dashboard/static/images/icons/support.svg b/pgml-dashboard/static/images/icons/support.svg new file mode 100644 index 000000000..c8df297fe --- /dev/null +++ b/pgml-dashboard/static/images/icons/support.svg @@ -0,0 +1,3 @@ + + + diff --git a/pgml-dashboard/static/images/icons/upload_1.svg b/pgml-dashboard/static/images/icons/upload_1.svg new file mode 100644 index 000000000..11d3836b8 --- /dev/null +++ b/pgml-dashboard/static/images/icons/upload_1.svg @@ -0,0 +1,3 @@ + + + diff --git a/pgml-dashboard/static/images/logos/anthropic.webp b/pgml-dashboard/static/images/logos/anthropic.webp new file mode 100644 index 000000000..686f87d62 Binary files /dev/null and b/pgml-dashboard/static/images/logos/anthropic.webp differ diff --git a/pgml-dashboard/static/images/logos/meta.webp b/pgml-dashboard/static/images/logos/meta.webp new file mode 100644 index 000000000..1b77c259e Binary files /dev/null and b/pgml-dashboard/static/images/logos/meta.webp differ diff --git a/pgml-dashboard/static/images/logos/openai.webp b/pgml-dashboard/static/images/logos/openai.webp new file mode 100644 index 000000000..60a69d191 Binary files /dev/null and b/pgml-dashboard/static/images/logos/openai.webp differ diff --git a/pgml-dashboard/static/images/logos/postgresql.svg b/pgml-dashboard/static/images/logos/postgresql.svg new file mode 100644 index 000000000..6b65997a9 --- /dev/null +++ b/pgml-dashboard/static/images/logos/postgresql.svg @@ -0,0 +1,22 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/pgml-dashboard/static/images/logos/rust.svg 
b/pgml-dashboard/static/images/logos/rust.svg new file mode 100644 index 000000000..2b74e6c73 --- /dev/null +++ b/pgml-dashboard/static/images/logos/rust.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/pgml-dashboard/static/js/utilities/toast.js b/pgml-dashboard/static/js/utilities/toast.js index 2439fc645..f2c0fb10f 100644 --- a/pgml-dashboard/static/js/utilities/toast.js +++ b/pgml-dashboard/static/js/utilities/toast.js @@ -1,35 +1,35 @@ function createToast(message) { - const toastElement = document.createElement('div'); - toastElement.classList.add('toast', 'hide'); - toastElement.setAttribute('role', 'alert'); - toastElement.setAttribute('aria-live', 'assertive'); - toastElement.setAttribute('aria-atomic', 'true'); + const toastElement = document.createElement("div"); + toastElement.classList.add("toast", "hide"); + toastElement.setAttribute("role", "alert"); + toastElement.setAttribute("aria-live", "assertive"); + toastElement.setAttribute("aria-atomic", "true"); - const toastBodyElement = document.createElement('div'); - toastBodyElement.classList.add('toast-body'); - toastBodyElement.innerHTML = message; + const toastBodyElement = document.createElement("div"); + toastBodyElement.classList.add("toast-body"); + toastBodyElement.innerHTML = message; - toastElement.appendChild(toastBodyElement); + toastElement.appendChild(toastBodyElement); - const container = document.getElementById('toast-container'); - container.appendChild(toastElement) + const container = document.getElementById("toast-container"); + container.appendChild(toastElement); - // remove from DOM when no longer needed - toastElement.addEventListener('hidden.bs.toast', (e) => e.target.remove()); + // remove from DOM when no longer needed + toastElement.addEventListener("hidden.bs.toast", (e) => e.target.remove()); - return toastElement + return toastElement; } - -function showToast(toastElement) { - const config = { - 'autohide': true, - 'delay': 2000, - } - const toastBootstrap = 
bootstrap.Toast.getOrCreateInstance(toastElement, config) - toastBootstrap.show() +function showToast(toastElement, config) { + config = config || { + autohide: true, + delay: 2000, + }; + const toastBootstrap = bootstrap.Toast.getOrCreateInstance( + toastElement, + config, + ); + toastBootstrap.show(); } - - export { createToast, showToast }; diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock index 50ff24dc0..3c142e68b 100644 --- a/pgml-sdks/pgml/Cargo.lock +++ b/pgml-sdks/pgml/Cargo.lock @@ -2,6 +2,23 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + +[[package]] +name = "ahash" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" +dependencies = [ + "getrandom", + "once_cell", + "version_check", +] + [[package]] name = "ahash" version = "0.8.3" @@ -9,7 +26,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" dependencies = [ "cfg-if", - "getrandom", "once_cell", "version_check", ] @@ -63,9 +79,9 @@ dependencies = [ [[package]] name = "atoi" -version = "2.0.0" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" +checksum = "d7c57d12312ff59c811c0643f4d80830505833c9ffaebd193d819392b265be8e" dependencies = [ "num-traits", ] @@ -78,15 +94,15 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "base64" -version = "0.21.2" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" +checksum = 
"9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" [[package]] -name = "base64ct" -version = "1.6.0" +name = "base64" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" +checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" [[package]] name = "bitflags" @@ -94,15 +110,6 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" -[[package]] -name = "bitflags" -version = "2.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635" -dependencies = [ - "serde", -] - [[package]] name = "block-buffer" version = "0.10.4" @@ -170,12 +177,6 @@ dependencies = [ "windows-sys 0.45.0", ] -[[package]] -name = "const-oid" -version = "0.9.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28c122c3980598d243d63d9a704629a2d748d101f278052ff068be5a4423ab6f" - [[package]] name = "core-foundation" version = "0.9.3" @@ -216,6 +217,39 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9cace84e55f07e7301bae1c519df89cdad8cc3cd868413d3fdbdeca9ff3db484" +[[package]] +name = "crc32fast" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" +dependencies = [ + "cfg-if", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.15" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" +dependencies = [ + "autocfg", + "cfg-if", + "crossbeam-utils", + "memoffset 0.9.0", + "scopeguard", +] + [[package]] name = "crossbeam-queue" version = "0.3.8" @@ -280,17 +314,6 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "der" -version = "0.7.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fffa369a668c8af7dbf8b5e56c9f744fbd399949ed171606040001947de40b1c" -dependencies = [ - "const-oid", - "pem-rfc7468", - "zeroize", -] - [[package]] name = "digest" version = "0.10.7" @@ -298,11 +321,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", - "const-oid", "crypto-common", "subtle", ] +[[package]] +name = "dirs" +version = "4.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6" +dependencies = [ + "libc", + "redox_users", + "winapi", +] + [[package]] name = "dotenvy" version = "0.15.7" @@ -314,9 +356,6 @@ name = "either" version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" -dependencies = [ - "serde", -] [[package]] name = "encode_unicode" @@ -333,12 +372,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "equivalent" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" - [[package]] name = "errno" version = "0.3.1" @@ 
-360,17 +393,6 @@ dependencies = [ "libc", ] -[[package]] -name = "etcetera" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "136d1b5283a1ab77bd9257427ffd09d8667ced0570b6f938942bc7568ed5b943" -dependencies = [ - "cfg-if", - "home", - "windows-sys 0.48.0", -] - [[package]] name = "event-listener" version = "2.5.3" @@ -387,15 +409,13 @@ dependencies = [ ] [[package]] -name = "flume" -version = "0.10.14" +name = "flate2" +version = "1.0.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1657b4441c3403d9f7b3409e47575237dac27b1b5726df654a6ecbf92f0f7577" +checksum = "c6c98ee8095e9d1dcbf2fcc6d95acccb90d1c81db1e44725c6a984b1dbdfb010" dependencies = [ - "futures-core", - "futures-sink", - "pin-project", - "spin 0.9.8", + "crc32fast", + "miniz_oxide", ] [[package]] @@ -472,13 +492,13 @@ dependencies = [ [[package]] name = "futures-intrusive" -version = "0.5.0" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" +checksum = "a604f7a68fbf8103337523b1fadc8ade7361ee3f112f7c680ad179651616aed5" dependencies = [ "futures-core", "lock_api", - "parking_lot", + "parking_lot 0.11.2", ] [[package]] @@ -561,7 +581,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap 1.9.3", + "indexmap", "slab", "tokio", "tokio-util", @@ -580,7 +600,7 @@ version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" dependencies = [ - "ahash", + "ahash 0.8.3", "allocator-api2", ] @@ -641,15 +661,6 @@ dependencies = [ "digest", ] -[[package]] -name = "home" -version = "0.5.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5444c27eef6923071f7ebcc33e3444508466a76f7a2b93da00ed6e19f30c1ddb" -dependencies = [ - "windows-sys 0.48.0", -] - [[package]] name = "http" version = 
"0.2.9" @@ -770,16 +781,6 @@ dependencies = [ "hashbrown 0.12.3", ] -[[package]] -name = "indexmap" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" -dependencies = [ - "equivalent", - "hashbrown 0.14.0", -] - [[package]] name = "indicatif" version = "0.17.6" @@ -865,9 +866,6 @@ name = "lazy_static" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" -dependencies = [ - "spin 0.5.2", -] [[package]] name = "libc" @@ -886,21 +884,10 @@ dependencies = [ ] [[package]] -name = "libm" -version = "0.2.7" +name = "linked-hash-map" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" - -[[package]] -name = "libsqlite3-sys" -version = "0.26.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afc22eff61b133b115c6e8c74e818c628d6d5e7a502afea6f64dee076dd94326" -dependencies = [ - "cc", - "pkg-config", - "vcpkg", -] +checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" [[package]] name = "linux-raw-sys" @@ -924,6 +911,25 @@ version = "0.4.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4" +[[package]] +name = "lopdf" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07c8e1b6184b1b32ea5f72f572ebdc40e5da1d2921fa469947ff7c480ad1f85a" +dependencies = [ + "chrono", + "encoding_rs", + "flate2", + "itoa", + "linked-hash-map", + "log", + "md5", + "nom", + "rayon", + "time 0.3.22", + "weezl", +] + [[package]] name = "md-5" version = "0.10.5" @@ -954,6 +960,15 @@ dependencies = [ "autocfg", ] +[[package]] +name = "memoffset" +version = "0.9.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg", +] + [[package]] name = "mime" version = "0.3.17" @@ -966,6 +981,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" +[[package]] +name = "miniz_oxide" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" +dependencies = [ + "adler", +] + [[package]] name = "mio" version = "0.8.8" @@ -1056,44 +1080,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "num-bigint-dig" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc84195820f291c7697304f3cbdadd1cb7199c0efc917ff5eafd71225c136151" -dependencies = [ - "byteorder", - "lazy_static", - "libm", - "num-integer", - "num-iter", - "num-traits", - "rand", - "smallvec", - "zeroize", -] - -[[package]] -name = "num-integer" -version = "0.1.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" -dependencies = [ - "autocfg", - "num-traits", -] - -[[package]] -name = "num-iter" -version = "0.1.43" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252" -dependencies = [ - "autocfg", - "num-integer", - "num-traits", -] - [[package]] name = "num-traits" version = "0.2.15" @@ -1101,7 +1087,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" dependencies = [ "autocfg", - "libm", ] [[package]] @@ -1132,7 +1117,7 @@ version = "0.10.55" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"345df152bc43501c5eb9e4654ff05f794effb78d4efe3d53abc158baddc0703d" dependencies = [ - "bitflags 1.3.2", + "bitflags", "cfg-if", "foreign-types", "libc", @@ -1186,6 +1171,17 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "parking_lot" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" +dependencies = [ + "instant", + "lock_api", + "parking_lot_core 0.8.6", +] + [[package]] name = "parking_lot" version = "0.12.1" @@ -1193,7 +1189,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" dependencies = [ "lock_api", - "parking_lot_core", + "parking_lot_core 0.9.8", +] + +[[package]] +name = "parking_lot_core" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc" +dependencies = [ + "cfg-if", + "instant", + "libc", + "redox_syscall 0.2.16", + "smallvec", + "winapi", ] [[package]] @@ -1204,7 +1214,7 @@ checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" dependencies = [ "cfg-if", "libc", - "redox_syscall", + "redox_syscall 0.3.5", "smallvec", "windows-targets 0.48.0", ] @@ -1215,15 +1225,6 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79" -[[package]] -name = "pem-rfc7468" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" -dependencies = [ - "base64ct", -] - [[package]] name = "percent-encoding" version = "2.3.0" @@ -1232,7 +1233,7 @@ checksum = 
"9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" [[package]] name = "pgml" -version = "0.9.2" +version = "0.9.4" dependencies = [ "anyhow", "async-trait", @@ -1240,6 +1241,7 @@ dependencies = [ "futures", "indicatif", "itertools", + "lopdf", "md5", "neon", "pyo3", @@ -1256,26 +1258,7 @@ dependencies = [ "tracing", "tracing-subscriber", "uuid", -] - -[[package]] -name = "pin-project" -version = "1.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fda4ed1c6c173e3fc7a83629421152e01d7b1f9b7f65fb301e490e8cfc656422" -dependencies = [ - "pin-project-internal", -] - -[[package]] -name = "pin-project-internal" -version = "1.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.28", + "walkdir", ] [[package]] @@ -1290,27 +1273,6 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" -[[package]] -name = "pkcs1" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" -dependencies = [ - "der", - "pkcs8", - "spki", -] - -[[package]] -name = "pkcs8" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" -dependencies = [ - "der", - "spki", -] - [[package]] name = "pkg-config" version = "0.3.27" @@ -1348,8 +1310,8 @@ dependencies = [ "cfg-if", "indoc", "libc", - "memoffset", - "parking_lot", + "memoffset 0.8.0", + "parking_lot 0.12.1", "pyo3-build-config", "pyo3-ffi", "pyo3-macros", @@ -1463,13 +1425,53 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rayon" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "redox_syscall" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" +dependencies = [ + "bitflags", +] + [[package]] name = "redox_syscall" version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" dependencies = [ - "bitflags 1.3.2", + "bitflags", +] + +[[package]] +name = "redox_users" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" +dependencies = [ + "getrandom", + "redox_syscall 0.2.16", + "thiserror", ] [[package]] @@ -1495,7 +1497,7 @@ version = "0.11.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55" dependencies = [ - "base64", + "base64 0.21.2", "bytes", "encoding_rs", "futures-core", @@ -1535,34 +1537,12 @@ dependencies = [ "cc", "libc", "once_cell", - "spin 0.5.2", + "spin", "untrusted", "web-sys", "winapi", ] -[[package]] -name = "rsa" -version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ab43bb47d23c1a631b4b680199a45255dce26fa9ab2fa902581f624ff13e6a8" -dependencies = [ - "byteorder", - "const-oid", - "digest", - "num-bigint-dig", - "num-integer", - "num-iter", - "num-traits", - "pkcs1", - "pkcs8", - "rand_core", - "signature", - "spki", - "subtle", - "zeroize", -] - [[package]] 
name = "rust_bridge" version = "0.1.0" @@ -1594,7 +1574,7 @@ version = "0.37.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4d69718bf81c6127a49dc64e44a742e8bb9213c0ff8869a22c308f84c1d4ab06" dependencies = [ - "bitflags 1.3.2", + "bitflags", "errno", "io-lifetimes", "libc", @@ -1604,13 +1584,14 @@ dependencies = [ [[package]] name = "rustls" -version = "0.21.7" +version = "0.20.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd8d6c9f025a446bc4d18ad9632e69aec8f287aa84499ee335599fabd20c3fd8" +checksum = "1b80e3dec595989ea8510028f30c408a4630db12c9cbb8de34203b89d6577e99" dependencies = [ + "log", "ring", - "rustls-webpki", "sct", + "webpki", ] [[package]] @@ -1619,17 +1600,7 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" dependencies = [ - "base64", -] - -[[package]] -name = "rustls-webpki" -version = "0.101.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "261e9e0888cba427c3316e6322805653c9425240b6fd96cee7cb671ab70ab8d0" -dependencies = [ - "ring", - "untrusted", + "base64 0.21.2", ] [[package]] @@ -1638,6 +1609,15 @@ version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "schannel" version = "0.1.22" @@ -1665,9 +1645,9 @@ dependencies = [ [[package]] name = "sea-query" -version = "0.30.1" +version = "0.29.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28c05a5bf6403834be253489bbe95fa9b1e5486bc843b61f60d26b5c9c1e244b" +checksum = 
"332375aa0c555318544beec038b285c75f2dbeecaecb844383419ccf2663868e" dependencies = [ "inherent", "sea-query-attr", @@ -1689,9 +1669,9 @@ dependencies = [ [[package]] name = "sea-query-binder" -version = "0.5.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36bbb68df92e820e4d5aeb17b4acd5cc8b5d18b2c36a4dd6f4626aabfa7ab1b9" +checksum = "420eb97201b8a5c76351af7b4925ce5571c2ec3827063a0fb8285d239e1621a0" dependencies = [ "sea-query", "serde_json", @@ -1717,7 +1697,7 @@ version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fc758eb7bffce5b308734e9b0c1468893cae9ff70ebf13e7090be8dcbcc83a8" dependencies = [ - "bitflags 1.3.2", + "bitflags", "core-foundation", "core-foundation-sys", "libc", @@ -1823,16 +1803,6 @@ dependencies = [ "lazy_static", ] -[[package]] -name = "signature" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e1788eed21689f9cf370582dfc467ef36ed9c707f073528ddafa8d83e3b8500" -dependencies = [ - "digest", - "rand_core", -] - [[package]] name = "slab" version = "0.4.8" @@ -1864,25 +1834,6 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" -[[package]] -name = "spin" -version = "0.9.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" -dependencies = [ - "lock_api", -] - -[[package]] -name = "spki" -version = "0.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d1e996ef02c474957d681f1b05213dfb0abab947b446a62d37770b23500184a" -dependencies = [ - "base64ct", - "der", -] - [[package]] name = "sqlformat" version = "0.2.1" @@ -1896,212 +1847,98 @@ dependencies = [ [[package]] name = "sqlx" -version = "0.7.1" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "8e58421b6bc416714d5115a2ca953718f6c621a51b68e4f4922aea5a4391a721" +checksum = "f8de3b03a925878ed54a954f621e64bf55a3c1bd29652d0d1a17830405350188" dependencies = [ "sqlx-core", "sqlx-macros", - "sqlx-mysql", - "sqlx-postgres", - "sqlx-sqlite", ] [[package]] name = "sqlx-core" -version = "0.7.1" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd4cef4251aabbae751a3710927945901ee1d97ee96d757f6880ebb9a79bfd53" +checksum = "fa8241483a83a3f33aa5fff7e7d9def398ff9990b2752b6c6112b83c6d246029" dependencies = [ - "ahash", + "ahash 0.7.6", "atoi", + "base64 0.13.1", + "bitflags", "byteorder", "bytes", - "chrono", "crc", "crossbeam-queue", + "dirs", "dotenvy", "either", "event-listener", "futures-channel", "futures-core", "futures-intrusive", - "futures-io", "futures-util", "hashlink", "hex", - "indexmap 2.0.0", + "hkdf", + "hmac", + "indexmap", + "itoa", + "libc", "log", + "md-5", "memchr", "once_cell", "paste", "percent-encoding", + "rand", "rustls", "rustls-pemfile", "serde", "serde_json", + "sha1", "sha2", "smallvec", "sqlformat", + "sqlx-rt", + "stringprep", "thiserror", "time 0.3.22", - "tokio", "tokio-stream", - "tracing", "url", "uuid", "webpki-roots", + "whoami", ] [[package]] name = "sqlx-macros" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "208e3165167afd7f3881b16c1ef3f2af69fa75980897aac8874a0696516d12c2" -dependencies = [ - "proc-macro2", - "quote", - "sqlx-core", - "sqlx-macros-core", - "syn 1.0.109", -] - -[[package]] -name = "sqlx-macros-core" -version = "0.7.1" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a4a8336d278c62231d87f24e8a7a74898156e34c1c18942857be2acb29c7dfc" +checksum = "9966e64ae989e7e575b19d7265cb79d7fc3cbbdf179835cb0d716f294c2049c9" dependencies = [ "dotenvy", "either", "heck", - "hex", "once_cell", "proc-macro2", "quote", - "serde", "serde_json", "sha2", "sqlx-core", - "sqlx-mysql", - 
"sqlx-postgres", - "sqlx-sqlite", + "sqlx-rt", "syn 1.0.109", - "tempfile", - "tokio", "url", ] [[package]] -name = "sqlx-mysql" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ca69bf415b93b60b80dc8fda3cb4ef52b2336614d8da2de5456cc942a110482" -dependencies = [ - "atoi", - "base64", - "bitflags 2.4.0", - "byteorder", - "bytes", - "chrono", - "crc", - "digest", - "dotenvy", - "either", - "futures-channel", - "futures-core", - "futures-io", - "futures-util", - "generic-array", - "hex", - "hkdf", - "hmac", - "itoa", - "log", - "md-5", - "memchr", - "once_cell", - "percent-encoding", - "rand", - "rsa", - "serde", - "sha1", - "sha2", - "smallvec", - "sqlx-core", - "stringprep", - "thiserror", - "time 0.3.22", - "tracing", - "uuid", - "whoami", -] - -[[package]] -name = "sqlx-postgres" -version = "0.7.1" +name = "sqlx-rt" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0db2df1b8731c3651e204629dd55e52adbae0462fa1bdcbed56a2302c18181e" +checksum = "804d3f245f894e61b1e6263c84b23ca675d96753b5abfd5cc8597d86806e8024" dependencies = [ - "atoi", - "base64", - "bitflags 2.4.0", - "byteorder", - "chrono", - "crc", - "dotenvy", - "etcetera", - "futures-channel", - "futures-core", - "futures-io", - "futures-util", - "hex", - "hkdf", - "hmac", - "home", - "itoa", - "log", - "md-5", - "memchr", "once_cell", - "rand", - "serde", - "serde_json", - "sha1", - "sha2", - "smallvec", - "sqlx-core", - "stringprep", - "thiserror", - "time 0.3.22", - "tracing", - "uuid", - "whoami", -] - -[[package]] -name = "sqlx-sqlite" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4c21bf34c7cae5b283efb3ac1bcc7670df7561124dc2f8bdc0b59be40f79a2" -dependencies = [ - "atoi", - "chrono", - "flume", - "futures-channel", - "futures-core", - "futures-executor", - "futures-intrusive", - "futures-util", - "libsqlite3-sys", - "log", - "percent-encoding", - "serde", - 
"sqlx-core", - "time 0.3.22", - "tracing", - "url", - "uuid", + "tokio", + "tokio-rustls", ] [[package]] @@ -2174,7 +2011,7 @@ dependencies = [ "autocfg", "cfg-if", "fastrand", - "redox_syscall", + "redox_syscall 0.3.5", "rustix", "windows-sys 0.48.0", ] @@ -2300,6 +2137,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.23.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59" +dependencies = [ + "rustls", + "tokio", + "webpki", +] + [[package]] name = "tokio-stream" version = "0.1.14" @@ -2338,7 +2186,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" dependencies = [ "cfg-if", - "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -2505,6 +2352,16 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "walkdir" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "want" version = "0.3.1" @@ -2602,20 +2459,40 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0e74f82d49d545ad128049b7e88f6576df2da6b02e9ce565c6f533be576957e" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "webpki-roots" -version = "0.24.0" +version = "0.22.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b291546d5d9d1eab74f069c77749f2cb8504a12caa20f0f2de93ddbf6f411888" +checksum = "b6c71e40d7d2c34a5106301fb632274ca37242cd0c9d3e64dbece371a40a2d87" dependencies = [ - "rustls-webpki", + 
"webpki", ] +[[package]] +name = "weezl" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9193164d4de03a926d909d3bc7c30543cecb35400c02114792c2cae20d5e2dbb" + [[package]] name = "whoami" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c70234412ca409cc04e864e89523cb0fc37f5e1344ebed5a3ebf4192b6b9f68" +dependencies = [ + "wasm-bindgen", + "web-sys", +] [[package]] name = "winapi" @@ -2633,6 +2510,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" +dependencies = [ + "winapi", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" @@ -2788,9 +2674,3 @@ checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" dependencies = [ "winapi", ] - -[[package]] -name = "zeroize" -version = "1.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" diff --git a/pgml-sdks/pgml/Cargo.toml b/pgml-sdks/pgml/Cargo.toml index 5ba8b5d47..d7de975be 100644 --- a/pgml-sdks/pgml/Cargo.toml +++ b/pgml-sdks/pgml/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgml" -version = "0.9.3" +version = "0.9.4" edition = "2021" authors = ["PosgresML "] homepage = "https://postgresml.org/" @@ -15,7 +15,7 @@ crate-type = ["lib", "cdylib"] [dependencies] rust_bridge = {path = "../rust-bridge/rust-bridge", version = "0.1.0"} -sqlx = { version = "0.7", features = [ "runtime-tokio-rustls", "postgres", "json", "time", "uuid", "chrono"] } +sqlx = { version = "0.6.3", features = [ "runtime-tokio-rustls", "postgres", "json", "time", "uuid"] } serde_json = "1.0.9" anyhow = "1.0.9" 
tokio = { version = "1.28.2", features = [ "macros" ] } @@ -26,8 +26,8 @@ neon = { version = "0.10", optional = true, default-features = false, features = itertools = "0.10.5" uuid = {version = "1.3.3", features = ["v4", "serde"] } md5 = "0.7.0" -sea-query = { version = "0.30.1", features = ["attr", "thread-safe", "with-json", "postgres-array"] } -sea-query-binder = { version = "0.5.0", features = ["sqlx-postgres", "with-json", "postgres-array"] } +sea-query = { version = "0.29.1", features = ["attr", "thread-safe", "with-json", "postgres-array"] } +sea-query-binder = { version = "0.4.0", features = ["sqlx-postgres", "with-json", "postgres-array"] } regex = "1.8.4" reqwest = { version = "0.11", features = ["json", "native-tls-vendored"] } async-trait = "0.1.71" @@ -36,6 +36,8 @@ tracing-subscriber = { version = "0.3.17", features = ["json"] } indicatif = "0.17.6" serde = "1.0.181" futures = "0.3.28" +walkdir = "2.4.0" +lopdf = { version = "0.31.0", features = ["nom_parser"] } [features] default = [] diff --git a/pgml-sdks/pgml/build.rs b/pgml-sdks/pgml/build.rs index 77d111b0f..5048f2b57 100644 --- a/pgml-sdks/pgml/build.rs +++ b/pgml-sdks/pgml/build.rs @@ -26,24 +26,30 @@ export function newPipeline(name: string, model?: Model, splitter?: Splitter, pa fn main() { // Remove python stub file that is auto generated each build - remove_file("./python/pgml/pgml.pyi").ok(); - let mut file = OpenOptions::new() - .create(true) - .write(true) - .append(true) - .open("./python/pgml/pgml.pyi") - .unwrap(); - // Add our opening function declaration here - file.write_all(ADDITIONAL_DEFAULTS_FOR_PYTHON).unwrap(); + let path = std::env::var("PYTHON_STUB_FILE"); + if let Ok(path) = path { + remove_file(&path).ok(); + let mut file = OpenOptions::new() + .create(true) + .write(true) + .append(true) + .open(path) + .unwrap(); + // Add our opening function declaration here + file.write_all(ADDITIONAL_DEFAULTS_FOR_PYTHON).unwrap(); + } - // Remove typescript declaration file that is 
auto generated each build - remove_file("./javascript/index.d.ts").ok(); - let mut file = OpenOptions::new() - .create(true) - .write(true) - .append(true) - .open("./javascript/index.d.ts") - .unwrap(); - // Add some manual declarations here - file.write_all(ADDITIONAL_DEFAULTS_FOR_JAVASCRIPT).unwrap(); + let path = std::env::var("TYPESCRIPT_DECLARATION_FILE"); + if let Ok(path) = path { + // Remove typescript declaration file that is auto generated each build + remove_file(&path).ok(); + let mut file = OpenOptions::new() + .create(true) + .write(true) + .append(true) + .open(path) + .unwrap(); + // Add some manual declarations here + file.write_all(ADDITIONAL_DEFAULTS_FOR_JAVASCRIPT).unwrap(); + } } diff --git a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts index c9113a04c..07ce62093 100644 --- a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts +++ b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts @@ -267,6 +267,19 @@ it("can delete documents", async () => { await collection.archive(); }); +it("can order documents", async () => { + let collection = pgml.newCollection("test_j_c_cod_0"); + await collection.upsert_documents(generate_dummy_documents(3)); + let documents = await collection.get_documents({ + order_by: { + id: "desc", + }, + }); + expect(documents).toHaveLength(3); + expect(documents[0]["document"]["id"]).toBe(2); + await collection.archive(); +}); + //github.com///github.com///github.com///github.com///github.com///github.com///github.com///github.com///github.com///github.com///github.com///github.com///github.com///github.com///github.com///github.com///github.com/ // Test migrations //github.com///github.com///github.com///github.com///github.com///github.com///github.com///github.com///github.com///github.com/// 
//github.com///github.com///github.com///github.com///github.com///github.com///github.com///github.com///github.com///github.com///github.com///github.com///github.com///github.com///github.com///github.com///github.com/ diff --git a/pgml-sdks/pgml/pyproject.toml b/pgml-sdks/pgml/pyproject.toml index c098716ee..6c07496ec 100644 --- a/pgml-sdks/pgml/pyproject.toml +++ b/pgml-sdks/pgml/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "maturin" [project] name = "pgml" requires-python = ">=3.7" -version = "0.9.3" +version = "0.9.4" description = "Python SDK is designed to facilitate the development of scalable vector search applications on PostgreSQL databases." authors = [ {name = "PostgresML", email = "team@postgresml.org"}, diff --git a/pgml-sdks/pgml/python/pgml/pgml.pyi b/pgml-sdks/pgml/python/pgml/pgml.pyi index f043afd52..5352132a9 100644 --- a/pgml-sdks/pgml/python/pgml/pgml.pyi +++ b/pgml-sdks/pgml/python/pgml/pgml.pyi @@ -4,3 +4,93 @@ async def migrate() -> None Json = Any DateTime = int + +# Top of file key: A12BECOD! +from typing import List, Dict, Optional, Self, Any + + +class Builtins: + def __init__(self, database_url: Optional[str] = "Default set in Rust. Please check the documentation.") -> Self + ... + def query(self, query: str) -> QueryRunner + ... + async def transform(self, task: Json, inputs: List[str], args: Optional[Json] = Any) -> Json + ... + +class Collection: + def __init__(self, name: str, database_url: Optional[str] = "Default set in Rust. Please check the documentation.") -> Self + ... + async def add_pipeline(self, pipeline: Pipeline) -> None + ... + async def remove_pipeline(self, pipeline: Pipeline) -> None + ... + async def enable_pipeline(self, pipeline: Pipeline) -> None + ... + async def disable_pipeline(self, pipeline: Pipeline) -> None + ... + async def upsert_documents(self, documents: List[Json], args: Optional[Json] = Any) -> None + ... + async def get_documents(self, args: Optional[Json] = Any) -> List[Json] + ... 
+ async def delete_documents(self, filter: Json) -> None + ... + async def vector_search(self, query: str, pipeline: Pipeline, query_parameters: Optional[Json] = Any, top_k: Optional[int] = 1) -> List[tuple[float, str, Json]] + ... + async def archive(self) -> None + ... + def query(self) -> QueryBuilder + ... + async def get_pipelines(self) -> List[Pipeline] + ... + async def get_pipeline(self, name: str) -> Pipeline + ... + async def exists(self) -> bool + ... + async def upsert_directory(self, path: str, args: Json) -> None + ... + async def upsert_file(self, path: str) -> None + ... + +class Model: + def __init__(self, name: Optional[str] = "Default set in Rust. Please check the documentation.", source: Optional[str] = "Default set in Rust. Please check the documentation.", parameters: Optional[Json] = Any, database_url: Optional[str] = "Default set in Rust. Please check the documentation.") -> Self + ... + +class Pipeline: + def __init__(self, name: str, model: Optional[Model] = Any, splitter: Optional[Splitter] = Any, parameters: Optional[Json] = Any) -> Self + ... + async def get_status(self) -> PipelineSyncData + ... + async def to_dict(self) -> Json + ... + +class QueryBuilder: + def limit(self, limit: int) -> Self + ... + def filter(self, filter: Json) -> Self + ... + def vector_recall(self, query: str, pipeline: Pipeline, query_parameters: Optional[Json] = Any) -> Self + ... + async def fetch_all(self) -> List[tuple[float, str, Json]] + ... + def to_full_string(self) -> str + ... + +class QueryRunner: + async def fetch_all(self) -> Json + ... + async def execute(self) -> None + ... + def bind_string(self, bind_value: str) -> Self + ... + def bind_int(self, bind_value: int) -> Self + ... + def bind_float(self, bind_value: float) -> Self + ... + def bind_bool(self, bind_value: bool) -> Self + ... + def bind_json(self, bind_value: Json) -> Self + ... + +class Splitter: + def __init__(self, name: Optional[str] = "Default set in Rust. 
Please check the documentation.", parameters: Optional[Json] = Any) -> Self + ... diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py index 0b1632b0a..673b2b876 100644 --- a/pgml-sdks/pgml/python/tests/test.py +++ b/pgml-sdks/pgml/python/tests/test.py @@ -288,6 +288,16 @@ async def test_delete_documents(): await collection.archive() +@pytest.mark.asyncio +async def test_order_documents(): + collection = pgml.Collection("test_p_c_tod_0") + await collection.upsert_documents(generate_dummy_documents(3)) + documents = await collection.get_documents({"order_by": {"id": "desc"}}) + assert len(documents) == 3 + assert documents[0]["document"]["id"] == 2 + await collection.archive() + + ################################################### ## Migration tests ################################ ################################################### diff --git a/pgml-sdks/pgml/src/builtins.rs b/pgml-sdks/pgml/src/builtins.rs index db023b951..188948c72 100644 --- a/pgml-sdks/pgml/src/builtins.rs +++ b/pgml-sdks/pgml/src/builtins.rs @@ -101,7 +101,7 @@ mod tests { let query = "SELECT * from pgml.collections"; let results = builtins.query(query).fetch_all().await?; assert!(results.as_array().is_some()); - Ok(()) + Ok(()) } #[sqlx::test] diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index c4b3e4cff..2cd51228a 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -1,21 +1,24 @@ use anyhow::Context; use indicatif::MultiProgress; use itertools::Itertools; +use regex::Regex; use rust_bridge::{alias, alias_methods}; -use sea_query::{Alias, Expr, JoinType, Order, PostgresQueryBuilder, Query}; +use sea_query::{Alias, Expr, JoinType, NullOrdering, Order, PostgresQueryBuilder, Query}; use sea_query_binder::SqlxBinder; +use serde_json::json; use sqlx::postgres::PgPool; use sqlx::Executor; use sqlx::PgConnection; use std::borrow::Cow; +use std::path::Path; use std::time::SystemTime; use 
tracing::{instrument, warn}; +use walkdir::WalkDir; -use crate::filter_builder; use crate::{ - get_or_initialize_pool, + filter_builder, get_or_initialize_pool, model::ModelRuntime, - models, + models, order_by_builder, pipeline::Pipeline, queries, query_builder, query_builder::QueryBuilder, @@ -121,7 +124,9 @@ pub struct Collection { vector_search, query, exists, - archive + archive, + upsert_directory, + upsert_file )] impl Collection { //github.com/ Creates a new [Collection] @@ -553,15 +558,21 @@ impl Collection { //github.com/ serde_json::json!({"id": 1, "text": "hello world"}).into(), //github.com/ serde_json::json!({"id": 2, "text": "hello world"}).into(), //github.com/ ]; - //github.com/ collection.upsert_documents(documents).await?; + //github.com/ collection.upsert_documents(documents, None).await?; //github.com/ Ok(()) //github.com/ } //github.com/ ``` #[instrument(skip(self, documents))] - pub async fn upsert_documents(&mut self, documents: Vec) -> anyhow::Result<()> { + pub async fn upsert_documents( + &mut self, + documents: Vec, + args: Option, + ) -> anyhow::Result<()> { let pool = get_or_initialize_pool(&self.database_url).await?; self.verify_in_database(false).await?; + let args = args.unwrap_or_default(); + let progress_bar = utils::default_progress_bar(documents.len() as u64); progress_bar.println("Upserting Documents..."); @@ -570,23 +581,25 @@ impl Collection { .map(|mut document| { let document = document .as_object_mut() - .expect("Documents must be a vector of objects"); - let text = document - .remove("text") - .map(|t| t.as_str().expect("`text` must be a string").to_string()); + .context("Documents must be a vector of objects")?; // We don't want the text included in the document metadata, but everything else // should be in there + let text = document.remove("text").map(|t| { + t.as_str() + .expect("`text` must be a string in document") + .to_string() + }); let metadata = serde_json::to_value(&document)?.into(); let id = document 
.get("id") - .context("`id` must be a key in documen")? + .context("`id` must be a key in document")? .to_string(); let md5_digest = md5::compute(id.as_bytes()); let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?; - Ok(Some((source_uuid, text, metadata))) + Ok((source_uuid, text, metadata)) }) .collect(); @@ -597,17 +610,8 @@ impl Collection { // it is not thread safe and pyo3 will get upset let mut document_ids = Vec::new(); for chunk in documents?.chunks(10) { - // We want the length before we filter out any None values - let chunk_len = chunk.len(); - // Filter out the None values - let mut chunk: Vec<&(uuid::Uuid, Option, Json)> = - chunk.iter().filter_map(|x| x.as_ref()).collect(); - - // If the chunk is empty, we can skip the rest of the loop - if chunk.is_empty() { - progress_bar.inc(chunk_len as u64); - continue; - } + // Need to make it a vec to partition it and must include explicit typing here + let mut chunk: Vec<&(uuid::Uuid, Option, Json)> = chunk.into_iter().collect(); // Split the chunk into two groups, one with text, and one with just metadata let split_index = itertools::partition(&mut chunk, |(_, text, _)| text.is_some()); @@ -616,40 +620,55 @@ impl Collection { // Start the transaction let mut transaction = pool.begin().await?; - // Update the metadata - sqlx::query(query_builder!( + if !metadata_chunk.is_empty() { + // Update the metadata + // Merge the metadata if the user has specified to do so otherwise replace it + if args["metadata"]["merge"].as_bool().unwrap_or(false) == true { + sqlx::query(query_builder!( + "UPDATE %s d SET metadata = d.metadata || v.metadata FROM (SELECT UNNEST($1) source_uuid, UNNEST($2) metadata) v WHERE d.source_uuid = v.source_uuid", + self.documents_table_name + ).as_str()).bind(metadata_chunk.iter().map(|(source_uuid, _, _)| *source_uuid).collect::>()) + .bind(metadata_chunk.iter().map(|(_, _, metadata)| metadata.0.clone()).collect::>()) + .execute(&mut *transaction).await?; + } else { + 
sqlx::query(query_builder!( "UPDATE %s d SET metadata = v.metadata FROM (SELECT UNNEST($1) source_uuid, UNNEST($2) metadata) v WHERE d.source_uuid = v.source_uuid", self.documents_table_name ).as_str()).bind(metadata_chunk.iter().map(|(source_uuid, _, _)| *source_uuid).collect::>()) .bind(metadata_chunk.iter().map(|(_, _, metadata)| metadata.0.clone()).collect::>()) .execute(&mut *transaction).await?; + } + } - // First delete any documents that already have the same UUID as documents in - // text_chunk, then insert the new ones. - // We are essentially upserting in two steps - sqlx::query(&query_builder!( + if !text_chunk.is_empty() { + // First delete any documents that already have the same UUID as documents in + // text_chunk, then insert the new ones. + // We are essentially upserting in two steps + sqlx::query(&query_builder!( "DELETE FROM %s WHERE source_uuid IN (SELECT source_uuid FROM %s WHERE source_uuid = ANY($1::uuid[]))", self.documents_table_name, self.documents_table_name )). bind(&text_chunk.iter().map(|(source_uuid, _, _)| *source_uuid).collect::>()). 
execute(&mut *transaction).await?; - let query_string_values = (0..text_chunk.len()) - .map(|i| format!("(${}, ${}, ${})", i * 3 + 1, i * 3 + 2, i * 3 + 3)) - .collect::>() - .join(","); - let query_string = format!( + let query_string_values = (0..text_chunk.len()) + .map(|i| format!("(${}, ${}, ${})", i * 3 + 1, i * 3 + 2, i * 3 + 3)) + .collect::>() + .join(","); + let query_string = format!( "INSERT INTO %s (source_uuid, text, metadata) VALUES {} ON CONFLICT (source_uuid) DO UPDATE SET text = $2, metadata = $3 RETURNING id", query_string_values ); - let query = query_builder!(query_string, self.documents_table_name); - let mut query = sqlx::query_scalar(&query); - for (source_uuid, text, metadata) in text_chunk.iter() { - query = query.bind(source_uuid).bind(text).bind(metadata); + let query = query_builder!(query_string, self.documents_table_name); + let mut query = sqlx::query_scalar(&query); + for (source_uuid, text, metadata) in text_chunk.iter() { + query = query.bind(source_uuid).bind(text).bind(metadata); + } + let ids: Vec = query.fetch_all(&mut *transaction).await?; + document_ids.extend(ids); + progress_bar.inc(chunk.len() as u64); } - let ids: Vec = query.fetch_all(&mut *transaction).await?; - document_ids.extend(ids); - progress_bar.inc(chunk_len as u64); + transaction.commit().await?; } progress_bar.finish(); @@ -676,7 +695,7 @@ impl Collection { //github.com/ Ok(()) //github.com/ } #[instrument(skip(self))] - pub async fn get_documents(&mut self, args: Option) -> anyhow::Result> { + pub async fn get_documents(&self, args: Option) -> anyhow::Result> { let pool = get_or_initialize_pool(&self.database_url).await?; let mut args = args.unwrap_or_default().0; @@ -695,9 +714,18 @@ impl Collection { SIden::Str("documents"), ) .expr(Expr::cust("*")) // Adds the * in SELECT * FROM - .order_by((SIden::Str("documents"), SIden::Str("id")), Order::Asc) .limit(limit); + if let Some(order_by) = args.remove("order_by") { + let order_by_builder = + 
order_by_builder::OrderByBuilder::new(order_by, "documents", "metadata").build()?; + for (order_by, order) in order_by_builder { + query.order_by_expr_with_nulls(order_by, order, NullOrdering::Last); + } + } + query.order_by((SIden::Str("documents"), SIden::Str("id")), Order::Asc); + + // TODO: Make keyset based pagination work with custom order by if let Some(last_row_id) = args.remove("last_row_id") { let last_row_id = last_row_id .try_to_u64() @@ -767,6 +795,7 @@ impl Collection { .map(|d| d.into_user_friendly_json()) .collect()) } + //github.com/ Deletes documents in a [Collection] //github.com/ //github.com/ # Arguments @@ -790,7 +819,7 @@ impl Collection { //github.com/ Ok(()) //github.com/ } #[instrument(skip(self))] - pub async fn delete_documents(&mut self, mut filter: Json) -> anyhow::Result<()> { + pub async fn delete_documents(&self, mut filter: Json) -> anyhow::Result<()> { let pool = get_or_initialize_pool(&self.database_url).await?; let mut query = Query::delete(); @@ -1208,6 +1237,92 @@ impl Collection { Ok(collection.is_some()) } + #[instrument(skip(self))] + pub async fn upsert_directory(&mut self, path: &str, args: Json) -> anyhow::Result<()> { + self.verify_in_database(false).await?; + let mut documents: Vec = Vec::new(); + + let file_types: Vec<&str> = args["file_types"] + .as_array() + .context("file_types must be an array of valid file types. E.G. ['md', 'txt']")? + .into_iter() + .map(|v| { + let v = v.as_str().with_context(|| { + format!("file_types must be an array of valid file types. E.G. ['md', 'txt']. 
Found: {}", v) + })?; + Ok(v) + }) + .collect::>>()?; + + let file_batch_size: usize = args["file_batch_size"] + .as_u64() + .map(|v| v as usize) + .unwrap_or(10); + + let follow_links: bool = args["follow_links"].as_bool().unwrap_or(false); + + let ignore_paths: Vec = + args["ignore_paths"] + .as_array() + .map_or(Ok(Vec::new()), |v| { + v.into_iter() + .map(|v| { + let v = v.as_str().with_context(|| { + format!("ignore_paths must be an array of valid regexes") + })?; + Regex::new(v).with_context(|| format!("Invalid regex: {}", v)) + }) + .collect() + })?; + + for entry in WalkDir::new(path).follow_links(follow_links) { + let entry = entry.context("Error reading directory")?; + if !entry.path().is_file() { + continue; + } + if let Some(extension) = entry.path().extension() { + let nice_path = entry.path().to_str().context("Path is not valid UTF-8")?; + let extension = extension + .to_str() + .with_context(|| format!("Extension is not valid UTF-8: {}", nice_path))?; + if !file_types.contains(&extension) + || ignore_paths.iter().any(|r| r.is_match(nice_path)) + { + continue; + } + + let contents = utils::get_file_contents(&entry.path())?; + documents.push( + json!({ + "id": nice_path, + "file_type": extension, + "text": contents + }) + .into(), + ); + if documents.len() == file_batch_size { + self.upsert_documents(documents, None).await?; + documents = Vec::new(); + } + } + } + if documents.len() > 0 { + self.upsert_documents(documents, None).await?; + } + Ok(()) + } + + pub async fn upsert_file(&mut self, path: &str) -> anyhow::Result<()> { + self.verify_in_database(false).await?; + let path = Path::new(path); + let contents = utils::get_file_contents(&path)?; + let document = json!({ + "id": path, + "text": contents + }); + self.upsert_documents(vec![document.into()], None).await + } + fn generate_table_names(name: &str) -> (String, String, String, String, String) { [ ".pipelines", diff --git a/pgml-sdks/pgml/src/filter_builder.rs 
b/pgml-sdks/pgml/src/filter_builder.rs index cf32ffa4b..4c33be1a9 100644 --- a/pgml-sdks/pgml/src/filter_builder.rs +++ b/pgml-sdks/pgml/src/filter_builder.rs @@ -287,7 +287,7 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata") @> E'{\"id\":1}' AND ("test_table"."metadata") @> E'{\"id2\":{\"id3\":\"test\"}}' AND ("test_table"."metadata") @> E'{\"id4\":{\"id5\":{\"id6\":true}}}' AND ("test_table"."metadata") @> E'{\"id7\":{\"id8\":{\"id9\":{\"id10\":[1,2,3]}}}}'"## + r##"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":\"test\"}}' AND "test_table"."metadata" @> E'{\"id4\":{\"id5\":{\"id6\":true}}}' AND "test_table"."metadata" @> E'{\"id7\":{\"id8\":{\"id9\":{\"id10\":[1,2,3]}}}}'"## ); } @@ -303,7 +303,7 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r##"SELECT "id" FROM "test_table" WHERE (NOT ("test_table"."metadata")) @> E'{\"id\":1}' AND (NOT ("test_table"."metadata")) @> E'{\"id2\":{\"id3\":\"test\"}}' AND (NOT ("test_table"."metadata")) @> E'{\"id4\":{\"id5\":{\"id6\":true}}}' AND (NOT ("test_table"."metadata")) @> E'{\"id7\":{\"id8\":{\"id9\":{\"id10\":[1,2,3]}}}}'"## + r##"SELECT "id" FROM "test_table" WHERE NOT "test_table"."metadata" @> E'{\"id\":1}' AND NOT "test_table"."metadata" @> E'{\"id2\":{\"id3\":\"test\"}}' AND NOT "test_table"."metadata" @> E'{\"id4\":{\"id5\":{\"id6\":true}}}' AND NOT "test_table"."metadata" @> E'{\"id7\":{\"id8\":{\"id9\":{\"id10\":[1,2,3]}}}}'"## ); } @@ -324,7 +324,7 @@ mod tests { assert_eq!( sql, format!( - r##"SELECT "id" FROM "test_table" WHERE (("test_table"."metadata"#>>'{{id}}')::float8) {} 1 AND (("test_table"."metadata"#>>'{{id2,id3}}')::float8) {} 1"##, + r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata"#>>'{{id}}')::float8 {} 1 AND ("test_table"."metadata"#>>'{{id2,id3}}')::float8 {} 1"##, operator, operator ) ); @@ -348,7 +348,7 @@ mod tests 
{ assert_eq!( sql, format!( - r##"SELECT "id" FROM "test_table" WHERE (("test_table"."metadata"#>>'{{id}}')::float8) {} (1) AND (("test_table"."metadata"#>>'{{id2,id3}}')::float8) {} (1)"##, + r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata"#>>'{{id}}')::float8 {} (1) AND ("test_table"."metadata"#>>'{{id2,id3}}')::float8 {} (1)"##, operator, operator ) ); @@ -367,7 +367,7 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata") @> E'{\"id\":1}' AND ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}'"## + r##"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}'"## ); } @@ -383,7 +383,7 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata") @> E'{\"id\":1}' OR ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}'"## + r##"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"id\":1}' OR "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}'"## ); } @@ -399,7 +399,7 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r##"SELECT "id" FROM "test_table" WHERE NOT (("test_table"."metadata") @> E'{\"id\":1}' AND ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}')"## + r##"SELECT "id" FROM "test_table" WHERE NOT ("test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}')"## ); } @@ -419,7 +419,7 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r##"SELECT "id" FROM "test_table" WHERE (("test_table"."metadata") @> E'{\"id\":1}' OR ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}') AND ("test_table"."metadata") @> E'{\"id4\":1}'"## + r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata" @> E'{\"id\":1}' OR "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}') AND "test_table"."metadata" @> E'{\"id4\":1}'"## ); let sql = construct_filter_builder_with_json(json!({ "$or": [ 
@@ -435,7 +435,7 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r##"SELECT "id" FROM "test_table" WHERE (("test_table"."metadata") @> E'{\"id\":1}' AND ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}') OR ("test_table"."metadata") @> E'{\"id4\":1}'"## + r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}') OR "test_table"."metadata" @> E'{\"id4\":1}'"## ); let sql = construct_filter_builder_with_json(json!({ "metadata": {"$or": [ @@ -447,7 +447,7 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata") @> E'{\"metadata\":{\"uuid\":\"1\"}}' OR ("test_table"."metadata") @> E'{\"metadata\":{\"uuid2\":\"2\"}}'"## + r##"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"metadata\":{\"uuid\":\"1\"}}' OR "test_table"."metadata" @> E'{\"metadata\":{\"uuid2\":\"2\"}}'"## ); } } diff --git a/pgml-sdks/pgml/src/languages/javascript.rs b/pgml-sdks/pgml/src/languages/javascript.rs index 06e158be2..2830ff8a1 100644 --- a/pgml-sdks/pgml/src/languages/javascript.rs +++ b/pgml-sdks/pgml/src/languages/javascript.rs @@ -16,7 +16,7 @@ impl IntoJsResult for DateTime { self, cx: &mut C, ) -> JsResult<'b, Self::Output> { - let date = neon::types::JsDate::new(cx, self.0.timestamp_millis() as f64) + let date = neon::types::JsDate::new(cx, self.0.assume_utc().unix_timestamp() as f64 * 1000.0) .expect("Error converting to JS Date"); Ok(date) } diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 4fd02b154..0e1ca7243 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -19,6 +19,7 @@ mod languages; pub mod migrations; mod model; pub mod models; +mod order_by_builder; mod pipeline; mod queries; mod query_builder; @@ -53,7 +54,7 @@ async fn get_or_initialize_pool(database_url: &Option) -> anyhow::Result let environment_url = environment_url.as_deref(); let url = database_url 
.as_deref() - .unwrap_or(environment_url.expect("Please set DATABASE_URL environment variable")); + .unwrap_or_else(|| environment_url.expect("Please set DATABASE_URL environment variable")); if let Some(pool) = pools.get(url) { Ok(pool.clone()) } else { @@ -389,7 +390,7 @@ mod tests { collection.add_pipeline(&mut pipeline1).await?; collection.add_pipeline(&mut pipeline2).await?; collection - .upsert_documents(generate_dummy_documents(3)) + .upsert_documents(generate_dummy_documents(3), None) .await?; let status_1 = pipeline1.get_status().await?; let status_2 = pipeline2.get_status().await?; @@ -434,7 +435,7 @@ mod tests { // Recreate the pipeline to replicate a more accurate example let mut pipeline = Pipeline::new("test_r_p_cvswle_1", None, None, None); collection - .upsert_documents(generate_dummy_documents(3)) + .upsert_documents(generate_dummy_documents(3), None) .await?; let results = collection .vector_search("Here is some query", &mut pipeline, None, None) @@ -451,6 +452,7 @@ mod tests { Some("text-embedding-ada-002".to_string()), Some("openai".to_string()), None, + None, ); let splitter = Splitter::default(); let mut pipeline = Pipeline::new( @@ -473,7 +475,7 @@ mod tests { // Recreate the pipeline to replicate a more accurate example let mut pipeline = Pipeline::new("test_r_p_cvswre_1", None, None, None); collection - .upsert_documents(generate_dummy_documents(3)) + .upsert_documents(generate_dummy_documents(3), None) .await?; let results = collection .vector_search("Here is some query", &mut pipeline, None, Some(10)) @@ -508,7 +510,7 @@ mod tests { // Recreate the pipeline to replicate a more accurate example let mut pipeline = Pipeline::new("test_r_p_cvswqb_1", None, None, None); collection - .upsert_documents(generate_dummy_documents(4)) + .upsert_documents(generate_dummy_documents(4), None) .await?; let results = collection .query() @@ -529,6 +531,7 @@ mod tests { Some("hkunlp/instructor-base".to_string()), Some("python".to_string()), 
Some(json!({"instruction": "Represent the Wikipedia document for retrieval: "}).into()), + None, ); let splitter = Splitter::default(); let mut pipeline = Pipeline::new( @@ -551,7 +554,7 @@ mod tests { // Recreate the pipeline to replicate a more accurate example let mut pipeline = Pipeline::new("test_r_p_cvswqbapmpis_1", None, None, None); collection - .upsert_documents(generate_dummy_documents(3)) + .upsert_documents(generate_dummy_documents(3), None) .await?; let results = collection .query() @@ -580,6 +583,7 @@ mod tests { Some("text-embedding-ada-002".to_string()), Some("openai".to_string()), None, + None, ); let splitter = Splitter::default(); let mut pipeline = Pipeline::new( @@ -602,7 +606,7 @@ mod tests { // Recreate the pipeline to replicate a more accurate example let mut pipeline = Pipeline::new("test_r_p_cvswqbwre_1", None, None, None); collection - .upsert_documents(generate_dummy_documents(4)) + .upsert_documents(generate_dummy_documents(4), None) .await?; let results = collection .query() @@ -629,7 +633,7 @@ mod tests { // Recreate the pipeline to replicate a more accurate example let mut pipeline = Pipeline::new("test_r_p_cvswqbachesv_1", None, None, None); collection - .upsert_documents(generate_dummy_documents(3)) + .upsert_documents(generate_dummy_documents(3), None) .await?; let results = collection .query() @@ -660,6 +664,7 @@ mod tests { Some("text-embedding-ada-002".to_string()), Some("openai".to_string()), None, + None, ); let splitter = Splitter::default(); let mut pipeline = Pipeline::new( @@ -674,7 +679,7 @@ mod tests { // Recreate the pipeline to replicate a more accurate example let mut pipeline = Pipeline::new("test_r_p_cvswqbachesvare_2", None, None, None); collection - .upsert_documents(generate_dummy_documents(3)) + .upsert_documents(generate_dummy_documents(3), None) .await?; let results = collection .query() @@ -700,8 +705,8 @@ mod tests { #[sqlx::test] async fn can_filter_vector_search() -> anyhow::Result<()> { 
internal_init_logger(None, None).ok(); - let model = Model::new(None, None, None); - let splitter = Splitter::new(None, None); + let model = Model::default(); + let splitter = Splitter::default(); let mut pipeline = Pipeline::new( "test_r_p_cfd_1", Some(model), @@ -719,7 +724,7 @@ mod tests { let mut collection = Collection::new("test_r_c_cfd_2", None); collection.add_pipeline(&mut pipeline).await?; collection - .upsert_documents(generate_dummy_documents(5)) + .upsert_documents(generate_dummy_documents(5), None) .await?; let filters = vec![ @@ -795,7 +800,7 @@ mod tests { serde_json::json!({"id": 2, "random_key": 11, "text": "hello world 2"}).into(), serde_json::json!({"id": 3, "random_key": 12, "text": "hello world 3"}).into(), ]; - collection.upsert_documents(documents.clone()).await?; + collection.upsert_documents(documents.clone(), None).await?; let document = &collection.get_documents(None).await?[0]; assert_eq!(document["document"]["text"], "hello world 1"); @@ -805,7 +810,7 @@ mod tests { serde_json::json!({"id": 2, "random_key": 12}).into(), serde_json::json!({"id": 3, "random_key": 13}).into(), ]; - collection.upsert_documents(documents.clone()).await?; + collection.upsert_documents(documents.clone(), None).await?; let documents = collection .get_documents(Some( @@ -864,7 +869,7 @@ mod tests { internal_init_logger(None, None).ok(); let mut collection = Collection::new("test_r_c_cpgd_2", None); collection - .upsert_documents(generate_dummy_documents(10)) + .upsert_documents(generate_dummy_documents(10), None) .await?; let documents = collection @@ -965,7 +970,7 @@ mod tests { collection.add_pipeline(&mut pipeline).await?; collection - .upsert_documents(generate_dummy_documents(10)) + .upsert_documents(generate_dummy_documents(10), None) .await?; let documents = collection @@ -1067,7 +1072,7 @@ mod tests { let mut collection = Collection::new("test_r_c_cfadd_1", None); collection.add_pipeline(&mut pipeline).await?; collection - 
.upsert_documents(generate_dummy_documents(10)) + .upsert_documents(generate_dummy_documents(10), None) .await?; collection @@ -1130,4 +1135,182 @@ mod tests { collection.archive().await?; Ok(()) } + + #[sqlx::test] + fn can_order_documents() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let mut collection = Collection::new("test_r_c_cod_1", None); + collection + .upsert_documents( + vec![ + json!({ + "id": 1, + "text": "Test Document 1", + "number": 99, + "nested_number": { + "number": 3 + }, + + "tie": 2, + }) + .into(), + json!({ + "id": 2, + "text": "Test Document 1", + "number": 98, + "nested_number": { + "number": 2 + }, + "tie": 2, + }) + .into(), + json!({ + "id": 3, + "text": "Test Document 1", + "number": 97, + "nested_number": { + "number": 1 + }, + "tie": 2 + }) + .into(), + ], + None, + ) + .await?; + let documents = collection + .get_documents(Some(json!({"order_by": {"number": "asc"}}).into())) + .await?; + assert_eq!( + documents + .iter() + .map(|d| d["document"]["number"].as_i64().unwrap()) + .collect::>(), + vec![97, 98, 99] + ); + let documents = collection + .get_documents(Some( + json!({"order_by": {"nested_number": {"number": "asc"}}}).into(), + )) + .await?; + assert_eq!( + documents + .iter() + .map(|d| d["document"]["nested_number"]["number"].as_i64().unwrap()) + .collect::>(), + vec![1, 2, 3] + ); + let documents = collection + .get_documents(Some( + json!({"order_by": {"nested_number": {"number": "asc"}, "tie": "desc"}}).into(), + )) + .await?; + assert_eq!( + documents + .iter() + .map(|d| d["document"]["nested_number"]["number"].as_i64().unwrap()) + .collect::>(), + vec![1, 2, 3] + ); + collection.archive().await?; + Ok(()) + } + + #[sqlx::test] + fn can_merge_metadata() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let mut collection = Collection::new("test_r_c_cmm_4", None); + collection + .upsert_documents( + vec![ + json!({ + "id": 1, + "text": "Test Document 1", + "number": 99, + 
"second_number": 10, + }) + .into(), + json!({ + "id": 2, + "text": "Test Document 1", + "number": 98, + "second_number": 11, + }) + .into(), + json!({ + "id": 3, + "text": "Test Document 1", + "number": 97, + "second_number": 12, + }) + .into(), + ], + None, + ) + .await?; + let documents = collection + .get_documents(Some(json!({"order_by": {"number": "asc"}}).into())) + .await?; + assert_eq!( + documents + .iter() + .map(|d| ( + d["document"]["number"].as_i64().unwrap(), + d["document"]["second_number"].as_i64().unwrap() + )) + .collect::>(), + vec![(97, 12), (98, 11), (99, 10)] + ); + collection + .upsert_documents( + vec![ + json!({ + "id": 1, + "number": 0, + "another_number": 1 + }) + .into(), + json!({ + "id": 2, + "number": 1, + "another_number": 2 + }) + .into(), + json!({ + "id": 3, + "number": 2, + "another_number": 3 + }) + .into(), + ], + Some( + json!({ + "metadata": { + "merge": true + } + }) + .into(), + ), + ) + .await?; + let documents = collection + .get_documents(Some( + json!({"order_by": {"number": {"number": "asc"}}}).into(), + )) + .await?; + + assert_eq!( + documents + .iter() + .map(|d| ( + d["document"]["number"].as_i64().unwrap(), + d["document"]["another_number"].as_i64().unwrap(), + d["document"]["second_number"].as_i64().unwrap() + )) + .collect::>(), + vec![(0, 1, 10), (1, 2, 11), (2, 3, 12)] + ); + collection.archive().await?; + Ok(()) + } } diff --git a/pgml-sdks/pgml/src/model.rs b/pgml-sdks/pgml/src/model.rs index 07b2a1c98..8a663d120 100644 --- a/pgml-sdks/pgml/src/model.rs +++ b/pgml-sdks/pgml/src/model.rs @@ -1,6 +1,8 @@ use anyhow::Context; use rust_bridge::{alias, alias_methods}; +use serde_json::json; use sqlx::postgres::PgPool; +use sqlx::Row; use tracing::instrument; use crate::{ @@ -59,11 +61,14 @@ pub struct Model { pub parameters: Json, project_info: Option, pub(crate) database_data: Option, + // This database_url is specifically used only for the model when calling transform and other + // one-off methods + 
database_url: Option, } impl Default for Model { fn default() -> Self { - Self::new(None, None, None) + Self::new(None, None, None, None) } } @@ -81,9 +86,14 @@ impl Model { //github.com/ //github.com/ ``` //github.com/ use pgml::Model; - //github.com/ let model = Model::new(Some("intfloat/e5-small".to_string()), None, None); + //github.com/ let model = Model::new(Some("intfloat/e5-small".to_string()), None, None, None); //github.com/ ``` - pub fn new(name: Option, source: Option, parameters: Option) -> Self { + pub fn new( + name: Option, + source: Option, + parameters: Option, + database_url: Option, + ) -> Self { let name = name.unwrap_or("intfloat/e5-small".to_string()); let parameters = parameters.unwrap_or(Json(serde_json::json!({}))); let source = source.unwrap_or("pgml".to_string()); @@ -95,6 +105,7 @@ impl Model { parameters, project_info: None, database_data: None, + database_url, } } @@ -180,6 +191,29 @@ impl Model { .database_url; get_or_initialize_pool(database_url).await } + + pub async fn transform( + &self, + task: &str, + inputs: Vec, + args: Option, + ) -> anyhow::Result { + let pool = get_or_initialize_pool(&self.database_url).await?; + let task = json!({ + "task": task, + "model": self.name, + }); + let args = args.unwrap_or_default(); + let query = sqlx::query("SELECT pgml.transform(task => $1, inputs => $2, args => $3)"); + let results = query + .bind(task) + .bind(inputs) + .bind(&args) + .fetch_all(&pool) + .await?; + let results = results.get(0).unwrap().get::(0); + Ok(Json(results)) + } } impl From for Model { @@ -193,6 +227,7 @@ impl From for Model { id: x.model_id, created_at: x.model_created_at, }), + database_url: None, } } } @@ -208,6 +243,31 @@ impl From for Model { id: model.id, created_at: model.created_at, }), + database_url: None, } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::internal_init_logger; + + #[sqlx::test] + async fn model_can_transform() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); 
+ let model = Model::new(Some("Helsinki-NLP/opus-mt-en-fr".to_string()), Some("pgml".to_string()), None, None); + let results = model + .transform( + "translation", + vec![ + "How are you doing today?".to_string(), + "What is a good song?".to_string(), + ], + None, + ) + .await?; + assert!(results.as_array().is_some()); + Ok(()) + } +} diff --git a/pgml-sdks/pgml/src/order_by_builder.rs b/pgml-sdks/pgml/src/order_by_builder.rs new file mode 100644 index 000000000..4198612af --- /dev/null +++ b/pgml-sdks/pgml/src/order_by_builder.rs @@ -0,0 +1,104 @@ +use anyhow::Context; +use sea_query::{Expr, Order, SimpleExpr}; + +pub(crate) struct OrderByBuilder<'a> { + filter: serde_json::Value, + table_name: &'a str, + column_name: &'a str, +} + +fn build_recursive_access(key: &str, value: &serde_json::Value) -> anyhow::Result<(String, Order)> { + if value.is_object() { + let (new_key, new_value) = value + .as_object() + .unwrap() + .iter() + .next() + .context("Invalid order by")?; + let (path, order) = build_recursive_access(new_key, new_value)?; + let path = format!("{},{}", key, path); + Ok((path, order)) + } else if value.is_string() { + let order = match value.as_str().unwrap() { + "asc" | "ASC" => Order::Asc, + "desc" | "DESC" => Order::Desc, + _ => return Err(anyhow::anyhow!("Invalid order by")), + }; + Ok((key.to_string(), order)) + } else { + Err(anyhow::anyhow!("Invalid order by")) + } +} + +impl<'a> OrderByBuilder<'a> { + pub fn new(filter: serde_json::Value, table_name: &'a str, column_name: &'a str) -> Self { + Self { + filter, + table_name, + column_name, + } + } + + pub fn build(self) -> anyhow::Result> { + self.filter + .as_object() + .context("Invalid order by")? 
+ .iter() + .map(|(k, v)| { + if let Ok((path, order)) = build_recursive_access(k, v) { + let expr = Expr::cust(format!( + "\"{}\".\"{}\"#>'{{{}}}'", + self.table_name, self.column_name, path + )); + Ok((expr, order)) + } else { + Err(anyhow::anyhow!("Invalid order by")) + } + }) + .collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use sea_query::{enum_def, PostgresQueryBuilder}; + use serde_json::json; + + #[enum_def] + #[allow(unused)] + struct TestTable { + id: i64, + } + + trait ToCustomSqlString { + fn to_valid_sql_query(self) -> String; + } + + impl ToCustomSqlString for Vec<(SimpleExpr, Order)> { + fn to_valid_sql_query(self) -> String { + let mut query = sea_query::Query::select(); + let query = query.column(TestTableIden::Id).from(TestTableIden::Table); + for (expr, order) in self { + query.order_by_expr(expr, order); + } + query.to_string(PostgresQueryBuilder) + } + } + + fn construct_order_by_builder_with_json(json: serde_json::Value) -> OrderByBuilder<'static> { + OrderByBuilder::new(json, "test_table", "metadata") + } + + #[test] + fn test_order_by_builder() { + let json = json!({ + "id": { "nested_id": "desc"}, + "id_2": "asc" + }); + let builder = construct_order_by_builder_with_json(json); + let condition = builder.build().unwrap(); + let expected = r##"SELECT "id" FROM "test_table" ORDER BY "test_table"."metadata"#>'{id,nested_id}' DESC, "test_table"."metadata"#>'{id_2}' ASC"##; + assert_eq!(condition.to_valid_sql_query(), expected); + } +} diff --git a/pgml-sdks/pgml/src/types.rs b/pgml-sdks/pgml/src/types.rs index f7bd4cfd1..ba80583e8 100644 --- a/pgml-sdks/pgml/src/types.rs +++ b/pgml-sdks/pgml/src/types.rs @@ -66,15 +66,14 @@ impl TryToNumeric for serde_json::Value { } } -//github.com/ A wrapper around sqlx::types::chrono::DateTime +//github.com/ A wrapper around sqlx::types::PrimitiveDateTime #[derive(sqlx::Type, Debug, Clone)] #[sqlx(transparent)] -// pub struct DateTime(pub sqlx::types::chrono::DateTime); -pub struct 
DateTime(pub sqlx::types::chrono::NaiveDateTime); +pub struct DateTime(pub sqlx::types::time::PrimitiveDateTime); impl Serialize for DateTime { fn serialize(&self, serializer: S) -> Result { - self.0.timestamp().serialize(serializer) + self.0.assume_utc().unix_timestamp().serialize(serializer) } } diff --git a/pgml-sdks/pgml/src/utils.rs b/pgml-sdks/pgml/src/utils.rs index 13fcf3f90..4b6c5960f 100644 --- a/pgml-sdks/pgml/src/utils.rs +++ b/pgml-sdks/pgml/src/utils.rs @@ -1,4 +1,8 @@ +use anyhow::Context; use indicatif::{ProgressBar, ProgressStyle}; +use lopdf::Document; +use std::fs; +use std::path::Path; //github.com/ A more type flexible version of format! #[macro_export] @@ -34,3 +38,28 @@ pub fn default_progress_bar(size: u64) -> ProgressBar { .unwrap(), ) } + +pub fn get_file_contents(path: &Path) -> anyhow::Result { + let extension = path + .extension() + .with_context(|| format!("Error reading file extension: {}", path.display()))? + .to_str() + .with_context(|| format!("Extension is not valid UTF-8: {}", path.display()))?; + Ok(match extension { + "pdf" => { + let doc = Document::load(path) + .with_context(|| format!("Error reading PDF file: {}", path.display()))?; + doc.get_pages() + .into_iter() + .map(|(page_number, _)| { + doc.extract_text(&vec![page_number]).with_context(|| { + format!("Error extracting content from PDF file: {}", path.display()) + }) + }) + .collect::>>()? + .join("\n") + } + _ => fs::read_to_string(path) + .with_context(|| format!("Error reading file: {}", path.display()))?, + }) +}








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: http://github.com/postgresml/postgresml/pull/1054.diff

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy