Skip to content

Commit 5e80c26

Browse files
authored
SDK - Better async streaming and 1.0.2 bump
1 parent f2f5506 commit 5e80c26

File tree

5 files changed

+40
-114
lines changed

5 files changed

+40
-114
lines changed

pgml-sdks/pgml/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "pgml"
3-
version = "1.0.1"
3+
version = "1.0.2"
44
edition = "2021"
55
authors = ["PosgresML <team@postgresml.org>"]
66
homepage = "https://postgresml.org/"

pgml-sdks/pgml/javascript/package.json

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "pgml",
3-
"version": "1.0.1",
3+
"version": "1.0.2",
44
"description": "Open Source Alternative for Building End-to-End Vector Search Applications without OpenAI & Pinecone",
55
"keywords": [
66
"postgres",

pgml-sdks/pgml/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ build-backend = "maturin"
55
[project]
66
name = "pgml"
77
requires-python = ">=3.7"
8-
version = "1.0.1"
8+
version = "1.0.2"
99
description = "Python SDK is designed to facilitate the development of scalable vector search applications on PostgreSQL databases."
1010
authors = [
1111
{name = "PostgresML", email = "team@postgresml.org"},

pgml-sdks/pgml/src/transformer_pipeline.rs

Lines changed: 34 additions & 106 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,6 @@
11
use anyhow::Context;
2-
use futures::Stream;
32
use rust_bridge::{alias, alias_methods};
4-
use sqlx::{postgres::PgRow, Row};
5-
use sqlx::{Postgres, Transaction};
6-
use std::collections::VecDeque;
7-
use std::future::Future;
8-
use std::pin::Pin;
9-
use std::task::Poll;
3+
use sqlx::Row;
104
use tracing::instrument;
115

126
/// Provides access to builtin database methods
@@ -22,99 +16,6 @@ use crate::{get_or_initialize_pool, types::Json};
2216
#[cfg(feature = "python")]
2317
use crate::types::{GeneralJsonAsyncIteratorPython, JsonPython};
2418

25-
#[allow(clippy::type_complexity)]
26-
struct TransformerStream {
27-
transaction: Option<Transaction<'static, Postgres>>,
28-
future: Option<Pin<Box<dyn Future<Output = Result<Vec<PgRow>, sqlx::Error>> + Send + 'static>>>,
29-
commit: Option<Pin<Box<dyn Future<Output = Result<(), sqlx::Error>> + Send + 'static>>>,
30-
done: bool,
31-
query: String,
32-
db_batch_size: i32,
33-
results: VecDeque<PgRow>,
34-
}
35-
36-
impl std::fmt::Debug for TransformerStream {
37-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38-
f.debug_struct("TransformerStream").finish()
39-
}
40-
}
41-
42-
impl TransformerStream {
43-
fn new(transaction: Transaction<'static, Postgres>, db_batch_size: i32) -> Self {
44-
let query = format!("FETCH {} FROM c", db_batch_size);
45-
Self {
46-
transaction: Some(transaction),
47-
future: None,
48-
commit: None,
49-
done: false,
50-
query,
51-
db_batch_size,
52-
results: VecDeque::new(),
53-
}
54-
}
55-
}
56-
57-
impl Stream for TransformerStream {
58-
type Item = anyhow::Result<Json>;
59-
60-
fn poll_next(
61-
mut self: Pin<&mut Self>,
62-
cx: &mut std::task::Context<'_>,
63-
) -> Poll<Option<Self::Item>> {
64-
if self.done {
65-
if let Some(c) = self.commit.as_mut() {
66-
if c.as_mut().poll(cx).is_ready() {
67-
self.commit = None;
68-
}
69-
}
70-
} else {
71-
if self.future.is_none() {
72-
unsafe {
73-
let s = self.as_mut().get_unchecked_mut();
74-
let s: *mut Self = s;
75-
let s = Box::leak(Box::from_raw(s));
76-
s.future = Some(Box::pin(
77-
sqlx::query(&s.query).fetch_all(&mut **s.transaction.as_mut().unwrap()),
78-
));
79-
}
80-
}
81-
82-
if let Poll::Ready(o) = self.as_mut().future.as_mut().unwrap().as_mut().poll(cx) {
83-
let rows = o?;
84-
if rows.len() < self.db_batch_size as usize {
85-
self.done = true;
86-
unsafe {
87-
let s = self.as_mut().get_unchecked_mut();
88-
let transaction = std::mem::take(&mut s.transaction).unwrap();
89-
s.commit = Some(Box::pin(transaction.commit()));
90-
}
91-
} else {
92-
unsafe {
93-
let s = self.as_mut().get_unchecked_mut();
94-
let s: *mut Self = s;
95-
let s = Box::leak(Box::from_raw(s));
96-
s.future = Some(Box::pin(
97-
sqlx::query(&s.query).fetch_all(&mut **s.transaction.as_mut().unwrap()),
98-
));
99-
}
100-
}
101-
for r in rows.into_iter() {
102-
self.results.push_back(r)
103-
}
104-
}
105-
}
106-
107-
if !self.results.is_empty() {
108-
let r = self.results.pop_front().unwrap();
109-
Poll::Ready(Some(Ok(r.get::<Json, _>(0))))
110-
} else if self.done {
111-
Poll::Ready(None)
112-
} else {
113-
Poll::Pending
114-
}
115-
}
116-
}
117-
11819
#[alias_methods(new, transform, transform_stream)]
11920
impl TransformerPipeline {
12021
/// Creates a new [TransformerPipeline]
@@ -200,7 +101,7 @@ impl TransformerPipeline {
200101
) -> anyhow::Result<GeneralJsonAsyncIterator> {
201102
let pool = get_or_initialize_pool(&self.database_url).await?;
202103
let args = args.unwrap_or_default();
203-
let batch_size = batch_size.unwrap_or(10);
104+
let batch_size = batch_size.unwrap_or(1);
204105

205106
let mut transaction = pool.begin().await?;
206107
// We set the task in the new constructor so we can unwrap here
@@ -234,10 +135,37 @@ impl TransformerPipeline {
234135
.await?;
235136
}
236137

237-
Ok(GeneralJsonAsyncIterator(Box::pin(TransformerStream::new(
238-
transaction,
239-
batch_size,
240-
))))
138+
let s = futures::stream::try_unfold(transaction, move |mut transaction| async move {
139+
let query = format!("FETCH {} FROM c", batch_size);
140+
let mut res: Vec<Json> = sqlx::query_scalar(&query)
141+
.fetch_all(&mut *transaction)
142+
.await?;
143+
if !res.is_empty() {
144+
if batch_size > 1 {
145+
let res: Vec<String> = res
146+
.into_iter()
147+
.map(|v| {
148+
v.0.as_array()
149+
.context("internal SDK error - cannot parse db value as array. Please post a new github issue")
150+
.map(|v| {
151+
v[0].as_str()
152+
.context(
153+
"internal SDK error - cannot parse db value as string. Please post a new github issue",
154+
)
155+
.map(|v| v.to_owned())
156+
})
157+
})
158+
.collect::<anyhow::Result<anyhow::Result<Vec<String>>>>()??;
159+
Ok(Some((serde_json::json!(res).into(), transaction)))
160+
} else {
161+
Ok(Some((std::mem::take(&mut res[0]), transaction)))
162+
}
163+
} else {
164+
transaction.commit().await?;
165+
Ok(None)
166+
}
167+
});
168+
Ok(GeneralJsonAsyncIterator(Box::pin(s)))
241169
}
242170
}
243171

@@ -305,7 +233,7 @@ mod tests {
305233
serde_json::json!("AI is going to").into(),
306234
Some(
307235
serde_json::json!({
308-
"max_new_tokens": 10
236+
"max_new_tokens": 30
309237
})
310238
.into(),
311239
),

pgml-sdks/pgml/src/types.rs

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
use anyhow::Context;
2-
use futures::{Stream, StreamExt};
2+
use futures::{stream::BoxStream, Stream, StreamExt};
33
use itertools::Itertools;
44
use rust_bridge::alias_manual;
55
use sea_query::Iden;
@@ -123,11 +123,9 @@ impl IntoTableNameAndSchema for String {
123123
}
124124
}
125125

126-
/// A wrapper around `std::pin::Pin<Box<dyn Stream<Item = anyhow::Result<Json>> + Send>>`
126+
/// A wrapper around `BoxStream<'static, anyhow::Result<Json>>`
127127
#[derive(alias_manual)]
128-
pub struct GeneralJsonAsyncIterator(
129-
pub std::pin::Pin<Box<dyn Stream<Item = anyhow::Result<Json>> + Send>>,
130-
);
128+
pub struct GeneralJsonAsyncIterator(pub BoxStream<'static, anyhow::Result<Json>>);
131129

132130
impl Stream for GeneralJsonAsyncIterator {
133131
type Item = anyhow::Result<Json>;

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy