Skip to content

Commit 0be25d0

Browse files
authored
Added OpenSourceAI and conversational support in the extension (#1206)
1 parent dd18739 commit 0be25d0

File tree

16 files changed

+1140
-161
lines changed

16 files changed

+1140
-161
lines changed

pgml-extension/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "pgml"
3-
version = "2.8.0"
3+
version = "2.8.1"
44
edition = "2021"
55

66
[lib]
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
-- pgml::api::transform_conversational_json
-- Overload of pgml.transform taking a JSONB task and an array of JSONB inputs.
-- The Rust implementation rejects any task whose "task" key is not 'conversational'.
2+
CREATE FUNCTION pgml."transform"(
3+
"task" jsonb, /* pgrx::datum::json::JsonB */
4+
"args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
5+
"inputs" jsonb[] DEFAULT ARRAY[]::JSONB[], /* Vec<pgrx::datum::json::JsonB> */
6+
"cache" bool DEFAULT false /* bool */
7+
) RETURNS jsonb /* pgrx::datum::json::JsonB */
8+
IMMUTABLE STRICT PARALLEL SAFE
9+
LANGUAGE c /* Rust */
10+
AS 'MODULE_PATHNAME', 'transform_conversational_json_wrapper';
11+
12+
-- pgml::api::transform_conversational_string
-- Overload of pgml.transform taking the task name as TEXT and an array of JSONB inputs.
-- The Rust implementation rejects any task name other than 'conversational'.
13+
CREATE FUNCTION pgml."transform"(
14+
"task" TEXT, /* alloc::string::String */
15+
"args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
16+
"inputs" jsonb[] DEFAULT ARRAY[]::JSONB[], /* Vec<pgrx::datum::json::JsonB> */
17+
"cache" bool DEFAULT false /* bool */
18+
) RETURNS jsonb /* pgrx::datum::json::JsonB */
19+
IMMUTABLE STRICT PARALLEL SAFE
20+
LANGUAGE c /* Rust */
21+
AS 'MODULE_PATHNAME', 'transform_conversational_string_wrapper';
22+
23+
-- pgml::api::transform_stream_string
-- The Rust return type of this overload changes from SetOfIterator<String> to
-- SetOfIterator<JsonB> in this release, so the existing function with the same
-- argument types is dropped before the jsonb-returning definition is created.
24+
DROP FUNCTION IF EXISTS pgml."transform_stream"(text,jsonb,text,boolean);
25+
CREATE FUNCTION pgml."transform_stream"(
26+
"task" TEXT, /* alloc::string::String */
27+
"args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
28+
"input" TEXT DEFAULT '', /* &str */
29+
"cache" bool DEFAULT false /* bool */
30+
) RETURNS SETOF jsonb /* pgrx::datum::json::JsonB */
31+
IMMUTABLE STRICT PARALLEL SAFE
32+
LANGUAGE c /* Rust */
33+
AS 'MODULE_PATHNAME', 'transform_stream_string_wrapper';
34+
35+
-- pgml::api::transform_stream_json
-- The Rust return type of this overload changes from SetOfIterator<String> to
-- SetOfIterator<JsonB> in this release, so the existing function with the same
-- argument types is dropped before the jsonb-returning definition is created.
36+
DROP FUNCTION IF EXISTS pgml."transform_stream"(jsonb,jsonb,text,boolean);
37+
CREATE FUNCTION pgml."transform_stream"(
38+
"task" jsonb, /* pgrx::datum::json::JsonB */
39+
"args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
40+
"input" TEXT DEFAULT '', /* &str */
41+
"cache" bool DEFAULT false /* bool */
42+
) RETURNS SETOF jsonb /* pgrx::datum::json::JsonB */
43+
IMMUTABLE STRICT PARALLEL SAFE
44+
LANGUAGE c /* Rust */
45+
AS 'MODULE_PATHNAME', 'transform_stream_json_wrapper';
46+
47+
-- pgml::api::transform_stream_conversational_string
-- Streaming overload taking the task name as TEXT and an array of JSONB inputs;
-- it binds to the *_string wrapper, matching the TEXT "task" argument.
48+
CREATE FUNCTION pgml."transform_stream"(
49+
"task" TEXT, /* alloc::string::String */
50+
"args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
51+
"inputs" jsonb[] DEFAULT ARRAY[]::JSONB[], /* Vec<pgrx::datum::json::JsonB> */
52+
"cache" bool DEFAULT false /* bool */
53+
) RETURNS SETOF jsonb /* pgrx::datum::json::JsonB */
54+
IMMUTABLE STRICT PARALLEL SAFE
55+
LANGUAGE c /* Rust */
56+
AS 'MODULE_PATHNAME', 'transform_stream_conversational_string_wrapper';
57+
58+
-- pgml::api::transform_stream_conversational_json
-- Streaming overload taking a JSONB task and an array of JSONB inputs;
-- it binds to the *_json wrapper, matching the jsonb "task" argument.
59+
CREATE FUNCTION pgml."transform_stream"(
60+
"task" jsonb, /* pgrx::datum::json::JsonB */
61+
"args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */
62+
"inputs" jsonb[] DEFAULT ARRAY[]::JSONB[], /* Vec<pgrx::datum::json::JsonB> */
63+
"cache" bool DEFAULT false /* bool */
64+
) RETURNS SETOF jsonb /* pgrx::datum::json::JsonB */
65+
IMMUTABLE STRICT PARALLEL SAFE
66+
LANGUAGE c /* Rust */
67+
AS 'MODULE_PATHNAME', 'transform_stream_conversational_json_wrapper';

pgml-extension/src/api.rs

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,50 @@ pub fn transform_string(
632632
}
633633
}
634634

635+
/// SQL-facing `transform` overload that takes a JSONB task and an array of JSONB inputs.
///
/// Raises a Postgres error unless the `"task"` key of `task` is the string
/// `"conversational"`. Otherwise `task`, `args` and `inputs` are forwarded to
/// `crate::bindings::transformers::transform` and its JSON output is returned as `JsonB`.
/// An error from that call is raised as a Postgres error carrying the error's display text.
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
#[pg_extern(immutable, parallel_safe, name = "transform")]
#[allow(unused_variables)] // cache is maintained for api compatibility
pub fn transform_conversational_json(
    task: JsonB,
    args: default!(JsonB, "'{}'"),
    inputs: default!(Vec<JsonB>, "ARRAY[]::JSONB[]"),
    cache: default!(bool, false),
) -> JsonB {
    // A missing or non-string "task" key compares unequal and is rejected as well.
    let is_conversational = task.0["task"].as_str() == Some("conversational");
    if !is_conversational {
        error!(
            "ARRAY[]::JSONB inputs for transform should only be used with a conversational task"
        );
    }

    let outcome = crate::bindings::transformers::transform(&task.0, &args.0, inputs);
    match outcome {
        Err(e) => error!("{e}"),
        Ok(output) => JsonB(output),
    }
}
657+
658+
/// SQL-facing `transform` overload that takes the task name as text and an array of JSONB inputs.
///
/// Raises a Postgres error unless `task` is exactly `"conversational"`. Otherwise the name is
/// wrapped as `{"task": <task>}` and forwarded with `args` and `inputs` to
/// `crate::bindings::transformers::transform`; its JSON output is returned as `JsonB`.
/// An error from that call is raised as a Postgres error carrying the error's display text.
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
#[pg_extern(immutable, parallel_safe, name = "transform")]
#[allow(unused_variables)] // cache is maintained for api compatibility
pub fn transform_conversational_string(
    task: String,
    args: default!(JsonB, "'{}'"),
    inputs: default!(Vec<JsonB>, "ARRAY[]::JSONB[]"),
    cache: default!(bool, false),
) -> JsonB {
    if task.as_str() != "conversational" {
        error!(
            "ARRAY[]::JSONB inputs for transform should only be used with a conversational task"
        );
    }

    // The bindings expect the task as a JSON object, so wrap the bare name.
    let task_spec = json!({ "task": task });
    let outcome = crate::bindings::transformers::transform(&task_spec, &args.0, inputs);
    match outcome {
        Err(e) => error!("{e}"),
        Ok(output) => JsonB(output),
    }
}
678+
635679
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
636680
#[pg_extern(immutable, parallel_safe, name = "transform_stream")]
637681
#[allow(unused_variables)] // cache is maintained for api compatibility
@@ -640,7 +684,7 @@ pub fn transform_stream_json(
640684
args: default!(JsonB, "'{}'"),
641685
input: default!(&str, "''"),
642686
cache: default!(bool, false),
643-
) -> SetOfIterator<'static, String> {
687+
) -> SetOfIterator<'static, JsonB> {
644688
// We can unwrap this because if there is an error the current transaction is aborted in the map_err call
645689
let python_iter =
646690
crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, input)
@@ -657,7 +701,7 @@ pub fn transform_stream_string(
657701
args: default!(JsonB, "'{}'"),
658702
input: default!(&str, "''"),
659703
cache: default!(bool, false),
660-
) -> SetOfIterator<'static, String> {
704+
) -> SetOfIterator<'static, JsonB> {
661705
let task_json = json!({ "task": task });
662706
// We can unwrap this because if there is an error the current transaction is aborted in the map_err call
663707
let python_iter =
@@ -667,6 +711,54 @@ pub fn transform_stream_string(
667711
SetOfIterator::new(python_iter)
668712
}
669713

714+
/// SQL-facing `transform_stream` overload that takes a JSONB task and an array of JSONB inputs.
///
/// Raises a Postgres error unless the `"task"` key of `task` is the string
/// `"conversational"`. Otherwise `task`, `args` and `inputs` are forwarded to
/// `crate::bindings::transformers::transform_stream_iterator`, and the resulting iterator is
/// wrapped in a `SetOfIterator` so each item is returned as one row.
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
#[pg_extern(immutable, parallel_safe, name = "transform_stream")]
#[allow(unused_variables)] // cache is maintained for api compatibility
pub fn transform_stream_conversational_json(
    task: JsonB,
    args: default!(JsonB, "'{}'"),
    inputs: default!(Vec<JsonB>, "ARRAY[]::JSONB[]"),
    cache: default!(bool, false),
) -> SetOfIterator<'static, JsonB> {
    // A missing or non-string "task" key compares unequal and is rejected as well.
    let is_conversational = task.0["task"].as_str() == Some("conversational");
    if !is_conversational {
        error!(
            "ARRAY[]::JSONB inputs for transform_stream should only be used with a conversational task"
        );
    }

    // On failure error! aborts the current transaction, so the Err arm never yields a value.
    let python_iter =
        match crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, inputs) {
            Ok(iter) => iter,
            Err(e) => error!("{e}"),
        };
    SetOfIterator::new(python_iter)
}
738+
739+
/// SQL-facing `transform_stream` overload that takes the task name as text and an array of
/// JSONB inputs.
///
/// Raises a Postgres error unless `task` is exactly `"conversational"`. Otherwise the name is
/// wrapped as `{"task": <task>}` and forwarded with `args` and `inputs` to
/// `crate::bindings::transformers::transform_stream_iterator`; the resulting iterator is
/// wrapped in a `SetOfIterator` so each item is returned as one row.
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
#[pg_extern(immutable, parallel_safe, name = "transform_stream")]
#[allow(unused_variables)] // cache is maintained for api compatibility
pub fn transform_stream_conversational_string(
    task: String,
    args: default!(JsonB, "'{}'"),
    inputs: default!(Vec<JsonB>, "ARRAY[]::JSONB[]"),
    cache: default!(bool, false),
) -> SetOfIterator<'static, JsonB> {
    if task != "conversational" {
        // Message wording matches the other conversational overloads ("ARRAY[]::JSONB").
        error!(
            "ARRAY[]::JSONB inputs for transform_stream should only be used with a conversational task"
        );
    }
    // The bindings expect the task as a JSON object, so wrap the bare name.
    let task_json = json!({ "task": task });
    // We can unwrap this because if there is an error the current transaction is aborted in the map_err call
    let python_iter =
        crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, inputs)
            .map_err(|e| error!("{e}"))
            .unwrap();
    SetOfIterator::new(python_iter)
}
761+
670762
#[cfg(feature = "python")]
671763
#[pg_extern(immutable, parallel_safe, name = "generate")]
672764
fn generate(project_name: &str, inputs: &str, config: default!(JsonB, "'{}'")) -> String {

pgml-extension/src/bindings/transformers/transform.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,28 +23,28 @@ impl TransformStreamIterator {
2323
}
2424

2525
impl Iterator for TransformStreamIterator {
26-
type Item = String;
26+
type Item = JsonB;
2727
fn next(&mut self) -> Option<Self::Item> {
2828
// We can unwrap this because if there is an error the current transaction is aborted in the map_err call
29-
Python::with_gil(|py| -> Result<Option<String>, PyErr> {
29+
Python::with_gil(|py| -> Result<Option<JsonB>, PyErr> {
3030
let code = "next(python_iter)";
3131
let res: &PyAny = py.eval(code, Some(self.locals.as_ref(py)), None)?;
3232
if res.is_none() {
3333
Ok(None)
3434
} else {
35-
let res: String = res.extract()?;
36-
Ok(Some(res))
35+
let res: Vec<String> = res.extract()?;
36+
Ok(Some(JsonB(serde_json::to_value(res).unwrap())))
3737
}
3838
})
3939
.map_err(|e| error!("{e}"))
4040
.unwrap()
4141
}
4242
}
4343

44-
pub fn transform(
44+
pub fn transform<T: serde::Serialize>(
4545
task: &serde_json::Value,
4646
args: &serde_json::Value,
47-
inputs: Vec<&str>,
47+
inputs: T,
4848
) -> Result<serde_json::Value> {
4949
crate::bindings::python::activate()?;
5050
whitelist::verify_task(task)?;
@@ -74,17 +74,17 @@ pub fn transform(
7474
Ok(serde_json::from_str(&results)?)
7575
}
7676

77-
pub fn transform_stream(
77+
pub fn transform_stream<T: serde::Serialize>(
7878
task: &serde_json::Value,
7979
args: &serde_json::Value,
80-
input: &str,
80+
input: T,
8181
) -> Result<Py<PyAny>> {
8282
crate::bindings::python::activate()?;
8383
whitelist::verify_task(task)?;
8484

8585
let task = serde_json::to_string(task)?;
8686
let args = serde_json::to_string(args)?;
87-
let inputs = serde_json::to_string(&vec![input])?;
87+
let input = serde_json::to_string(&input)?;
8888

8989
Python::with_gil(|py| -> Result<Py<PyAny>> {
9090
let transform: Py<PyAny> = get_module!(PY_MODULE)
@@ -99,7 +99,7 @@ pub fn transform_stream(
9999
&[
100100
task.into_py(py),
101101
args.into_py(py),
102-
inputs.into_py(py),
102+
input.into_py(py),
103103
true.into_py(py),
104104
],
105105
),
@@ -110,10 +110,10 @@ pub fn transform_stream(
110110
})
111111
}
112112

113-
pub fn transform_stream_iterator(
113+
pub fn transform_stream_iterator<T: serde::Serialize>(
114114
task: &serde_json::Value,
115115
args: &serde_json::Value,
116-
input: &str,
116+
input: T,
117117
) -> Result<TransformStreamIterator> {
118118
let python_iter = transform_stream(task, args, input)
119119
.map_err(|e| error!("{e}"))

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy