Skip to content

Commit ec8b1e3

Browse files
committed
Working RAG in the SDK - still needs some cleanup
1 parent 89c2f54 commit ec8b1e3

File tree

4 files changed

+510
-21
lines changed

4 files changed

+510
-21
lines changed

pgml-sdks/pgml/src/collection.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use walkdir::WalkDir;
1818

1919
use crate::debug_sqlx_query;
2020
use crate::filter_builder::FilterBuilder;
21+
use crate::rag_query_builder::build_rag_query;
2122
use crate::search_query_builder::build_search_query;
2223
use crate::vector_search_query_builder::build_vector_search_query;
2324
use crate::{
@@ -312,6 +313,9 @@ impl Collection {
312313

313314
let mp = MultiProgress::new();
314315
mp.println(format!("Added Pipeline {}, Now Syncing...", pipeline.name))?;
316+
317+
// TODO: Revisit this. If the pipeline is added but fails to sync, then it will be "out of sync" with the documents in the table
318+
// This is rare, but could happen
315319
pipeline
316320
.resync(project_info, pool.acquire().await?.as_mut())
317321
.await?;
@@ -1020,6 +1024,24 @@ impl Collection {
10201024
.collect())
10211025
}
10221026

1027+
#[instrument(skip(self))]
1028+
pub async fn rag(&self, query: Json, pipeline: &Pipeline) -> anyhow::Result<String> {
1029+
let pool = get_or_initialize_pool(&self.database_url).await?;
1030+
let (built_query, values) = build_rag_query(query.clone(), self, pipeline).await?;
1031+
let results: Vec<(Json,)> = sqlx::query_as_with(&built_query, values)
1032+
.fetch_all(&pool)
1033+
.await?;
1034+
Ok(results[0]
1035+
.0
1036+
.as_array()
1037+
.context("Error converting LLM response to Array")?
1038+
.first()
1039+
.context("Error getting first LLM response")?
1040+
.as_str()
1041+
.context("Error converting LLM response to string")?
1042+
.to_owned())
1043+
}
1044+
10231045
/// Archives a [Collection]
10241046
/// This will free up the name to be reused. It does not delete it.
10251047
///

pgml-sdks/pgml/src/lib.rs

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ mod pipeline;
2929
mod queries;
3030
mod query_builder;
3131
mod query_runner;
32+
mod rag_query_builder;
3233
mod remote_embeddings;
3334
mod search_query_builder;
3435
mod single_field_pipeline;
@@ -1813,4 +1814,173 @@ mod tests {
18131814
collection.archive().await?;
18141815
Ok(())
18151816
}
1817+
1818+
///////////////////////////////
1819+
// RAG ////////////////////////
1820+
///////////////////////////////
1821+
1822+
#[tokio::test]
1823+
async fn can_rag_with_local_embeddings() -> anyhow::Result<()> {
1824+
internal_init_logger(None, None).ok();
1825+
let collection_name = "test r_c_crwle_1";
1826+
let mut collection = Collection::new(collection_name, None)?;
1827+
let documents = generate_dummy_documents(10);
1828+
collection.upsert_documents(documents.clone(), None).await?;
1829+
let pipeline_name = "0";
1830+
let mut pipeline = Pipeline::new(
1831+
pipeline_name,
1832+
Some(
1833+
json!({
1834+
"body": {
1835+
"splitter": {
1836+
"model": "recursive_character"
1837+
},
1838+
"semantic_search": {
1839+
"model": "intfloat/e5-small-v2",
1840+
"parameters": {
1841+
"prompt": "passage: "
1842+
}
1843+
},
1844+
},
1845+
})
1846+
.into(),
1847+
),
1848+
)?;
1849+
collection.add_pipeline(&mut pipeline).await?;
1850+
// Single variable test
1851+
let results = collection
1852+
.rag(
1853+
json!({
1854+
"CONTEXT": {
1855+
"vector_search": {
1856+
"query": {
1857+
"fields": {
1858+
"body": {
1859+
"query": "Test document: 2",
1860+
"boost": 1.0,
1861+
"parameters": {
1862+
"prompt": "query: "
1863+
}
1864+
},
1865+
},
1866+
},
1867+
"limit": 5
1868+
},
1869+
"aggregate": {
1870+
"join": "\n"
1871+
}
1872+
},
1873+
"completion": {
1874+
"model": "mistralai/Mistral-7B-Instruct-v0.2",
1875+
"prompt": "Some text with {CONTEXT}",
1876+
"temperature": 0.7
1877+
}
1878+
})
1879+
.into(),
1880+
&mut pipeline,
1881+
)
1882+
.await?;
1883+
eprintln!("{}", results);
1884+
1885+
// Multi-variable test
1886+
let results = collection
1887+
.rag(
1888+
json!({
1889+
"CONTEXT": {
1890+
"vector_search": {
1891+
"query": {
1892+
"fields": {
1893+
"body": {
1894+
"query": "Test document: 2",
1895+
"boost": 1.0,
1896+
"parameters": {
1897+
"prompt": "query: "
1898+
}
1899+
},
1900+
},
1901+
},
1902+
"limit": 2
1903+
},
1904+
"aggregate": {
1905+
"join": "\n"
1906+
}
1907+
},
1908+
"CONTEXT2": {
1909+
"vector_search": {
1910+
"query": {
1911+
"fields": {
1912+
"body": {
1913+
"query": "Test document: 3",
1914+
"boost": 1.0,
1915+
"parameters": {
1916+
"prompt": "query: "
1917+
}
1918+
},
1919+
}
1920+
},
1921+
"limit": 2
1922+
},
1923+
"aggregate": {
1924+
"join": "\n"
1925+
}
1926+
},
1927+
"completion": {
1928+
"model": "mistralai/Mistral-7B-Instruct-v0.2",
1929+
"prompt": "Some text with {CONTEXT} AND {CONTEXT2}",
1930+
"temperature": 0.7
1931+
}
1932+
})
1933+
.into(),
1934+
&mut pipeline,
1935+
)
1936+
.await?;
1937+
eprintln!("{}", results);
1938+
1939+
// Chat test
1940+
let results = collection
1941+
.rag(
1942+
json!({
1943+
"CONTEXT": {
1944+
"vector_search": {
1945+
"query": {
1946+
"fields": {
1947+
"body": {
1948+
"query": "Test document: 2",
1949+
"boost": 1.0,
1950+
"parameters": {
1951+
"prompt": "query: "
1952+
}
1953+
},
1954+
},
1955+
},
1956+
"limit": 2
1957+
},
1958+
"aggregate": {
1959+
"join": "\n"
1960+
}
1961+
},
1962+
"chat": {
1963+
"model": "mistralai/Mistral-7B-Instruct-v0.2",
1964+
"messages": [
1965+
// {
1966+
// "role": "system",
1967+
// "content": "You are a friendly and helpful chatbot"
1968+
// },
1969+
{
1970+
"role": "user",
1971+
"content": "Some text with {CONTEXT}",
1972+
}
1973+
],
1974+
"temperature": 0.7
1975+
}
1976+
})
1977+
.into(),
1978+
&mut pipeline,
1979+
)
1980+
.await?;
1981+
eprintln!("{}", results);
1982+
1983+
// collection.archive().await?;
1984+
Ok(())
1985+
}
18161986
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy