diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b07bc57f0..d0cf77825 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,7 +34,7 @@ jobs: path: | ~/.cargo pgml-extension/target - ~/.pgx + ~/.pgrx key: ${{ runner.os }}-rust-${{ hashFiles('pgml-extension/Cargo.lock') }} - name: Submodules run: | @@ -43,20 +43,20 @@ jobs: run: | curl https://sh.rustup.rs -sSf | sh -s -- -y source ~/.cargo/env - cargo install cargo-pgx --version "0.7.1" + cargo install cargo-pgrx --version "0.7.4" - if [[ ! -d ~/.pgx ]]; then - cargo pgx init + if [[ ! -d ~/.pgrx ]]; then + cargo pgrx init fi - cargo pgx test + cargo pgrx test - cargo pgx stop - cargo pgx start + cargo pgrx stop + cargo pgrx start # psql -p 28813 -h 127.0.0.1 -d pgml -P pager -f tests/test.sql - cargo pgx stop + cargo pgrx stop diff --git a/.github/workflows/package-extension.yml b/.github/workflows/package-extension.yml index d9e31243d..e70a1d2b3 100644 --- a/.github/workflows/package-extension.yml +++ b/.github/workflows/package-extension.yml @@ -79,47 +79,47 @@ jobs: curl -sLO https://github.com/deb-s3/deb-s3/releases/download/0.11.4/deb-s3-0.11.4.gem sudo gem install deb-s3-0.11.4.gem dpkg-deb --version - - name: Install pgx + - name: Install pgrx uses: postgresml/gh-actions-cargo@master with: working-directory: pgml-extension command: install - args: cargo-pgx --version "0.7.1" + args: cargo-pgrx --version "0.7.4" - - name: pgx init + - name: pgrx init uses: postgresml/gh-actions-cargo@master with: working-directory: pgml-extension - command: pgx + command: pgrx args: init --pg11=/usr/lib/postgresql/11/bin/pg_config --pg12=/usr/lib/postgresql/12/bin/pg_config --pg13=/usr/lib/postgresql/13/bin/pg_config --pg14=/usr/lib/postgresql/14/bin/pg_config --pg15=/usr/lib/postgresql/15/bin/pg_config - name: Build Postgres 11 uses: postgresml/gh-actions-cargo@master with: working-directory: pgml-extension - command: pgx + command: pgrx args: package --pg-config 
/usr/lib/postgresql/11/bin/pg_config - name: Build Postgres 12 uses: postgresml/gh-actions-cargo@master with: working-directory: pgml-extension - command: pgx + command: pgrx args: package --pg-config /usr/lib/postgresql/12/bin/pg_config - name: Build Postgres 13 uses: postgresml/gh-actions-cargo@master with: working-directory: pgml-extension - command: pgx + command: pgrx args: package --pg-config /usr/lib/postgresql/13/bin/pg_config - name: Build Postgres 14 uses: postgresml/gh-actions-cargo@master with: working-directory: pgml-extension - command: pgx + command: pgrx args: package --pg-config /usr/lib/postgresql/14/bin/pg_config - name: Build Postgres 15 uses: postgresml/gh-actions-cargo@master with: working-directory: pgml-extension - command: pgx + command: pgrx args: package --pg-config /usr/lib/postgresql/15/bin/pg_config - name: Build debs env: diff --git a/pgml-docs/docs/blog/benchmarks/python_microservices_vs_postgresml/README.md b/pgml-docs/docs/blog/benchmarks/python_microservices_vs_postgresml/README.md index 4e9abd22f..4e45061b0 100644 --- a/pgml-docs/docs/blog/benchmarks/python_microservices_vs_postgresml/README.md +++ b/pgml-docs/docs/blog/benchmarks/python_microservices_vs_postgresml/README.md @@ -3,7 +3,7 @@ ## PostgresML ``` -cargo pgx run --release +cargo pgrx run --release ``` ### Schema diff --git a/pgml-docs/docs/blog/postgresml-is-moving-to-rust-for-our-2.0-release.md b/pgml-docs/docs/blog/postgresml-is-moving-to-rust-for-our-2.0-release.md index 80dfc323d..7848e4cf6 100644 --- a/pgml-docs/docs/blog/postgresml-is-moving-to-rust-for-our-2.0-release.md +++ b/pgml-docs/docs/blog/postgresml-is-moving-to-rust-for-our-2.0-release.md @@ -170,7 +170,7 @@ Spoiler alert: idiomatic Rust is about 10x faster than native SQL, embedded PL/p LIMIT 1; ``` -We're building with the Rust [pgx](https://github.com/tcdi/pgx/tree/master/pgx) crate that makes our development cycle even nicer than the one we use to manage Python. 
It really streamlines creating an extension in Rust, so all we have to worry about is writing our functions. It took about an hour to port all of our vector operations to Rust with BLAS support, and another week to port all the "business logic" for maintaining model training and deployment. We've even gained some new capabilities for caching models across connections (independent processes), now that we have access to Postgres shared memory, without having to worry about Python's GIL and GC. This is the dream of Apache's Arrow project, realized for our applications, without having to change the world, just our implementations. 🤩 Single-copy end-to-end machine learning, with parallel processing and shared data access. +We're building with the Rust [pgrx](https://github.com/tcdi/pgrx/tree/master/pgrx) crate that makes our development cycle even nicer than the one we use to manage Python. It really streamlines creating an extension in Rust, so all we have to worry about is writing our functions. It took about an hour to port all of our vector operations to Rust with BLAS support, and another week to port all the "business logic" for maintaining model training and deployment. We've even gained some new capabilities for caching models across connections (independent processes), now that we have access to Postgres shared memory, without having to worry about Python's GIL and GC. This is the dream of Apache's Arrow project, realized for our applications, without having to change the world, just our implementations. 🤩 Single-copy end-to-end machine learning, with parallel processing and shared data access. ## What about XGBoost and friends? ML isn't just about basic math and a little bit of business logic. It's about all those complicated algorithms beyond linear regression for gradient boosting and deep learning. The good news is that most of these libraries are implemented in C/C++, and just have Python bindings. 
There are also bindings for Rust ([lightgbm](https://github.com/vaaaaanquish/lightgbm-rs), [xgboost](https://github.com/davechallis/rust-xgboost), [tensorflow](https://github.com/tensorflow/rust), [torch](https://github.com/LaurentMazare/tch-rs)). diff --git a/pgml-docs/docs/developer_guide/overview.md b/pgml-docs/docs/developer_guide/overview.md index f2df939ed..d34bb84ba 100644 --- a/pgml-docs/docs/developer_guide/overview.md +++ b/pgml-docs/docs/developer_guide/overview.md @@ -62,7 +62,7 @@ The development environment for each differs slightly, but overall we use Python ## Postgres extension -PostgresML is a Rust extension written with `tcdi/pgx` crate. Local development therefore requires the [latest Rust compiler](https://www.rust-lang.org/learn/get-started) and PostgreSQL development headers and libraries. +PostgresML is a Rust extension written with `tcdi/pgrx` crate. Local development therefore requires the [latest Rust compiler](https://www.rust-lang.org/learn/get-started) and PostgreSQL development headers and libraries. The extension code is located in: @@ -72,17 +72,17 @@ cd pgml-extension/ You'll need to install basic dependencies -Once there, you can initialize `pgx` and get going: +Once there, you can initialize `pgrx` and get going: #### Pgx command line and environments ```commandline -cargo install cargo-pgx --version "0.7.1" && \ -cargo pgx init # This will take a few minutes +cargo install cargo-pgrx --version "0.7.4" && \ +cargo pgrx init # This will take a few minutes ``` #### Update postgresql.conf -`pgx` uses Postgres 15 by default. Since `pgml` is using shared memory, you need to add it to `shared_preload_libraries` in `postgresql.conf` which, for `pgx`, is located in `~/.pgx/data-15/postgresql.conf`. +`pgrx` uses Postgres 15 by default. Since `pgml` is using shared memory, you need to add it to `shared_preload_libraries` in `postgresql.conf` which, for `pgrx`, is located in `~/.pgrx/data-15/postgresql.conf`. 
``` shared_preload_libraries = 'pgml' # (change requires restart) @@ -91,19 +91,19 @@ shared_preload_libraries = 'pgml' # (change requires restart) Run the unit tests ```commandline -cargo pgx test +cargo pgrx test ``` Run the integration tests: ```commandline -cargo pgx run --release +cargo pgrx run --release psql -h localhost -p 28813 -d pgml -f tests/test.sql -P pager ``` Run an interactive psql session ```commandline -cargo pgx run +cargo pgrx run ``` Create the extension in your database: @@ -147,10 +147,10 @@ By default, the extension is built without CUDA support for XGBoost and LightGBM ```commandline -CUDACXX=/usr/local/cuda/bin/nvcc cargo pgx run --release --features pg15,python,cuda +CUDACXX=/usr/local/cuda/bin/nvcc cargo pgrx run --release --features pg15,python,cuda ``` -If you ever want to reset the environment, simply spin up the database with `cargo pgx run` and drop the extension and metadata tables: +If you ever want to reset the environment, simply spin up the database with `cargo pgrx run` and drop the extension and metadata tables: ```postgresql DROP EXTENSION IF EXISTS pgml CASCADE; @@ -190,7 +190,7 @@ Basic installation can be achieved with: cd postgresml/pgml-dashboard ``` -2. Set the `DATABASE_URL` environment variable, for example to a running interactive `cargo pgx run` session started previously: +2. Set the `DATABASE_URL` environment variable, for example to a running interactive `cargo pgrx run` session started previously: ```commandline export DATABASE_URL=postgres://localhost:28815/pgml ``` diff --git a/pgml-docs/docs/user_guides/setup/v2/installation.md b/pgml-docs/docs/user_guides/setup/v2/installation.md index 9bd45da6a..a1e5f11b1 100644 --- a/pgml-docs/docs/user_guides/setup/v2/installation.md +++ b/pgml-docs/docs/user_guides/setup/v2/installation.md @@ -92,15 +92,15 @@ sudo apt-get install postgresql cd pgml-extension ``` - 5. 
Install [`pgx`](https://github.com/tcdi/pgx) and build the extension (this will take a few minutes): + 5. Install [`pgrx`](https://github.com/tcdi/pgrx) and build the extension (this will take a few minutes): **With Python support:** ```bash export POSTGRES_VERSION=15 - cargo install cargo-pgx --version "0.7.1" && \ - cargo pgx init --pg${POSTGRES_VERSION} /usr/bin/pg_config && \ - cargo pgx package + cargo install cargo-pgrx --version "0.7.4" && \ + cargo pgrx init --pg${POSTGRES_VERSION} /usr/bin/pg_config && \ + cargo pgrx package ``` **Without Python support:** @@ -108,9 +108,9 @@ sudo apt-get install postgresql ```bash export POSTGRES_VERSION=15 cp docker/Cargo.toml.no-python Cargo.toml && \ - cargo install cargo-pgx --version "0.7.1" && \ - cargo pgx init --pg${POSTGRES_VERSION} /usr/bin/pg_config && \ - cargo pgx package + cargo install cargo-pgrx --version "0.7.4" && \ + cargo pgrx init --pg${POSTGRES_VERSION} /usr/bin/pg_config && \ + cargo pgrx package ``` 6. Copy the extension binaries into Postgres system folders: @@ -152,12 +152,12 @@ sudo apt-get install postgresql For example, `openssl` requires some environment variables set in `~/.zsh` for the compiler to find the library. - 4. Install [`pgx`](https://github.com/tcdi/pgx) and build the extension (this will take a few minutes): + 4. 
Install [`pgrx`](https://github.com/tcdi/pgrx) and build the extension (this will take a few minutes): ``` - cargo install cargo-pgx && \ - cargo pgx init --pg15 /usr/bin/pg_config && \ - cargo pgx install + cargo install cargo-pgrx && \ + cargo pgrx init --pg15 /usr/bin/pg_config && \ + cargo pgrx install ``` diff --git a/pgml-extension/Cargo.lock b/pgml-extension/Cargo.lock index a731ecb6c..b0221a7e7 100644 --- a/pgml-extension/Cargo.lock +++ b/pgml-extension/Cargo.lock @@ -10,9 +10,9 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "aho-corasick" -version = "0.7.20" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac" +checksum = "67fc08ce920c31afb70f013dcce1bfc3a3195de6a228474e45e1f145b36f8d04" dependencies = [ "memchr", ] @@ -26,6 +26,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "anstyle" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41ed9a86bf92ae6580e0a31281f65a1b1d867c0cc68d5346e2ae128dddfa6a7d" + [[package]] name = "anyhow" version = "1.0.70" @@ -84,7 +90,7 @@ checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842" dependencies = [ "proc-macro2", "quote 1.0.26", - "syn 2.0.13", + "syn 2.0.15", ] [[package]] @@ -343,9 +349,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.2.1" +version = "4.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "046ae530c528f252094e4a77886ee1374437744b2bff1497aa898bbddbbb29b3" +checksum = "956ac1f6381d8d82ab4684768f89c0ea3afe66925ceadb4eeb3fc452ffc55d62" dependencies = [ "clap_builder", "clap_derive", @@ -358,16 +364,17 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eca953650a7350560b61db95a0ab1d9c6f7b74d146a9e08fb258b834f3cf7e2c" dependencies = [ - "clap 4.2.1", + "clap 4.2.4", "doc-comment", 
] [[package]] name = "clap_builder" -version = "4.2.1" +version = "4.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "223163f58c9a40c3b0a43e1c4b50a9ce09f007ea2cb1ec258a687945b4b7929f" +checksum = "84080e799e54cff944f4b4a4b0e71630b0e0443b25b985175c7dddc1a859b749" dependencies = [ + "anstyle", "bitflags", "clap_lex", ] @@ -381,7 +388,7 @@ dependencies = [ "heck", "proc-macro2", "quote 1.0.26", - "syn 2.0.13", + "syn 2.0.15", ] [[package]] @@ -426,9 +433,9 @@ checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" [[package]] name = "cpufeatures" -version = "0.2.6" +version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "280a9f2d8b3a38871a3c8a46fb80db65e5e5ed97da80c4d08bf27fb63e35e181" +checksum = "3e4c1eaa2012c47becbbad2ab175484c2a84d1185b566fb2cc5b8707343dfe58" dependencies = [ "libc", ] @@ -450,9 +457,9 @@ checksum = "6548a0ad5d2549e111e1f6a11a6c2e2d00ce6a3dafe22948d67c2b443f775e52" [[package]] name = "crossbeam-channel" -version = "0.5.7" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf2b3e8478797446514c91ef04bafcb59faba183e621ad488df88983cc14128c" +checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200" dependencies = [ "cfg-if", "crossbeam-utils", @@ -529,7 +536,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd4056f63fce3b82d852c3da92b08ea59959890813a7f4ce9c0ff85b10cf301b" dependencies = [ "quote 1.0.26", - "syn 2.0.13", + "syn 2.0.15", ] [[package]] @@ -716,13 +723,13 @@ dependencies = [ [[package]] name = "errno" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50d6a0976c999d473fe89ad888d5a284e55366d9dc9038b1ba2aa15128c4afa0" +checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a" dependencies = [ "errno-dragonfly", "libc", - "windows-sys 0.45.0", + "windows-sys 0.48.0", 
] [[package]] @@ -848,7 +855,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote 1.0.26", - "syn 2.0.13", + "syn 2.0.15", ] [[package]] @@ -890,9 +897,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.8" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" +checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" dependencies = [ "cfg-if", "libc", @@ -907,7 +914,7 @@ checksum = "e77ac7b51b8e6313251737fcef4b1c01a2ea102bde68415b62c0ee9268fec357" dependencies = [ "proc-macro2", "quote 1.0.26", - "syn 2.0.13", + "syn 2.0.15", ] [[package]] @@ -1102,9 +1109,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.141" +version = "0.2.142" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3304a64d199bb964be99741b7a14d26972741915b3649639149b2479bb46f4b5" +checksum = "6a987beff54b60ffa6d51982e1aa1146bc42f19bd26be28b0586f252fccf5317" [[package]] name = "libloading" @@ -1233,9 +1240,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.3.1" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d59d8c75012853d2e872fb56bc8a2e53718e2cafe1a4c823143141c6d90c322f" +checksum = "36eb31c1778188ae1e64398743890d0877fef36d11521ac60406b42016e8c2cf" [[package]] name = "lock_api" @@ -1258,9 +1265,9 @@ dependencies = [ [[package]] name = "matrixmultiply" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "add85d4dd35074e6fedc608f8c8f513a3548619a9024b751949ef0e8e45a4d84" +checksum = "bb99c395ae250e1bf9133673f03ca9f97b7e71b705436bf8f089453445d1e9fe" dependencies = [ "rawpointer", ] @@ -1417,9 +1424,9 @@ dependencies = [ [[package]] name = "ntapi" -version = 
"0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc51db7b362b205941f71232e56c625156eb9a929f8cf74a428fd5bc094a4afc" +checksum = "e8a3895c6391c39d7fe7ebc444a87eb2991b2a0bc718fdabd071eec617fc68e4" dependencies = [ "winapi", ] @@ -1554,9 +1561,9 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.49" +version = "0.10.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d2f106ab837a24e03672c59b1239669a0596406ff657c3c0835b6b7f0f35a33" +checksum = "97ea2d98598bf9ada7ea6ee8a30fb74f9156b63bbe495d64ec2b87c269d2dda3" dependencies = [ "bitflags", "cfg-if", @@ -1575,7 +1582,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote 1.0.26", - "syn 2.0.13", + "syn 2.0.15", ] [[package]] @@ -1586,9 +1593,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.84" +version = "0.9.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a20eace9dc2d82904039cb76dcf50fb1a0bba071cfd1629720b5d6f1ddba0fa" +checksum = "992bac49bdbab4423199c654a5515bd2a6c6a23bf03f2dd3bdb7e5ae6259bc69" dependencies = [ "cc", "libc", @@ -1681,7 +1688,7 @@ dependencies = [ [[package]] name = "pgml" -version = "2.4.0" +version = "2.4.1" dependencies = [ "anyhow", "blas", @@ -1701,9 +1708,9 @@ dependencies = [ "once_cell", "openblas-src", "parking_lot", - "pgx", - "pgx-pg-sys", - "pgx-tests", + "pgrx", + "pgrx-pg-sys", + "pgrx-tests", "pyo3", "rand", "rmp-serde", @@ -1715,10 +1722,10 @@ dependencies = [ ] [[package]] -name = "pgx" +name = "pgrx" version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c2947326bd9a80ec122207f0a59367592f79c053390d6ee961fe17a71ef1e3d" +checksum = "520e5f1b0c97fa07446eddf2e9c8a410ced351014a2339ef680bfc177d989b55" dependencies = [ "atomic-traits", "bitflags", @@ -1726,9 +1733,9 @@ 
dependencies = [ "heapless", "libc", "once_cell", - "pgx-macros", - "pgx-pg-sys", - "pgx-sql-entity-graph", + "pgrx-macros", + "pgrx-pg-sys", + "pgrx-sql-entity-graph", "seahash", "seq-macro", "serde", @@ -1742,22 +1749,22 @@ dependencies = [ ] [[package]] -name = "pgx-macros" +name = "pgrx-macros" version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96bf5c70a467b39c1a67a2e1ec7acc4ba8bb32e5bf2d3dead2d89b8442f31ff9" +checksum = "b28d2441f5e36140541868ef5fb92b964e9c10e242fe2cfdcb464f05afd7fd16" dependencies = [ - "pgx-sql-entity-graph", + "pgrx-sql-entity-graph", "proc-macro2", "quote 1.0.26", "syn 1.0.109", ] [[package]] -name = "pgx-pg-config" +name = "pgrx-pg-config" version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "020f2f1e0805a60321a375d0f27d771678d59b808bbb5f632c42607a661ab63a" +checksum = "bea222a7b5241102003693f48ee187878e92a5d4f6c5f344e00786899253978d" dependencies = [ "dirs 4.0.0", "eyre", @@ -1771,19 +1778,19 @@ dependencies = [ ] [[package]] -name = "pgx-pg-sys" +name = "pgrx-pg-sys" version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db2371dc1ee5c6f32b9a862fe1706e7ddf862003f167d21d9886b4b4f3f2391e" +checksum = "9ad42206dbd7780516c159da71a46b6785ee958595adebc87cbfb21ceb18e9cf" dependencies = [ "bindgen 0.60.1", "eyre", "libc", "memoffset 0.8.0", "once_cell", - "pgx-macros", - "pgx-pg-config", - "pgx-sql-entity-graph", + "pgrx-macros", + "pgrx-pg-config", + "pgrx-sql-entity-graph", "proc-macro2", "quote 1.0.26", "serde", @@ -1793,10 +1800,10 @@ dependencies = [ ] [[package]] -name = "pgx-sql-entity-graph" +name = "pgrx-sql-entity-graph" version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e5b7304665fe3a052dd353a08d013c4d5d780a49be8b60d27c430492b1d442e" +checksum = "87dd0800eb8ab52d161d57db297aaa00b2debdba77f738ff011664c8751e92c9" dependencies = [ "convert_case", "eyre", @@ -1808,19 
+1815,19 @@ dependencies = [ ] [[package]] -name = "pgx-tests" +name = "pgrx-tests" version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2dfa440a295e0a6bc1a7c87af83dc5e9f7a85c05d28b9fa77f1793f6883f917" +checksum = "e4ab1f3b9bbcdbea070d439e49ad146ea8a5a25104128a8b7f8522d21445a940" dependencies = [ "clap-cargo", "eyre", "libc", "once_cell", "owo-colors", - "pgx", - "pgx-macros", - "pgx-pg-config", + "pgrx", + "pgrx-macros", + "pgrx-pg-config", "postgres", "regex", "serde", @@ -2116,9 +2123,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.7.3" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b1f693b24f6ac912f4893ef08244d70b6067480d2f1a46e950c9691e6749d1d" +checksum = "af83e617f331cc6ae2da5443c602dfa5af81e517212d9d611a5b3ba1777b5370" dependencies = [ "aho-corasick", "memchr", @@ -2127,9 +2134,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.6.29" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" +checksum = "a5996294f19bd3aae0453a862ad728f60e6600695733dd5df01da90c54363a3c" [[package]] name = "rmp" @@ -2179,16 +2186,16 @@ dependencies = [ [[package]] name = "rustix" -version = "0.37.7" +version = "0.37.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2aae838e49b3d63e9274e1c01833cc8139d3fec468c3b84688c628f44b1ae11d" +checksum = "d9b864d3c18a5785a05953adeed93e2dca37ed30f18e69bba9f30079d51f363f" dependencies = [ "bitflags", "errno", "io-lifetimes", "libc", "linux-raw-sys", - "windows-sys 0.45.0", + "windows-sys 0.48.0", ] [[package]] @@ -2309,9 +2316,9 @@ checksum = "e6b44e8fc93a14e66336d230954dda83d18b4605ccace8fe09bc7514a71ad0bc" [[package]] name = "serde" -version = "1.0.159" +version = "1.0.160" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"3c04e8343c3daeec41f58990b9d77068df31209f2af111e059e9fe9646693065" +checksum = "bb2f3770c8bce3bcda7e149193a069a0f4365bda1fa5cd88e03bca26afc1216c" dependencies = [ "serde_derive", ] @@ -2328,20 +2335,20 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.159" +version = "1.0.160" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c614d17805b093df4b147b51339e7e44bf05ef59fba1e45d83500bcfb4d8585" +checksum = "291a097c63d8497e00160b166a967a4a79c64f3facdd01cbd7502231688d77df" dependencies = [ "proc-macro2", "quote 1.0.26", - "syn 2.0.13", + "syn 2.0.15", ] [[package]] name = "serde_json" -version = "1.0.95" +version = "1.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d721eca97ac802aa7777b701877c8004d950fc142651367300d21c1cc0194744" +checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1" dependencies = [ "indexmap", "itoa", @@ -2485,12 +2492,12 @@ dependencies = [ [[package]] name = "socket2" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc8d618c6641ae355025c449427f9e96b98abf99a772be3cef6708d15c77147a" +checksum = "6d283f86695ae989d1e18440a943880967156325ba025f05049946bff47bcc2b" dependencies = [ "libc", - "windows-sys 0.45.0", + "windows-sys 0.48.0", ] [[package]] @@ -2579,9 +2586,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.13" +version = "2.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c9da457c5285ac1f936ebd076af6dac17a61cfe7826f2076b4d015cf47bc8ec" +checksum = "a34fcf3e8b60f57e6a14301a2e916d323af98b0ea63c599441eec8558660c822" dependencies = [ "proc-macro2", "quote 1.0.26", @@ -2700,7 +2707,7 @@ checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" dependencies = [ "proc-macro2", "quote 1.0.26", - "syn 2.0.13", + "syn 2.0.15", ] [[package]] @@ -2791,7 +2798,7 @@ dependencies = [ "pin-project-lite", 
"postgres-protocol", "postgres-types", - "socket2 0.5.1", + "socket2 0.5.2", "tokio", "tokio-util", ] @@ -2888,9 +2895,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.16" +version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6176eae26dd70d0c919749377897b54a9276bd7061339665dd68777926b5a70" +checksum = "30a651bc37f915e81f087d86e62a18eec5f79550c7faff886f7090b4ea757c77" dependencies = [ "sharded-slab", "thread_local", @@ -2905,9 +2912,9 @@ checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" [[package]] name = "typetag" -version = "0.2.7" +version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edc3ebbaab23e6cc369cb48246769d031f5bd85f1b28141f32982e3c0c7b33cf" +checksum = "6a6898cc6f6a32698cc3e14d5632a14d2b23ed9f7b11e6b8e05ce685990acc22" dependencies = [ "erased-serde", "inventory", @@ -2918,13 +2925,13 @@ dependencies = [ [[package]] name = "typetag-impl" -version = "0.2.7" +version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb01b60fcc3f5e17babb1a9956263f3ccd2cadc3e52908400231441683283c1d" +checksum = "2c3e1c30cedd24fc597f7d37a721efdbdc2b1acae012c1ef1218f4c7c2c0f3e7" dependencies = [ "proc-macro2", "quote 1.0.26", - "syn 2.0.13", + "syn 2.0.15", ] [[package]] @@ -3012,9 +3019,9 @@ dependencies = [ [[package]] name = "uuid" -version = "1.3.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1674845326ee10d37ca60470760d4288a6f80f304007d92e5c53bab78c9cfd79" +checksum = "5b55a3fef2a1e3b3a00ce878640918820d3c51081576ac657d23af9fc7928fdb" dependencies = [ "getrandom", ] diff --git a/pgml-extension/Cargo.toml b/pgml-extension/Cargo.toml index 2e68aa17c..a0a879007 100644 --- a/pgml-extension/Cargo.toml +++ b/pgml-extension/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgml" -version = "2.4.0" +version = "2.4.1" edition = "2021" [lib] @@ -8,18 +8,18 
@@ crate-type = ["cdylib"] [features] default = ["pg15", "python"] -pg11 = ["pgx/pg11", "pgx-tests/pg11" ] -pg12 = ["pgx/pg12", "pgx-tests/pg12" ] -pg13 = ["pgx/pg13", "pgx-tests/pg13" ] -pg14 = ["pgx/pg14", "pgx-tests/pg14" ] -pg15 = ["pgx/pg15", "pgx-tests/pg15" ] +pg11 = ["pgrx/pg11", "pgrx-tests/pg11" ] +pg12 = ["pgrx/pg12", "pgrx-tests/pg12" ] +pg13 = ["pgrx/pg13", "pgrx-tests/pg13" ] +pg14 = ["pgrx/pg14", "pgrx-tests/pg14" ] +pg15 = ["pgrx/pg15", "pgrx-tests/pg15" ] pg_test = [] python = ["pyo3"] cuda = ["xgboost/cuda", "lightgbm/cuda"] [dependencies] -pgx = "=0.7.4" -pgx-pg-sys = "=0.7.4" +pgrx = "=0.7.4" +pgrx-pg-sys = "=0.7.4" xgboost = { git="https://github.com/postgresml/rust-xgboost.git", branch = "master" } once_cell = "1" rand = "0.8" @@ -48,7 +48,7 @@ flate2 = "1.0" csv = "1.1" [dev-dependencies] -pgx-tests = "=0.7.4" +pgrx-tests = "=0.7.4" [profile.dev] panic = "unwind" diff --git a/pgml-extension/Dockerfile b/pgml-extension/Dockerfile index 25f336260..1c61f9e99 100644 --- a/pgml-extension/Dockerfile +++ b/pgml-extension/Dockerfile @@ -37,8 +37,8 @@ RUN useradd postgresml -m -s /bin/bash -G sudo RUN echo 'postgresml ALL=(ALL) NOPASSWD: ALL' >> /etc/sudoers USER postgresml RUN curl https://sh.rustup.rs -sSf | sh -s -- -y -RUN $HOME/.cargo/bin/cargo install cargo-pgx --version "0.7.4" -RUN $HOME/.cargo/bin/cargo pgx init +RUN $HOME/.cargo/bin/cargo install cargo-pgrx --version "0.7.4" +RUN $HOME/.cargo/bin/cargo pgrx init RUN curl https://www.postgresql.org/media/keys/ACCC4CF8.asc | gpg --dearmor | sudo tee /etc/apt/trusted.gpg.d/apt.postgresql.org.gpg >/dev/null RUN sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list' RUN sudo apt update diff --git a/pgml-extension/control b/pgml-extension/control index 640540353..7ce5f6059 100644 --- a/pgml-extension/control +++ b/pgml-extension/control @@ -11,5 +11,5 @@ Description: PostgresML - machine learning with PostgreSQL learning 
directly in the database. It allows to both train algorithms on tables or views, and to predict novel datapoints using only SQL. - The extension is written in Rust using tcdi/pgx with some additional + The extension is written in Rust using tcdi/pgrx with some additional functionality written in Python & PLPython. diff --git a/pgml-extension/docker/Cargo.toml.cuda b/pgml-extension/docker/Cargo.toml.cuda index 97102f767..8ee22af42 100644 --- a/pgml-extension/docker/Cargo.toml.cuda +++ b/pgml-extension/docker/Cargo.toml.cuda @@ -8,19 +8,19 @@ crate-type = ["cdylib"] [features] default = ["pg15", "python", "cuda"] -pg10 = ["pgx/pg10", "pgx-tests/pg10" ] -pg11 = ["pgx/pg11", "pgx-tests/pg11" ] -pg12 = ["pgx/pg12", "pgx-tests/pg12" ] -pg13 = ["pgx/pg13", "pgx-tests/pg13" ] -pg14 = ["pgx/pg14", "pgx-tests/pg14" ] -pg15 = ["pgx/pg15", "pgx-tests/pg15" ] +pg10 = ["pgrx/pg10", "pgrx-tests/pg10" ] +pg11 = ["pgrx/pg11", "pgrx-tests/pg11" ] +pg12 = ["pgrx/pg12", "pgrx-tests/pg12" ] +pg13 = ["pgrx/pg13", "pgrx-tests/pg13" ] +pg14 = ["pgrx/pg14", "pgrx-tests/pg14" ] +pg15 = ["pgrx/pg15", "pgrx-tests/pg15" ] pg_test = [] python = ["pyo3"] cuda = ["xgboost/cuda", "lightgbm/cuda"] [dependencies] -pgx = "=0.7.1" -pgx-pg-sys = "=0.7.1" +pgrx = "=0.7.4" +pgrx-pg-sys = "=0.7.4" xgboost = { git="https://github.com/postgresml/rust-xgboost.git", branch = "master" } once_cell = "1" rand = "0.8" @@ -49,7 +49,7 @@ flate2 = "1.0" csv = "1.1" [dev-dependencies] -pgx-tests = "=0.7.1" +pgrx-tests = "=0.7.4" [profile.dev] panic = "unwind" diff --git a/pgml-extension/docker/Cargo.toml.no-python b/pgml-extension/docker/Cargo.toml.no-python index e4dbf82cd..e1cc1bc42 100644 --- a/pgml-extension/docker/Cargo.toml.no-python +++ b/pgml-extension/docker/Cargo.toml.no-python @@ -8,19 +8,19 @@ crate-type = ["cdylib"] [features] default = ["pg15"] -pg10 = ["pgx/pg10", "pgx-tests/pg10" ] -pg11 = ["pgx/pg11", "pgx-tests/pg11" ] -pg12 = ["pgx/pg12", "pgx-tests/pg12" ] -pg13 = ["pgx/pg13", "pgx-tests/pg13" ] 
-pg14 = ["pgx/pg14", "pgx-tests/pg14" ] -pg14 = ["pgx/pg15", "pgx-tests/pg15" ] +pg10 = ["pgrx/pg10", "pgrx-tests/pg10" ] +pg11 = ["pgrx/pg11", "pgrx-tests/pg11" ] +pg12 = ["pgrx/pg12", "pgrx-tests/pg12" ] +pg13 = ["pgrx/pg13", "pgrx-tests/pg13" ] +pg14 = ["pgrx/pg14", "pgrx-tests/pg14" ] +pg15 = ["pgrx/pg15", "pgrx-tests/pg15" ] pg_test = [] python = ["pyo3"] cuda = ["xgboost/cuda", "lightgbm/cuda"] [dependencies] -pgx = "=0.7.1" -pgx-pg-sys = "=0.7.1" +pgrx = "=0.7.4" +pgrx-pg-sys = "=0.7.4" xgboost = { git="https://github.com/postgresml/rust-xgboost.git", branch = "master" } once_cell = "1" rand = "0.8" @@ -49,7 +49,7 @@ flate2 = "1.0" csv = "1.1" [dev-dependencies] -pgx-tests = "=0.7.1" +pgrx-tests = "=0.7.4" [profile.dev] panic = "unwind" diff --git a/pgml-extension/examples/vectors.sql b/pgml-extension/examples/vectors.sql index 5fd546e57..335124705 100644 --- a/pgml-extension/examples/vectors.sql +++ b/pgml-extension/examples/vectors.sql @@ -30,3 +30,17 @@ SELECT pgml.distance_l1(ARRAY[1.0, 2.0, 3.0]::FLOAT4[], ARRAY[4.0, 5.0, 6.0]::FL SELECT pgml.distance_l2(ARRAY[1.0, 2.0, 3.0]::FLOAT4[], ARRAY[4.0, 5.0, 6.0]::FLOAT4[]); SELECT pgml.dot_product(ARRAY[1.0, 2.0, 3.0]::FLOAT4[], ARRAY[4.0, 5.0, 6.0]::FLOAT4[]); SELECT pgml.cosine_similarity(ARRAY[1.0, 2.0, 3.0]::FLOAT4[], ARRAY[1.0, 2.0, 3.0]::FLOAT4[]); + +-- Aggregates +WITH vectors AS ( +SELECT * FROM ( + VALUES + (ARRAY[-2,-4,-6,-8]::FLOAT4[]), + (ARRAY[-1,-2,-3,-4]::FLOAT4[]), + (ARRAY[0,0,0,0]::FLOAT4[]), + (ARRAY[1,2,3,4]::FLOAT4[]), + (ARRAY[1,2,3,4]::FLOAT4[]), + (NULL) + ) AS vectors (embedding) +) SELECT pgml.sum(embedding), pgml.min(embedding), pgml.max(embedding), pgml.min_abs(embedding), pgml.max_abs(embedding), pgml.divide(pgml.sum(embedding), count(embedding)) as avg + FROM vectors; diff --git a/pgml-extension/sql/setup_examples.sql b/pgml-extension/sql/setup_examples.sql index af696af89..872588f1a 100644 --- a/pgml-extension/sql/setup_examples.sql +++ b/pgml-extension/sql/setup_examples.sql 
@@ -3,7 +3,7 @@ --- --- Usage: --- ---- $ cargo pgx run --release +--- $ cargo pgrx run --release --- $ psql -P pager-off -h localhost -p 28813 -d pgml -f sql/setup_examples.sql --- \set ON_ERROR_STOP true diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 7f09fb8c8..cd8538c0c 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -3,8 +3,8 @@ use std::fmt::Write; use std::str::FromStr; use ndarray::Zip; -use pgx::iter::{SetOfIterator, TableIterator}; -use pgx::*; +use pgrx::iter::{SetOfIterator, TableIterator}; +use pgrx::*; #[cfg(feature = "python")] use pyo3::prelude::*; @@ -464,7 +464,7 @@ fn predict_batch(project_name: &str, features: Vec) -> SetOfIterator<'stati } #[pg_extern(strict, name = "predict")] -fn predict_row(project_name: &str, row: pgx::datum::AnyElement) -> f32 { +fn predict_row(project_name: &str, row: pgrx::datum::AnyElement) -> f32 { predict_model_row(Project::get_deployed_model_id(project_name), row) } @@ -489,7 +489,7 @@ fn predict_model_batch(model_id: i64, features: Vec) -> Vec { } #[pg_extern(strict, name = "predict")] -fn predict_model_row(model_id: i64, row: pgx::datum::AnyElement) -> f32 { +fn predict_model_row(model_id: i64, row: pgrx::datum::AnyElement) -> f32 { let model = Model::find_cached(model_id); let snapshot = &model.snapshot; let numeric_encoded_features = model.numeric_encode_features(&[row]); @@ -534,7 +534,7 @@ fn load_dataset( limit: default!(Option, "NULL"), kwargs: default!(JsonB, "'{}'"), ) -> TableIterator<'static, (name!(table_name, String), name!(rows, i64))> { - // cast limit since pgx doesn't support usize + // cast limit since pgrx doesn't support usize let limit: Option = limit.map(|limit| limit.try_into().unwrap()); let (name, rows) = match source { "breast_cancer" => dataset::load_breast_cancer(limit), diff --git a/pgml-extension/src/bindings/lightgbm.rs b/pgml-extension/src/bindings/lightgbm.rs index e5795080e..65825f4ae 100644 --- 
a/pgml-extension/src/bindings/lightgbm.rs +++ b/pgml-extension/src/bindings/lightgbm.rs @@ -3,7 +3,7 @@ use crate::orm::dataset::Dataset; use crate::orm::task::Task; use crate::orm::Hyperparams; use lightgbm; -use pgx::*; +use pgrx::*; use serde_json::json; pub struct Estimator { diff --git a/pgml-extension/src/bindings/linfa.rs b/pgml-extension/src/bindings/linfa.rs index 420337208..8c358e7f6 100644 --- a/pgml-extension/src/bindings/linfa.rs +++ b/pgml-extension/src/bindings/linfa.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; use super::Bindings; use crate::orm::*; -use pgx::*; +use pgrx::*; #[derive(Debug, Serialize, Deserialize)] pub struct LinearRegression { diff --git a/pgml-extension/src/bindings/mod.rs b/pgml-extension/src/bindings/mod.rs index a1d35526a..6f62ebd34 100644 --- a/pgml-extension/src/bindings/mod.rs +++ b/pgml-extension/src/bindings/mod.rs @@ -1,7 +1,7 @@ use std::fmt::Debug; #[allow(unused_imports)] // used for test macros -use pgx::*; +use pgrx::*; use crate::orm::*; diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py index 0ef690410..30cb4dc49 100644 --- a/pgml-extension/src/bindings/transformers.py +++ b/pgml-extension/src/bindings/transformers.py @@ -6,6 +6,7 @@ import numpy as np import datasets +from InstructorEmbedding import INSTRUCTOR from rouge import Rouge from sacrebleu.metrics import BLEU from sentence_transformers import SentenceTransformer @@ -67,10 +68,23 @@ def transform(task, args, inputs, cache): def embed(transformer, text, kwargs): kwargs = json.loads(kwargs) + instructor = transformer.startswith("hkunlp/instructor") + if instructor: + klass = INSTRUCTOR + text = [[kwargs.pop("instruction"), text]] + else: + klass = SentenceTransformer + if transformer not in __cache_sentence_transformer_by_name: - __cache_sentence_transformer_by_name[transformer] = SentenceTransformer(transformer) + __cache_sentence_transformer_by_name[transformer] = klass(transformer) model = 
__cache_sentence_transformer_by_name[transformer] - return model.encode(text, **kwargs) + + result = model.encode(text, **kwargs) + if instructor: + result = result[0] + + return result + def load_dataset(name, subset, limit: None, kwargs: "{}"): kwargs = json.loads(kwargs) diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index 504202ba8..95b9ab298 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -4,7 +4,7 @@ use std::path::PathBuf; use std::str::FromStr; use once_cell::sync::Lazy; -use pgx::*; +use pgrx::*; use pyo3::prelude::*; use pyo3::types::PyTuple; diff --git a/pgml-extension/src/bindings/xgboost.rs b/pgml-extension/src/bindings/xgboost.rs index b7bebd91d..2eafc6839 100644 --- a/pgml-extension/src/bindings/xgboost.rs +++ b/pgml-extension/src/bindings/xgboost.rs @@ -13,7 +13,7 @@ use crate::orm::Hyperparams; use crate::bindings::Bindings; -use pgx::*; +use pgrx::*; #[pg_extern] fn xgboost_version() -> String { diff --git a/pgml-extension/src/lib.rs b/pgml-extension/src/lib.rs index a692ce368..e47825ad2 100644 --- a/pgml-extension/src/lib.rs +++ b/pgml-extension/src/lib.rs @@ -4,7 +4,7 @@ extern crate openblas_src; extern crate serde; extern crate signal_hook; -use pgx::*; +use pgrx::*; pub const VERSION: &str = env!("CARGO_PKG_VERSION"); diff --git a/pgml-extension/src/orm/algorithm.rs b/pgml-extension/src/orm/algorithm.rs index 277e90147..098e04adb 100644 --- a/pgml-extension/src/orm/algorithm.rs +++ b/pgml-extension/src/orm/algorithm.rs @@ -1,4 +1,4 @@ -use pgx::*; +use pgrx::*; use serde::Deserialize; #[derive(PostgresEnum, Copy, Clone, Eq, PartialEq, Debug, Deserialize)] diff --git a/pgml-extension/src/orm/dataset.rs b/pgml-extension/src/orm/dataset.rs index 8092522ab..92d20369c 100644 --- a/pgml-extension/src/orm/dataset.rs +++ b/pgml-extension/src/orm/dataset.rs @@ -1,7 +1,7 @@ use std::fmt::{Display, Formatter}; use 
flate2::read::GzDecoder; -use pgx::*; +use pgrx::*; use serde::Deserialize; #[derive(Debug)] diff --git a/pgml-extension/src/orm/file.rs b/pgml-extension/src/orm/file.rs index 89fa059c5..121243661 100644 --- a/pgml-extension/src/orm/file.rs +++ b/pgml-extension/src/orm/file.rs @@ -4,7 +4,7 @@ use std::str::FromStr; use std::sync::Arc; use once_cell::sync::Lazy; -use pgx::*; +use pgrx::*; use crate::bindings::Bindings; diff --git a/pgml-extension/src/orm/model.rs b/pgml-extension/src/orm/model.rs index bc9fbd7c3..9866f5c33 100644 --- a/pgml-extension/src/orm/model.rs +++ b/pgml-extension/src/orm/model.rs @@ -10,8 +10,8 @@ use indexmap::IndexMap; use itertools::{izip, Itertools}; use ndarray::ArrayView1; use once_cell::sync::Lazy; -use pgx::heap_tuple::PgHeapTuple; -use pgx::*; +use pgrx::heap_tuple::PgHeapTuple; +use pgrx::*; use rand::prelude::SliceRandom; use serde_json::json; @@ -928,13 +928,13 @@ impl Model { ).unwrap(); } - pub fn numeric_encode_features(&self, rows: &[pgx::datum::AnyElement]) -> Vec { - // TODO handle FLOAT4[] as if it were pgx::datum::AnyElement, skipping all this, and going straight to predict + pub fn numeric_encode_features(&self, rows: &[pgrx::datum::AnyElement]) -> Vec { + // TODO handle FLOAT4[] as if it were pgrx::datum::AnyElement, skipping all this, and going straight to predict let mut features = Vec::new(); // TODO pre-allocate space let columns = &self.snapshot.columns; for row in rows { match row.oid() { - pgx_pg_sys::RECORDOID => { + pgrx_pg_sys::RECORDOID => { let tuple = unsafe { PgHeapTuple::from_composite_datum(row.datum()) }; for index in 1..tuple.len() + 1 { let column = &columns[index - 1]; @@ -944,19 +944,19 @@ impl Model { match &column.statistics.categories { Some(_categories) => { let key = match attribute.atttypid { - pgx_pg_sys::UNKNOWNOID => { + pgrx_pg_sys::UNKNOWNOID => { error!("Type information missing for column: {:?}. 
If this is intended to be a TEXT or other categorical column, you will need to explicitly cast it, e.g. change `{:?}` to `CAST({:?} AS TEXT)`.", column.name, column.name, column.name); } - pgx_pg_sys::TEXTOID - | pgx_pg_sys::VARCHAROID - | pgx_pg_sys::BPCHAROID => { + pgrx_pg_sys::TEXTOID + | pgrx_pg_sys::VARCHAROID + | pgrx_pg_sys::BPCHAROID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); element .unwrap() .unwrap_or(snapshot::NULL_CATEGORY_KEY.to_string()) } - pgx_pg_sys::BOOLOID => { + pgrx_pg_sys::BOOLOID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); element @@ -965,7 +965,7 @@ impl Model { k.to_string() }) } - pgx_pg_sys::INT2OID => { + pgrx_pg_sys::INT2OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); element @@ -974,7 +974,7 @@ impl Model { k.to_string() }) } - pgx_pg_sys::INT4OID => { + pgrx_pg_sys::INT4OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); element @@ -983,7 +983,7 @@ impl Model { k.to_string() }) } - pgx_pg_sys::INT8OID => { + pgrx_pg_sys::INT8OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); element @@ -992,7 +992,7 @@ impl Model { k.to_string() }) } - pgx_pg_sys::FLOAT4OID => { + pgrx_pg_sys::FLOAT4OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); element @@ -1001,7 +1001,7 @@ impl Model { k.to_string() }) } - pgx_pg_sys::FLOAT8OID => { + pgrx_pg_sys::FLOAT8OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); element @@ -1020,82 +1020,82 @@ impl Model { } None => { match attribute.atttypid { - pgx_pg_sys::UNKNOWNOID => { + pgrx_pg_sys::UNKNOWNOID => { error!("Type information missing for column: {:?}. If this is intended to be a FLOAT4 or other numeric column, you will need to explicitly cast it, e.g. 
change `{:?}` to `CAST({:?} AS FLOAT4)`.", column.name, column.name, column.name); } - pgx_pg_sys::BOOLOID => { + pgrx_pg_sys::BOOLOID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); features.push( element.unwrap().map_or(f32::NAN, |v| v as u8 as f32), ); } - pgx_pg_sys::INT2OID => { + pgrx_pg_sys::INT2OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); features .push(element.unwrap().map_or(f32::NAN, |v| v as f32)); } - pgx_pg_sys::INT4OID => { + pgrx_pg_sys::INT4OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); features .push(element.unwrap().map_or(f32::NAN, |v| v as f32)); } - pgx_pg_sys::INT8OID => { + pgrx_pg_sys::INT8OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); features .push(element.unwrap().map_or(f32::NAN, |v| v as f32)); } - pgx_pg_sys::FLOAT4OID => { + pgrx_pg_sys::FLOAT4OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); features.push(element.unwrap().map_or(f32::NAN, |v| v)); } - pgx_pg_sys::FLOAT8OID => { + pgrx_pg_sys::FLOAT8OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); features .push(element.unwrap().map_or(f32::NAN, |v| v as f32)); } // TODO handle NULL to NaN for arrays - pgx_pg_sys::BOOLARRAYOID => { + pgrx_pg_sys::BOOLARRAYOID => { let element: Result>, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); for j in element.as_ref().unwrap().as_ref().unwrap() { features.push(*j as i8 as f32); } } - pgx_pg_sys::INT2ARRAYOID => { + pgrx_pg_sys::INT2ARRAYOID => { let element: Result>, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); for j in element.as_ref().unwrap().as_ref().unwrap() { features.push(*j as f32); } } - pgx_pg_sys::INT4ARRAYOID => { + pgrx_pg_sys::INT4ARRAYOID => { let element: Result>, TryFromDatumError> = 
tuple.get_by_index(index.try_into().unwrap()); for j in element.as_ref().unwrap().as_ref().unwrap() { features.push(*j as f32); } } - pgx_pg_sys::INT8ARRAYOID => { + pgrx_pg_sys::INT8ARRAYOID => { let element: Result>, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); for j in element.as_ref().unwrap().as_ref().unwrap() { features.push(*j as f32); } } - pgx_pg_sys::FLOAT4ARRAYOID => { + pgrx_pg_sys::FLOAT4ARRAYOID => { let element: Result>, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); for j in element.as_ref().unwrap().as_ref().unwrap() { features.push(*j as f32); } } - pgx_pg_sys::FLOAT8ARRAYOID => { + pgrx_pg_sys::FLOAT8ARRAYOID => { let element: Result>, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); for j in element.as_ref().unwrap().as_ref().unwrap() { diff --git a/pgml-extension/src/orm/project.rs b/pgml-extension/src/orm/project.rs index caf12e022..595fb09ea 100644 --- a/pgml-extension/src/orm/project.rs +++ b/pgml-extension/src/orm/project.rs @@ -4,7 +4,7 @@ use std::fmt::{Display, Error, Formatter}; use std::str::FromStr; use once_cell::sync::Lazy; -use pgx::*; +use pgrx::*; use crate::orm::*; diff --git a/pgml-extension/src/orm/runtime.rs b/pgml-extension/src/orm/runtime.rs index abc02ed35..a28656b7d 100644 --- a/pgml-extension/src/orm/runtime.rs +++ b/pgml-extension/src/orm/runtime.rs @@ -1,4 +1,4 @@ -use pgx::*; +use pgrx::*; use serde::Deserialize; #[derive(PostgresEnum, Copy, Clone, Eq, PartialEq, Debug, Deserialize)] diff --git a/pgml-extension/src/orm/sampling.rs b/pgml-extension/src/orm/sampling.rs index f8781f97b..6bb3d7b5a 100644 --- a/pgml-extension/src/orm/sampling.rs +++ b/pgml-extension/src/orm/sampling.rs @@ -1,4 +1,4 @@ -use pgx::*; +use pgrx::*; use serde::Deserialize; #[derive(PostgresEnum, Copy, Clone, Eq, PartialEq, Debug, Deserialize)] diff --git a/pgml-extension/src/orm/search.rs b/pgml-extension/src/orm/search.rs index 58961af8b..4169e3e27 100644 --- 
a/pgml-extension/src/orm/search.rs +++ b/pgml-extension/src/orm/search.rs @@ -1,4 +1,4 @@ -use pgx::*; +use pgrx::*; use serde::Deserialize; #[derive(PostgresEnum, Copy, Clone, Eq, PartialEq, Debug, Deserialize)] diff --git a/pgml-extension/src/orm/snapshot.rs b/pgml-extension/src/orm/snapshot.rs index 121c82066..27f623347 100644 --- a/pgml-extension/src/orm/snapshot.rs +++ b/pgml-extension/src/orm/snapshot.rs @@ -5,7 +5,7 @@ use std::str::FromStr; use indexmap::IndexMap; use ndarray::Zip; -use pgx::*; +use pgrx::*; use serde::{Deserialize, Serialize}; use serde_json::json; diff --git a/pgml-extension/src/orm/status.rs b/pgml-extension/src/orm/status.rs index 4f99468e5..10d05e1f1 100644 --- a/pgml-extension/src/orm/status.rs +++ b/pgml-extension/src/orm/status.rs @@ -1,4 +1,4 @@ -use pgx::*; +use pgrx::*; use serde::Deserialize; #[derive(PostgresEnum, Copy, Clone, Eq, PartialEq, Debug, Deserialize)] diff --git a/pgml-extension/src/orm/strategy.rs b/pgml-extension/src/orm/strategy.rs index 597f45f4e..2e8e54edf 100644 --- a/pgml-extension/src/orm/strategy.rs +++ b/pgml-extension/src/orm/strategy.rs @@ -1,4 +1,4 @@ -use pgx::*; +use pgrx::*; use serde::Deserialize; #[derive(PostgresEnum, Copy, Clone, Eq, PartialEq, Debug, Deserialize)] diff --git a/pgml-extension/src/orm/task.rs b/pgml-extension/src/orm/task.rs index f9285a2cd..bd9d69d56 100644 --- a/pgml-extension/src/orm/task.rs +++ b/pgml-extension/src/orm/task.rs @@ -1,4 +1,4 @@ -use pgx::*; +use pgrx::*; use serde::Deserialize; #[derive(PostgresEnum, Copy, Clone, Eq, PartialEq, Debug, Deserialize)] @@ -14,7 +14,7 @@ pub enum Task { text2text, } -// unfortunately the pgx macro expands the enum names to underscore, but huggingface uses dash +// unfortunately the pgrx macro expands the enum names to underscore, but huggingface uses dash impl Task { pub fn to_pg_enum(&self) -> String { match *self { diff --git a/pgml-extension/src/vectors.rs b/pgml-extension/src/vectors.rs index 411f0b7eb..57cbe2abe 100644 --- 
a/pgml-extension/src/vectors.rs +++ b/pgml-extension/src/vectors.rs @@ -1,4 +1,4 @@ -use pgx::*; +use pgrx::*; #[pg_extern(immutable, parallel_safe, strict, name = "add")] fn add_scalar_s(vector: Vec, addend: f32) -> Vec { @@ -330,6 +330,664 @@ fn cosine_similarity_d(vector: Vec, other: Vec) -> f64 { } } +#[derive(Copy, Clone, Default, Debug)] +pub struct SumS; + +#[pg_aggregate] +impl Aggregate for SumS { + const NAME: &'static str = "sum"; + type Args = Option>; + type State = Option>; + type Finalize = Vec; + + fn state( + mut current: Self::State, + arg: Self::Args, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::State { + match arg { + None => {}, + Some(arg) => { + match current { + None => { + _ = current.insert(arg); + } + Some(ref mut vec) => { + for (i, v) in arg.iter().enumerate() { + vec[i] += v; + } + } + } + } + } + current + } + + fn combine( + mut first: Self::State, + second: Self::State, + _fcinfo: pg_sys::FunctionCallInfo + ) -> Self::State { + match (&mut first, &second) { + (None, None) => None, + (Some(_), None) => first, + (None, Some(_)) => second, + (Some(first_inner), Some(second_inner)) => { + for (i, v) in second_inner.iter().enumerate() { + first_inner[i] += v; + } + first + } + } + } + + fn finalize( + mut current: Self::State, + _direct_arg: Self::OrderedSetArgs, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::Finalize { + let inner = current.get_or_insert_with(|| Vec::new() ); + + inner.clone() + } +} + +#[derive(Copy, Clone, Default, Debug)] +pub struct SumD; + +#[pg_aggregate] +impl Aggregate for SumD { + const NAME: &'static str = "sum"; + type Args = Option>; + type State = Option>; + type Finalize = Vec; + + fn state( + mut current: Self::State, + arg: Self::Args, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::State { + match arg { + None => {}, + Some(arg) => { + match current { + None => { + _ = current.insert(arg); + } + Some(ref mut vec) => { + for (i, v) in arg.iter().enumerate() { + vec[i] += v; + } + } + } + } + } 
+ current + } + + fn combine( + mut first: Self::State, + second: Self::State, + _fcinfo: pg_sys::FunctionCallInfo + ) -> Self::State { + match (&mut first, &second) { + (None, None) => None, + (Some(_), None) => first, + (None, Some(_)) => second, + (Some(first_inner), Some(second_inner)) => { + for (i, v) in second_inner.iter().enumerate() { + first_inner[i] += v; + } + first + } + } + } + + fn finalize( + mut current: Self::State, + _direct_arg: Self::OrderedSetArgs, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::Finalize { + let inner = current.get_or_insert_with(|| Vec::new() ); + + inner.clone() + } +} + + +#[derive(Copy, Clone, Default, Debug)] +pub struct MaxAbsS; + +#[pg_aggregate] +impl Aggregate for MaxAbsS { + const NAME: &'static str = "max_abs"; + type Args = Option>; + type State = Option>; + type Finalize = Vec; + + fn state( + mut current: Self::State, + arg: Self::Args, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::State { + match arg { + None => {}, + Some(arg) => { + match current { + None => { + _ = current.insert(arg.into_iter().map(|v| v.abs()).collect()); + } + Some(ref mut vec) => { + for (i, &v) in arg.iter().enumerate() { + if v.abs() > vec[i].abs() { + vec[i] = v.abs(); + } + } + } + } + } + } + current + } + + fn combine( + mut first: Self::State, + second: Self::State, + _fcinfo: pg_sys::FunctionCallInfo + ) -> Self::State { + match (&mut first, &second) { + (None, None) => None, + (Some(_), None) => first, + (None, Some(_)) => second, + (Some(first_inner), Some(second_inner)) => { + for (i, &v) in second_inner.iter().enumerate() { + if v.abs() > first_inner[i].abs() { + first_inner[i] = v.abs(); + } + } + first + } + } + } + + fn finalize( + mut current: Self::State, + _direct_arg: Self::OrderedSetArgs, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::Finalize { + let inner = current.get_or_insert_with(|| Vec::new() ); + + inner.clone() + } +} + +#[derive(Copy, Clone, Default, Debug)] +pub struct MaxAbsD {} + +#[pg_aggregate] 
+impl Aggregate for MaxAbsD { + const NAME: &'static str = "max_abs"; + type Args = Option>; + type State = Option>; + type Finalize = Vec; + + fn state( + mut current: Self::State, + arg: Self::Args, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::State { + match arg { + None => {}, + Some(arg) => { + match current { + None => { + _ = current.insert(arg.into_iter().map(|v| v.abs()).collect()); + } + Some(ref mut vec) => { + for (i, &v) in arg.iter().enumerate() { + if v.abs() > vec[i].abs() { + vec[i] = v.abs(); + } + } + } + } + } + } + current + } + + fn combine( + mut first: Self::State, + second: Self::State, + _fcinfo: pg_sys::FunctionCallInfo + ) -> Self::State { + match (&mut first, &second) { + (None, None) => None, + (Some(_), None) => first, + (None, Some(_)) => second, + (Some(first_inner), Some(second_inner)) => { + for (i, &v) in second_inner.iter().enumerate() { + if v.abs() > first_inner[i].abs() { + first_inner[i] = v.abs(); + } + } + first + } + } + } + + fn finalize( + mut current: Self::State, + _direct_arg: Self::OrderedSetArgs, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::Finalize { + let inner = current.get_or_insert_with(|| Vec::new() ); + + inner.clone() + } +} + + + +#[derive(Copy, Clone, Default, Debug)] +pub struct MaxS; + +#[pg_aggregate] +impl Aggregate for MaxS { + const NAME: &'static str = "max"; + type Args = Option>; + type State = Option>; + type Finalize = Vec; + + fn state( + mut current: Self::State, + arg: Self::Args, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::State { + match arg { + None => {}, + Some(arg) => { + match current { + None => { + _ = current.insert(arg); + } + Some(ref mut vec) => { + for (i, &v) in arg.iter().enumerate() { + if v > vec[i] { + vec[i] = v; + } + } + } + } + } + } + current + } + + fn combine( + mut first: Self::State, + second: Self::State, + _fcinfo: pg_sys::FunctionCallInfo + ) -> Self::State { + match (&mut first, &second) { + (None, None) => None, + (Some(_), None) => first, + 
(None, Some(_)) => second, + (Some(first_inner), Some(second_inner)) => { + for (i, &v) in second_inner.iter().enumerate() { + if v > first_inner[i] { + first_inner[i] = v; + } + } + first + } + } + } + + fn finalize( + mut current: Self::State, + _direct_arg: Self::OrderedSetArgs, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::Finalize { + let inner = current.get_or_insert_with(|| Vec::new() ); + + inner.clone() + } +} + +#[derive(Copy, Clone, Default, Debug)] +pub struct MaxD {} + +#[pg_aggregate] +impl Aggregate for MaxD { + const NAME: &'static str = "max"; + type Args = Option>; + type State = Option>; + type Finalize = Vec; + + fn state( + mut current: Self::State, + arg: Self::Args, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::State { + match arg { + None => {}, + Some(arg) => { + match current { + None => { + _ = current.insert(arg); + } + Some(ref mut vec) => { + for (i, &v) in arg.iter().enumerate() { + if v > vec[i] { + vec[i] = v; + } + } + } + } + } + } + current + } + + fn combine( + mut first: Self::State, + second: Self::State, + _fcinfo: pg_sys::FunctionCallInfo + ) -> Self::State { + match (&mut first, &second) { + (None, None) => None, + (Some(_), None) => first, + (None, Some(_)) => second, + (Some(first_inner), Some(second_inner)) => { + for (i, &v) in second_inner.iter().enumerate() { + if v > first_inner[i] { + first_inner[i] = v; + } + } + first + } + } + } + + fn finalize( + mut current: Self::State, + _direct_arg: Self::OrderedSetArgs, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::Finalize { + let inner = current.get_or_insert_with(|| Vec::new() ); + + inner.clone() + } +} + + +#[derive(Copy, Clone, Default, Debug)] +pub struct MinS; + +#[pg_aggregate] +impl Aggregate for MinS { + const NAME: &'static str = "min"; + type Args = Option>; + type State = Option>; + type Finalize = Vec; + + fn state( + mut current: Self::State, + arg: Self::Args, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::State { + match arg { + None => {}, 
+ Some(arg) => { + match current { + None => { + _ = current.insert(arg); + } + Some(ref mut vec) => { + for (i, &v) in arg.iter().enumerate() { + if v < vec[i] { + vec[i] = v; + } + } + } + } + } + } + current + } + + fn combine( + mut first: Self::State, + second: Self::State, + _fcinfo: pg_sys::FunctionCallInfo + ) -> Self::State { + match (&mut first, &second) { + (None, None) => None, + (Some(_), None) => first, + (None, Some(_)) => second, + (Some(first_inner), Some(second_inner)) => { + for (i, &v) in second_inner.iter().enumerate() { + if v < first_inner[i] { + first_inner[i] = v; + } + } + first + } + } + } + + fn finalize( + mut current: Self::State, + _direct_arg: Self::OrderedSetArgs, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::Finalize { + let inner = current.get_or_insert_with(|| Vec::new() ); + + inner.clone() + } +} + +#[derive(Copy, Clone, Default, Debug)] +pub struct MinD {} + +#[pg_aggregate] +impl Aggregate for MinD { + const NAME: &'static str = "min"; + type Args = Option>; + type State = Option>; + type Finalize = Vec; + + fn state( + mut current: Self::State, + arg: Self::Args, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::State { + match arg { + None => {}, + Some(arg) => { + match current { + None => { + _ = current.insert(arg); + } + Some(ref mut vec) => { + for (i, &v) in arg.iter().enumerate() { + if v < vec[i] { + vec[i] = v; + } + } + } + } + } + } + current + } + + fn combine( + mut first: Self::State, + second: Self::State, + _fcinfo: pg_sys::FunctionCallInfo + ) -> Self::State { + match (&mut first, &second) { + (None, None) => None, + (Some(_), None) => first, + (None, Some(_)) => second, + (Some(first_inner), Some(second_inner)) => { + for (i, &v) in second_inner.iter().enumerate() { + if v < first_inner[i] { + first_inner[i] = v; + } + } + first + } + } + } + + fn finalize( + mut current: Self::State, + _direct_arg: Self::OrderedSetArgs, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::Finalize { + let inner = 
current.get_or_insert_with(|| Vec::new() ); + + inner.clone() + } +} + + +#[derive(Copy, Clone, Default, Debug)] +pub struct MinAbsS; + +#[pg_aggregate] +impl Aggregate for MinAbsS { + const NAME: &'static str = "min_abs"; + type Args = Option>; + type State = Option>; + type Finalize = Vec; + + fn state( + mut current: Self::State, + arg: Self::Args, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::State { + match arg { + None => {}, + Some(arg) => { + match current { + None => { + _ = current.insert(arg.into_iter().map(|v| v.abs()).collect()); + } + Some(ref mut vec) => { + for (i, &v) in arg.iter().enumerate() { + if v.abs() < vec[i].abs() { + vec[i] = v.abs(); + } + } + } + } + } + } + current + } + + fn combine( + mut first: Self::State, + second: Self::State, + _fcinfo: pg_sys::FunctionCallInfo + ) -> Self::State { + match (&mut first, &second) { + (None, None) => None, + (Some(_), None) => first, + (None, Some(_)) => second, + (Some(first_inner), Some(second_inner)) => { + for (i, &v) in second_inner.iter().enumerate() { + if v.abs() < first_inner[i].abs() { + first_inner[i] = v.abs(); + } + } + first + } + } + } + + fn finalize( + mut current: Self::State, + _direct_arg: Self::OrderedSetArgs, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::Finalize { + let inner = current.get_or_insert_with(|| Vec::new() ); + + inner.clone() + } +} + +#[derive(Copy, Clone, Default, Debug)] +pub struct MinAbsD {} + +#[pg_aggregate] +impl Aggregate for MinAbsD { + const NAME: &'static str = "min_abs"; + type Args = Option>; + type State = Option>; + type Finalize = Vec; + + fn state( + mut current: Self::State, + arg: Self::Args, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::State { + match arg { + None => {}, + Some(arg) => { + match current { + None => { + _ = current.insert(arg.into_iter().map(|v| v.abs()).collect()); + } + Some(ref mut vec) => { + for (i, &v) in arg.iter().enumerate() { + if v.abs() < vec[i].abs() { + vec[i] = v.abs(); + } + } + } + } + } + } + 
current + } + + fn combine( + mut first: Self::State, + second: Self::State, + _fcinfo: pg_sys::FunctionCallInfo + ) -> Self::State { + match (&mut first, &second) { + (None, None) => None, + (Some(_), None) => first, + (None, Some(_)) => second, + (Some(first_inner), Some(second_inner)) => { + for (i, &v) in second_inner.iter().enumerate() { + if v.abs() < first_inner[i].abs() { + first_inner[i] = v.abs(); + } + } + first + } + } + } + + fn finalize( + mut current: Self::State, + _direct_arg: Self::OrderedSetArgs, + _fcinfo: pg_sys::FunctionCallInfo, + ) -> Self::Finalize { + let inner = current.get_or_insert_with(|| Vec::new() ); + + inner.clone() + } +} + + #[cfg(any(test, feature = "pg_test"))] #[pg_schema] mod tests { diff --git a/pgml-extension/tests/test.sql b/pgml-extension/tests/test.sql index ed14c510d..ece0cd165 100644 --- a/pgml-extension/tests/test.sql +++ b/pgml-extension/tests/test.sql @@ -3,7 +3,7 @@ --- --- Usage: --- ---- $ cargo pgx run --release +--- $ cargo pgrx run --release --- $ psql -h localhost -p 28815 -d pgml -f tests/test.sql -P pager --- \set ON_ERROR_STOP true pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy