Skip to content

Commit 0e9a873

Browse files
authored
Move PythonIterator into its own function (#1156)
1 parent d5c8629 commit 0e9a873

File tree

2 files changed

+58
-49
lines changed

2 files changed

+58
-49
lines changed

pgml-extension/src/api.rs

Lines changed: 10 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ use std::str::FromStr;
44
use ndarray::Zip;
55
use pgrx::iter::{SetOfIterator, TableIterator};
66
use pgrx::*;
7-
use pyo3::prelude::*;
8-
use pyo3::types::{IntoPyDict, PyDict};
97

108
#[cfg(feature = "python")]
119
use serde_json::json;
@@ -634,40 +632,6 @@ pub fn transform_string(
634632
}
635633
}
636634

637-
struct TransformStreamIterator {
638-
locals: Py<PyDict>,
639-
}
640-
641-
impl TransformStreamIterator {
642-
fn new(python_iter: Py<PyAny>) -> Self {
643-
let locals = Python::with_gil(|py| -> Result<Py<PyDict>, PyErr> {
644-
Ok([("python_iter", python_iter)].into_py_dict(py).into())
645-
})
646-
.map_err(|e| error!("{e}"))
647-
.unwrap();
648-
Self { locals }
649-
}
650-
}
651-
652-
impl Iterator for TransformStreamIterator {
653-
type Item = String;
654-
fn next(&mut self) -> Option<Self::Item> {
655-
// We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
656-
Python::with_gil(|py| -> Result<Option<String>, PyErr> {
657-
let code = "next(python_iter)";
658-
let res: &PyAny = py.eval(code, Some(self.locals.as_ref(py)), None)?;
659-
if res.is_none() {
660-
Ok(None)
661-
} else {
662-
let res: String = res.extract()?;
663-
Ok(Some(res))
664-
}
665-
})
666-
.map_err(|e| error!("{e}"))
667-
.unwrap()
668-
}
669-
}
670-
671635
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
672636
#[pg_extern(immutable, parallel_safe, name = "transform_stream")]
673637
#[allow(unused_variables)] // cache is maintained for api compatibility
@@ -678,11 +642,11 @@ pub fn transform_stream_json(
678642
cache: default!(bool, false),
679643
) -> SetOfIterator<'static, String> {
680644
// We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
681-
let python_iter = crate::bindings::transformers::transform_stream(&task.0, &args.0, input)
682-
.map_err(|e| error!("{e}"))
683-
.unwrap();
684-
let res = TransformStreamIterator::new(python_iter);
685-
SetOfIterator::new(res)
645+
let python_iter =
646+
crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, input)
647+
.map_err(|e| error!("{e}"))
648+
.unwrap();
649+
SetOfIterator::new(python_iter)
686650
}
687651

688652
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
@@ -696,11 +660,11 @@ pub fn transform_stream_string(
696660
) -> SetOfIterator<'static, String> {
697661
let task_json = json!({ "task": task });
698662
// We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
699-
let python_iter = crate::bindings::transformers::transform_stream(&task_json, &args.0, input)
700-
.map_err(|e| error!("{e}"))
701-
.unwrap();
702-
let res = TransformStreamIterator::new(python_iter);
703-
SetOfIterator::new(res)
663+
let python_iter =
664+
crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, input)
665+
.map_err(|e| error!("{e}"))
666+
.unwrap();
667+
SetOfIterator::new(python_iter)
704668
}
705669

706670
#[cfg(feature = "python")]

pgml-extension/src/bindings/transformers/transformers.rs

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,52 @@
11
use super::whitelist;
22
use super::TracebackError;
33
use anyhow::Result;
4+
use pgrx::*;
45
use pyo3::prelude::*;
5-
use pyo3::types::PyTuple;
6+
use pyo3::types::{IntoPyDict, PyDict, PyTuple};
7+
68
create_pymodule!("/src/bindings/transformers/transformers.py");
79

10+
pub struct TransformStreamIterator {
11+
locals: Py<PyDict>,
12+
}
13+
14+
impl TransformStreamIterator {
15+
fn new(python_iter: Py<PyAny>) -> Self {
16+
let locals = Python::with_gil(|py| -> Result<Py<PyDict>, PyErr> {
17+
Ok([("python_iter", python_iter)].into_py_dict(py).into())
18+
})
19+
.map_err(|e| error!("{e}"))
20+
.unwrap();
21+
Self { locals }
22+
}
23+
}
24+
25+
impl Iterator for TransformStreamIterator {
26+
type Item = String;
27+
fn next(&mut self) -> Option<Self::Item> {
28+
// We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
29+
Python::with_gil(|py| -> Result<Option<String>, PyErr> {
30+
let code = "next(python_iter)";
31+
let res: &PyAny = py.eval(code, Some(self.locals.as_ref(py)), None)?;
32+
if res.is_none() {
33+
Ok(None)
34+
} else {
35+
let res: String = res.extract()?;
36+
Ok(Some(res))
37+
}
38+
})
39+
.map_err(|e| error!("{e}"))
40+
.unwrap()
41+
}
42+
}
43+
844
pub fn transform(
945
task: &serde_json::Value,
1046
args: &serde_json::Value,
1147
inputs: Vec<&str>,
1248
) -> Result<serde_json::Value> {
1349
crate::bindings::python::activate()?;
14-
1550
whitelist::verify_task(task)?;
1651

1752
let task = serde_json::to_string(task)?;
@@ -45,7 +80,6 @@ pub fn transform_stream(
4580
input: &str,
4681
) -> Result<Py<PyAny>> {
4782
crate::bindings::python::activate()?;
48-
4983
whitelist::verify_task(task)?;
5084

5185
let task = serde_json::to_string(task)?;
@@ -75,3 +109,14 @@ pub fn transform_stream(
75109
Ok(output)
76110
})
77111
}
112+
113+
pub fn transform_stream_iterator(
114+
task: &serde_json::Value,
115+
args: &serde_json::Value,
116+
input: &str,
117+
) -> Result<TransformStreamIterator> {
118+
let python_iter = transform_stream(task, args, input)
119+
.map_err(|e| error!("{e}"))
120+
.unwrap();
121+
Ok(TransformStreamIterator::new(python_iter))
122+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy