Skip to content

Commit 67ba48c

Browse files
authored
Updated to support streaming (#1151)
1 parent 3e8cc28 commit 67ba48c

File tree

11 files changed

+328
-135
lines changed

11 files changed

+328
-135
lines changed

pgml-sdks/pgml/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "pgml"
3-
version = "0.9.5"
3+
version = "0.9.6"
44
edition = "2021"
55
authors = ["PosgresML <team@postgresml.org>"]
66
homepage = "https://postgresml.org/"

pgml-sdks/pgml/javascript/package.json

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "pgml",
3-
"version": "0.9.5",
3+
"version": "0.9.6",
44
"description": "Open Source Alternative for Building End-to-End Vector Search Applications without OpenAI & Pinecone",
55
"keywords": [
66
"postgres",

pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts

Lines changed: 22 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -280,6 +280,28 @@ it("can order documents", async () => {
280280
await collection.archive();
281281
});
282282

283+
///////////////////////////////////////////////////
284+
// Transformer Pipeline Tests /////////////////////
285+
///////////////////////////////////////////////////
286+
287+
it("can transformer pipeline", async () => {
288+
const t = pgml.newTransformerPipeline("text-generation");
289+
const it = await t.transform(["AI is going to"], {max_new_tokens: 5});
290+
expect(it.length).toBeGreaterThan(0)
291+
});
292+
293+
it("can transformer pipeline stream", async () => {
294+
const t = pgml.newTransformerPipeline("text-generation");
295+
const it = await t.transform_stream("AI is going to", {max_new_tokens: 5});
296+
let result = await it.next();
297+
let output = [];
298+
while (!result.done) {
299+
output.push(result.value);
300+
result = await it.next();
301+
}
302+
expect(output.length).toBeGreaterThan(0)
303+
});
304+
283305
///////////////////////////////////////////////////
284306
// Test migrations ////////////////////////////////
285307
///////////////////////////////////////////////////

pgml-sdks/pgml/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -5,7 +5,7 @@ build-backend = "maturin"
55
[project]
66
name = "pgml"
77
requires-python = ">=3.7"
8-
version = "0.9.5"
8+
version = "0.9.6"
99
description = "Python SDK is designed to facilitate the development of scalable vector search applications on PostgreSQL databases."
1010
authors = [
1111
{name = "PostgresML", email = "team@postgresml.org"},

pgml-sdks/pgml/python/pgml/pgml.pyi

Lines changed: 0 additions & 96 deletions
This file was deleted.

pgml-sdks/pgml/python/tests/test.py

Lines changed: 21 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -298,6 +298,27 @@ async def test_order_documents():
298298
await collection.archive()
299299

300300

301+
###################################################
302+
## Transformer Pipeline Tests #####################
303+
###################################################
304+
305+
306+
@pytest.mark.asyncio
307+
async def test_transformer_pipeline():
308+
t = pgml.TransformerPipeline("text-generation")
309+
it = await t.transform(["AI is going to"], {"max_new_tokens": 5})
310+
assert (len(it)) > 0
311+
312+
@pytest.mark.asyncio
313+
async def test_transformer_pipeline_stream():
314+
t = pgml.TransformerPipeline("text-generation")
315+
it = await t.transform_stream("AI is going to", {"max_new_tokens": 5})
316+
total = []
317+
async for c in it:
318+
total.append(c)
319+
assert (len(total)) > 0
320+
321+
301322
###################################################
302323
## Migration tests ################################
303324
###################################################

pgml-sdks/pgml/src/languages/javascript.rs

Lines changed: 64 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -1,8 +1,11 @@
1+
use futures::StreamExt;
12
use neon::prelude::*;
23
use rust_bridge::javascript::{FromJsType, IntoJsResult};
4+
use std::sync::Arc;
35

46
use crate::{
57
pipeline::PipelineSyncData,
8+
transformer_pipeline::TransformerStream,
69
types::{DateTime, Json},
710
};
811

@@ -16,8 +19,9 @@ impl IntoJsResult for DateTime {
1619
self,
1720
cx: &mut C,
1821
) -> JsResult<'b, Self::Output> {
19-
let date = neon::types::JsDate::new(cx, self.0.assume_utc().unix_timestamp() as f64 * 1000.0)
20-
.expect("Error converting to JS Date");
22+
let date =
23+
neon::types::JsDate::new(cx, self.0.assume_utc().unix_timestamp() as f64 * 1000.0)
24+
.expect("Error converting to JS Date");
2125
Ok(date)
2226
}
2327
}
@@ -69,6 +73,64 @@ impl IntoJsResult for PipelineSyncData {
6973
}
7074
}
7175

76+
#[derive(Clone)]
77+
struct TransformerStreamArcMutex(Arc<tokio::sync::Mutex<TransformerStream>>);
78+
79+
impl Finalize for TransformerStreamArcMutex {}
80+
81+
fn transform_stream_iterate_next(mut cx: FunctionContext) -> JsResult<JsPromise> {
82+
let this = cx.this();
83+
let s: Handle<JsBox<TransformerStreamArcMutex>> = this
84+
.get(&mut cx, "s")
85+
.expect("Error getting self in transformer_stream_iterate_next");
86+
let ts: &TransformerStreamArcMutex = &s;
87+
let ts: TransformerStreamArcMutex = ts.clone();
88+
89+
let channel = cx.channel();
90+
let (deferred, promise) = cx.promise();
91+
crate::get_or_set_runtime().spawn(async move {
92+
let mut ts = ts.0.lock().await;
93+
let v = ts.next().await;
94+
deferred
95+
.try_settle_with(&channel, move |mut cx| {
96+
let o = cx.empty_object();
97+
if let Some(v) = v {
98+
let v: String = v.expect("Error calling next on TransformerStream");
99+
let v = cx.string(v);
100+
let d = cx.boolean(false);
101+
o.set(&mut cx, "value", v)
102+
.expect("Error setting object value in transformer_sream_iterate_next");
103+
o.set(&mut cx, "done", d)
104+
.expect("Error setting object value in transformer_sream_iterate_next");
105+
} else {
106+
let d = cx.boolean(true);
107+
o.set(&mut cx, "done", d)
108+
.expect("Error setting object value in transformer_sream_iterate_next");
109+
}
110+
Ok(o)
111+
})
112+
.expect("Error sending js");
113+
});
114+
Ok(promise)
115+
}
116+
117+
impl IntoJsResult for TransformerStream {
118+
type Output = JsObject;
119+
fn into_js_result<'a, 'b, 'c: 'b, C: Context<'c>>(
120+
self,
121+
cx: &mut C,
122+
) -> JsResult<'b, Self::Output> {
123+
let o = cx.empty_object();
124+
let f: Handle<JsFunction> = JsFunction::new(cx, transform_stream_iterate_next)?;
125+
o.set(cx, "next", f)?;
126+
let s = cx.boxed(TransformerStreamArcMutex(Arc::new(
127+
tokio::sync::Mutex::new(self),
128+
)));
129+
o.set(cx, "s", s)?;
130+
Ok(o)
131+
}
132+
}
133+
72134
////////////////////////////////////////////////////////////////////////////////
73135
// JS To Rust //////////////////////////////////////////////////////////////////
74136
////////////////////////////////////////////////////////////////////////////////

pgml-sdks/pgml/src/languages/python.rs

Lines changed: 58 additions & 18 deletions
Original file line number · Diff line number · Diff line change
@@ -1,65 +1,99 @@
1+
use futures::StreamExt;
12
use pyo3::conversion::IntoPy;
23
use pyo3::types::{PyDict, PyFloat, PyInt, PyList, PyString};
34
use pyo3::{prelude::*, types::PyBool};
5+
use std::sync::Arc;
46

57
use rust_bridge::python::CustomInto;
68

7-
use crate::{pipeline::PipelineSyncData, types::Json};
9+
use crate::{pipeline::PipelineSyncData, transformer_pipeline::TransformerStream, types::Json};
810

911
////////////////////////////////////////////////////////////////////////////////
1012
// Rust to PY //////////////////////////////////////////////////////////////////
1113
////////////////////////////////////////////////////////////////////////////////
1214

13-
impl ToPyObject for Json {
14-
fn to_object(&self, py: Python) -> PyObject {
15+
impl IntoPy<PyObject> for Json {
16+
fn into_py(self, py: Python) -> PyObject {
1517
match &self.0 {
16-
serde_json::Value::Bool(x) => x.to_object(py),
18+
serde_json::Value::Bool(x) => x.into_py(py),
1719
serde_json::Value::Number(x) => {
1820
if x.is_f64() {
1921
x.as_f64()
2022
.expect("Error converting to f64 in impl ToPyObject for Json")
21-
.to_object(py)
23+
.into_py(py)
2224
} else {
2325
x.as_i64()
2426
.expect("Error converting to i64 in impl ToPyObject for Json")
25-
.to_object(py)
27+
.into_py(py)
2628
}
2729
}
28-
serde_json::Value::String(x) => x.to_object(py),
30+
serde_json::Value::String(x) => x.into_py(py),
2931
serde_json::Value::Array(x) => {
3032
let list = PyList::empty(py);
3133
for v in x.iter() {
32-
list.append(Json(v.clone()).to_object(py)).unwrap();
34+
list.append(Json(v.clone()).into_py(py)).unwrap();
3335
}
34-
list.to_object(py)
36+
list.into_py(py)
3537
}
3638
serde_json::Value::Object(x) => {
3739
let dict = PyDict::new(py);
3840
for (k, v) in x.iter() {
39-
dict.set_item(k, Json(v.clone()).to_object(py)).unwrap();
41+
dict.set_item(k, Json(v.clone()).into_py(py)).unwrap();
4042
}
41-
dict.to_object(py)
43+
dict.into_py(py)
4244
}
4345
serde_json::Value::Null => py.None(),
4446
}
4547
}
4648
}
4749

48-
impl IntoPy<PyObject> for Json {
50+
impl IntoPy<PyObject> for PipelineSyncData {
4951
fn into_py(self, py: Python) -> PyObject {
50-
self.to_object(py)
52+
Json::from(self).into_py(py)
5153
}
5254
}
5355

54-
impl ToPyObject for PipelineSyncData {
55-
fn to_object(&self, py: Python) -> PyObject {
56-
Json::from(self.clone()).to_object(py)
56+
#[pyclass]
57+
#[derive(Clone)]
58+
struct TransformerStreamPython {
59+
wrapped: Arc<tokio::sync::Mutex<TransformerStream>>,
60+
}
61+
62+
#[pymethods]
63+
impl TransformerStreamPython {
64+
fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
65+
slf
66+
}
67+
68+
fn __anext__<'p>(slf: PyRefMut<'_, Self>, py: Python<'p>) -> PyResult<Option<PyObject>> {
69+
let ts = slf.wrapped.clone();
70+
let fut = pyo3_asyncio::tokio::future_into_py(py, async move {
71+
let mut ts = ts.lock().await;
72+
if let Some(o) = ts.next().await {
73+
Ok(Some(Python::with_gil(|py| {
74+
o.expect("Error calling next on TransformerStream")
75+
.to_object(py)
76+
})))
77+
} else {
78+
Err(pyo3::exceptions::PyStopAsyncIteration::new_err(
79+
"stream exhausted",
80+
))
81+
}
82+
})?;
83+
Ok(Some(fut.into()))
5784
}
5885
}
5986

60-
impl IntoPy<PyObject> for PipelineSyncData {
87+
impl IntoPy<PyObject> for TransformerStream {
6188
fn into_py(self, py: Python) -> PyObject {
62-
self.to_object(py)
89+
let f: Py<TransformerStreamPython> = Py::new(
90+
py,
91+
TransformerStreamPython {
92+
wrapped: Arc::new(tokio::sync::Mutex::new(self)),
93+
},
94+
)
95+
.expect("Error converting TransformerStream to TransformerStreamPython");
96+
f.to_object(py)
6397
}
6498
}
6599

@@ -115,6 +149,12 @@ impl FromPyObject<'_> for PipelineSyncData {
115149
}
116150
}
117151

152+
impl FromPyObject<'_> for TransformerStream {
153+
fn extract(_ob: &PyAny) -> PyResult<Self> {
154+
panic!("We must implement this, but this is impossible to be reached")
155+
}
156+
}
157+
118158
////////////////////////////////////////////////////////////////////////////////
119159
// Rust to Rust //////////////////////////////////////////////////////////////////
120160
////////////////////////////////////////////////////////////////////////////////

0 commit comments

Comments (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy