diff --git a/.gitignore b/.gitignore index 4434b9069..8d5ed3336 100644 --- a/.gitignore +++ b/.gitignore @@ -161,3 +161,4 @@ cython_debug/ # local scratch pad scratch.sql +scratch.py diff --git a/pgml-dashboard/src/models.rs b/pgml-dashboard/src/models.rs index 9469a8861..3fbcaaf79 100644 --- a/pgml-dashboard/src/models.rs +++ b/pgml-dashboard/src/models.rs @@ -60,16 +60,22 @@ impl Project { pub fn key_metric_name(&self) -> anyhow::Result<&'static str> { match self.task.as_ref().unwrap().as_str() { - "classification" | "text-classification" => Ok("f1"), + "classification" | "text_classification" | "question_answering" => Ok("f1"), "regression" => Ok("r2"), + "summarization" => Ok("rouge_ngram_f1"), + "translation" => Ok("bleu"), + "text_generation" | "text2text" => Ok("perplexity"), task => Err(anyhow::anyhow!("Unhandled task: {}", task)), } } pub fn key_metric_display_name(&self) -> anyhow::Result<&'static str> { match self.task.as_ref().unwrap().as_str() { - "classification" | "text-classification" => Ok("F1"), + "classification" | "text_classification" | "question_answering" => Ok("F1"), "regression" => Ok("R2"), + "summarization" => Ok("Rouge Ngram F1"), + "translation" => Ok("Bleu"), + "text_generation" | "text2text" => Ok("Perplexity"), task => Err(anyhow::anyhow!("Unhandled task: {}", task)), } } diff --git a/pgml-docs/docs/user_guides/transformers/fine_tuning.md b/pgml-docs/docs/user_guides/transformers/fine_tuning.md index 287aae92e..26cfb61f7 100644 --- a/pgml-docs/docs/user_guides/transformers/fine_tuning.md +++ b/pgml-docs/docs/user_guides/transformers/fine_tuning.md @@ -34,8 +34,53 @@ You can view the newly loaded data in your Postgres database: 103 | {"en": "ROLES_OF_TRANSLATORS", "es": "Rafael Osuna rosuna@wol. es Traductor"} (5 rows) ``` +This huggingface dataset stores the data as language key pairs in a JSON document. To use it with PostgresML, we'll need to provide a `VIEW` that structures the data into more primitively typed columns. + +=== "SQL" + + ```sql linenums="1" + CREATE OR REPLACE VIEW kde4_en_to_es AS + SELECT translation->>'en' AS "en", translation->>'es' AS "es" + FROM pgml.kde4 + LIMIT 10; + ``` + +=== "Result" + + ```sql linenums="1" + CREATE VIEW + ``` + +Now, we can see the data in more normalized form. The exact column names don't matter for now, we'll specify which one is the target during the training call, and the other one will be used as the input. + +=== "SQL" + + ```sql linenums="1" + SELECT * FROM kde4_en_to_es LIMIT 10; + ``` + +=== "Result" + + ```sql linenums="1" + en | es + + --------------------------------------------------------------------------------------------+-------------------------------------------------------------------------- + ------------------------------ + Lauri Watts | Lauri Watts + & Lauri. Watts. mail; | & Lauri. Watts. mail; + ROLES_OF_TRANSLATORS | Rafael Osuna rosuna@wol. es Traductor Miguel Revilla Rodríguez yo@miguelr + evilla. com Traductor + 2006-02-26 3.5.1 | 2006-02-26 3.5.1 + The Babel & konqueror; plugin gives you quick access to the Babelfish translation service. | La extensión Babel de & konqueror; le permite un acceso rápido al servici + o de traducción de Babelfish. + KDE | KDE + kdeaddons | kdeaddons + konqueror | konqueror + plugins | extensiones + babelfish | babelfish + (10 rows) + ``` -When you're constructing your own datasets for translation, it's important to mirror the same table structure. You'll need a `JSONB` column named `translation`, that has first has a "from" language name/value pair, and then a "to" language name/value pair. In this English to Spanish example we use from "en" to "es". You'll pass a `y_column_name` of `translation` to tune the model. ### Tune the model Tuning is very similar to training with PostgresML, although we specify a `model_name` to download from Hugging Face instead of the base `algorithm`. @@ -43,9 +88,9 @@ Tuning is very similar to training with PostgresML, although we specify a `model ```sql linenums="1" title="tune.sql" SELECT pgml.tune( 'Translate English to Spanish', - task => 'translation_en_to_es', - relation_name => 'pgml.kde4', - y_column_name => 'translation', + task => 'translation', + relation_name => 'kde4_en_to_es', + y_column_name => 'es', -- translate into spanish model_name => 'Helsinki-NLP/opus-mt-en-es', hyperparams => '{ "learning_rate": 2e-5, @@ -289,7 +334,8 @@ Or, it might be interesting to concat the title to the text field to see how rel ```sql linenums="1" title="concat_title.sql" CREATE OR REPLACE VIEW billsum_training_data -AS SELECT title || '\n' || "text" AS "text", summary FROM pgml.billsum; +AS SELECT title || '\n' || "text" AS "text", summary FROM pgml.billsum +LIMIT 10; ``` @@ -310,14 +356,14 @@ SELECT pgml.tune( "per_device_eval_batch_size": 2, "num_train_epochs": 1, "weight_decay": 0.01, - "max_input_length": 1024, - "max_summary_length": 128 + "max_length": 1024 }', test_size => 0.2, test_sampling => 'last' ); ``` + ### Make predictions === "SQL" @@ -355,3 +401,27 @@ The default for predict in a classification problem classifies the statement as This shows that there is a 6.26% chance for category 0 (negative sentiment), and a 93.73% chance it's category 1 (positive sentiment). See the [task documentation](https://huggingface.co/tasks/text-classification) for more examples, use cases, models and datasets. + + + +## Text Generation + +```postgresql linenums="1" + SELECT pgml.load_dataset('bookcorpus', "limit" => 100); + + SELECT pgml.tune( + 'GPT Generator', + task => 'text-generation', + relation_name => 'pgml.bookcorpus', + y_column_name => 'text', + model_name => 'gpt2', + hyperparams => '{ + "learning_rate": 2e-5, + "num_train_epochs": 1 + }', + test_size => 0.2, + test_sampling => 'last' + ); + + SELECT pgml.generate('GPT Generator', 'While I wandered weak and weary'); +``` diff --git a/pgml-extension/Cargo.lock b/pgml-extension/Cargo.lock index 5d22613eb..6f839f330 100644 --- a/pgml-extension/Cargo.lock +++ b/pgml-extension/Cargo.lock @@ -78,12 +78,12 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.64" +version = "0.1.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cd7fce9ba8c3c042128ce72d8b2ddbf3a05747efb67ea0313c635e10bda47a2" +checksum = "b84f9ebcc6c1f5b8cb160f6990096a5c127f423fcb6e1ccc46c370cbdfb75dfc" dependencies = [ "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] @@ -160,7 +160,7 @@ dependencies = [ "log", "peeking_take_while", "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "regex", "rustc-hash", "shlex 0.1.1", @@ -180,7 +180,7 @@ dependencies = [ "lazycell", "peeking_take_while", "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "regex", "rustc-hash", "shlex 1.1.0", @@ -200,7 +200,7 @@ dependencies = [ "log", "peeking_take_while", "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "regex", "rustc-hash", "shlex 1.1.0", @@ -257,25 +257,13 @@ dependencies = [ [[package]] name = "block-buffer" -version = "0.10.3" +version = "0.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69cce20737498f97b993470a6e536b8523f0af7892a4f928cceb1ac5e52ebe7e" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" dependencies = [ "generic-array", ] -[[package]] -name = "bstr" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba3569f383e8f1598449f1a423e72e99569137b47740b1da11ef19af3d5c3223" -dependencies = [ - "lazy_static", - "memchr", - "regex-automata", - "serde", -] - [[package]] name = "byteorder" version = "1.4.3" @@ -329,9 +317,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clang-sys" -version = "1.4.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa2e27ae6ab525c3d369ded447057bca5438d86dc3a68f6faafb8269ba82ebf3" +checksum = "77ed9a53e5d4d9c573ae844bfac6872b159cb1d1585a83b29e7a64b7eef7332a" dependencies = [ "glob", "libc", @@ -355,9 +343,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.1.4" +version = "4.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f13b9c79b5d1dd500d20ef541215a6423c75829ef43117e1b4d17fd8af0b5d76" +checksum = "c3d7ae14b20b94cb02149ed21a86c423859cbe18dc7ed69845cace50e52b40a5" dependencies = [ "bitflags", "clap_derive", @@ -371,28 +359,28 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eca953650a7350560b61db95a0ab1d9c6f7b74d146a9e08fb258b834f3cf7e2c" dependencies = [ - "clap 4.1.4", + "clap 4.1.8", "doc-comment", ] [[package]] name = "clap_derive" -version = "4.1.0" +version = "4.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "684a277d672e91966334af371f1a7b5833f9aa00b07c84e92fbce95e00208ce8" +checksum = "44bec8e5c9d09e439c4335b1af0abaab56dcf3b94999a936e1bb47b9134288f0" dependencies = [ "heck", "proc-macro-error", "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] [[package]] name = "clap_lex" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "783fe232adfca04f90f56201b26d79682d4cd2625e0bc7290b95123afe558ade" +checksum = "350b9cf31731f9957399229e9b2adc51eeabdfbe9d71d9a0552275fd12710d09" dependencies = [ "os_str_bytes", ] @@ -454,9 +442,9 @@ checksum = "6548a0ad5d2549e111e1f6a11a6c2e2d00ce6a3dafe22948d67c2b443f775e52" [[package]] name = "crossbeam-channel" -version = "0.5.6" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2dd04ddaf88237dc3b8d8f9a3c1004b506b54b3313403944054d23c0870c521" +checksum = "cf2b3e8478797446514c91ef04bafcb59faba183e621ad488df88983cc14128c" dependencies = [ "cfg-if", "crossbeam-utils", @@ -464,9 +452,9 @@ dependencies = [ [[package]] name = "crossbeam-deque" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "715e8152b692bba2d374b53d4875445368fdf21a94751410af607a5ac677d1fc" +checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" dependencies = [ "cfg-if", "crossbeam-epoch", @@ -475,22 +463,22 @@ dependencies = [ [[package]] name = "crossbeam-epoch" -version = "0.9.13" +version = "0.9.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01a9af1f4c2ef74bb8aa1f7e19706bc72d03598c8a570bb5de72243c7a9d9d5a" +checksum = "46bd5f3f85273295a9d14aedfb86f6aadbff6d8f5295c4a9edb08e819dcf5695" dependencies = [ "autocfg", "cfg-if", "crossbeam-utils", - "memoffset 0.7.1", + "memoffset 0.8.0", "scopeguard", ] [[package]] name = "crossbeam-utils" -version = "0.8.14" +version = "0.8.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fb766fa798726286dbbb842f174001dab8abc7b627a1dd86e0b7222a95d929f" +checksum = "3c063cd8cc95f5c377ed0d4b49a4b21f632396ff690e8470c29b3359b346984b" dependencies = [ "cfg-if", ] @@ -507,13 +495,12 @@ dependencies = [ [[package]] name = "csv" -version = "1.1.6" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22813a6dc45b335f9bade10bf7271dc477e81113e89eb251a0bc2a8a81c536e1" +checksum = "0b015497079b9a9d69c02ad25de6c0a6edef051ea6360a327d0bd05802ef64ad" dependencies = [ - "bstr", "csv-core", - "itoa 0.4.8", + "itoa", "ryu", "serde", ] @@ -533,15 +520,15 @@ version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d2301688392eb071b0bf1a37be05c469d3cc4dbbd95df672fe28ab021e6a096" dependencies = [ - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] [[package]] name = "darling" -version = "0.14.3" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0808e1bd8671fb44a113a14e13497557533369847788fa2ae912b6ebfce9fa8" +checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850" dependencies = [ "darling_core", "darling_macro", @@ -549,26 +536,26 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.14.3" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "001d80444f28e193f30c2f293455da62dcf9a6b29918a4253152ae2b1de592cb" +checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0" dependencies = [ "fnv", "ident_case", "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "strsim 0.10.0", "syn 1.0.109", ] [[package]] name = "darling_macro" -version = "0.14.3" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b36230598a2d5de7ec1c6f51f72d8a99a9208daff41de2084d06e3fd3ea56685" +checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" dependencies = [ "darling_core", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] @@ -610,7 +597,7 @@ checksum = "1f91d4cfa921f1c05904dc3c57b4a32c38aed3340cce209f3a6fd1478babafc4" dependencies = [ "darling", "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] @@ -712,13 +699,34 @@ dependencies = [ [[package]] name = "erased-serde" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4ca605381c017ec7a5fef5e548f1cfaa419ed0f6df6367339300db74c92aa7d" +checksum = "4f2b0c2380453a92ea8b6c8e5f64ecaafccddde8ceab55ff7a8ac1029f894569" dependencies = [ "serde", ] +[[package]] +name = "errno" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f639046355ee4f37944e44f60642c6f3a7efa3cf6b78c78a0d989a8ce6c396a1" +dependencies = [ + "errno-dragonfly", + "libc", + "winapi", +] + +[[package]] +name = "errno-dragonfly" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" +dependencies = [ + "cc", + "libc", +] + [[package]] name = "eyre" version = "0.6.8" @@ -737,23 +745,23 @@ checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" [[package]] name = "fastrand" -version = "1.8.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a407cfaa3385c4ae6b23e84623d48c2798d06e3e6a1878f7f59f17b3f86499" +checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" dependencies = [ "instant", ] [[package]] name = "filetime" -version = "0.2.19" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e884668cd0c7480504233e951174ddc3b382f7c2666e3b7310b5c4e7b0c37f9" +checksum = "8a3de6e8d11b22ff9edc6d916f890800597d60f8b2da1caf2955c274638d6412" dependencies = [ "cfg-if", "libc", "redox_syscall", - "windows-sys 0.42.0", + "windows-sys 0.45.0", ] [[package]] @@ -810,9 +818,9 @@ checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" [[package]] name = "futures-channel" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e5317663a9089767a1ec00a487df42e0ca174b61b4483213ac24448e4664df5" +checksum = "164713a5a0dcc3e7b4b1ed7d3b433cabc18025386f9339346e8daf15963cf7ac" dependencies = [ "futures-core", "futures-sink", @@ -820,38 +828,38 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec90ff4d0fe1f57d600049061dc6bb68ed03c7d2fbd697274c41805dcb3f8608" +checksum = "86d7a0c1aa76363dac491de0ee99faf6941128376f1cf96f07db7603b7de69dd" [[package]] name = "futures-macro" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95a73af87da33b5acf53acfebdc339fe592ecf5357ac7c0a7734ab9d8c876a70" +checksum = "3eb14ed937631bd8b8b8977f2c198443447a8355b6e3ca599f38c975e5a963b6" dependencies = [ "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] [[package]] name = "futures-sink" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f310820bb3e8cfd46c80db4d7fb8353e15dfff853a127158425f31e0be6c8364" +checksum = "ec93083a4aecafb2a80a885c9de1f0ccae9dbd32c2bb54b0c3a65690e0b8d2f2" [[package]] name = "futures-task" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcf79a1bf610b10f42aea489289c5a2c478a786509693b80cd39c44ccd936366" +checksum = "fd65540d33b37b16542a0438c12e6aeead10d4ac5d05bd3f805b8f35ab592879" [[package]] name = "futures-util" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c1d6de3acfef38d2be4b1f543f553131788603495be83da675e180c8d6b7bd1" +checksum = "3ef6b17e481503ec85211fed8f39d1970f128935ca1f814cd32ac4a6842e84ab" dependencies = [ "futures-core", "futures-macro", @@ -885,12 +893,12 @@ dependencies = [ [[package]] name = "ghost" -version = "0.1.7" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41973d4c45f7a35af8753ba3457cc99d406d863941fd7f52663cff54a5ab99b3" +checksum = "69e0cd8a998937e25c6ba7cc276b96ec5cc3f4dc4ab5de9ede4fb152bdd5c5eb" dependencies = [ "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] @@ -1023,14 +1031,24 @@ dependencies = [ [[package]] name = "inventory" -version = "0.3.3" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16fe3b35d64bd1f72917f06425e7573a2f63f74f42c8f56e53ea6826dde3a2b5" +checksum = "498ae1c9c329c7972b917506239b557a60386839192f1cf0ca034f345b65db99" dependencies = [ "ctor", "ghost", ] +[[package]] +name = "io-lifetimes" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfa919a82ea574332e2de6e74b4c36e74d41982b335080fa59d4ef31be20fdf3" +dependencies = [ + "libc", + "windows-sys 0.45.0", +] + [[package]] name = "itertools" version = "0.10.5" @@ -1042,15 +1060,9 @@ dependencies = [ [[package]] name = "itoa" -version = "0.4.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b71991ff56294aa922b450139ee08b3bfc70982c6b2c7562771375cf73542dd4" - -[[package]] -name = "itoa" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440" +checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" [[package]] name = "kdtree" @@ -1075,9 +1087,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.139" +version = "0.2.140" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79" +checksum = "99227334921fae1a979cf0bfdfcc6b3e5ce376ef57e16fb6fb3ea2ed6095f80c" [[package]] name = "libloading" @@ -1204,6 +1216,12 @@ dependencies = [ "thiserror", ] +[[package]] +name = "linux-raw-sys" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f051f77a7c8e6957c0696eac88f26b0117e54f52d3fc682ab19397a8812846a4" + [[package]] name = "lock_api" version = "0.4.9" @@ -1267,9 +1285,9 @@ dependencies = [ [[package]] name = "memoffset" -version = "0.7.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4" +checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1" dependencies = [ "autocfg", ] @@ -1291,14 +1309,14 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5d732bc30207a6423068df043e3d02e0735b155ad7ce1a6f76fe2baa5b158de" +checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9" dependencies = [ "libc", "log", "wasi", - "windows-sys 0.42.0", + "windows-sys 0.45.0", ] [[package]] @@ -1540,9 +1558,9 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.45" +version = "0.10.46" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b102428fd03bc5edf97f62620f7298614c45cedf287c271e7ed450bbaf83f2e1" +checksum = "fd2523381e46256e40930512c7fd25562b9eae4812cb52078f155e87217c9d1e" dependencies = [ "bitflags", "cfg-if", @@ -1560,7 +1578,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b501e44f11665960c7e7fcf062c7d96a14ade4aa98116c004b2e37b5be7d736c" dependencies = [ "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] @@ -1572,9 +1590,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.80" +version = "0.9.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23bbbf7854cd45b83958ebe919f0e8e516793727652e27fda10a8384cfc790b7" +checksum = "176be2629957c157240f68f61f2d0053ad3a4ecfdd9ebf1e6521d18d9635cf67" dependencies = [ "autocfg", "cc", @@ -1632,9 +1650,9 @@ dependencies = [ [[package]] name = "paste" -version = "1.0.11" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d01a5bd0424d00070b0098dd17ebca6f961a959dead1dbcbbbc1d1cd8d3deeba" +checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79" [[package]] name = "pathsearch" @@ -1660,9 +1678,9 @@ checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" [[package]] name = "pest" -version = "2.5.4" +version = "2.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ab62d2fa33726dbe6321cc97ef96d8cde531e3eeaf858a058de53a8a6d40d8f" +checksum = "8cbd939b234e95d72bc393d51788aec68aeeb5d51e748ca08ff3aad58cb722f7" dependencies = [ "thiserror", "ucd-trie", @@ -1715,9 +1733,9 @@ dependencies = [ [[package]] name = "pgx" -version = "0.7.2" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c3d224bd3a4fe3798498c16cb37a955e7c4a4e4e9bb01e6dfd8c3738c903d4f" +checksum = "fc91f19f84e7c1ba7b25953b042bd487b6e1bbec4c3af09f61a6ac31207ff776" dependencies = [ "atomic-traits", "bitflags", @@ -1742,21 +1760,21 @@ dependencies = [ [[package]] name = "pgx-macros" -version = "0.7.2" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "100cd28f400753e7aeb54820d63ebafff8c1b87018d55f18404477f8c047dd9e" +checksum = "1ebfde3c33353d42c2fbcc76bea758b37018b33b1391c93d6402546569914e94" dependencies = [ "pgx-sql-entity-graph", "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] [[package]] name = "pgx-pg-config" -version = "0.7.2" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ce4f7099005b82e1b386a82a98dd436b734804974ee539b814eb4d1a78e4c1f" +checksum = "e97c27bab88fdb7b94e549b02267ab9595bd9d1043718d6d72bc2d34cf1e3952" dependencies = [ "dirs 4.0.0", "eyre", @@ -1771,9 +1789,9 @@ dependencies = [ [[package]] name = "pgx-pg-sys" -version = "0.7.2" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "081d104fcb693fdef1a911f11c2aad295a6b778de67069e54c40b46702c1b2d8" +checksum = "6b79c48c564bed305d202b852321603107e5f3ac31f25ea2cc4031475f38d0b3" dependencies = [ "bindgen 0.60.1", "eyre", @@ -1784,7 +1802,7 @@ dependencies = [ "pgx-pg-config", "pgx-sql-entity-graph", "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "serde", "shlex 1.1.0", "sptr", @@ -1793,15 +1811,15 @@ dependencies = [ [[package]] name = "pgx-sql-entity-graph" -version = "0.7.2" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e03b43619e70e955894ba94883132d88ff7acf0cd15fd14e50c837a2bd3a6595" +checksum = "573a8d8c23be24c39f7b7fbbc7e15d95aa0327acd61ba95c9c9f237fec51f205" dependencies = [ "convert_case", "eyre", "petgraph", "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "regex", "seq-macro", "syn 1.0.109", @@ -1813,9 +1831,9 @@ dependencies = [ [[package]] name = "pgx-tests" -version = "0.7.2" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4069f5b6c98d5542bacac0fdb072d6198297267df132fedd4499a156937e4ca" +checksum = "fc09f25ae560bc4e3308022999416966beda5b60d2957b9ab92bffaf2d6a86c3" dependencies = [ "clap-cargo", "eyre", @@ -1927,7 +1945,7 @@ checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" dependencies = [ "proc-macro-error-attr", "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", "version_check", ] @@ -1939,15 +1957,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" dependencies = [ "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "version_check", ] [[package]] name = "proc-macro2" -version = "1.0.51" +version = "1.0.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6" +checksum = "1d0e1ae9e836cc3beddd63db0df682593d7e2d3d891ae8c9083d2113e1744224" dependencies = [ "unicode-ident", ] @@ -1997,7 +2015,7 @@ checksum = "94144a1266e236b1c932682136dc35a9dee8d3589728f68130c7c3861ef96b28" dependencies = [ "proc-macro2", "pyo3-macros-backend", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] @@ -2008,7 +2026,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8df9be978a2d2f0cdebabb03206ed73b11314701a5bfe71b0d753b81997777f" dependencies = [ "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] @@ -2020,9 +2038,9 @@ checksum = "7a6e920b65c65f10b2ae65c831a81a073a89edd28c7cce89475bff467ab4167a" [[package]] name = "quote" -version = "1.0.23" +version = "1.0.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b" +checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" dependencies = [ "proc-macro2", ] @@ -2093,9 +2111,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" [[package]] name = "rayon" -version = "1.6.1" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db3a213adf02b3bcfd2d3846bb41cb22857d131789e01df434fb7e7bc0759b7" +checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" dependencies = [ "either", "rayon-core", @@ -2103,9 +2121,9 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.10.2" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "356a0625f1954f730c0201cdab48611198dc6ce21f4acff55089b5a78e6e835b" +checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" dependencies = [ "crossbeam-channel", "crossbeam-deque", @@ -2159,15 +2177,6 @@ version = "0.6.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848" -[[package]] -name = "remove_dir_all" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7" -dependencies = [ - "winapi", -] - [[package]] name = "rmp" version = "0.8.11" @@ -2211,7 +2220,21 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" dependencies = [ - "semver 1.0.16", + "semver 1.0.17", +] + +[[package]] +name = "rustix" +version = "0.36.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd5c6ff11fecd55b40746d1995a02f2eb375bf8c00d192d521ee09f42bef37bc" +dependencies = [ + "bitflags", + "errno", + "io-lifetimes", + "libc", + "linux-raw-sys", + "windows-sys 0.45.0", ] [[package]] @@ -2237,15 +2260,15 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.11" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5583e89e108996506031660fe09baa5011b9dd0341b89029313006d1fb508d70" +checksum = "4f3208ce4d8448b3f3e7d168a73f5e0c43a61e32930de3bceeccedb388b6bf06" [[package]] name = "ryu" -version = "1.0.12" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b4b9743ed687d4b4bcedf9ff5eaa7398495ae14e61cba0a295704edbc7decde" +checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" [[package]] name = "same-file" @@ -2311,9 +2334,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.16" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58bc9567378fc7690d6b2addae4e60ac2eeea07becb2c64b9f218b53865cba2a" +checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed" [[package]] name = "semver-parser" @@ -2326,15 +2349,15 @@ dependencies = [ [[package]] name = "seq-macro" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1685deded9b272198423bdbdb907d8519def2f26cf3699040e54e8c4fbd5c5ce" +checksum = "e6b44e8fc93a14e66336d230954dda83d18b4605ccace8fe09bc7514a71ad0bc" [[package]] name = "serde" -version = "1.0.152" +version = "1.0.156" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb7d1f0d3021d347a83e556fc4683dea2ea09d87bccdf88ff5c12545d89d5efb" +checksum = "314b5b092c0ade17c00142951e50ced110ec27cea304b1037c6969246c2469a4" dependencies = [ "serde_derive", ] @@ -2351,12 +2374,12 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.152" +version = "1.0.156" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af487d118eecd09402d70a5d72551860e788df87b464af30e5ea6a38c75c541e" +checksum = "d7e29c4601e36bcec74a223228dce795f4cd3616341a4af93520ca1a837c087d" dependencies = [ "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] @@ -2367,7 +2390,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c533a59c9d8a93a09c6ab31f0fd5e5f4dd1b8fc9434804029839884765d04ea" dependencies = [ "indexmap", - "itoa 1.0.5", + "itoa", "ryu", "serde", ] @@ -2406,9 +2429,9 @@ checksum = "43b2853a4d09f215c24cc5489c992ce46052d359b5109343cbafbf26bc62f8a3" [[package]] name = "signal-hook" -version = "0.3.14" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a253b5e89e2698464fc26b545c9edceb338e18a89effeeecfea192c3025be29d" +checksum = "732768f1176d21d09e076c23a93123d40bba92d50c4058da34d45c8de8e682b9" dependencies = [ "libc", "signal-hook-registry", @@ -2416,9 +2439,9 @@ dependencies = [ [[package]] name = "signal-hook-registry" -version = "1.4.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" dependencies = [ "libc", ] @@ -2431,9 +2454,9 @@ checksum = "7bd3e3206899af3f8b12af284fafc038cc1dc2b41d1b89dd17297221c5d225de" [[package]] name = "slab" -version = "0.4.7" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4614a76b2a8be0058caa9dbbaf66d988527d86d003c11a94fbd335d7661edcef" +checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d" dependencies = [ "autocfg", ] @@ -2489,9 +2512,9 @@ checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" [[package]] name = "socket2" -version = "0.4.7" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02e2d2db9033d13a1567121ddd7a095ee144db4e1ca1b1bda3419bc0da294ebd" +checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662" dependencies = [ "libc", "winapi", @@ -2499,9 +2522,9 @@ dependencies = [ [[package]] name = "spin" -version = "0.9.5" +version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7dccf47db1b41fa1573ed27ccf5e08e3ca771cb994f776668c5ebda893b248fc" +checksum = "b5d6e0250b93c8427a177b849d144a96d5acc57006149479403d7861ab721e34" dependencies = [ "lock_api", ] @@ -2577,7 +2600,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" dependencies = [ "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "unicode-ident", ] @@ -2592,9 +2615,9 @@ dependencies = [ [[package]] name = "sysinfo" -version = "0.27.7" +version = "0.27.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "975fe381e0ecba475d4acff52466906d95b153a40324956552e027b2a9eaa89e" +checksum = "a902e9050fca0a5d6877550b769abd2bd1ce8c04634b941dbe2809735e1a1e33" dependencies = [ "cfg-if", "core-foundation-sys", @@ -2630,22 +2653,21 @@ dependencies = [ [[package]] name = "target-lexicon" -version = "0.12.5" +version = "0.12.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9410d0f6853b1d94f0e519fb95df60f29d2c1eff2d921ffdf01a4c8a3b54f12d" +checksum = "8ae9980cab1db3fceee2f6c6f643d5d8de2997c58ee8d25fb0cc8a9e9e7348e5" [[package]] name = "tempfile" -version = "3.3.0" +version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cdb1ef4eaeeaddc8fbd371e5017057064af0911902ef36b39801f67cc6d79e4" +checksum = "af18f7ae1acd354b992402e9ec5864359d693cd8a79dcbef59f76891701c1e95" dependencies = [ "cfg-if", "fastrand", - "libc", "redox_syscall", - "remove_dir_all", - "winapi", + "rustix", + "windows-sys 0.42.0", ] [[package]] @@ -2679,30 +2701,31 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.38" +version = "1.0.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a9cd18aa97d5c45c6603caea1da6628790b37f7a34b6ca89522331c5180fed0" +checksum = "a5ab016db510546d856297882807df8da66a16fb8c4101cb8b30054b0d5b2d9c" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.38" +version = "1.0.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fb327af4685e4d03fa8cbcf1716380da910eeb2bb8be417e7f9fd3fb164f36f" +checksum = "5420d42e90af0c38c3290abcca25b9b3bdf379fc9f55c528f53a269d9c9a267e" dependencies = [ "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] [[package]] name = "thread_local" -version = "1.1.4" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5516c27b78311c50bf42c071425c560ac799b11c30b31f87e3081965fe5e0180" +checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152" dependencies = [ + "cfg-if", "once_cell", ] @@ -2712,7 +2735,7 @@ version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cd0cbfecb4d19b5ea75bb31ad904eb5b9fa13f21079c3b92017ebdf4999a5890" dependencies = [ - "itoa 1.0.5", + "itoa", "libc", "num_threads", "serde", @@ -2752,9 +2775,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.25.0" +version = "1.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e00990ebabbe4c14c08aca901caed183ecd5c09562a12c824bb53d3c3fd3af" +checksum = "03201d01c3c27a29c8a5cee5b55a93ddae1ccf6f08f65365c2c918f8c1b76f64" dependencies = [ "autocfg", "bytes", @@ -2763,7 +2786,7 @@ dependencies = [ "mio", "pin-project-lite", "socket2", - "windows-sys 0.42.0", + "windows-sys 0.45.0", ] [[package]] @@ -2792,9 +2815,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.4" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bb2e075f03b3d66d8d8785356224ba688d2906a371015e225beeb65ca92c740" +checksum = "5427d89453009325de0d8f342c9490009f76e999cb7672d77e46267448f7e6b2" dependencies = [ "bytes", "futures-core", @@ -2832,7 +2855,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a" dependencies = [ "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] @@ -2893,9 +2916,9 @@ checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" [[package]] name = "typetag" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8eecd98403ae5ea2813689125cf5b3f99c40b8abed46c0a8945c81eadb673b31" +checksum = "69bf9bd14fed1815295233a0eee76a963283b53ebcbd674d463f697d3bfcae0c" dependencies = [ "erased-serde", "inventory", @@ -2906,12 +2929,12 @@ dependencies = [ [[package]] name = "typetag-impl" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f9568611f0de5e83e0993b85c54679cd0afd659adcfcb0233f16280b980492e" +checksum = "bf9f5f225956dc2254c6c27500deac9390a066b2e8a1a571300627a7c4400a33" dependencies = [ "proc-macro2", - "quote 1.0.23", + "quote 1.0.26", "syn 1.0.109", ] @@ -2929,15 +2952,15 @@ checksum = "ccb97dac3243214f8d8507998906ca3e2e0b900bf9bf4870477f125b82e68f6e" [[package]] name = "unicode-bidi" -version = "0.3.10" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d54675592c1dbefd78cbd98db9bacd89886e1ca50692a0692baefffdeb92dd58" +checksum = "524b68aca1d05e03fdf03fcdce2c6c94b6daf6d16861ddaa7e4f2b6638a9052c" [[package]] name = "unicode-ident" -version = "1.0.6" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc" +checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4" [[package]] name = "unicode-normalization" @@ -3119,9 +3142,9 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e2522491fbfcd58cc84d47aeb2958948c4b8982e9a2d8a2a35bbaed431390e7" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" dependencies = [ "windows_aarch64_gnullvm", "windows_aarch64_msvc", @@ -3134,45 +3157,45 @@ dependencies = [ [[package]] name = "windows_aarch64_gnullvm" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c9864e83243fdec7fc9c5444389dcbbfd258f745e7853198f365e3c4968a608" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" [[package]] name = "windows_aarch64_msvc" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c8b1b673ffc16c47a9ff48570a9d85e25d265735c503681332589af6253c6c7" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" [[package]] name = "windows_i686_gnu" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de3887528ad530ba7bdbb1faa8275ec7a1155a45ffa57c37993960277145d640" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" [[package]] name = "windows_i686_msvc" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf4d1122317eddd6ff351aa852118a2418ad4214e6613a50e0191f7004372605" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" [[package]] name = "windows_x86_64_gnu" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1040f221285e17ebccbc2591ffdc2d44ee1f9186324dd3e84e99ac68d699c45" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" [[package]] name = "windows_x86_64_gnullvm" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "628bfdf232daa22b0d64fdb62b09fcc36bb01f05a3939e20ab73aaf9470d0463" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" [[package]] name = "windows_x86_64_msvc" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "447660ad36a13288b1db4d4248e857b510e8c3a225c822ba4fb748c0aafecffd" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" [[package]] name = "wyz" diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 868357cb4..cf966e21a 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -68,7 +68,8 @@ pub fn validate_shared_library() { WHERE name = 'shared_preload_libraries' LIMIT 1", ) - .unwrap().unwrap(); + .unwrap() + .unwrap(); if !shared_preload_libraries.contains("pgml") { error!("`pgml` must be added to `shared_preload_libraries` setting or models cannot be deployed"); @@ -103,7 +104,7 @@ fn version() -> String { #[pg_extern] fn train( project_name: &str, - task: default!(Option, "NULL"), + task: default!(Option<&str>, "NULL"), relation_name: default!(Option<&str>, "NULL"), y_column_name: default!(Option<&str>, "NULL"), algorithm: default!(Algorithm, "'linear'"), @@ -149,7 +150,7 @@ fn train( #[pg_extern] fn train_joint( project_name: &str, - task: default!(Option, "NULL"), + task: default!(Option<&str>, "NULL"), relation_name: default!(Option<&str>, "NULL"), y_column_name: default!(Option>, "NULL"), algorithm: default!(Algorithm, "'linear'"), @@ -172,6 +173,7 @@ fn train_joint( name!(deployed, bool), ), > { + let task = task.map(|t| Task::from_str(t).unwrap()); let project = match Project::find_by_name(project_name) { Some(project) => project, None => Project::create(project_name, match task { @@ -276,6 +278,9 @@ fn train_joint( deploy = false; } } + _ => error!( + "Training only supports `classification` and `regression` task types." + ), } } } @@ -314,7 +319,8 @@ fn deploy( let (project_id, task) = Spi::get_two_with_args::( "SELECT id, task::TEXT from pgml.projects WHERE name = $1", vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())], - ).unwrap(); + ) + .unwrap(); let project_id = project_id.unwrap_or_else(|| error!("Project named `{}` does not exist.", project_name)); @@ -332,17 +338,34 @@ fn deploy( } match strategy { Strategy::best_score => match task { + Task::classification | Task::question_answering | Task::text_classification => { + let _ = write!( + sql, + "{predicate}\nORDER BY models.metrics->>'f1' DESC NULLS LAST" + ); + } Task::regression => { let _ = write!( sql, "{predicate}\nORDER BY models.metrics->>'r2' DESC NULLS LAST" ); } - - Task::classification => { + Task::summarization => { let _ = write!( sql, - "{predicate}\nORDER BY models.metrics->>'f1' DESC NULLS LAST" + "{predicate}\nORDER BY models.metrics->>'rouge_ngram_f1' DESC NULLS LAST" + ); + } + Task::text_generation | Task::text2text => { + let _ = write!( + sql, + "{predicate}\nORDER BY models.metrics->>'perplexity' ASC NULLS LAST" + ); + } + Task::translation => { + let _ = write!( + sql, + "{predicate}\nORDER BY models.metrics->>'bleu' DESC NULLS LAST" ); } }, @@ -377,7 +400,8 @@ fn deploy( let (model_id, algorithm) = Spi::get_two_with_args::( &sql, vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())], - ).unwrap(); + ) + .unwrap(); let model_id = model_id.expect("No qualified models exist for this deployment."); let algorithm = algorithm.expect("No qualified models exist for this deployment."); @@ -444,10 +468,8 @@ fn predict_model_row(model_id: i64, row: pgx::datum::AnyElement) -> f32 { let features_width = snapshot.features_width(); let mut processed = vec![0_f32; features_width]; - let feature_data = ndarray::ArrayView2::from_shape( - (1, features_width), - &numeric_encoded_features, - ).unwrap(); + let feature_data = + ndarray::ArrayView2::from_shape((1, features_width), &numeric_encoded_features).unwrap(); Zip::from(feature_data.columns()) .and(&snapshot.feature_positions) @@ -458,7 +480,6 @@ fn predict_model_row(model_id: i64, row: pgx::datum::AnyElement) -> f32 { model.predict(&processed) } - #[pg_extern] fn snapshot( relation_name: &str, @@ -481,18 +502,24 @@ fn snapshot( #[pg_extern] fn load_dataset( source: &str, + subset: default!(Option, "NULL"), limit: default!(Option, "NULL"), + kwargs: default!(JsonB, "'{}'"), ) -> TableIterator<'static, (name!(table_name, String), name!(rows, i64))> { // cast limit since pgx doesn't support usize let limit: Option = limit.map(|limit| limit.try_into().unwrap()); let (name, rows) = match source { - "breast_cancer" => crate::orm::dataset::load_breast_cancer(limit), - "diabetes" => crate::orm::dataset::load_diabetes(limit), - "digits" => crate::orm::dataset::load_digits(limit), - "iris" => crate::orm::dataset::load_iris(limit), - "linnerud" => crate::orm::dataset::load_linnerud(limit), - "wine" => crate::orm::dataset::load_wine(limit), - _ => error!("Unknown source: `{source}`"), + "breast_cancer" => dataset::load_breast_cancer(limit), + "diabetes" => dataset::load_diabetes(limit), + "digits" => dataset::load_digits(limit), + "iris" => dataset::load_iris(limit), + "linnerud" => dataset::load_linnerud(limit), + "wine" => dataset::load_wine(limit), + _ => { + let rows = + crate::bindings::transformers::load_dataset(source, subset, limit, &kwargs.0); + (source.into(), rows as i64) + } }; TableIterator::new(vec![(name, rows)].into_iter()) @@ -525,6 +552,195 @@ pub fn transform_string( )) } +#[cfg(feature = "python")] +#[pg_extern(name = "generate")] +fn generate(project_name: &str, inputs: &str) -> String { + generate_batch(project_name, Vec::from([inputs])) + .first() + .unwrap() + .to_string() +} + +#[cfg(feature = "python")] +#[pg_extern(name = "generate")] +fn generate_batch(project_name: &str, inputs: Vec<&str>) -> Vec { + crate::bindings::transformers::generate(Project::get_deployed_model_id(project_name), inputs) +} + +#[cfg(feature = "python")] +#[allow(clippy::too_many_arguments)] +#[pg_extern] +fn tune( + project_name: &str, + task: default!(Option<&str>, "NULL"), + relation_name: default!(Option<&str>, "NULL"), + y_column_name: default!(Option<&str>, "NULL"), + model_name: default!(Option<&str>, "NULL"), + hyperparams: default!(JsonB, "'{}'"), + test_size: default!(f32, 0.25), + test_sampling: default!(Sampling, "'last'"), + automatic_deploy: default!(Option, true), + materialize_snapshot: default!(bool, false), +) -> TableIterator< + 'static, + ( + name!(status, String), + name!(task, String), + name!(algorithm, String), + name!(deployed, bool), + ), +> { + let task = task.map(|t| Task::from_str(t).unwrap()); + let preprocess = JsonB(serde_json::from_str("{}").unwrap()); + let project = match Project::find_by_name(project_name) { + Some(project) => project, + None => Project::create( + project_name, + match task { + Some(task) => task, + None => error!( + "Project `{}` does not exist. To create a new project, provide the task.", + project_name + ), + }, + ), + }; + + if task.is_some() && task.unwrap() != project.task { + error!("Project `{:?}` already exists with a different task: `{:?}`. Create a new project instead.", project.name, project.task); + } + + let mut snapshot = match relation_name { + None => { + let snapshot = project + .last_snapshot() + .expect("You must pass a `relation_name` and `y_column_name` to snapshot the first time you train a model."); + + info!("Using existing snapshot from {}", snapshot.snapshot_name(),); + + snapshot + } + + Some(relation_name) => { + info!( + "Snapshotting table \"{}\", this may take a little while...", + relation_name + ); + + let snapshot = Snapshot::create( + relation_name, + vec![y_column_name + .expect("You must pass a `y_column_name` when you pass a `relation_name`") + .to_string()], + test_size, + test_sampling, + materialize_snapshot, + preprocess, + ); + + if materialize_snapshot { + info!( + "Snapshot of table \"{}\" created and saved in {}", + relation_name, + snapshot.snapshot_name(), + ); + } + + snapshot + } + }; + + // algorithm will be transformers, stash the model_name in a hyperparam for v1 compatibility. + let mut hyperparams = hyperparams.0.as_object().unwrap().clone(); + hyperparams.insert(String::from("model_name"), json!(model_name)); + let hyperparams = JsonB(json!(hyperparams)); + + // # Default repeatable random state when possible + // let algorithm = Model.algorithm_from_name_and_task(algorithm, task); + // if "random_state" in algorithm().get_params() and "random_state" not in hyperparams: + // hyperparams["random_state"] = 0 + let model = Model::tune(&project, &mut snapshot, &hyperparams); + let new_metrics: &serde_json::Value = &model.metrics.unwrap().0; + let new_metrics = new_metrics.as_object().unwrap(); + + let deployed_metrics = Spi::get_one_with_args::( + " + SELECT models.metrics + FROM pgml.models + JOIN pgml.deployments + ON deployments.model_id = models.id + JOIN pgml.projects + ON projects.id = deployments.project_id + WHERE projects.name = $1 + ORDER by deployments.created_at DESC + LIMIT 1;", + vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())], + ); + + let mut deploy = true; + match automatic_deploy { + // Deploy only if metrics are better than previous model. + Some(true) | None => { + if let Ok(Some(deployed_metrics)) = deployed_metrics { + let deployed_metrics = deployed_metrics.0.as_object().unwrap(); + match project.task { + Task::classification | Task::question_answering | Task::text_classification => { + if deployed_metrics.get("f1").unwrap().as_f64() + > new_metrics.get("f1").unwrap().as_f64() + { + deploy = false; + } + } + Task::regression => { + if deployed_metrics.get("r2").unwrap().as_f64() + > new_metrics.get("r2").unwrap().as_f64() + { + deploy = false; + } + } + Task::translation => { + if deployed_metrics.get("bleu").unwrap().as_f64() + > new_metrics.get("bleu").unwrap().as_f64() + { + deploy = false; + } + } + Task::summarization => { + if deployed_metrics.get("rouge_ngram_f1").unwrap().as_f64() + > new_metrics.get("rouge_ngram_f1").unwrap().as_f64() + { + deploy = false; + } + } + Task::text_generation | Task::text2text => { + if deployed_metrics.get("perplexity").unwrap().as_f64() + < new_metrics.get("perplexity").unwrap().as_f64() + { + deploy = false; + } + } + } + } + } + + Some(false) => deploy = false, + }; + + if deploy { + project.deploy(model.id); + } + + TableIterator::new( + vec![( + project.name, + project.task.to_string(), + model.algorithm.to_string(), + deploy, + )] + .into_iter(), + ) +} + #[cfg(feature = "python")] #[pg_extern(name = "sklearn_f1_score")] pub fn sklearn_f1_score(ground_truth: Vec, y_hat: Vec) -> f32 { @@ -578,31 +794,36 @@ pub fn dump_all(path: &str) { Spi::run(&format!( "COPY pgml.projects TO '{}' CSV HEADER", p.to_str().unwrap() - )).unwrap(); + )) + .unwrap(); let p = std::path::Path::new(path).join("snapshots.csv"); Spi::run(&format!( "COPY pgml.snapshots TO '{}' CSV HEADER", p.to_str().unwrap() - )).unwrap(); + )) + .unwrap(); let p = std::path::Path::new(path).join("models.csv"); Spi::run(&format!( "COPY pgml.models TO '{}' CSV HEADER", p.to_str().unwrap() - )).unwrap(); + )) + .unwrap(); let p = std::path::Path::new(path).join("files.csv"); Spi::run(&format!( "COPY pgml.files TO '{}' CSV HEADER", p.to_str().unwrap() - )).unwrap(); + )) + .unwrap(); let p = std::path::Path::new(path).join("deployments.csv"); Spi::run(&format!( "COPY pgml.deployments TO '{}' CSV HEADER", p.to_str().unwrap() - )).unwrap(); + )) + .unwrap(); } #[pg_extern] @@ -611,31 +832,36 @@ pub fn load_all(path: &str) { Spi::run(&format!( "COPY pgml.projects FROM '{}' CSV HEADER", p.to_str().unwrap() - )).unwrap(); + )) + .unwrap(); let p = std::path::Path::new(path).join("snapshots.csv"); Spi::run(&format!( "COPY pgml.snapshots FROM '{}' CSV HEADER", p.to_str().unwrap() - )).unwrap(); + )) + .unwrap(); let p = std::path::Path::new(path).join("models.csv"); Spi::run(&format!( "COPY pgml.models FROM '{}' CSV HEADER", p.to_str().unwrap() - )).unwrap(); + )) + .unwrap(); let p = std::path::Path::new(path).join("files.csv"); Spi::run(&format!( "COPY pgml.files FROM '{}' CSV HEADER", p.to_str().unwrap() - )).unwrap(); + )) + .unwrap(); let p = std::path::Path::new(path).join("deployments.csv"); Spi::run(&format!( "COPY pgml.deployments FROM '{}' CSV HEADER", p.to_str().unwrap() - )).unwrap(); + )) + .unwrap(); } #[cfg(any(test, feature = "pg_test"))] @@ -665,7 +891,7 @@ mod tests { 0.5, Sampling::last, true, - JsonB(serde_json::Value::Object(Hyperparams::new())) + JsonB(serde_json::Value::Object(Hyperparams::new())), ); assert!(snapshot.id > 0); } @@ -681,7 +907,7 @@ mod tests { 0.5, Sampling::last, true, - JsonB(serde_json::Value::Object(Hyperparams::new())) + JsonB(serde_json::Value::Object(Hyperparams::new())), ); }); @@ -695,14 +921,15 @@ mod tests { // Modify postgresql.conf and add shared_preload_libraries = 'pgml' // to test deployments. let setting = - Spi::get_one::("select setting from pg_settings where name = 'data_directory'").unwrap(); + Spi::get_one::("select setting from pg_settings where name = 'data_directory'") + .unwrap(); info!("Data directory: {}", setting.unwrap()); for runtime in [Runtime::python, Runtime::rust] { let result: Vec<(String, String, String, bool)> = train( "Test project", - Some(Task::regression), + Some(&Task::regression.to_string()), Some("pgml.diabetes"), Some("target"), Algorithm::linear, @@ -715,7 +942,7 @@ mod tests { Some(runtime), Some(true), false, - JsonB(serde_json::Value::Object(Hyperparams::new())) + JsonB(serde_json::Value::Object(Hyperparams::new())), ) .collect(); @@ -734,14 +961,15 @@ mod tests { // Modify postgresql.conf and add shared_preload_libraries = 'pgml' // to test deployments. let setting = - Spi::get_one::("select setting from pg_settings where name = 'data_directory'").unwrap(); + Spi::get_one::("select setting from pg_settings where name = 'data_directory'") + .unwrap(); info!("Data directory: {}", setting.unwrap()); for runtime in [Runtime::python, Runtime::rust] { let result: Vec<(String, String, String, bool)> = train( "Test project 2", - Some(Task::classification), + Some(&Task::classification.to_string()), Some("pgml.digits"), Some("target"), Algorithm::xgboost, @@ -754,7 +982,7 @@ mod tests { Some(runtime), Some(true), false, - JsonB(serde_json::Value::Object(Hyperparams::new())) + JsonB(serde_json::Value::Object(Hyperparams::new())), ) .collect(); @@ -773,14 +1001,15 @@ mod tests { // Modify postgresql.conf and add shared_preload_libraries = 'pgml' // to test deployments. let setting = - Spi::get_one::("select setting from pg_settings where name = 'data_directory'").unwrap(); + Spi::get_one::("select setting from pg_settings where name = 'data_directory'") + .unwrap(); info!("Data directory: {}", setting.unwrap()); for runtime in [Runtime::python, Runtime::rust] { let result: Vec<(String, String, String, bool)> = train( "Test project 3", - Some(Task::classification), + Some(&Task::classification.to_string()), Some("pgml.breast_cancer"), Some("malignant"), Algorithm::xgboost, @@ -793,7 +1022,7 @@ mod tests { Some(runtime), Some(true), true, - JsonB(serde_json::Value::Object(Hyperparams::new())) + JsonB(serde_json::Value::Object(Hyperparams::new())), ) .collect(); diff --git a/pgml-extension/src/bindings/lightgbm.rs b/pgml-extension/src/bindings/lightgbm.rs index b28f945d1..e5795080e 100644 --- a/pgml-extension/src/bindings/lightgbm.rs +++ b/pgml-extension/src/bindings/lightgbm.rs @@ -3,6 +3,7 @@ use crate::orm::dataset::Dataset; use crate::orm::task::Task; use crate::orm::Hyperparams; use lightgbm; +use pgx::*; use serde_json::json; pub struct Estimator { @@ -52,6 +53,7 @@ fn fit(dataset: &Dataset, hyperparams: &Hyperparams, task: Task) -> Box error!("lightgbm only supports `regression` and `classification` tasks."), }; let data = lightgbm::Dataset::from_vec( diff --git a/pgml-extension/src/bindings/mod.rs b/pgml-extension/src/bindings/mod.rs index d97a51868..a1d35526a 100644 --- a/pgml-extension/src/bindings/mod.rs +++ b/pgml-extension/src/bindings/mod.rs @@ -52,7 +52,7 @@ mod tests { 0.5, Sampling::last, false, - JsonB(serde_json::Value::Object(Hyperparams::new())) + JsonB(serde_json::Value::Object(Hyperparams::new())), ); let classification = Project::create("classification", Task::classification); let mut breast_cancer = Snapshot::create( @@ -61,7 +61,7 @@ mod tests { 0.5, Sampling::last, false, - JsonB(serde_json::Value::Object(Hyperparams::new())) + JsonB(serde_json::Value::Object(Hyperparams::new())), ); let mut regressors = Vec::new(); diff --git a/pgml-extension/src/bindings/sklearn.rs b/pgml-extension/src/bindings/sklearn.rs index 5963fc287..886504cff 100644 --- a/pgml-extension/src/bindings/sklearn.rs +++ b/pgml-extension/src/bindings/sklearn.rs @@ -9,6 +9,7 @@ /// defined in `src/bindings/sklearn.py`. use std::collections::HashMap; +use once_cell::sync::Lazy; use pyo3::prelude::*; use pyo3::types::PyTuple; @@ -16,6 +17,17 @@ use crate::bindings::Bindings; use crate::orm::*; +static PY_MODULE: Lazy> = Lazy::new(|| { + Python::with_gil(|py| -> Py { + let src = include_str!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/bindings/sklearn.py" + )); + + PyModule::from_code(py, src, "", "").unwrap().into() + }) +}); + pub fn linear_regression(dataset: &Dataset, hyperparams: &Hyperparams) -> Box { fit(dataset, hyperparams, "linear_regression") } @@ -290,17 +302,11 @@ fn fit( hyperparams: &Hyperparams, algorithm_task: &'static str, ) -> Box { - let module = include_str!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/src/bindings/sklearn.py" - )); - let hyperparams = serde_json::to_string(hyperparams).unwrap(); let (estimator, predict, predict_proba) = Python::with_gil(|py| -> (Py, Py, Py) { - let module = PyModule::from_code(py, module, "", "").unwrap(); - let estimator: Py = module.getattr("estimator").unwrap().into(); + let estimator: Py = PY_MODULE.getattr(py, "estimator").unwrap().into(); let train: Py = estimator .call1( @@ -321,20 +327,20 @@ fn fit( .call1(py, PyTuple::new(py, &[&dataset.x_train, &dataset.y_train])) .unwrap(); - let predict: Py = module - .getattr("predictor") + let predict: Py = PY_MODULE + .getattr(py, "predictor") .unwrap() - .call1(PyTuple::new(py, &[&estimator])) + .call1(py, PyTuple::new(py, &[&estimator])) .unwrap() - .extract() + .extract(py) .unwrap(); - let predict_proba: Py = module - .getattr("predictor_proba") + let predict_proba: Py = PY_MODULE + .getattr(py, "predictor_proba") .unwrap() - .call1(PyTuple::new(py, &[&estimator])) + .call1(py, PyTuple::new(py, &[&estimator])) .unwrap() - .extract() + .extract(py) .unwrap(); (estimator, predict, predict_proba) @@ -389,17 +395,11 @@ impl Bindings for Estimator { /// Serialize self to bytes fn to_bytes(&self) -> Vec { - let module = include_str!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/src/bindings/sklearn.py" - )); - Python::with_gil(|py| -> Vec { - let module = PyModule::from_code(py, module, "", "").unwrap(); - let save = module.getattr("save").unwrap(); - save.call1(PyTuple::new(py, &[&self.estimator])) + let save = PY_MODULE.getattr(py, "save").unwrap(); + save.call1(py, PyTuple::new(py, &[&self.estimator])) .unwrap() - .extract() + .extract(py) .unwrap() }) } @@ -409,34 +409,28 @@ impl Bindings for Estimator { where Self: Sized, { - let module = include_str!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/src/bindings/sklearn.py" - )); - Python::with_gil(|py| -> Box { - let module = PyModule::from_code(py, module, "", "").unwrap(); - let load = module.getattr("load").unwrap(); + let load = PY_MODULE.getattr(py, "load").unwrap(); let estimator: Py = load - .call1(PyTuple::new(py, &[bytes])) + .call1(py, PyTuple::new(py, &[bytes])) .unwrap() - .extract() + .extract(py) .unwrap(); - let predict: Py = module - .getattr("predictor") + let predict: Py = PY_MODULE + .getattr(py, "predictor") .unwrap() - .call1(PyTuple::new(py, &[&estimator])) + .call1(py, PyTuple::new(py, &[&estimator])) .unwrap() - .extract() + .extract(py) .unwrap(); - let predict_proba: Py = module - .getattr("predictor_proba") + let predict_proba: Py = PY_MODULE + .getattr(py, "predictor_proba") .unwrap() - .call1(PyTuple::new(py, &[&estimator])) + .call1(py, PyTuple::new(py, &[&estimator])) .unwrap() - .extract() + .extract(py) .unwrap(); Box::new(Estimator { @@ -449,18 +443,12 @@ impl Bindings for Estimator { } fn sklearn_metric(name: &str, ground_truth: &[f32], y_hat: &[f32]) -> f32 { - let module = include_str!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/src/bindings/sklearn.py" - )); - Python::with_gil(|py| -> f32 { - let module = PyModule::from_code(py, module, "", "").unwrap(); - let calculate_metric = module.getattr("calculate_metric").unwrap(); + let calculate_metric = PY_MODULE.getattr(py, "calculate_metric").unwrap(); let wrapper: Py = calculate_metric - .call1(PyTuple::new(py, &[name])) + .call1(py, PyTuple::new(py, &[name])) .unwrap() - .extract() + .extract(py) .unwrap(); let score: f32 = wrapper @@ -490,18 +478,12 @@ pub fn recall(ground_truth: &[f32], y_hat: &[f32]) -> f32 { } pub fn confusion_matrix(ground_truth: &[f32], y_hat: &[f32]) -> Vec> { - let module = include_str!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/src/bindings/sklearn.py" - )); - Python::with_gil(|py| -> Vec> { - let module = PyModule::from_code(py, module, "", "").unwrap(); - let calculate_metric = module.getattr("calculate_metric").unwrap(); + let calculate_metric = PY_MODULE.getattr(py, "calculate_metric").unwrap(); let wrapper: Py = calculate_metric - .call1(PyTuple::new(py, &["confusion_matrix"])) + .call1(py, PyTuple::new(py, &["confusion_matrix"])) .unwrap() - .extract() + .extract(py) .unwrap(); let matrix: Vec> = wrapper @@ -515,18 +497,12 @@ pub fn confusion_matrix(ground_truth: &[f32], y_hat: &[f32]) -> Vec> { } pub fn regression_metrics(ground_truth: &[f32], y_hat: &[f32]) -> HashMap { - let module = include_str!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/src/bindings/sklearn.py" - )); - Python::with_gil(|py| -> HashMap { - let module = PyModule::from_code(py, module, "", "").unwrap(); - let calculate_metric = module.getattr("regression_metrics").unwrap(); + let calculate_metric = PY_MODULE.getattr(py, "regression_metrics").unwrap(); let scores: HashMap = calculate_metric - .call1(PyTuple::new(py, &[ground_truth, y_hat])) + .call1(py, PyTuple::new(py, &[ground_truth, y_hat])) .unwrap() - .extract() + .extract(py) .unwrap(); scores @@ -538,18 +514,12 @@ pub fn classification_metrics( y_hat: &[f32], num_classes: usize, ) -> HashMap { - let module = include_str!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/src/bindings/sklearn.py" - )); - let mut scores = Python::with_gil(|py| -> HashMap { - let module = PyModule::from_code(py, module, "", "").unwrap(); - let calculate_metric = module.getattr("classification_metrics").unwrap(); + let calculate_metric = PY_MODULE.getattr(py, "classification_metrics").unwrap(); let scores: HashMap = calculate_metric - .call1(PyTuple::new(py, &[ground_truth, y_hat])) + .call1(py, PyTuple::new(py, &[ground_truth, y_hat])) .unwrap() - .extract() + .extract(py) .unwrap(); scores @@ -564,12 +534,8 @@ pub fn classification_metrics( } pub fn package_version(name: &str) -> String { - let mut version = String::new(); - - Python::with_gil(|py| { + Python::with_gil(|py| -> String { let package = py.import(name).unwrap(); - version = package.getattr("__version__").unwrap().extract().unwrap(); - }); - - version + package.getattr("__version__").unwrap().extract().unwrap() + }) } diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py index e6f0c12b0..d03affaa2 100644 --- a/pgml-extension/src/bindings/transformers.py +++ b/pgml-extension/src/bindings/transformers.py @@ -1,5 +1,41 @@ -import transformers +import os import json +import math +import shutil +import time + + +import datasets +from rouge import Rouge +from sacrebleu.metrics import BLEU +from sklearn.metrics import ( + mean_squared_error, + r2_score, + f1_score, + precision_score, + recall_score, + roc_auc_score, + accuracy_score, + log_loss, +) +import torch +from tqdm import tqdm +import transformers +from transformers import ( + AutoTokenizer, + DataCollatorWithPadding, + DataCollatorForLanguageModeling, + DataCollatorForSeq2Seq, + DefaultDataCollator, + AutoModelForSequenceClassification, + AutoModelForQuestionAnswering, + AutoModelForSeq2SeqLM, + AutoModelForCausalLM, + TrainingArguments, + Trainer, +) + +__cache_transformer_by_model_id = {} def transform(task, args, inputs): task = json.loads(task) @@ -12,3 +48,404 @@ def transform(task, args, inputs): inputs = [json.loads(input) for input in inputs] return json.dumps(pipe(inputs, **args)) + +def load_dataset(name, subset, limit: None, kwargs: "{}"): + kwargs = json.loads(kwargs) + + if limit: + dataset = datasets.load_dataset(name, subset, split=f"train[:{limit}]", **kwargs) + else: + dataset = datasets.load_dataset(name, subset, **kwargs) + + data = None + types = None + if isinstance(dataset, datasets.Dataset): + data = dataset.to_dict() + types = {name: feature.dtype for name, feature in dataset.features.items()} + elif isinstance(dataset, datasets.DatasetDict): + data = {} + # Merge train/test splits, we'll re-split back in PostgresML. + for name, split in dataset.items(): + types = {name: feature.dtype for name, feature in split.features.items()} + for field, values in split.to_dict().items(): + if field in data: + data[field] += values + else: + data[field] = values + else: + raise PgMLException(f"Unhandled dataset type: {type(dataset)}") + + return json.dumps({"data": data, "types": types}) + +def tokenize_text_classification(tokenizer, max_length, x, y): + encoding = tokenizer(x, padding=True, truncation=True) + encoding["label"] = y + return datasets.Dataset.from_dict(encoding.data) + +def tokenize_translation(tokenizer, max_length, x, y): + encoding = tokenizer(x, max_length=max_length, truncation=True, text_target=y) + return datasets.Dataset.from_dict(encoding.data) + +def tokenize_summarization(tokenizer, max_length, x, y): + encoding = tokenizer(x, max_length=max_length, truncation=True, text_target=y) + return datasets.Dataset.from_dict(encoding.data) + +def tokenize_text_generation(tokenizer, max_length, y): + encoding = tokenizer(y, max_length=max_length) + return datasets.Dataset.from_dict(encoding.data) + +def tokenize_question_answering(tokenizer, max_length, x, y): + pass + +def compute_metrics_summarization(model, tokenizer, hyperparams, x, y): + all_preds = [] + all_labels = y + + batch_size = hyperparams["per_device_eval_batch_size"] + batches = int(math.ceil(len(y) / batch_size)) + with torch.no_grad(): + for i in range(batches): + inputs = x[i * batch_size : min((i + 1) * batch_size, len(x))] + tokens = tokenizer.batch_encode_plus( + inputs, + padding=True, + truncation=True, + return_tensors="pt", + return_token_type_ids=False, + ).to(model.device) + predictions = model.generate(**tokens) + decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True) + all_preds.extend(decoded_preds) + bleu = BLEU().corpus_score(all_preds, [[l] for l in all_labels]) + rouge = Rouge().get_scores(all_preds, all_labels, avg=True) + return { + "bleu": bleu.score, + "rouge_ngram_f1": rouge["rouge-1"]["f"], + "rouge_ngram_precision": rouge["rouge-1"]["p"], + "rouge_ngram_recall": rouge["rouge-1"]["r"], + "rouge_bigram_f1": rouge["rouge-2"]["f"], + "rouge_bigram_precision": rouge["rouge-2"]["p"], + "rouge_bigram_recall": rouge["rouge-2"]["r"], + } + +def compute_metrics_text_classification(self, dataset): + feature = label = None + for name, type in dataset.features.items(): + if isinstance(type, datasets.features.features.ClassLabel): + label = name + elif isinstance(type, datasets.features.features.Value): + feature = name + else: + raise PgMLException(f"Unhandled feature type: {type}") + logits = torch.Tensor(device="cpu") + + batch_size = self.hyperparams["per_device_eval_batch_size"] + batches = int(math.ceil(len(dataset) / batch_size)) + + with torch.no_grad(): + for i in range(batches): + slice = dataset.select(range(i * batch_size, min((i + 1) * batch_size, len(dataset)))) + tokens = self.tokenizer(slice[feature], padding=True, truncation=True, return_tensors="pt") + tokens.to(self.model.device) + result = self.model(**tokens).logits.to("cpu") + logits = torch.cat((logits, result), 0) + + metrics = {} + + y_pred = logits.argmax(-1) + y_prob = torch.nn.functional.softmax(logits, dim=-1) + y_test = numpy.array(dataset[label]).flatten() + + metrics["mean_squared_error"] = mean_squared_error(y_test, y_pred) + metrics["r2"] = r2_score(y_test, y_pred) + metrics["f1"] = f1_score(y_test, y_pred, average="weighted") + metrics["precision"] = precision_score(y_test, y_pred, average="weighted") + metrics["recall"] = recall_score(y_test, y_pred, average="weighted") + metrics["accuracy"] = accuracy_score(y_test, y_pred) + metrics["log_loss"] = log_loss(y_test, y_prob) + roc_auc_y_prob = y_prob + if y_prob.shape[1] == 2: # binary classification requires only the greater label by passed to roc_auc_score + roc_auc_y_prob = y_prob[:, 1] + metrics["roc_auc"] = roc_auc_score(y_test, roc_auc_y_prob, average="weighted", multi_class="ovo") + + return metrics + +def compute_metrics_translation(model, tokenizer, hyperparams, x, y): + all_preds = [] + all_labels = y + + batch_size = hyperparams["per_device_eval_batch_size"] + batches = int(math.ceil(len(y) / batch_size)) + with torch.no_grad(): + for i in range(batches): + inputs = x[i * batch_size : min((i + 1) * batch_size, len(x))] + tokens = tokenizer.batch_encode_plus( + inputs, + padding=True, + truncation=True, + return_tensors="pt", + return_token_type_ids=False, + ).to(model.device) + predictions = model.generate(**tokens) + decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True) + all_preds.extend(decoded_preds) + bleu = BLEU().corpus_score(all_preds, [[l] for l in all_labels]) + rouge = Rouge().get_scores(all_preds, all_labels, avg=True) + return { + "bleu": bleu.score, + "rouge_ngram_f1": rouge["rouge-1"]["f"], + "rouge_ngram_precision": rouge["rouge-1"]["p"], + "rouge_ngram_recall": rouge["rouge-1"]["r"], + "rouge_bigram_f1": rouge["rouge-2"]["f"], + "rouge_bigram_precision": rouge["rouge-2"]["p"], + "rouge_bigram_recall": rouge["rouge-2"]["r"], + } + +def compute_metrics_question_answering(model, tokenizer, hyperparams, x, y): + batch_size = self.hyperparams["per_device_eval_batch_size"] + batches = int(math.ceil(len(dataset) / batch_size)) + + with torch.no_grad(): + for i in range(batches): + slice = dataset.select(range(i * batch_size, min((i + 1) * batch_size, len(dataset)))) + tokens = self.algorithm["tokenizer"].encode_plus( + slice["question"], slice["context"], return_tensors="pt" + ) + tokens.to(self.algorithm["model"].device) + outputs = self.algorithm["model"](**tokens) + answer_start = torch.argmax(outputs[0]) + answer_end = torch.argmax(outputs[1]) + 1 + answer = self.algorithm["tokenizer"].convert_tokens_to_string( + self.algorithm["tokenizer"].convert_ids_to_tokens(tokens["input_ids"][0][answer_start:answer_end]) + ) + + def compute_exact_match(prediction, truth): + return int(normalize_text(prediction) == normalize_text(truth)) + + def compute_f1(prediction, truth): + pred_tokens = normalize_text(prediction).split() + truth_tokens = normalize_text(truth).split() + + # if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise + if len(pred_tokens) == 0 or len(truth_tokens) == 0: + return int(pred_tokens == truth_tokens) + + common_tokens = set(pred_tokens) & set(truth_tokens) + + # if there are no common tokens then f1 = 0 + if len(common_tokens) == 0: + return 0 + + prec = len(common_tokens) / len(pred_tokens) + rec = len(common_tokens) / len(truth_tokens) + + return 2 * (prec * rec) / (prec + rec) + + def get_gold_answers(example): + """helper function that retrieves all possible true answers from a squad2.0 example""" + + gold_answers = [answer["text"] for answer in example.answers if answer["text"]] + + # if gold_answers doesn't exist it's because this is a negative example - + # the only correct answer is an empty string + if not gold_answers: + gold_answers = [""] + + return gold_answers + + metrics = {} + metrics["exact_match"] = 0 + + return metrics + +def compute_metrics_text_generation(model, tokenizer, hyperparams, y): + full_text = "" + for entry in y: + if entry: + full_text += "\n\n" + entry + + encodings = tokenizer(full_text, return_tensors="pt") + + # TODO make these more configurable + stride = 512 + config = model.config.to_dict() + max_length = config.get("n_positions", 1024) + + stride = min(stride, max_length) + seq_len = encodings.input_ids.size(1) + + nlls = [] + prev_end_loc = 0 + for begin_loc in tqdm(range(0, seq_len, stride)): + end_loc = min(begin_loc + max_length, seq_len) + trg_len = end_loc - prev_end_loc # may be different from stride on last loop + input_ids = encodings.input_ids[:, begin_loc:end_loc].to(model.device) + target_ids = input_ids.clone() + target_ids[:, :-trg_len] = -100 + + with torch.no_grad(): + outputs = model(input_ids, labels=target_ids) + + # loss is calculated using CrossEntropyLoss which averages over input tokens. + # Multiply it with trg_len to get the summation instead of average. + # We will take average over all the tokens to get the true average + # in the last step of this example. + neg_log_likelihood = outputs.loss * trg_len + + nlls.append(neg_log_likelihood) + + prev_end_loc = end_loc + if end_loc == seq_len: + break + + perplexity = torch.exp(torch.stack(nlls).sum() / end_loc) + + return { + "perplexity": perplexity + } + +def tune(task, hyperparams, path, x_train, x_test, y_train, y_test): + hyperparams = json.loads(hyperparams) + model_name = hyperparams.pop("model_name") + tokenizer = AutoTokenizer.from_pretrained(model_name) + + algorithm = {} + + if task == "text-classification": + model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2) + train = tokenize_text_classification(tokenizer, max_length, x_train, y_train) + test = tokenize_text_classification(tokenizer, max_length, x_test, y_test) + data_collator = DefaultDataCollator() + elif task == "question-answering": + max_length = hyperparams.pop("max_length", None) + algorithm["stride"] = hyperparams.pop("stride", 128) + algorithm["model"] = AutoModelForQuestionAnswering.from_pretrained(model_name) + train = tokenize_question_answering(tokenizer, max_length, x_train, y_train) + test = tokenize_question_answering(tokenizer, max_length, x_test, y_test) + data_collator = DefaultDataCollator() + elif task == "summarization": + max_length = hyperparams.pop("max_length", 1024) + model = AutoModelForSeq2SeqLM.from_pretrained(model_name) + train = tokenize_summarization(tokenizer, max_length, x_train, y_train) + test = tokenize_summarization(tokenizer, max_length, x_test, y_test) + data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model) + elif task == "translation": + max_length = hyperparams.pop("max_length", None) + model = AutoModelForSeq2SeqLM.from_pretrained(model_name) + train = tokenize_translation(tokenizer, max_length, x_train, y_train) + test = tokenize_translation(tokenizer, max_length, x_test, y_test) + data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt") + elif task == "text-generation": + max_length = hyperparams.pop("max_length", None) + tokenizer.pad_token = tokenizer.eos_token + model = AutoModelForCausalLM.from_pretrained(model_name) + model.resize_token_embeddings(len(tokenizer)) + train = tokenize_text_generation(tokenizer, max_length, y_train) + test = tokenize_text_generation(tokenizer, max_length, y_test) + data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="pt") + else: + raise PgMLException(f"unhandled task type: {task}") + trainer = Trainer( + model=model, + args=TrainingArguments(output_dir=path, **hyperparams), + train_dataset=train, + eval_dataset=test, + tokenizer=tokenizer, + data_collator=data_collator, + ) + start = time.perf_counter() + trainer.train() + fit_time = time.perf_counter() - start + model.eval() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Test + start = time.perf_counter() + if task == "summarization": + metrics = compute_metrics_summarization(model, tokenizer, hyperparams, x_test, y_test) + elif task == "text-classification": + metrics = compute_metrics_text_classification(model, tokenizer, hyperparams, x_test, y_test) + elif task == "question-answering": + metrics = compute_metrics_question_answering(model, tokenizer, hyperparams, x_test, y_test) + elif task == "translation": + metrics = compute_metrics_translation(model, tokenizer, hyperparams, x_test, y_test) + elif task == "text-generation": + metrics = compute_metrics_text_generation(model, tokenizer, hyperparams, y_test) + else: + raise PgMLException(f"unhandled task type: {task}") + metrics["score_time"] = time.perf_counter() - start + metrics["fit_time"] = fit_time + + # Save the results + if os.path.isdir(path): + shutil.rmtree(path, ignore_errors=True) + trainer.save_model() + + return metrics + +class MissingModelError(Exception): + pass + +def get_transformer_by_model_id(model_id): + global __cache_transformer_by_model_id + if model_id in __cache_transformer_by_model_id: + return __cache_transformer_by_model_id[model_id] + else: + raise MissingModelError + +def load_model(model_id, task, dir): + if task == "summarization": + __cache_transformer_by_model_id[model_id] = { + "tokenizer": AutoTokenizer.from_pretrained(dir), + "model": AutoModelForSeq2SeqLM.from_pretrained(dir), + } + elif task == "text-classification": + __cache_transformer_by_model_id[model_id] = { + "tokenizer": AutoTokenizer.from_pretrained(dir), + "model": AutoModelForSequenceClassification.from_pretrained(dir), + } + elif task == "translation": + __cache_transformer_by_model_id[model_id] = { + "tokenizer": AutoTokenizer.from_pretrained(dir), + "model": AutoModelForSeq2SeqLM.from_pretrained(dir), + } + elif task == "question-answering": + __cache_transformer_by_model_id[model_id] = { + "tokenizer": AutoTokenizer.from_pretrained(dir), + "model": AutoModelForQuestionAnswering.from_pretrained(dir), + } + elif task == "text-generation": + __cache_transformer_by_model_id[model_id] = { + "tokenizer": AutoTokenizer.from_pretrained(dir), + "model": AutoModelForCausalLM.from_pretrained(dir), + } + + else: + raise Exception(f"unhandled task type: {task}") + +def generate(model_id, data): + result = get_transformer_by_model_id(model_id) + tokenizer = result["tokenizer"] + model = result["model"] + + all_preds = [] + + batch_size = 1 # TODO hyperparams + batches = int(math.ceil(len(data) / batch_size)) + + with torch.no_grad(): + for i in range(batches): + start = i * batch_size + end = min((i + 1) * batch_size, len(data)) + tokens = tokenizer.batch_encode_plus( + data[start:end], + padding=True, + truncation=True, + return_tensors="pt", + return_token_type_ids=False, + ).to(model.device) + predictions = model.generate(**tokens) + decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True) + all_preds.extend(decoded_preds) + return all_preds diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index 7aaeac17c..7c479aa8e 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -1,25 +1,39 @@ +use std::collections::HashMap; +use std::io::Write; +use std::path::PathBuf; +use std::str::FromStr; + +use once_cell::sync::Lazy; +use pgx::*; use pyo3::prelude::*; use pyo3::types::PyTuple; +use crate::orm::{Task, TextDataset}; + +static PY_MODULE: Lazy> = Lazy::new(|| { + Python::with_gil(|py| -> Py { + let src = include_str!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/bindings/transformers.py" + )); + + PyModule::from_code(py, src, "", "").unwrap().into() + }) +}); + pub fn transform( task: &serde_json::Value, args: &serde_json::Value, inputs: &Vec, ) -> serde_json::Value { - let module = include_str!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/src/bindings/transformers.py" - )); - let task = serde_json::to_string(task).unwrap(); let args = serde_json::to_string(args).unwrap(); let inputs = serde_json::to_string(inputs).unwrap(); let results = Python::with_gil(|py| -> String { - let module = PyModule::from_code(py, module, "", "").unwrap(); - let transformer: Py = module.getattr("transform").unwrap().into(); + let transform: Py = PY_MODULE.getattr(py, "transform").unwrap().into(); - transformer + transform .call1( py, PyTuple::new( @@ -33,3 +47,228 @@ pub fn transform( }); serde_json::from_str(&results).unwrap() } + +pub fn tune( + task: &Task, + dataset: TextDataset, + hyperparams: &JsonB, + path: &std::path::PathBuf, +) -> HashMap { + let task = task.to_string(); + let hyperparams = serde_json::to_string(&hyperparams.0).unwrap(); + + let metrics = Python::with_gil(|py| -> HashMap { + let tune = PY_MODULE.getattr(py, "tune").unwrap(); + let result = tune.call1( + py, + ( + &task, + &hyperparams, + path.to_str().unwrap(), + dataset.x_train, + dataset.x_test, + dataset.y_train, + dataset.y_test, + ), + ); + let result = match result { + Err(e) => { + let traceback = e.traceback(py).unwrap().format().unwrap(); + error!("{traceback} {e}") + } + Ok(o) => o, + }; + result.extract(py).unwrap() + }); + metrics +} + +pub fn generate(model_id: i64, inputs: Vec<&str>) -> Vec { + Python::with_gil(|py| -> Vec { + let generate = PY_MODULE.getattr(py, "generate").unwrap(); + // cloning inputs in case we have to re-call on error is rather unfortunate here + let result = generate.call1(py, (model_id, inputs.clone())); + let result = match result { + Err(e) => { + if e.get_type(py).name().unwrap() == "MissingModelError" { + let mut dir = std::path::PathBuf::from("/tmp/postgresml/models"); + dir.push(model_id.to_string()); + if !dir.exists() { + dump_model(model_id, dir.clone()); + } + let task = Spi::get_one_with_args::( + "SELECT task::TEXT + FROM pgml.projects + JOIN pgml.models + ON models.project_id = projects.id + WHERE models.id = $1", + vec![(PgBuiltInOids::INT8OID.oid(), model_id.into_datum())], + ) + .unwrap() + .unwrap(); + + let load = PY_MODULE.getattr(py, "load_model").unwrap(); + let task = Task::from_str(&task).unwrap(); + load.call1(py, (model_id, task.to_string(), dir)).unwrap(); + + generate.call1(py, (model_id, inputs)).unwrap() + } else { + let traceback = e.traceback(py).unwrap().format().unwrap(); + error!("{traceback} {e}") + } + } + Ok(o) => o, + }; + result.extract(py).unwrap() + }) +} + +fn dump_model(model_id: i64, dir: PathBuf) { + if dir.exists() { + std::fs::remove_dir_all(&dir).unwrap(); + } + std::fs::create_dir_all(&dir).unwrap(); + Spi::connect(|client| { + let result = client.select("SELECT path, part, data FROM pgml.files WHERE model_id = $1 ORDER BY path ASC, part ASC", + None, + Some(vec![ + (PgBuiltInOids::INT8OID.oid(), model_id.into_datum()), + ]) + ).unwrap(); + for row in result { + let mut path = dir.clone(); + path.push(row.get::(1).unwrap().unwrap()); + let data: Vec = row.get(3).unwrap().unwrap(); + let mut file = std::fs::OpenOptions::new() + .create(true) + .append(true) + .open(path) + .unwrap(); + file.write(&data).unwrap(); + file.flush().unwrap(); + } + }); +} + +pub fn load_dataset( + name: &str, + subset: Option, + limit: Option, + kwargs: &serde_json::Value, +) -> usize { + let kwargs = serde_json::to_string(kwargs).unwrap(); + + let dataset = Python::with_gil(|py| -> String { + let load_dataset: Py = PY_MODULE.getattr(py, "load_dataset").unwrap().into(); + load_dataset + .call1( + py, + PyTuple::new( + py, + &[ + name.into_py(py), + subset.into_py(py), + limit.into_py(py), + kwargs.into_py(py), + ], + ), + ) + .unwrap() + .extract(py) + .unwrap() + }); + + let table_name = format!("pgml.\"{}\"", name); + + // Columns are a (name: String, values: Vec) pair + let json: serde_json::Value = serde_json::from_str(&dataset).unwrap(); + let json = json.as_object().unwrap(); + let types = json.get("types").unwrap().as_object().unwrap(); + let data = json.get("data").unwrap().as_object().unwrap(); + let column_names = types + .iter() + .map(|(name, _type)| name.clone()) + .collect::>() + .join(", "); + let column_types = types + .iter() + .map(|(name, type_)| { + let type_ = match type_.as_str().unwrap() { + "string" => "TEXT", + "dict" => "JSONB", + "int64" => "INT8", + "int32" => "INT4", + "int16" => "INT2", + "float64" => "FLOAT8", + "float32" => "FLOAT4", + "float16" => "FLOAT4", + "bool" => "BOOLEAN", + _ => error!( + "unhandled dataset feature while reading dataset: {:?}", + type_ + ), + }; + format!("{name} {type_}") + }) + .collect::>() + .join(", "); + let column_placeholders = types + .iter() + .enumerate() + .map(|(i, _)| { + let placeholder = i + 1; + format!("${placeholder}") + }) + .collect::>() + .join(", "); + let num_cols = types.len(); + let num_rows = data.values().next().unwrap().as_array().unwrap().len(); + + // Avoid the existence warning by checking the schema for the table first + let table_count = Spi::get_one_with_args::("SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'pgml'", vec![ + (PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum()) + ]).unwrap().unwrap(); + match table_count { + 1 => Spi::run(&format!(r#"DROP TABLE IF EXISTS {table_name}"#)).unwrap(), + _ => (), + } + + Spi::run(&format!(r#"CREATE TABLE {table_name} ({column_types})"#)).unwrap(); + let insert = + format!(r#"INSERT INTO {table_name} ({column_names}) VALUES ({column_placeholders})"#); + for i in 0..num_rows { + let mut row = Vec::with_capacity(num_cols); + for (name, values) in data { + let value = values.as_array().unwrap().get(i).unwrap(); + match types.get(name).unwrap().as_str().unwrap() { + "string" => row.push(( + PgBuiltInOids::TEXTOID.oid(), + value.as_str().unwrap().into_datum(), + )), + "dict" => row.push(( + PgBuiltInOids::JSONBOID.oid(), + JsonB(value.clone()).into_datum(), + )), + "int64" | "int32" | "int16" => row.push(( + PgBuiltInOids::INT8OID.oid(), + value.as_i64().unwrap().into_datum(), + )), + "float64" | "float32" | "float16" => row.push(( + PgBuiltInOids::FLOAT8OID.oid(), + value.as_f64().unwrap().into_datum(), + )), + "bool" => row.push(( + PgBuiltInOids::BOOLOID.oid(), + value.as_bool().unwrap().into_datum(), + )), + type_ => error!( + "unhandled dataset value type while reading dataset: {:?} {:?}", + value, type_ + ), + } + } + Spi::run_with_args(&insert, Some(row)).unwrap(); + } + + num_rows +} diff --git a/pgml-extension/src/bindings/xgboost.rs b/pgml-extension/src/bindings/xgboost.rs index bfa373bd6..b7bebd91d 100644 --- a/pgml-extension/src/bindings/xgboost.rs +++ b/pgml-extension/src/bindings/xgboost.rs @@ -285,7 +285,8 @@ impl Bindings for Estimator { '2' )::bigint", ) - .unwrap().unwrap(); + .unwrap() + .unwrap(); estimator .set_param("nthread", &concurrency.to_string()) diff --git a/pgml-extension/src/orm/algorithm.rs b/pgml-extension/src/orm/algorithm.rs index b5142a181..277e90147 100644 --- a/pgml-extension/src/orm/algorithm.rs +++ b/pgml-extension/src/orm/algorithm.rs @@ -37,6 +37,7 @@ pub enum Algorithm { hist_gradient_boosting, linear_svm, lightgbm, + transformers, } impl std::str::FromStr for Algorithm { @@ -77,6 +78,7 @@ impl std::str::FromStr for Algorithm { "hist_gradient_boosting" => Ok(Algorithm::hist_gradient_boosting), "linear_svm" => Ok(Algorithm::linear_svm), "lightgbm" => Ok(Algorithm::lightgbm), + "transformers" => Ok(Algorithm::transformers), _ => Err(()), } } @@ -120,6 +122,7 @@ impl std::string::ToString for Algorithm { Algorithm::hist_gradient_boosting => "hist_gradient_boosting".to_string(), Algorithm::linear_svm => "linear_svm".to_string(), Algorithm::lightgbm => "lightgbm".to_string(), + Algorithm::transformers => "transformers".to_string(), } } } diff --git a/pgml-extension/src/orm/dataset.rs b/pgml-extension/src/orm/dataset.rs index e4de76bdc..75ab1e4e5 100644 --- a/pgml-extension/src/orm/dataset.rs +++ b/pgml-extension/src/orm/dataset.rs @@ -68,6 +68,41 @@ impl Dataset { } } +#[derive(Debug)] +pub struct TextDataset { + pub x_train: Vec, + pub y_train: Vec, + pub x_test: Vec, + pub y_test: Vec, + pub num_features: usize, + pub num_labels: usize, + pub num_rows: usize, + pub num_train_rows: usize, + pub num_test_rows: usize, + pub num_distinct_labels: usize, +} + +impl Display for TextDataset { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { + write!( + f, + "TextDataset {{ num_features: {}, num_labels: {}, num_distinct_labels: {}, num_rows: {}, num_train_rows: {}, num_test_rows: {} }}", + self.num_features, self.num_labels, self.num_distinct_labels, self.num_rows, self.num_train_rows, self.num_test_rows, + ) + } +} + +fn drop_table_if_exists(table_name: &str) { + // Avoid the existence for DROP TABLE IF EXISTS warning by checking the schema for the table first + let table_count = Spi::get_one_with_args::("SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'pgml'", vec![ + (PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum()) + ]).unwrap().unwrap(); + match table_count { + 1 => Spi::run(&format!(r#"DROP TABLE {table_name}"#)).unwrap(), + _ => (), + } +} + #[derive(Deserialize)] struct BreastCancerRow { mean_radius: f32, @@ -104,7 +139,7 @@ struct BreastCancerRow { } pub fn load_breast_cancer(limit: Option) -> (String, i64) { - Spi::run("DROP TABLE IF EXISTS pgml.breast_cancer").unwrap(); + drop_table_if_exists("breast_cancer"); Spi::run( r#"CREATE TABLE pgml.breast_cancer ( "mean radius" FLOAT4, @@ -139,7 +174,8 @@ pub fn load_breast_cancer(limit: Option) -> (String, i64) { "worst fractal dimension" FLOAT4, "malignant" BOOLEAN )"#, - ).unwrap(); + ) + .unwrap(); let limit = match limit { Some(limit) => limit, @@ -297,7 +333,7 @@ struct DiabetesRow { } pub fn load_diabetes(limit: Option) -> (String, i64) { - Spi::run("DROP TABLE IF EXISTS pgml.diabetes").unwrap(); + drop_table_if_exists("diabetes"); Spi::run( "CREATE TABLE pgml.diabetes ( age FLOAT4, @@ -312,7 +348,8 @@ pub fn load_diabetes(limit: Option) -> (String, i64) { s6 FLOAT4, target FLOAT4 )", - ).unwrap(); + ) + .unwrap(); let limit = match limit { Some(limit) => limit, @@ -347,7 +384,8 @@ pub fn load_diabetes(limit: Option) -> (String, i64) { (PgBuiltInOids::FLOAT4OID.oid(), row.s6.into_datum()), (PgBuiltInOids::FLOAT4OID.oid(), row.target.into_datum()), ]), - ).unwrap(); + ) + .unwrap(); inserted += 1; } @@ -361,7 +399,7 @@ struct DigitsRow { } pub fn load_digits(limit: Option) -> (String, i64) { - Spi::run("DROP TABLE IF EXISTS pgml.digits").unwrap(); + drop_table_if_exists("digits"); Spi::run("CREATE TABLE pgml.digits (image SMALLINT[][], target SMALLINT)").unwrap(); let limit = match limit { @@ -388,7 +426,8 @@ pub fn load_digits(limit: Option) -> (String, i64) { (PgBuiltInOids::TEXTOID.oid(), row.image.into_datum()), (PgBuiltInOids::INT2OID.oid(), row.target.into_datum()), ]), - ).unwrap(); + ) + .unwrap(); inserted += 1; } @@ -405,7 +444,7 @@ struct IrisRow { } pub fn load_iris(limit: Option) -> (String, i64) { - Spi::run("DROP TABLE IF EXISTS pgml.iris").unwrap(); + drop_table_if_exists("iris"); Spi::run( "CREATE TABLE pgml.iris ( sepal_length FLOAT4, @@ -414,7 +453,8 @@ pub fn load_iris(limit: Option) -> (String, i64) { petal_width FLOAT4, target INT4 )", - ).unwrap(); + ) + .unwrap(); let limit = match limit { Some(limit) => limit, @@ -449,7 +489,8 @@ pub fn load_iris(limit: Option) -> (String, i64) { (PgBuiltInOids::FLOAT4OID.oid(), row.petal_width.into_datum()), (PgBuiltInOids::INT4OID.oid(), row.target.into_datum()), ]), - ).unwrap(); + ) + .unwrap(); inserted += 1; } @@ -467,7 +508,7 @@ struct LinnerudRow { } pub fn load_linnerud(limit: Option) -> (String, i64) { - Spi::run("DROP TABLE IF EXISTS pgml.linnerud").unwrap(); + drop_table_if_exists("linnerud"); Spi::run( "CREATE TABLE pgml.linnerud( chins FLOAT4, @@ -477,7 +518,8 @@ pub fn load_linnerud(limit: Option) -> (String, i64) { waist FLOAT4, pulse FLOAT4 )", - ).unwrap(); + ) + .unwrap(); let limit = match limit { Some(limit) => limit, @@ -507,7 +549,8 @@ pub fn load_linnerud(limit: Option) -> (String, i64) { (PgBuiltInOids::FLOAT4OID.oid(), row.waist.into_datum()), (PgBuiltInOids::FLOAT4OID.oid(), row.pulse.into_datum()), ]), - ).unwrap(); + ) + .unwrap(); inserted += 1; } @@ -533,7 +576,7 @@ struct WineRow { } pub fn load_wine(limit: Option) -> (String, i64) { - Spi::run("DROP TABLE IF EXISTS pgml.wine").unwrap(); + drop_table_if_exists("wine"); Spi::run( r#"CREATE TABLE pgml.wine ( alcohol FLOAT4, @@ -551,7 +594,8 @@ pub fn load_wine(limit: Option) -> (String, i64) { proline FLOAT4, target INT4 )"#, - ).unwrap(); + ) + .unwrap(); let limit = match limit { Some(limit) => limit, diff --git a/pgml-extension/src/orm/file.rs b/pgml-extension/src/orm/file.rs index 9c727c172..89fa059c5 100644 --- a/pgml-extension/src/orm/file.rs +++ b/pgml-extension/src/orm/file.rs @@ -60,14 +60,10 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Arc Arc Arc { crate::bindings::linfa::LogisticRegression::from_bytes(&data) } + _ => error!("Rust runtime only supports `classification` and `regression` task types for linear algorithms."), }, Algorithm::svm => crate::bindings::linfa::Svm::from_bytes(&data), _ => todo!(), //smartcore_load(&data, task, algorithm, &hyperparams), diff --git a/pgml-extension/src/orm/mod.rs b/pgml-extension/src/orm/mod.rs index 590dbaa04..abe00f1c1 100644 --- a/pgml-extension/src/orm/mod.rs +++ b/pgml-extension/src/orm/mod.rs @@ -13,6 +13,7 @@ pub mod task; pub use algorithm::Algorithm; pub use dataset::Dataset; +pub use dataset::TextDataset; pub use model::Model; pub use project::Project; pub use runtime::Runtime; diff --git a/pgml-extension/src/orm/model.rs b/pgml-extension/src/orm/model.rs index 4b778fc1f..bc9fbd7c3 100644 --- a/pgml-extension/src/orm/model.rs +++ b/pgml-extension/src/orm/model.rs @@ -10,13 +10,12 @@ use indexmap::IndexMap; use itertools::{izip, Itertools}; use ndarray::ArrayView1; use once_cell::sync::Lazy; -use pgx::*; use pgx::heap_tuple::PgHeapTuple; +use pgx::*; use rand::prelude::SliceRandom; use serde_json::json; use crate::bindings::*; -use crate::orm::Dataset; use crate::orm::*; #[allow(clippy::type_complexity)] @@ -79,6 +78,7 @@ impl Model { Algorithm::linear => match project.task { Task::classification => Runtime::python, Task::regression => Runtime::rust, + _ => error!("No default runtime available for tasks other than `classification` and `regression` when using a linear algorithm."), }, _ => Runtime::python, }, @@ -87,13 +87,13 @@ impl Model { }, }; - let dataset = snapshot.dataset(); + let dataset = snapshot.tabular_dataset(); let status = Status::in_progress; // Create the model record. Spi::connect(|client| { let result = client.select(" - INSERT INTO pgml.models (project_id, snapshot_id, algorithm, runtime, hyperparams, status, search, search_params, search_args, num_features) - VALUES ($1, $2, cast($3 AS pgml.algorithm), cast($4 AS pgml.runtime), $5, cast($6 as pgml.status), $7, $8, $9, $10) + INSERT INTO pgml.models (project_id, snapshot_id, algorithm, runtime, hyperparams, status, search, search_params, search_args, num_features) + VALUES ($1, $2, cast($3 AS pgml.algorithm), cast($4 AS pgml.runtime), $5, cast($6 as pgml.status), $7, $8, $9, $10) RETURNING id, project_id, snapshot_id, algorithm, runtime, hyperparams, status, metrics, search, search_params, search_args, created_at, updated_at;", Some(1), Some(vec![ @@ -129,7 +129,7 @@ impl Model { bindings: None, num_classes: match project.task { Task::regression => 0, - Task::classification => snapshot.num_classes(), + _ => snapshot.num_classes(), }, num_features: snapshot.num_features(), }); @@ -141,7 +141,6 @@ impl Model { let mut model = model.unwrap(); info!("Training {}", model); - model.fit(&dataset); Spi::run_with_args( @@ -152,12 +151,123 @@ impl Model { Status::successful.to_string().into_datum(), ), (PgBuiltInOids::INT8OID.oid(), model.id.into_datum()), - ]) - ).unwrap(); + ]), + ) + .unwrap(); model } + #[allow(clippy::too_many_arguments)] + pub fn tune(project: &Project, snapshot: &mut Snapshot, hyperparams: &JsonB) -> Model { + let mut model: Option = None; + let dataset = snapshot.text_dataset(); + + // Create the model record. + Spi::connect(|client| { + let result = client.select(" + INSERT INTO pgml.models (project_id, snapshot_id, algorithm, runtime, hyperparams, status, search, search_params, search_args, num_features) + VALUES ($1, $2, cast($3 AS pgml.algorithm), cast($4 AS pgml.runtime), $5, cast($6 as pgml.status), $7, $8, $9, $10) + RETURNING id, project_id, snapshot_id, algorithm, runtime::TEXT, hyperparams, status, metrics, search, search_params, search_args, created_at, updated_at;", + Some(1), + Some(vec![ + (PgBuiltInOids::INT8OID.oid(), project.id.into_datum()), + (PgBuiltInOids::INT8OID.oid(), snapshot.id.into_datum()), + (PgBuiltInOids::TEXTOID.oid(), Algorithm::transformers.to_string().into_datum()), + (PgBuiltInOids::TEXTOID.oid(), Runtime::python.to_string().into_datum()), + (PgBuiltInOids::JSONBOID.oid(), JsonB(json!(hyperparams)).into_datum()), + (PgBuiltInOids::TEXTOID.oid(), Status::in_progress.to_string().into_datum()), + (PgBuiltInOids::TEXTOID.oid(), None::>.into_datum()), + (PgBuiltInOids::JSONBOID.oid(), JsonB(serde_json::from_str("{}").unwrap()).into_datum()), + (PgBuiltInOids::JSONBOID.oid(), JsonB(serde_json::from_str("{}").unwrap()).into_datum()), + (PgBuiltInOids::INT8OID.oid(), (dataset.num_features as i64).into_datum()), + ]), + ).unwrap().first(); + if !result.is_empty() { + model = Some(Model { + id: result.get(1).unwrap().unwrap(), + project_id: result.get(2).unwrap().unwrap(), + snapshot_id: result.get(3).unwrap().unwrap(), + algorithm: Algorithm::from_str(result.get(4).unwrap().unwrap()).unwrap(), + runtime: Runtime::from_str(result.get(5).unwrap().unwrap()).unwrap(), + hyperparams: result.get(6).unwrap().unwrap(), + status: Status::from_str(result.get(7).unwrap().unwrap()).unwrap(), + metrics: result.get(8).unwrap(), + search: result + .get(9) + .unwrap() + .map(|search| Search::from_str(search).unwrap()), + search_params: result.get(10).unwrap().unwrap(), + search_args: result.get(11).unwrap().unwrap(), + created_at: result.get(12).unwrap().unwrap(), + updated_at: result.get(13).unwrap().unwrap(), + project: project.clone(), + snapshot: snapshot.clone(), + bindings: None, + num_classes: 0, + num_features: snapshot.num_features(), + }); + } + + result + }); + + let mut model = model.unwrap(); + let id = model.id; + let path = std::path::PathBuf::from(format!("/tmp/postgresml/models/{id}")); + + info!("Tuning {}", model); + let metrics = transformers::tune(&project.task, dataset, &model.hyperparams, &path); + model.metrics = Some(JsonB(json!(metrics))); + info!("Metrics: {:?}", &metrics); + + Spi::get_one_with_args::( + "UPDATE pgml.models SET hyperparams = $1, metrics = $2 WHERE id = $3 RETURNING id", + vec![ + ( + PgBuiltInOids::JSONBOID.oid(), + JsonB(model.hyperparams.0.clone()).into_datum(), + ), + ( + PgBuiltInOids::JSONBOID.oid(), + JsonB(model.metrics.as_ref().unwrap().0.clone()).into_datum(), + ), + (PgBuiltInOids::INT8OID.oid(), model.id.into_datum()), + ], + ) + .unwrap(); + + // Save the bindings. + for entry in std::fs::read_dir(&path).unwrap() { + let path = entry.unwrap().path(); + let bytes = std::fs::read(&path).unwrap(); + for (i, chunk) in bytes.chunks(100_000_000).enumerate() { + Spi::get_one_with_args::( + "INSERT INTO pgml.files (model_id, path, part, data) VALUES($1, $2, $3, $4) RETURNING id", + vec![ + (PgBuiltInOids::INT8OID.oid(), model.id.into_datum()), + (PgBuiltInOids::TEXTOID.oid(), path.file_name().unwrap().to_str().into_datum()), + (PgBuiltInOids::INT8OID.oid(), (i as i64).into_datum()), + (PgBuiltInOids::BYTEAOID.oid(), chunk.into_datum()), + ], + ).unwrap(); + } + } + + Spi::run_with_args( + "UPDATE pgml.models SET status = $1::pgml.status WHERE id = $2", + Some(vec![ + ( + PgBuiltInOids::TEXTOID.oid(), + Status::successful.to_string().into_datum(), + ), + (PgBuiltInOids::INT8OID.oid(), model.id.into_datum()), + ]), + ) + .unwrap(); + model + } + fn find(id: i64) -> Model { let mut model = None; // Create the model record. @@ -187,7 +297,9 @@ impl Model { WHERE files.model_id = $1 LIMIT 1", vec![(PgBuiltInOids::INT8OID.oid(), id.into_datum())], - ).unwrap().unwrap(); + ) + .unwrap() + .unwrap(); let bindings: Box = match runtime { Runtime::rust => { @@ -205,6 +317,7 @@ impl Model { Task::classification => { crate::bindings::linfa::LogisticRegression::from_bytes(&data) } + _ => error!("No default runtime available for tasks other than `classification` and `regression` when using a linear algorithm."), }, Algorithm::svm => crate::bindings::linfa::Svm::from_bytes(&data), _ => todo!(), //smartcore_load(&data, task, algorithm, &hyperparams), @@ -223,7 +336,7 @@ impl Model { let num_features = snapshot.num_features(); let num_classes = match project.task { Task::regression => 0, - Task::classification => snapshot.num_classes(), + _ => snapshot.num_classes(), }; model = Some(Model { @@ -236,7 +349,8 @@ impl Model { status: Status::from_str(result.get(7).unwrap().unwrap()).unwrap(), metrics: result.get(8).unwrap(), search: result - .get(9).unwrap() + .get(9) + .unwrap() .map(|search| Search::from_str(search).unwrap()), search_params: result.get(10).unwrap().unwrap(), search_args: result.get(11).unwrap().unwrap(), @@ -253,7 +367,12 @@ impl Model { result }); - model.unwrap_or_else(|| error!("pgml.models WHERE id = {:?} could not be loaded. Does it exist?", id)) + model.unwrap_or_else(|| { + error!( + "pgml.models WHERE id = {:?} could not be loaded. Does it exist?", + id + ) + }) } pub fn find_cached(id: i64) -> Arc { @@ -288,6 +407,7 @@ impl Model { Algorithm::svm => linfa::Svm::fit, _ => todo!(), }, + _ => error!("use pgml.tune for transformers tasks"), }, #[cfg(not(feature = "python"))] @@ -366,6 +486,7 @@ impl Model { Algorithm::lightgbm => sklearn::lightgbm_classification, _ => panic!("{:?} does not support classification", self.algorithm), }, + _ => error!("use pgml.tune for transformers tasks"), }, } } @@ -530,6 +651,7 @@ impl Model { // This one is inaccurate, I have it in my TODO to reimplement. metrics.insert("mcc".to_string(), confusion_matrix.mcc()); } + _ => error!("no tests for huggingface"), } metrics @@ -621,7 +743,7 @@ impl Model { check_for_interrupts!(); }) } - .unwrap(); + .unwrap(); let mut n_iter: usize = 10; let mut cv: usize = if self.search.is_some() { 5 } else { 1 }; @@ -678,6 +800,12 @@ impl Model { let target_metric = match self.project.task { Task::regression => "r2", Task::classification => "f1", + Task::question_answering => "f1", + Task::translation => "blue", + Task::summarization => "rouge_ngram_f1", + Task::text_classification => "f1", + Task::text_generation => "perplexity", + Task::text2text => "perplexity", }; let mut i = 0; let mut best_index = 0; @@ -788,7 +916,7 @@ impl Model { (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()), ], ) - .unwrap(); + .unwrap(); // Save the bindings. Spi::get_one_with_args::( @@ -819,42 +947,73 @@ impl Model { pgx_pg_sys::UNKNOWNOID => { error!("Type information missing for column: {:?}. If this is intended to be a TEXT or other categorical column, you will need to explicitly cast it, e.g. change `{:?}` to `CAST({:?} AS TEXT)`.", column.name, column.name, column.name); } - pgx_pg_sys::TEXTOID | pgx_pg_sys::VARCHAROID | pgx_pg_sys::BPCHAROID => { + pgx_pg_sys::TEXTOID + | pgx_pg_sys::VARCHAROID + | pgx_pg_sys::BPCHAROID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); - element.unwrap().unwrap_or(snapshot::NULL_CATEGORY_KEY.to_string()) + element + .unwrap() + .unwrap_or(snapshot::NULL_CATEGORY_KEY.to_string()) } pgx_pg_sys::BOOLOID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); - element.unwrap().map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string()) + element + .unwrap() + .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| { + k.to_string() + }) } pgx_pg_sys::INT2OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); - element.unwrap().map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string()) + element + .unwrap() + .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| { + k.to_string() + }) } pgx_pg_sys::INT4OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); - element.unwrap().map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string()) + element + .unwrap() + .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| { + k.to_string() + }) } pgx_pg_sys::INT8OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); - element.unwrap().map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string()) + element + .unwrap() + .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| { + k.to_string() + }) } pgx_pg_sys::FLOAT4OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); - element.unwrap().map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string()) + element + .unwrap() + .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| { + k.to_string() + }) } pgx_pg_sys::FLOAT8OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); - element.unwrap().map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string()) + element + .unwrap() + .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| { + k.to_string() + }) } - _ => error!("Unsupported type for categorical column: {:?}. oid: {:?}", column.name, attribute.atttypid), + _ => error!( + "Unsupported type for categorical column: {:?}. oid: {:?}", + column.name, attribute.atttypid + ), }; let value = column.get_category_value(&key); features.push(value); @@ -867,22 +1026,27 @@ impl Model { pgx_pg_sys::BOOLOID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); - features.push(element.unwrap().map_or(f32::NAN, |v| v as u8 as f32)); + features.push( + element.unwrap().map_or(f32::NAN, |v| v as u8 as f32), + ); } pgx_pg_sys::INT2OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); - features.push(element.unwrap().map_or(f32::NAN, |v| v as f32)); + features + .push(element.unwrap().map_or(f32::NAN, |v| v as f32)); } pgx_pg_sys::INT4OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); - features.push(element.unwrap().map_or(f32::NAN, |v| v as f32)); + features + .push(element.unwrap().map_or(f32::NAN, |v| v as f32)); } pgx_pg_sys::INT8OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); - features.push(element.unwrap().map_or(f32::NAN, |v| v as f32)); + features + .push(element.unwrap().map_or(f32::NAN, |v| v as f32)); } pgx_pg_sys::FLOAT4OID => { let element: Result, TryFromDatumError> = @@ -892,7 +1056,8 @@ impl Model { pgx_pg_sys::FLOAT8OID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index.try_into().unwrap()); - features.push(element.unwrap().map_or(f32::NAN, |v| v as f32)); + features + .push(element.unwrap().map_or(f32::NAN, |v| v as f32)); } // TODO handle NULL to NaN for arrays pgx_pg_sys::BOOLARRAYOID => { @@ -937,13 +1102,18 @@ impl Model { features.push(*j as f32); } } - _ => error!("Unsupported type for quantitative column: {:?}. oid: {:?}", column.name, attribute.atttypid), + _ => error!( + "Unsupported type for quantitative column: {:?}. oid: {:?}", + column.name, attribute.atttypid + ), } } } } } - _ => error!("This preprocessing requires Postgres `record` types created with `row()`.") + _ => error!( + "This preprocessing requires Postgres `record` types created with `row()`." + ), } } features @@ -961,6 +1131,7 @@ impl Model { .as_ref() .unwrap() .predict_proba(features, self.num_features), + _ => error!("no predict_proba for huggingface"), } } @@ -974,6 +1145,7 @@ impl Model { Task::classification => { error!("You can't predict joint probabilities for a classification model") } + _ => error!("no predict_joint for huggingface"), } } diff --git a/pgml-extension/src/orm/project.rs b/pgml-extension/src/orm/project.rs index c0f372930..caf12e022 100644 --- a/pgml-extension/src/orm/project.rs +++ b/pgml-extension/src/orm/project.rs @@ -43,7 +43,7 @@ impl Project { let project_id = match projects.get(project_name) { Some(project_id) => *project_id, None => { - let (project_id, model_id) = Spi::get_two_with_args::( + let result = Spi::get_two_with_args::( "SELECT deployments.project_id, deployments.model_id FROM pgml.deployments JOIN pgml.projects ON projects.id = deployments.project_id @@ -51,7 +51,14 @@ impl Project { ORDER BY deployments.created_at DESC LIMIT 1", vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())], - ).unwrap(); + ); + let (project_id, model_id) = match result { + Ok(o) => o, + Err(_) => error!( + "No deployed model exists for the project named: `{}`", + project_name + ), + }; let project_id = project_id.unwrap_or_else(|| { error!( "No deployed model exists for the project named: `{}`", @@ -74,7 +81,6 @@ impl Project { project_id } }; - *PROJECT_ID_TO_DEPLOYED_MODEL_ID .share() .get(&project_id) @@ -157,7 +163,7 @@ impl Project { Some(1), Some(vec![ (PgBuiltInOids::TEXTOID.oid(), name.into_datum()), - (PgBuiltInOids::TEXTOID.oid(), task.to_string().into_datum()), + (PgBuiltInOids::TEXTOID.oid(), task.to_pg_enum().into_datum()), ]) ).unwrap().first(); if !result.is_empty() { diff --git a/pgml-extension/src/orm/snapshot.rs b/pgml-extension/src/orm/snapshot.rs index 434de00de..121c82066 100644 --- a/pgml-extension/src/orm/snapshot.rs +++ b/pgml-extension/src/orm/snapshot.rs @@ -1,17 +1,17 @@ use std::cmp::Ordering; +use std::collections::HashMap; use std::fmt::{Display, Error, Formatter}; use std::str::FromStr; -use std::collections::HashMap; -use ndarray::Zip; use indexmap::IndexMap; +use ndarray::Zip; use pgx::*; use serde::{Deserialize, Serialize}; use serde_json::json; -use crate::orm::Dataset; use crate::orm::Sampling; use crate::orm::Status; +use crate::orm::{Dataset, TextDataset}; // Categories use a designated string to represent NULL categorical values, // rather than Option = None, because the JSONB serialization schema @@ -142,8 +142,7 @@ impl Column { fn nominal_type(pg_type: &str) -> bool { match pg_type { - "bpchar" | "text" | "varchar" | - "bpchar[]" | "text[]" | "varchar[]" => true, + "bpchar" | "text" | "varchar" | "bpchar[]" | "text[]" | "varchar[]" => true, _ => false, } } @@ -164,10 +163,15 @@ impl Column { pub(crate) fn scale(&self, value: f32) -> f32 { match self.preprocessor.scale { Scale::standard => (value - self.statistics.mean) / self.statistics.std_dev, - Scale::min_max => (value - self.statistics.min) / (self.statistics.max - self.statistics.min), + Scale::min_max => { + (value - self.statistics.min) / (self.statistics.max - self.statistics.min) + } Scale::max_abs => value / self.statistics.max_abs, - Scale::robust => (value - self.statistics.median) / (self.statistics.ventiles[15] - self.statistics.ventiles[5]), - Scale::preserve => value + Scale::robust => { + (value - self.statistics.median) + / (self.statistics.ventiles[15] - self.statistics.ventiles[5]) + } + Scale::preserve => value, } } @@ -191,7 +195,7 @@ impl Column { pub(crate) fn encoded_width(&self) -> usize { match self.preprocessor.encode { Encode::one_hot => self.statistics.categories.as_ref().unwrap().len() - 1, - _ => 1 + _ => 1, } } @@ -199,7 +203,13 @@ impl Column { self.size } - pub(crate) fn preprocess(&self, data: &ndarray::ArrayView, processed_data: &mut Vec, features_width: usize, position: usize) { + pub(crate) fn preprocess( + &self, + data: &ndarray::ArrayView, + processed_data: &mut Vec, + features_width: usize, + position: usize, + ) { for (row, &d) in data.iter().enumerate() { let value = self.impute(d); match &self.preprocessor.encode { @@ -208,13 +218,17 @@ impl Column { let one_hot = if i == value as usize { 1. } else { 0. } as f32; processed_data[row * features_width + position + i] = one_hot; } - }, + } _ => processed_data[row * features_width + position] = self.scale(value), }; } } - fn analyze(&mut self, array: &ndarray::ArrayView, target: &ndarray::ArrayView) { + fn analyze( + &mut self, + array: &ndarray::ArrayView, + target: &ndarray::ArrayView, + ) { // target encode if necessary before analyzing match &self.preprocessor.encode { Encode::target_mean => { @@ -232,21 +246,32 @@ impl Column { } // Data is filtered for NaN because it is not well defined statistically, and they are counted as separate stat - let mut data = array.iter().filter_map(|n| if n.is_nan() { None } else { Some(*n) }).collect::>(); + let mut data = array + .iter() + .filter_map(|n| if n.is_nan() { None } else { Some(*n) }) + .collect::>(); data.sort_by(|a, b| a.total_cmp(&b)); // FixMe: Arrays are analyzed many times, clobbering/appending to the same stats, columns are also re-analyzed in memory during tests, which can cause unnexpected failures let mut statistics = &mut self.statistics; statistics.min = *data.first().unwrap(); statistics.max = *data.last().unwrap(); - statistics.max_abs = if statistics.min.abs() > statistics.max.abs() { statistics.min.abs() } else { statistics.max.abs() }; + statistics.max_abs = if statistics.min.abs() > statistics.max.abs() { + statistics.min.abs() + } else { + statistics.max.abs() + }; statistics.mean = data.iter().sum::() / data.len() as f32; statistics.median = data[data.len() / 2]; statistics.missing = array.len() - data.len(); - statistics.variance = data.iter().map(|i| { - let diff = statistics.mean - (*i); - diff * diff - }).sum::() / data.len() as f32; + statistics.variance = data + .iter() + .map(|i| { + let diff = statistics.mean - (*i); + diff * diff + }) + .sum::() + / data.len() as f32; statistics.std_dev = statistics.variance.sqrt(); let mut i = 0; let histogram_boundaries = ndarray::Array::linspace(statistics.min, statistics.max, 21); @@ -263,7 +288,7 @@ impl Column { if value == previous { streak += 1; } else if !previous.is_nan() { - if streak > max_streak { + if streak > max_streak { modes = vec![previous]; max_streak = streak; } else if streak == max_streak { @@ -443,7 +468,8 @@ impl Snapshot { let jsonb: JsonB = result.get(7).unwrap().unwrap(); let columns: Vec = serde_json::from_value(jsonb.0).unwrap(); let jsonb: JsonB = result.get(8).unwrap().unwrap(); - let analysis: Option> = Some(serde_json::from_value(jsonb.0).unwrap()); + let analysis: Option> = + Some(serde_json::from_value(jsonb.0).unwrap()); let mut s = Snapshot { id: result.get(1).unwrap().unwrap(), @@ -481,7 +507,8 @@ impl Snapshot { // Validate table exists. let (schema_name, table_name) = Self::fully_qualified_table(relation_name); - let preprocessors: HashMap = serde_json::from_value(preprocess.0).expect("is valid"); + let preprocessors: HashMap = + serde_json::from_value(preprocess.0).expect("is valid"); Spi::connect(|client| { let mut columns: Vec = Vec::new(); @@ -604,7 +631,7 @@ impl Snapshot { } pub(crate) fn labels(&self) -> impl Iterator { - self.columns.iter().filter(|c| c.label ) + self.columns.iter().filter(|c| c.label) } pub(crate) fn label_positions(&self) -> Vec { @@ -612,7 +639,10 @@ impl Snapshot { let mut row_position = 0; for column in self.labels() { for _ in 0..column.size { - label_positions.push(ColumnRowPosition {column_position: column.position, row_position}); + label_positions.push(ColumnRowPosition { + column_position: column.position, + row_position, + }); row_position += column.encoded_width(); } } @@ -620,7 +650,7 @@ impl Snapshot { } pub(crate) fn features(&self) -> impl Iterator { - self.columns.iter().filter(|c| !c.label ) + self.columns.iter().filter(|c| !c.label) } pub(crate) fn feature_positions(&self) -> Vec { @@ -628,7 +658,10 @@ impl Snapshot { let mut row_position = 0; for column in self.features() { for _ in 0..column.size { - feature_positions.push(ColumnRowPosition {column_position: column.position, row_position}); + feature_positions.push(ColumnRowPosition { + column_position: column.position, + row_position, + }); row_position += column.encoded_width(); } } @@ -636,11 +669,14 @@ impl Snapshot { } pub(crate) fn num_labels(&self) -> usize { - self.labels().map(|f| f.size ).sum::() + self.labels().map(|f| f.size).sum::() } pub(crate) fn first_label(&self) -> &Column { - self.labels().filter(|l| l.name == self.y_column_name[0] ).next().unwrap() + self.labels() + .filter(|l| l.name == self.y_column_name[0]) + .next() + .unwrap() } pub(crate) fn num_classes(&self) -> usize { @@ -651,15 +687,15 @@ impl Snapshot { } pub(crate) fn num_features(&self) -> usize { - self.features().map(|c| c.size ).sum::() + self.features().map(|c| c.size).sum::() } pub(crate) fn features_width(&self) -> usize { - self.features().map(|f| f.array_width() * f.encoded_width() ).sum::() + self.features() + .map(|f| f.array_width() * f.encoded_width()) + .sum::() } - - fn fully_qualified_table(relation_name: &str) -> (String, String) { let parts = relation_name .split('.') @@ -708,14 +744,125 @@ impl Snapshot { } } - pub fn dataset(&mut self) -> Dataset { + fn select_sql(&self) -> String { + format!( + "SELECT {} FROM {} {}", + self.columns + .iter() + .map(|c| c.quoted_name()) + .collect::>() + .join(", "), + self.relation_name(), + match self.materialized { + // If the snapshot is materialized, we already randomized it. + true => "", + false => { + if self.test_sampling == Sampling::random { + "ORDER BY random()" + } else { + "" + } + } + }, + ) + } + + fn train_test_split(&self, num_rows: usize) -> (usize, usize) { + let num_test_rows = if self.test_size > 1.0 { + self.test_size as usize + } else { + (num_rows as f32 * self.test_size).round() as usize + }; + + let num_train_rows = num_rows - num_test_rows; + if num_train_rows == 0 { + error!( + "test_size = {} is too large. There are only {} samples.", + num_test_rows, num_rows + ); + } + + (num_train_rows, num_test_rows) + } + + pub fn text_dataset(&mut self) -> TextDataset { + let mut data = None; + + Spi::connect(|client| { + let result = client.select(&self.select_sql(), None, None).unwrap(); + let num_rows = result.len(); + let (num_train_rows, num_test_rows) = self.train_test_split(num_rows); + let num_features = self.num_features(); + let num_labels = self.num_labels(); + + let mut x_train: Vec = Vec::with_capacity(num_train_rows * num_features); + let mut y_train: Vec = Vec::with_capacity(num_train_rows * num_labels); + let mut x_test: Vec = Vec::with_capacity(num_test_rows * num_features); + let mut y_test: Vec = Vec::with_capacity(num_test_rows * num_labels); + + result.enumerate().for_each(|(i, row)| { + for column in &mut self.columns { + let vector = if column.label { + if i < num_train_rows { + &mut y_train + } else { + &mut y_test + } + } else if i < num_train_rows { + &mut x_train + } else { + &mut x_test + }; + + match column.pg_type.as_str() { + "bpchar" | "text" | "varchar" => { + match row[column.position].value::().unwrap() { + Some(text) => vector.push(text), + None => error!("NULL training text is not handled"), + } + } + _ => error!("only text type columns are supported"), + } + } + }); + + data = Some(TextDataset { + x_train, + y_train, + x_test, + y_test, + num_features, + num_labels, + num_rows, + num_test_rows, + num_train_rows, + // TODO rename and audit this + num_distinct_labels: self.num_classes(), + }); + + Ok::, i64>(Some(())) // this return type is nonsense + }) + .unwrap(); + + let data = data.unwrap(); + + info!("{}", data); + + data + } + + pub fn tabular_dataset(&mut self) -> Dataset { let numeric_encoded_dataset = self.numeric_encoded_dataset(); // Analyze labels let label_data = ndarray::ArrayView2::from_shape( - (numeric_encoded_dataset.num_train_rows, numeric_encoded_dataset.num_labels), + ( + numeric_encoded_dataset.num_train_rows, + numeric_encoded_dataset.num_labels, + ), &numeric_encoded_dataset.y_train, - ).unwrap(); + ) + .unwrap(); // The data for the first label let target_data = label_data.columns().into_iter().next().unwrap(); @@ -728,9 +875,13 @@ impl Snapshot { // Analyze features let feature_data = ndarray::ArrayView2::from_shape( - (numeric_encoded_dataset.num_train_rows, numeric_encoded_dataset.num_features), + ( + numeric_encoded_dataset.num_train_rows, + numeric_encoded_dataset.num_features, + ), &numeric_encoded_dataset.x_train, - ).unwrap(); + ) + .unwrap(); Zip::from(feature_data.columns()) .and(&self.feature_positions()) .for_each(|data, position| { @@ -739,18 +890,28 @@ impl Snapshot { }); let mut analysis = IndexMap::new(); - analysis.insert("samples".to_string(), numeric_encoded_dataset.num_rows as f32); + analysis.insert( + "samples".to_string(), + numeric_encoded_dataset.num_rows as f32, + ); self.analysis = Some(analysis); // Record the analysis Spi::run_with_args( "UPDATE pgml.snapshots SET analysis = $1, columns = $2 WHERE id = $3", Some(vec![ - (PgBuiltInOids::JSONBOID.oid(), JsonB(json!(self.analysis)).into_datum()), - (PgBuiltInOids::JSONBOID.oid(), JsonB(json!(self.columns)).into_datum()), + ( + PgBuiltInOids::JSONBOID.oid(), + JsonB(json!(self.analysis)).into_datum(), + ), + ( + PgBuiltInOids::JSONBOID.oid(), + JsonB(json!(self.columns)).into_datum(), + ), (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()), - ]) - ).unwrap(); + ]), + ) + .unwrap(); let features_width = self.features_width(); let mut x_train = vec![0_f32; features_width * numeric_encoded_dataset.num_train_rows]; @@ -763,9 +924,13 @@ impl Snapshot { let mut x_test = vec![0_f32; features_width * numeric_encoded_dataset.num_test_rows]; let test_features = ndarray::ArrayView2::from_shape( - (numeric_encoded_dataset.num_test_rows, numeric_encoded_dataset.num_features), + ( + numeric_encoded_dataset.num_test_rows, + numeric_encoded_dataset.num_features, + ), &numeric_encoded_dataset.x_test, - ).unwrap(); + ) + .unwrap(); Zip::from(test_features.columns()) .and(&self.feature_positions()) .for_each(|data, position| { @@ -787,45 +952,10 @@ impl Snapshot { pub fn numeric_encoded_dataset(&mut self) -> Dataset { let mut data = None; Spi::connect(|client| { - let sql = format!( - "SELECT {} FROM {} {}", - self.columns - .iter() - .map(|c| c.quoted_name()) - .collect::>() - .join(", "), - self.relation_name(), - match self.materialized { - // If the snapshot is materialized, we already randomized it. - true => "", - false => { - if self.test_sampling == Sampling::random { - "ORDER BY random()" - } else { - "" - } - } - }, - ); - // Postgres Arrays arrays are 1 indexed and so are SPI tuples... - let result = client.select(&sql, None, None).unwrap(); + let result = client.select(&self.select_sql(), None, None).unwrap(); let num_rows = result.len(); - - let num_test_rows = if self.test_size > 1.0 { - self.test_size as usize - } else { - (num_rows as f32 * self.test_size).round() as usize - }; - - let num_train_rows = num_rows - num_test_rows; - if num_train_rows == 0 { - error!( - "test_size = {} is too large. There are only {} samples.", - num_test_rows, num_rows - ); - } - + let (num_train_rows, num_test_rows) = self.train_test_split(num_rows); let num_features = self.num_features(); let num_labels = self.num_labels(); @@ -968,7 +1098,7 @@ impl Snapshot { let num_features = self.num_features(); let num_labels = self.num_labels(); - data = Some(Dataset { + data = Some(Dataset{ x_train, y_train, x_test, @@ -1009,6 +1139,9 @@ fn check_column_size(column: &mut Column, len: usize) { if column.size == 0 { column.size = len; } else if column.size != len { - error!("Mismatched array length for feature `{}`. Expected: {} Received: {}", column.name, column.size, len); + error!( + "Mismatched array length for feature `{}`. Expected: {} Received: {}", + column.name, column.size, len + ); } } diff --git a/pgml-extension/src/orm/task.rs b/pgml-extension/src/orm/task.rs index 1c2cd92cf..f9285a2cd 100644 --- a/pgml-extension/src/orm/task.rs +++ b/pgml-extension/src/orm/task.rs @@ -6,6 +6,28 @@ use serde::Deserialize; pub enum Task { regression, classification, + question_answering, + summarization, + translation, + text_classification, + text_generation, + text2text, +} + +// unfortunately the pgx macro expands the enum names to underscore, but huggingface uses dash +impl Task { + pub fn to_pg_enum(&self) -> String { + match *self { + Task::regression => "regression".to_string(), + Task::classification => "classification".to_string(), + Task::question_answering => "question_answering".to_string(), + Task::summarization => "summarization".to_string(), + Task::translation => "translation".to_string(), + Task::text_classification => "text_classification".to_string(), + Task::text_generation => "text_generation".to_string(), + Task::text2text => "text2text".to_string(), + } + } } impl std::str::FromStr for Task { @@ -15,6 +37,12 @@ impl std::str::FromStr for Task { match input { "regression" => Ok(Task::regression), "classification" => Ok(Task::classification), + "question-answering" | "question_answering" => Ok(Task::question_answering), + "summarization" => Ok(Task::summarization), + "translation" => Ok(Task::translation), + "text-classification" | "text_classification" => Ok(Task::text_classification), + "text-generation" | "text_generation" => Ok(Task::text_generation), + "text2text" => Ok(Task::text2text), _ => Err(()), } } @@ -25,6 +53,12 @@ impl std::string::ToString for Task { match *self { Task::regression => "regression".to_string(), Task::classification => "classification".to_string(), + Task::question_answering => "question-answering".to_string(), + Task::summarization => "summarization".to_string(), + Task::translation => "translation".to_string(), + Task::text_classification => "text-classification".to_string(), + Task::text_generation => "text-generation".to_string(), + Task::text2text => "text2text".to_string(), } } } pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy