diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 5b8ddc4e7..bb97b31e8 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -4,8 +4,6 @@ use std::str::FromStr; use ndarray::Zip; use pgrx::iter::{SetOfIterator, TableIterator}; use pgrx::*; -use pyo3::prelude::*; -use pyo3::types::{IntoPyDict, PyDict}; #[cfg(feature = "python")] use serde_json::json; @@ -634,40 +632,6 @@ pub fn transform_string( } } -struct TransformStreamIterator { - locals: Py, -} - -impl TransformStreamIterator { - fn new(python_iter: Py) -> Self { - let locals = Python::with_gil(|py| -> Result, PyErr> { - Ok([("python_iter", python_iter)].into_py_dict(py).into()) - }) - .map_err(|e| error!("{e}")) - .unwrap(); - Self { locals } - } -} - -impl Iterator for TransformStreamIterator { - type Item = String; - fn next(&mut self) -> Option { - // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call - Python::with_gil(|py| -> Result, PyErr> { - let code = "next(python_iter)"; - let res: &PyAny = py.eval(code, Some(self.locals.as_ref(py)), None)?; - if res.is_none() { - Ok(None) - } else { - let res: String = res.extract()?; - Ok(Some(res)) - } - }) - .map_err(|e| error!("{e}")) - .unwrap() - } -} - #[cfg(all(feature = "python", not(feature = "use_as_lib")))] #[pg_extern(immutable, parallel_safe, name = "transform_stream")] #[allow(unused_variables)] // cache is maintained for api compatibility @@ -678,11 +642,11 @@ pub fn transform_stream_json( cache: default!(bool, false), ) -> SetOfIterator<'static, String> { // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call - let python_iter = crate::bindings::transformers::transform_stream(&task.0, &args.0, input) - .map_err(|e| error!("{e}")) - .unwrap(); - let res = TransformStreamIterator::new(python_iter); - SetOfIterator::new(res) + let python_iter = + crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, input) + .map_err(|e| error!("{e}")) + .unwrap(); + SetOfIterator::new(python_iter) } #[cfg(all(feature = "python", not(feature = "use_as_lib")))] @@ -696,11 +660,11 @@ pub fn transform_stream_string( ) -> SetOfIterator<'static, String> { let task_json = json!({ "task": task }); // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call - let python_iter = crate::bindings::transformers::transform_stream(&task_json, &args.0, input) - .map_err(|e| error!("{e}")) - .unwrap(); - let res = TransformStreamIterator::new(python_iter); - SetOfIterator::new(res) + let python_iter = + crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, input) + .map_err(|e| error!("{e}")) + .unwrap(); + SetOfIterator::new(python_iter) } #[cfg(feature = "python")] diff --git a/pgml-extension/src/bindings/transformers/transformers.rs b/pgml-extension/src/bindings/transformers/transformers.rs index 55d59b070..6b89dd2a8 100644 --- a/pgml-extension/src/bindings/transformers/transformers.rs +++ b/pgml-extension/src/bindings/transformers/transformers.rs @@ -1,17 +1,52 @@ use super::whitelist; use super::TracebackError; use anyhow::Result; +use pgrx::*; use pyo3::prelude::*; -use pyo3::types::PyTuple; +use pyo3::types::{IntoPyDict, PyDict, PyTuple}; + create_pymodule!("/src/bindings/transformers/transformers.py"); +pub struct TransformStreamIterator { + locals: Py, +} + +impl TransformStreamIterator { + fn new(python_iter: Py) -> Self { + let locals = Python::with_gil(|py| -> Result, PyErr> { + Ok([("python_iter", python_iter)].into_py_dict(py).into()) + }) + .map_err(|e| error!("{e}")) + .unwrap(); + Self { locals } + } +} + +impl Iterator for TransformStreamIterator { + type Item = String; + fn next(&mut self) -> Option { + // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call + Python::with_gil(|py| -> Result, PyErr> { + let code = "next(python_iter)"; + let res: &PyAny = py.eval(code, Some(self.locals.as_ref(py)), None)?; + if res.is_none() { + Ok(None) + } else { + let res: String = res.extract()?; + Ok(Some(res)) + } + }) + .map_err(|e| error!("{e}")) + .unwrap() + } +} + pub fn transform( task: &serde_json::Value, args: &serde_json::Value, inputs: Vec<&str>, ) -> Result { crate::bindings::python::activate()?; - whitelist::verify_task(task)?; let task = serde_json::to_string(task)?; @@ -45,7 +80,6 @@ pub fn transform_stream( input: &str, ) -> Result> { crate::bindings::python::activate()?; - whitelist::verify_task(task)?; let task = serde_json::to_string(task)?; @@ -75,3 +109,14 @@ pub fn transform_stream( Ok(output) }) } + +pub fn transform_stream_iterator( + task: &serde_json::Value, + args: &serde_json::Value, + input: &str, +) -> Result { + let python_iter = transform_stream(task, args, input) + .map_err(|e| error!("{e}")) + .unwrap(); + Ok(TransformStreamIterator::new(python_iter)) +} pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy