diff --git a/pgml-extension/examples/chunking.sql b/pgml-extension/examples/chunking.sql new file mode 100644 index 000000000..f8559ef7c --- /dev/null +++ b/pgml-extension/examples/chunking.sql @@ -0,0 +1,62 @@ +--- Chunk text for LLM embeddings and vectorization. + +DROP TABLE documents CASCADE; +CREATE TABLE documents ( + id BIGSERIAL PRIMARY KEY, + document TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +DROP TABLE splitters CASCADE; +CREATE TABLE splitters ( + id BIGSERIAL PRIMARY KEY, + splitter VARCHAR NOT NULL DEFAULT 'recursive_character' +); + +DROP TABLE document_chunks CASCADE; +CREATE TABLE document_chunks( + id BIGSERIAL PRIMARY KEY, + document_id BIGINT NOT NULL REFERENCES documents(id), + splitter_id BIGINT NOT NULL REFERENCES splitters(id), + chunk_index BIGINT NOT NULL, + chunk VARCHAR +); + +INSERT INTO documents VALUES ( + 1, + 'It was the best of times, it was the worst of times, it was the age of wisdom, + it was the age of foolishness, it was the epoch of belief, it was the epoch of incredulity, it was the season of Light, + it was the season of Darkness, it was the spring of hope, it was the winter of despair, we had everything before us, + we had nothing before us, we were all going direct to Heaven, we were all going direct the other way—in short, the period was so far like + the present period, that some of its noisiest authorities insisted on its being received, for good or for evil, in the superlative degree of comparison only.', + NOW() +); + +INSERT INTO splitters VALUES (1, 'recursive_character'); + +WITH document AS ( + SELECT id, document + FROM documents + WHERE id = 1 +), + +splitter AS ( + SELECT id, splitter + FROM splitters + WHERE id = 1 +) + +INSERT INTO document_chunks SELECT + nextval('document_chunks_id_seq'::regclass), + (SELECT id FROM document), + (SELECT id FROM splitter), + chunk_index, + chunk +FROM + pgml.chunk( + (SELECT splitter FROM splitter), + (SELECT document FROM document), + 
'{"chunk_size": 2, "chunk_overlap": 2}' + ); + +SELECT * FROM document_chunks LIMIT 5; diff --git a/pgml-extension/requirements.txt b/pgml-extension/requirements.txt index 1d766a091..405dc0a70 100644 --- a/pgml-extension/requirements.txt +++ b/pgml-extension/requirements.txt @@ -18,3 +18,4 @@ torchvision==0.15.2 tqdm==4.65.0 transformers==4.29.2 xgboost==1.7.5 +langchain==0.0.180 diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index e3ecff1e4..914952e91 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -565,7 +565,23 @@ fn load_dataset( #[pg_extern(immutable, parallel_safe)] pub fn embed(transformer: &str, text: &str, kwargs: default!(JsonB, "'{}'")) -> Vec<f32> { - crate::bindings::transformers::embed(transformer, &text, &kwargs.0) + crate::bindings::transformers::embed(transformer, text, &kwargs.0) +} + +#[pg_extern(immutable, parallel_safe)] +pub fn chunk( + splitter: &str, + text: &str, + kwargs: default!(JsonB, "'{}'"), +) -> TableIterator<'static, (name!(chunk_index, i64), name!(chunk, String))> { + let chunks = crate::bindings::langchain::chunk(splitter, text, &kwargs.0); + let chunks = chunks + .into_iter() + .enumerate() + .map(|(i, chunk)| (i as i64 + 1, chunk)) + .collect::<Vec<(i64, String)>>(); + + TableIterator::new(chunks.into_iter()) } #[cfg(feature = "python")] @@ -575,7 +591,7 @@ pub fn transform_json( task: JsonB, args: default!(JsonB, "'{}'"), inputs: default!(Vec<&str>, "ARRAY[]::TEXT[]"), - cache: default!(bool, false) + cache: default!(bool, false), ) -> JsonB { JsonB(crate::bindings::transformers::transform( &task.0, &args.0, &inputs, @@ -589,7 +605,7 @@ pub fn transform_string( task: String, args: default!(JsonB, "'{}'"), inputs: default!(Vec<&str>, "ARRAY[]::TEXT[]"), - cache: default!(bool, false) + cache: default!(bool, false), ) -> JsonB { let mut task_map = HashMap::new(); task_map.insert("task", task); diff --git a/pgml-extension/src/bindings/langchain.py b/pgml-extension/src/bindings/langchain.py new file mode 100644 index 
000000000..7bd224230 --- /dev/null +++ b/pgml-extension/src/bindings/langchain.py @@ -0,0 +1,29 @@ +from langchain.text_splitter import ( + CharacterTextSplitter, + LatexTextSplitter, + MarkdownTextSplitter, + NLTKTextSplitter, + PythonCodeTextSplitter, + RecursiveCharacterTextSplitter, + SpacyTextSplitter, +) +import json + +SPLITTERS = { + "character": CharacterTextSplitter, + "latex": LatexTextSplitter, + "markdown": MarkdownTextSplitter, + "nltk": NLTKTextSplitter, + "python": PythonCodeTextSplitter, + "recursive_character": RecursiveCharacterTextSplitter, + "spacy": SpacyTextSplitter, +} + + +def chunk(splitter, text, args): + kwargs = json.loads(args) + + if splitter in SPLITTERS: + return SPLITTERS[splitter](**kwargs).split_text(text) + else: + raise ValueError("Unknown splitter: {}".format(splitter)) diff --git a/pgml-extension/src/bindings/langchain.rs b/pgml-extension/src/bindings/langchain.rs new file mode 100644 index 000000000..61b3d61ef --- /dev/null +++ b/pgml-extension/src/bindings/langchain.rs @@ -0,0 +1,37 @@ +use once_cell::sync::Lazy; +use pgrx::*; +use pyo3::prelude::*; +use pyo3::types::PyTuple; + +static PY_MODULE: Lazy<Py<PyModule>> = Lazy::new(|| { + Python::with_gil(|py| -> Py<PyModule> { + let src = include_str!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/bindings/langchain.py" + )); + + PyModule::from_code(py, src, "", "").unwrap().into() + }) +}); + +pub fn chunk(splitter: &str, text: &str, kwargs: &serde_json::Value) -> Vec<String> { + crate::bindings::venv::activate(); + + let kwargs = serde_json::to_string(kwargs).unwrap(); + + Python::with_gil(|py| -> Vec<String> { + let chunk: Py<PyAny> = PY_MODULE.getattr(py, "chunk").unwrap().into(); + + chunk + .call1( + py, + PyTuple::new( + py, + &[splitter.into_py(py), text.into_py(py), kwargs.into_py(py)], + ), + ) + .unwrap() + .extract(py) + .unwrap() + }) +} diff --git a/pgml-extension/src/bindings/mod.rs b/pgml-extension/src/bindings/mod.rs index b147a8104..77a46e161 100644 --- a/pgml-extension/src/bindings/mod.rs +++ 
b/pgml-extension/src/bindings/mod.rs @@ -5,6 +5,8 @@ use pgrx::*; use crate::orm::*; +#[cfg(feature = "python")] +pub mod langchain; pub mod lightgbm; pub mod linfa; #[cfg(feature = "python")] diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index 9ce3825f4..8f296d812 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -40,11 +40,7 @@ pub fn transform( py, PyTuple::new( py, - &[ - task.into_py(py), - args.into_py(py), - inputs.into_py(py), - ], + &[task.into_py(py), args.into_py(py), inputs.into_py(py)], ), ) .unwrap() diff --git a/pgml-extension/tests/test.sql b/pgml-extension/tests/test.sql index 1c4dd614b..c1ef81d58 100644 --- a/pgml-extension/tests/test.sql +++ b/pgml-extension/tests/test.sql @@ -27,5 +27,6 @@ SELECT pgml.load_dataset('wine'); \i examples/multi_classification.sql \i examples/regression.sql \i examples/vectors.sql +\i examples/chunking.sql -- transformers are generally too slow to run in the test suite --\i examples/transformers.sql pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy