Skip to content

Commit 82fcbd4

Browse files
authored
Chatbot page is almost ready to go (#1054)
1 parent a2908fa commit 82fcbd4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+2992
-614
lines changed

pgml-dashboard/Cargo.lock

Lines changed: 455 additions & 59 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pgml-dashboard/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,5 @@ pgvector = { version = "0.2.2", features = [ "sqlx", "postgres" ] }
4545
console-subscriber = "*"
4646
glob = "*"
4747
pgml-components = { path = "../packages/pgml-components" }
48+
reqwest = { version = "0.11.20", features = ["json"] }
49+
pgml = { version = "0.9.2", path = "../pgml-sdks/pgml/" }

pgml-dashboard/package-lock.json

Lines changed: 35 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pgml-dashboard/package.json

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"dependencies": {
3+
"autosize": "^6.0.1",
4+
"dompurify": "^3.0.6",
5+
"marked": "^9.1.0"
6+
}
7+
}

pgml-dashboard/src/api/chatbot.rs

Lines changed: 338 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,338 @@
1+
use anyhow::Context;
2+
use pgml::{Collection, Pipeline};
3+
use rand::{distributions::Alphanumeric, Rng};
4+
use reqwest::Client;
5+
use rocket::{
6+
http::Status,
7+
outcome::IntoOutcome,
8+
request::{self, FromRequest},
9+
route::Route,
10+
serde::json::Json,
11+
Request,
12+
};
13+
use serde::{Deserialize, Serialize};
14+
use serde_json::json;
15+
use std::time::{SystemTime, UNIX_EPOCH};
16+
17+
use crate::{
18+
forms,
19+
responses::{Error, ResponseOk},
20+
};
21+
22+
pub struct User {
23+
chatbot_session_id: String,
24+
}
25+
26+
#[rocket::async_trait]
27+
impl<'r> FromRequest<'r> for User {
28+
type Error = ();
29+
30+
async fn from_request(request: &'r Request<'_>) -> request::Outcome<User, ()> {
31+
request
32+
.cookies()
33+
.get_private("chatbot_session_id")
34+
.map(|c| User {
35+
chatbot_session_id: c.value().to_string(),
36+
})
37+
.or_forward(Status::Unauthorized)
38+
}
39+
}
40+
41+
#[derive(Serialize, Deserialize, PartialEq, Eq)]
42+
enum ChatRole {
43+
User,
44+
Bot,
45+
}
46+
47+
#[derive(Clone, Copy, Serialize, Deserialize)]
48+
enum ChatbotBrain {
49+
OpenAIGPT4,
50+
PostgresMLFalcon180b,
51+
AnthropicClaude,
52+
MetaLlama2,
53+
}
54+
55+
impl TryFrom<u8> for ChatbotBrain {
56+
type Error = anyhow::Error;
57+
58+
fn try_from(value: u8) -> anyhow::Result<Self> {
59+
match value {
60+
0 => Ok(ChatbotBrain::OpenAIGPT4),
61+
1 => Ok(ChatbotBrain::PostgresMLFalcon180b),
62+
2 => Ok(ChatbotBrain::AnthropicClaude),
63+
3 => Ok(ChatbotBrain::MetaLlama2),
64+
_ => Err(anyhow::anyhow!("Invalid brain id")),
65+
}
66+
}
67+
}
68+
69+
#[derive(Clone, Copy, Serialize, Deserialize)]
70+
enum KnowledgeBase {
71+
PostgresML,
72+
PyTorch,
73+
Rust,
74+
PostgreSQL,
75+
}
76+
77+
impl KnowledgeBase {
78+
// The topic and knowledge base are the same for now but may be different later
79+
fn topic(&self) -> &'static str {
80+
match self {
81+
Self::PostgresML => "PostgresML",
82+
Self::PyTorch => "PyTorch",
83+
Self::Rust => "Rust",
84+
Self::PostgreSQL => "PostgreSQL",
85+
}
86+
}
87+
88+
fn collection(&self) -> &'static str {
89+
match self {
90+
Self::PostgresML => "PostgresML",
91+
Self::PyTorch => "PyTorch",
92+
Self::Rust => "Rust",
93+
Self::PostgreSQL => "PostgreSQL",
94+
}
95+
}
96+
}
97+
98+
impl TryFrom<u8> for KnowledgeBase {
99+
type Error = anyhow::Error;
100+
101+
fn try_from(value: u8) -> anyhow::Result<Self> {
102+
match value {
103+
0 => Ok(KnowledgeBase::PostgresML),
104+
1 => Ok(KnowledgeBase::PyTorch),
105+
2 => Ok(KnowledgeBase::Rust),
106+
3 => Ok(KnowledgeBase::PostgreSQL),
107+
_ => Err(anyhow::anyhow!("Invalid knowledge base id")),
108+
}
109+
}
110+
}
111+
112+
#[derive(Serialize, Deserialize)]
113+
struct Document {
114+
id: String,
115+
text: String,
116+
role: ChatRole,
117+
user_id: String,
118+
model: ChatbotBrain,
119+
knowledge_base: KnowledgeBase,
120+
timestamp: u128,
121+
}
122+
123+
impl Document {
124+
fn new(text: String, role: ChatRole, user_id: String, model: ChatbotBrain, knowledge_base: KnowledgeBase) -> Document {
125+
let id = rand::thread_rng()
126+
.sample_iter(&Alphanumeric)
127+
.take(32)
128+
.map(char::from)
129+
.collect();
130+
let timestamp = SystemTime::now()
131+
.duration_since(UNIX_EPOCH)
132+
.unwrap()
133+
.as_millis();
134+
Document {
135+
id,
136+
text,
137+
role,
138+
user_id,
139+
model,
140+
knowledge_base,
141+
timestamp,
142+
}
143+
}
144+
}
145+
146+
async fn get_openai_chatgpt_answer(
147+
knowledge_base: KnowledgeBase,
148+
history: &str,
149+
context: &str,
150+
question: &str,
151+
) -> Result<String, Error> {
152+
let openai_api_key = std::env::var("OPENAI_API_KEY")?;
153+
let base_prompt = std::env::var("CHATBOT_CHATGPT_BASE_PROMPT")?;
154+
let system_prompt = std::env::var("CHATBOT_CHATGPT_SYSTEM_PROMPT")?;
155+
156+
let system_prompt = system_prompt
157+
.replace("{topic}", knowledge_base.topic())
158+
.replace("{persona}", "Engineer")
159+
.replace("{language}", "English");
160+
161+
let content = base_prompt
162+
.replace("{history}", history)
163+
.replace("{context}", context)
164+
.replace("{question}", question);
165+
166+
let body = json!({
167+
"model": "gpt-4",
168+
"messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": content}],
169+
"temperature": 0.7
170+
});
171+
172+
let response = Client::new()
173+
.post("https://api.openai.com/v1/chat/completions")
174+
.bearer_auth(openai_api_key)
175+
.json(&body)
176+
.send()
177+
.await?
178+
.json::<serde_json::Value>()
179+
.await?;
180+
181+
let response = response["choices"]
182+
.as_array()
183+
.context("No data returned from OpenAI")?[0]["message"]["content"]
184+
.as_str()
185+
.context("The reponse content from OpenAI was not a string")?
186+
.to_string();
187+
188+
Ok(response)
189+
}
190+
191+
#[post("/chatbot/get-answer", format = "json", data = "<data>")]
192+
pub async fn chatbot_get_answer(
193+
user: User,
194+
data: Json<forms::ChatbotPostData>,
195+
) -> Result<ResponseOk, Error> {
196+
match wrapped_chatbot_get_answer(user, data).await {
197+
Ok(response) => Ok(ResponseOk(
198+
json!({
199+
"answer": response,
200+
})
201+
.to_string(),
202+
)),
203+
Err(error) => {
204+
eprintln!("Error: {:?}", error);
205+
Ok(ResponseOk(
206+
json!({
207+
"error": error.to_string(),
208+
})
209+
.to_string(),
210+
))
211+
}
212+
}
213+
}
214+
215+
pub async fn wrapped_chatbot_get_answer(
216+
user: User,
217+
data: Json<forms::ChatbotPostData>,
218+
) -> Result<String, Error> {
219+
let brain = ChatbotBrain::try_from(data.model)?;
220+
let knowledge_base = KnowledgeBase::try_from(data.knowledge_base)?;
221+
222+
// Create it up here so the timestamps that order the conversation are accurate
223+
let user_document = Document::new(
224+
data.question.clone(),
225+
ChatRole::User,
226+
user.chatbot_session_id.clone(),
227+
brain,
228+
knowledge_base
229+
);
230+
231+
let collection = knowledge_base.collection();
232+
let collection = Collection::new(
233+
collection,
234+
Some(std::env::var("CHATBOT_DATABASE_URL").expect("CHATBOT_DATABASE_URL not set")),
235+
);
236+
237+
let mut history_collection = Collection::new(
238+
"ChatHistory",
239+
Some(std::env::var("CHATBOT_DATABASE_URL").expect("CHATBOT_DATABASE_URL not set")),
240+
);
241+
let messages = history_collection
242+
.get_documents(Some(
243+
json!({
244+
"limit": 5,
245+
"order_by": {"timestamp": "desc"},
246+
"filter": {
247+
"metadata": {
248+
"$and" : [
249+
{
250+
"$or":
251+
[
252+
{"role": {"$eq": ChatRole::Bot}},
253+
{"role": {"$eq": ChatRole::User}}
254+
]
255+
},
256+
{
257+
"user_id": {
258+
"$eq": user.chatbot_session_id
259+
}
260+
},
261+
{
262+
"knowledge_base": {
263+
"$eq": knowledge_base
264+
}
265+
},
266+
{
267+
"model": {
268+
"$eq": brain
269+
}
270+
}
271+
]
272+
}
273+
}
274+
275+
})
276+
.into(),
277+
))
278+
.await?;
279+
280+
let mut history = messages
281+
.into_iter()
282+
.map(|m| {
283+
// Can probably remove this clone
284+
let chat_role: ChatRole = serde_json::from_value(m["document"]["role"].to_owned())?;
285+
if chat_role == ChatRole::Bot {
286+
Ok(format!("Assistant: {}", m["document"]["text"]))
287+
} else {
288+
Ok(format!("User: {}", m["document"]["text"]))
289+
}
290+
})
291+
.collect::<anyhow::Result<Vec<String>>>()?;
292+
history.reverse();
293+
let history = history.join("\n");
294+
295+
let mut pipeline = Pipeline::new("v1", None, None, None);
296+
let context = collection
297+
.query()
298+
.vector_recall(&data.question, &mut pipeline, Some(json!({
299+
"instruction": "Represent the Wikipedia question for retrieving supporting documents: "
300+
}).into()))
301+
.limit(5)
302+
.fetch_all()
303+
.await?
304+
.into_iter()
305+
.map(|(_, context, metadata)| format!("#### Document {}: {}", metadata["id"], context))
306+
.collect::<Vec<String>>()
307+
.join("\n");
308+
309+
let answer = match brain {
310+
_ => get_openai_chatgpt_answer(knowledge_base, &history, &context, &data.question).await,
311+
}?;
312+
313+
let new_history_messages: Vec<pgml::types::Json> = vec![
314+
serde_json::to_value(user_document).unwrap().into(),
315+
serde_json::to_value(Document::new(
316+
answer.clone(),
317+
ChatRole::Bot,
318+
user.chatbot_session_id.clone(),
319+
brain,
320+
knowledge_base
321+
))
322+
.unwrap()
323+
.into(),
324+
];
325+
326+
// We do not want to block our return waiting for this to happen
327+
tokio::spawn(async move {
328+
history_collection
329+
.upsert_documents(new_history_messages, None)
330+
.await.expect("Failed to upsert user history");
331+
});
332+
333+
Ok(answer)
334+
}
335+
336+
pub fn routes() -> Vec<Route> {
337+
routes![chatbot_get_answer]
338+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy