Skip to content

Commit fb5b502

Browse files
authored
Added TransformerPipeline (#1128)
1 parent f2e4517 commit fb5b502

File tree

3 files changed

+106
-72
lines changed

3 files changed

+106
-72
lines changed

pgml-sdks/pgml/src/lib.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ mod query_builder;
2626
mod query_runner;
2727
mod remote_embeddings;
2828
mod splitter;
29+
mod transformer_pipeline;
2930
pub mod types;
3031
mod utils;
3132

@@ -35,6 +36,7 @@ pub use collection::Collection;
3536
pub use model::Model;
3637
pub use pipeline::Pipeline;
3738
pub use splitter::Splitter;
39+
pub use transformer_pipeline::TransformerPipeline;
3840

3941
// This is used when inserting collections to set the sdk_version used during creation
4042
static SDK_VERSION: &str = "0.9.2";
@@ -149,6 +151,7 @@ fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> {
149151
m.add_class::<model::ModelPython>()?;
150152
m.add_class::<splitter::SplitterPython>()?;
151153
m.add_class::<builtins::BuiltinsPython>()?;
154+
m.add_class::<transformer_pipeline::TransformerPipelinePython>()?;
152155
Ok(())
153156
}
154157

@@ -193,6 +196,10 @@ fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> {
193196
cx.export_function("newModel", model::ModelJavascript::new)?;
194197
cx.export_function("newSplitter", splitter::SplitterJavascript::new)?;
195198
cx.export_function("newBuiltins", builtins::BuiltinsJavascript::new)?;
199+
cx.export_function(
200+
"newTransformerPipeline",
201+
transformer_pipeline::TransformerPipelineJavascript::new,
202+
)?;
196203
cx.export_function("newPipeline", pipeline::PipelineJavascript::new)?;
197204
Ok(())
198205
}
@@ -448,7 +455,6 @@ mod tests {
448455
Some("text-embedding-ada-002".to_string()),
449456
Some("openai".to_string()),
450457
None,
451-
None,
452458
);
453459
let splitter = Splitter::default();
454460
let mut pipeline = Pipeline::new(
@@ -527,7 +533,6 @@ mod tests {
527533
Some("hkunlp/instructor-base".to_string()),
528534
Some("python".to_string()),
529535
Some(json!({"instruction": "Represent the Wikipedia document for retrieval: "}).into()),
530-
None,
531536
);
532537
let splitter = Splitter::default();
533538
let mut pipeline = Pipeline::new(
@@ -579,7 +584,6 @@ mod tests {
579584
Some("text-embedding-ada-002".to_string()),
580585
Some("openai".to_string()),
581586
None,
582-
None,
583587
);
584588
let splitter = Splitter::default();
585589
let mut pipeline = Pipeline::new(
@@ -660,7 +664,6 @@ mod tests {
660664
Some("text-embedding-ada-002".to_string()),
661665
Some("openai".to_string()),
662666
None,
663-
None,
664667
);
665668
let splitter = Splitter::default();
666669
let mut pipeline = Pipeline::new(

pgml-sdks/pgml/src/model.rs

Lines changed: 2 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
use anyhow::Context;
22
use rust_bridge::{alias, alias_methods};
3-
use serde_json::json;
43
use sqlx::postgres::PgPool;
5-
use sqlx::Row;
64
use tracing::instrument;
75

86
use crate::{
@@ -61,14 +59,11 @@ pub struct Model {
6159
pub parameters: Json,
6260
project_info: Option<ProjectInfo>,
6361
pub(crate) database_data: Option<ModelDatabaseData>,
64-
// This database_url is specifically used only for the model when calling transform and other
65-
// one-off methods
66-
database_url: Option<String>,
6762
}
6863

6964
impl Default for Model {
7065
fn default() -> Self {
71-
Self::new(None, None, None, None)
66+
Self::new(None, None, None)
7267
}
7368
}
7469

@@ -88,12 +83,7 @@ impl Model {
8883
/// use pgml::Model;
8984
/// let model = Model::new(Some("intfloat/e5-small".to_string()), None, None);
9085
/// ```
91-
pub fn new(
92-
name: Option<String>,
93-
source: Option<String>,
94-
parameters: Option<Json>,
95-
database_url: Option<String>,
96-
) -> Self {
86+
pub fn new(name: Option<String>, source: Option<String>, parameters: Option<Json>) -> Self {
9787
let name = name.unwrap_or("intfloat/e5-small".to_string());
9888
let parameters = parameters.unwrap_or(Json(serde_json::json!({})));
9989
let source = source.unwrap_or("pgml".to_string());
@@ -105,7 +95,6 @@ impl Model {
10595
parameters,
10696
project_info: None,
10797
database_data: None,
108-
database_url,
10998
}
11099
}
111100

@@ -191,30 +180,6 @@ impl Model {
191180
.database_url;
192181
get_or_initialize_pool(database_url).await
193182
}
194-
195-
pub async fn transform(
196-
&self,
197-
task: &str,
198-
inputs: Vec<String>,
199-
args: Option<Json>,
200-
) -> anyhow::Result<Json> {
201-
let pool = get_or_initialize_pool(&self.database_url).await?;
202-
let task = json!({
203-
"task": task,
204-
"model": self.name,
205-
"trust_remote_code": true
206-
});
207-
let args = args.unwrap_or_default();
208-
let query = sqlx::query("SELECT pgml.transform(task => $1, inputs => $2, args => $3)");
209-
let results = query
210-
.bind(task)
211-
.bind(inputs)
212-
.bind(&args)
213-
.fetch_all(&pool)
214-
.await?;
215-
let results = results.get(0).unwrap().get::<serde_json::Value, _>(0);
216-
Ok(Json(results))
217-
}
218183
}
219184

220185
impl From<models::PipelineWithModelAndSplitter> for Model {
@@ -228,7 +193,6 @@ impl From<models::PipelineWithModelAndSplitter> for Model {
228193
id: x.model_id,
229194
created_at: x.model_created_at,
230195
}),
231-
database_url: None,
232196
}
233197
}
234198
}
@@ -244,36 +208,6 @@ impl From<models::Model> for Model {
244208
id: model.id,
245209
created_at: model.created_at,
246210
}),
247-
database_url: None,
248211
}
249212
}
250213
}
251-
252-
#[cfg(test)]
253-
mod tests {
254-
use super::*;
255-
use crate::internal_init_logger;
256-
257-
#[sqlx::test]
258-
async fn model_can_transform() -> anyhow::Result<()> {
259-
internal_init_logger(None, None).ok();
260-
let model = Model::new(
261-
Some("Helsinki-NLP/opus-mt-en-fr".to_string()),
262-
Some("pgml".to_string()),
263-
None,
264-
None,
265-
);
266-
let results = model
267-
.transform(
268-
"translation",
269-
vec![
270-
"How are you doing today?".to_string(),
271-
"What is a good song?".to_string(),
272-
],
273-
None,
274-
)
275-
.await?;
276-
assert!(results.as_array().is_some());
277-
Ok(())
278-
}
279-
}
pgml-sdks/pgml/src/transformer_pipeline.rs

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
use rust_bridge::{alias, alias_methods};
2+
use sqlx::Row;
3+
use tracing::instrument;
4+
5+
/// Runs transformer tasks in the database by calling `pgml.transform` with a stored task and optional model
6+
#[derive(alias, Debug, Clone)]
7+
pub struct TransformerPipeline {
8+
task: Json,
9+
database_url: Option<String>,
10+
}
11+
12+
use crate::{get_or_initialize_pool, types::Json};
13+
14+
#[cfg(feature = "python")]
15+
use crate::types::JsonPython;
16+
17+
#[alias_methods(new, transform)]
18+
impl TransformerPipeline {
19+
pub fn new(
20+
task: &str,
21+
model: Option<String>,
22+
args: Option<Json>,
23+
database_url: Option<String>,
24+
) -> Self {
25+
let mut args = args.unwrap_or_default();
26+
let a = args.as_object_mut().expect("args must be an object");
27+
a.insert("task".to_string(), task.to_string().into());
28+
if let Some(m) = model {
29+
a.insert("model".to_string(), m.into());
30+
}
31+
32+
Self {
33+
task: args,
34+
database_url,
35+
}
36+
}
37+
38+
#[instrument(skip(self))]
39+
pub async fn transform(&self, inputs: Vec<String>, args: Option<Json>) -> anyhow::Result<Json> {
40+
let pool = get_or_initialize_pool(&self.database_url).await?;
41+
let args = args.unwrap_or_default();
42+
43+
let results = sqlx::query("SELECT pgml.transform(task => $1, inputs => $2, args => $3)")
44+
.bind(&self.task)
45+
.bind(inputs)
46+
.bind(&args)
47+
.fetch_all(&pool)
48+
.await?;
49+
let results = results.get(0).unwrap().get::<serde_json::Value, _>(0);
50+
Ok(Json(results))
51+
}
52+
}
53+
54+
#[cfg(test)]
55+
mod tests {
56+
use super::*;
57+
use crate::internal_init_logger;
58+
59+
#[sqlx::test]
60+
async fn transformer_pipeline_can_transform() -> anyhow::Result<()> {
61+
internal_init_logger(None, None).ok();
62+
let t = TransformerPipeline::new(
63+
"translation_en_to_fr",
64+
Some("t5-base".to_string()),
65+
None,
66+
None,
67+
);
68+
let results = t
69+
.transform(
70+
vec![
71+
"How are you doing today?".to_string(),
72+
"What is a good song?".to_string(),
73+
],
74+
None,
75+
)
76+
.await?;
77+
assert!(results.as_array().is_some());
78+
Ok(())
79+
}
80+
81+
#[sqlx::test]
82+
async fn transformer_pipeline_can_transform_with_default_model() -> anyhow::Result<()> {
83+
internal_init_logger(None, None).ok();
84+
let t = TransformerPipeline::new("translation_en_to_fr", None, None, None);
85+
let results = t
86+
.transform(
87+
vec![
88+
"How are you doing today?".to_string(),
89+
"What is a good song?".to_string(),
90+
],
91+
None,
92+
)
93+
.await?;
94+
assert!(results.as_array().is_some());
95+
Ok(())
96+
}
97+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy