diff --git a/pgml-dashboard/Cargo.lock b/pgml-dashboard/Cargo.lock index f28f4db10..121b52b3b 100644 --- a/pgml-dashboard/Cargo.lock +++ b/pgml-dashboard/Cargo.lock @@ -854,6 +854,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "data-encoding" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e962a19be5cfc3f3bf6dd8f61eb50107f356ad6270fbb3ed41476571db78be5" + [[package]] name = "debugid" version = "0.8.0" @@ -1229,9 +1235,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" +checksum = "da0290714b38af9b4a7b094b8a37086d1b4e61f2df9122c3cad2577669145335" dependencies = [ "futures-channel", "futures-core", @@ -1244,9 +1250,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" +checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb" dependencies = [ "futures-core", "futures-sink", @@ -1254,15 +1260,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c" [[package]] name = "futures-executor" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0" +checksum = "0f4fb8693db0cf099eadcca0efe2a5a22e4550f98ed16aba6c48700da29597bc" dependencies = [ "futures-core", "futures-task", @@ -1282,15 +1288,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" +checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa" [[package]] name = "futures-macro" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" +checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" dependencies = [ "proc-macro2", "quote", @@ -1299,21 +1305,21 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" +checksum = "e36d3378ee38c2a36ad710c5d30c2911d752cb941c00c72dbabfb786a7970817" [[package]] name = "futures-task" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" +checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2" [[package]] name = "futures-util" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" +checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104" dependencies = [ "futures-channel", "futures-core", @@ -2515,6 +2521,7 @@ dependencies = [ "csv-async", "dotenv", "env_logger", + "futures", "glob", "itertools", "lazy_static", @@ -2530,6 +2537,7 @@ dependencies = [ "regex", "reqwest", "rocket", + "rocket_ws", "sailfish", "scraper", "sentry", @@ -3041,8 +3049,8 @@ dependencies = [ [[package]] name = "rocket" -version = "0.5.0-rc.3" -source = "git+https://github.com/SergioBenitez/Rocket#07fe79796f058ab12683ff9e344558bece263274" +version = "0.6.0-dev" +source = "git+https://github.com/SergioBenitez/Rocket#7f7d352e453e83f3d23ee12f8965ce75c977fcea" dependencies = [ "async-stream", "async-trait", @@ -3078,8 +3086,8 @@ dependencies = [ [[package]] name = "rocket_codegen" -version = "0.5.0-rc.3" -source = "git+https://github.com/SergioBenitez/Rocket#07fe79796f058ab12683ff9e344558bece263274" +version = "0.6.0-dev" +source = "git+https://github.com/SergioBenitez/Rocket#7f7d352e453e83f3d23ee12f8965ce75c977fcea" dependencies = [ "devise", "glob", @@ -3094,8 +3102,8 @@ dependencies = [ [[package]] name = "rocket_http" -version = "0.5.0-rc.3" -source = "git+https://github.com/SergioBenitez/Rocket#07fe79796f058ab12683ff9e344558bece263274" +version = "0.6.0-dev" +source = "git+https://github.com/SergioBenitez/Rocket#7f7d352e453e83f3d23ee12f8965ce75c977fcea" dependencies = [ "cookie", "either", @@ -3118,6 +3126,15 @@ dependencies = [ "uncased", ] +[[package]] +name = "rocket_ws" +version = "0.1.0" +source = "git+https://github.com/SergioBenitez/Rocket#7f7d352e453e83f3d23ee12f8965ce75c977fcea" +dependencies = [ + "rocket", + "tokio-tungstenite", +] + [[package]] name = "rust-stemmers" version = "1.2.0" @@ -4337,6 +4354,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "212d5dcb2a1ce06d81107c3d0ffa3121fe974b73f068c8282cb1c32328113b6c" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.8" @@ -4525,6 +4554,25 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" +[[package]] +name = "tungstenite" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e3dac10fd62eaf6617d3a904ae222845979aec67c615d1c842b4002c7666fb9" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "rand", + "sha1", + "thiserror", + "url", + "utf-8", +] + [[package]] name = "typed-arena" version = "2.0.2" diff --git a/pgml-dashboard/Cargo.toml b/pgml-dashboard/Cargo.toml index 47238f6ed..6d1b803dd 100644 --- a/pgml-dashboard/Cargo.toml +++ b/pgml-dashboard/Cargo.toml @@ -50,3 +50,5 @@ tokio = { version = "1", features = ["full"] } url = "2.4" yaml-rust = "0.4" zoomies = { git="https://github.com/HyperparamAI/zoomies.git", branch="master" } +ws = { package = "rocket_ws", git = "https://github.com/SergioBenitez/Rocket" } +futures = "0.3.29" diff --git a/pgml-dashboard/src/api/chatbot.rs b/pgml-dashboard/src/api/chatbot.rs index c4b12d0c2..0b8978844 100644 --- a/pgml-dashboard/src/api/chatbot.rs +++ b/pgml-dashboard/src/api/chatbot.rs @@ -1,9 +1,10 @@ use anyhow::Context; -use pgml::{Collection, Pipeline}; +use futures::stream::StreamExt; +use pgml::{types::GeneralJsonAsyncIterator, Collection, OpenSourceAI, Pipeline}; use rand::{distributions::Alphanumeric, Rng}; use reqwest::Client; use rocket::{ - http::Status, + http::{Cookie, CookieJar, Status}, outcome::IntoOutcome, request::{self, FromRequest}, route::Route, @@ -14,11 +15,6 @@ use serde::{Deserialize, Serialize}; use serde_json::json; use std::time::{SystemTime, UNIX_EPOCH}; -use crate::{ - forms, - responses::{Error, ResponseOk}, -}; - pub struct User { chatbot_session_id: String, } @@ -40,32 +36,134 @@ impl<'r> FromRequest<'r> for User { #[derive(Serialize, Deserialize, PartialEq, Eq)] enum ChatRole { + System, User, Bot, } +impl ChatRole { + fn to_model_specific_role(&self, brain: &ChatbotBrain) -> &'static str { + match self { + ChatRole::User => "user", + ChatRole::Bot => match brain { + ChatbotBrain::OpenAIGPT4 + | ChatbotBrain::TekniumOpenHermes25Mistral7B + | ChatbotBrain::Starling7b => "assistant", + ChatbotBrain::GrypheMythoMaxL213b => "model", + }, + ChatRole::System => "system", + } + } +} + #[derive(Clone, Copy, Serialize, Deserialize)] enum ChatbotBrain { OpenAIGPT4, - PostgresMLFalcon180b, - AnthropicClaude, - MetaLlama2, + TekniumOpenHermes25Mistral7B, + GrypheMythoMaxL213b, + Starling7b, +} + +impl ChatbotBrain { + fn is_open_source(&self) -> bool { + !matches!(self, Self::OpenAIGPT4) + } + + fn get_system_message( + &self, + knowledge_base: &KnowledgeBase, + context: &str, + ) -> anyhow::Result { + match self { + Self::OpenAIGPT4 => { + let system_prompt = std::env::var("CHATBOT_CHATGPT_SYSTEM_PROMPT")?; + let system_prompt = system_prompt + .replace("{topic}", knowledge_base.topic()) + .replace("{persona}", "Engineer") + .replace("{language}", "English"); + Ok(serde_json::json!({ + "role": "system", + "content": system_prompt + })) + } + _ => Ok(serde_json::json!({ + "role": "system", + "content": format!(r#"You are a friendly and helpful chatbot that uses the following documents to answer the user's questions with the best of your ability. There is one rule: Do Not Lie. + +{} + + "#, context) + })), + } + } + + fn into_model_json(self) -> serde_json::Value { + match self { + Self::TekniumOpenHermes25Mistral7B => serde_json::json!({ + "model": "TheBloke/OpenHermes-2.5-Mistral-7B-GPTQ", + "revision": "main", + "device_map": "auto", + "quantization_config": { + "bits": 4, + "max_input_length": 10000 + } + }), + Self::GrypheMythoMaxL213b => serde_json::json!({ + "model": "TheBloke/MythoMax-L2-13B-GPTQ", + "revision": "main", + "device_map": "auto", + "quantization_config": { + "bits": 4, + "max_input_length": 10000 + } + }), + Self::Starling7b => serde_json::json!({ + "model": "TheBloke/Starling-LM-7B-alpha-GPTQ", + "revision": "main", + "device_map": "auto", + "quantization_config": { + "bits": 4, + "max_input_length": 10000 + } + }), + _ => unimplemented!(), + } + } + + fn get_chat_template(&self) -> Option<&'static str> { + match self { + Self::TekniumOpenHermes25Mistral7B => Some("{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"), + Self::GrypheMythoMaxL213b => Some("{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '### Instruction:\n' + message['content'] + '\n'}}\n{% elif message['role'] == 'system' %}\n{{ message['content'] + '\n'}}\n{% elif message['role'] == 'model' %}\n{{ '### Response:>\n' + message['content'] + eos_token + '\n'}}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '### Response:' }}\n{% endif %}\n{% endfor %}"), + _ => None + } + } } -impl TryFrom for ChatbotBrain { +impl TryFrom<&str> for ChatbotBrain { type Error = anyhow::Error; - fn try_from(value: u8) -> anyhow::Result { + fn try_from(value: &str) -> anyhow::Result { match value { - 0 => Ok(ChatbotBrain::OpenAIGPT4), - 1 => Ok(ChatbotBrain::PostgresMLFalcon180b), - 2 => Ok(ChatbotBrain::AnthropicClaude), - 3 => Ok(ChatbotBrain::MetaLlama2), + "teknium/OpenHermes-2.5-Mistral-7B" => Ok(ChatbotBrain::TekniumOpenHermes25Mistral7B), + "Gryphe/MythoMax-L2-13b" => Ok(ChatbotBrain::GrypheMythoMaxL213b), + "openai" => Ok(ChatbotBrain::OpenAIGPT4), + "berkeley-nest/Starling-LM-7B-alpha" => Ok(ChatbotBrain::Starling7b), _ => Err(anyhow::anyhow!("Invalid brain id")), } } } +impl From for &'static str { + fn from(value: ChatbotBrain) -> Self { + match value { + ChatbotBrain::TekniumOpenHermes25Mistral7B => "teknium/OpenHermes-2.5-Mistral-7B", + ChatbotBrain::GrypheMythoMaxL213b => "Gryphe/MythoMax-L2-13b", + ChatbotBrain::OpenAIGPT4 => "openai", + ChatbotBrain::Starling7b => "berkeley-nest/Starling-LM-7B-alpha", + } + } +} + #[derive(Clone, Copy, Serialize, Deserialize)] enum KnowledgeBase { PostgresML, @@ -95,20 +193,31 @@ impl KnowledgeBase { } } -impl TryFrom for KnowledgeBase { +impl TryFrom<&str> for KnowledgeBase { type Error = anyhow::Error; - fn try_from(value: u8) -> anyhow::Result { + fn try_from(value: &str) -> anyhow::Result { match value { - 0 => Ok(KnowledgeBase::PostgresML), - 1 => Ok(KnowledgeBase::PyTorch), - 2 => Ok(KnowledgeBase::Rust), - 3 => Ok(KnowledgeBase::PostgreSQL), + "postgresml" => Ok(KnowledgeBase::PostgresML), + "pytorch" => Ok(KnowledgeBase::PyTorch), + "rust" => Ok(KnowledgeBase::Rust), + "postgresql" => Ok(KnowledgeBase::PostgreSQL), _ => Err(anyhow::anyhow!("Invalid knowledge base id")), } } } +impl From for &'static str { + fn from(value: KnowledgeBase) -> Self { + match value { + KnowledgeBase::PostgresML => "postgresml", + KnowledgeBase::PyTorch => "pytorch", + KnowledgeBase::Rust => "rust", + KnowledgeBase::PostgreSQL => "postgresql", + } + } +} + #[derive(Serialize, Deserialize)] struct Document { id: String, @@ -122,7 +231,7 @@ struct Document { impl Document { fn new( - text: String, + text: &str, role: ChatRole, user_id: String, model: ChatbotBrain, @@ -139,7 +248,7 @@ impl Document { .as_millis(); Document { id, - text, + text: text.to_string(), role, user_id, model, @@ -149,29 +258,11 @@ impl Document { } } -async fn get_openai_chatgpt_answer( - knowledge_base: KnowledgeBase, - history: &str, - context: &str, - question: &str, -) -> Result { +async fn get_openai_chatgpt_answer(messages: M) -> anyhow::Result { let openai_api_key = std::env::var("OPENAI_API_KEY")?; - let base_prompt = std::env::var("CHATBOT_CHATGPT_BASE_PROMPT")?; - let system_prompt = std::env::var("CHATBOT_CHATGPT_SYSTEM_PROMPT")?; - - let system_prompt = system_prompt - .replace("{topic}", knowledge_base.topic()) - .replace("{persona}", "Engineer") - .replace("{language}", "English"); - - let content = base_prompt - .replace("{history}", history) - .replace("{context}", context) - .replace("{question}", question); - let body = json!({ "model": "gpt-3.5-turbo", - "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": content}], + "messages": messages, "temperature": 0.7 }); @@ -194,60 +285,133 @@ async fn get_openai_chatgpt_answer( Ok(response) } -#[post("/chatbot/get-answer", format = "json", data = "")] -pub async fn chatbot_get_answer( - user: User, - data: Json, -) -> Result { - match wrapped_chatbot_get_answer(user, data).await { - Ok(response) => Ok(ResponseOk( - json!({ - "answer": response, - }) - .to_string(), - )), - Err(error) => { - eprintln!("Error: {:?}", error); - Ok(ResponseOk( - json!({ - "error": error.to_string(), - }) - .to_string(), - )) +struct UpdateHistory { + collection: Collection, + user_document: Document, + model: ChatbotBrain, + knowledge_base: KnowledgeBase, +} + +impl UpdateHistory { + fn new( + collection: Collection, + user_document: Document, + model: ChatbotBrain, + knowledge_base: KnowledgeBase, + ) -> Self { + Self { + collection, + user_document, + model, + knowledge_base, } } + + fn update_history(mut self, chatbot_response: &str) -> anyhow::Result<()> { + let chatbot_document = Document::new( + chatbot_response, + ChatRole::Bot, + self.user_document.user_id.to_owned(), + self.model, + self.knowledge_base, + ); + let new_history_messages: Vec = vec![ + serde_json::to_value(self.user_document).unwrap().into(), + serde_json::to_value(chatbot_document).unwrap().into(), + ]; + // We do not want to block our return waiting for this to happen + tokio::spawn(async move { + self.collection + .upsert_documents(new_history_messages, None) + .await + .expect("Failed to upsert user history"); + }); + Ok(()) + } } -pub async fn wrapped_chatbot_get_answer( - user: User, - data: Json, -) -> Result { - let brain = ChatbotBrain::try_from(data.model)?; - let knowledge_base = KnowledgeBase::try_from(data.knowledge_base)?; - - // Create it up here so the timestamps that order the conversation are accurate - let user_document = Document::new( - data.question.clone(), - ChatRole::User, - user.chatbot_session_id.clone(), - brain, - knowledge_base, - ); +#[derive(Serialize)] +struct StreamResponse { + id: Option, + error: Option, + result: Option, + partial_result: Option, +} - let collection = knowledge_base.collection(); - let collection = Collection::new( - collection, - Some(std::env::var("CHATBOT_DATABASE_URL").expect("CHATBOT_DATABASE_URL not set")), - ); +impl StreamResponse { + fn from_error(id: Option, error: E) -> Self { + StreamResponse { + id, + error: Some(format!("{error}")), + result: None, + partial_result: None, + } + } + + fn from_result(id: u64, result: &str) -> Self { + StreamResponse { + id: Some(id), + error: None, + result: Some(result.to_string()), + partial_result: None, + } + } + + fn from_partial_result(id: u64, result: &str) -> Self { + StreamResponse { + id: Some(id), + error: None, + result: None, + partial_result: Some(result.to_string()), + } + } +} + +#[get("/chatbot/clear-history")] +pub async fn clear_history(cookies: &CookieJar<'_>) -> Status { + // let cookie = Cookie::build("chatbot_session_id").path("/"); + let cookie = Cookie::new("chatbot_session_id", ""); + cookies.remove(cookie); + Status::Ok +} + +#[derive(Serialize)] +pub struct GetHistoryResponse { + result: Option>, + error: Option, +} + +#[derive(Serialize)] +struct HistoryMessage { + side: String, + content: String, + knowledge_base: String, + brain: String, +} + +#[get("/chatbot/get-history")] +pub async fn chatbot_get_history(user: User) -> Json { + match do_chatbot_get_history(&user, 100).await { + Ok(messages) => Json(GetHistoryResponse { + result: Some(messages), + error: None, + }), + Err(e) => Json(GetHistoryResponse { + result: None, + error: Some(format!("{e}")), + }), + } +} - let mut history_collection = Collection::new( +async fn do_chatbot_get_history(user: &User, limit: usize) -> anyhow::Result> { + let history_collection = Collection::new( "ChatHistory", Some(std::env::var("CHATBOT_DATABASE_URL").expect("CHATBOT_DATABASE_URL not set")), ); - let messages = history_collection + let mut messages = history_collection .get_documents(Some( json!({ - "limit": 5, + "limit": limit, "order_by": {"timestamp": "desc"}, "filter": { "metadata": { @@ -263,16 +427,6 @@ pub async fn wrapped_chatbot_get_answer( "user_id": { "$eq": user.chatbot_session_id } - }, - { - "knowledge_base": { - "$eq": knowledge_base - } - }, - { - "model": { - "$eq": brain - } } ] } @@ -282,24 +436,108 @@ pub async fn wrapped_chatbot_get_answer( .into(), )) .await?; - - let mut history = messages + messages.reverse(); + let messages: anyhow::Result> = messages .into_iter() .map(|m| { - // Can probably remove this clone - let chat_role: ChatRole = serde_json::from_value(m["document"]["role"].to_owned())?; - if chat_role == ChatRole::Bot { - Ok(format!("Assistant: {}", m["document"]["text"])) - } else { - Ok(format!("User: {}", m["document"]["text"])) - } + let side: String = m["document"]["role"] + .as_str() + .context("Error parsing chat role")? + .to_string() + .to_lowercase(); + let content: String = m["document"]["text"] + .as_str() + .context("Error parsing text")? + .to_string(); + let model: ChatbotBrain = serde_json::from_value(m["document"]["model"].to_owned()) + .context("Error parsing model")?; + let model: &str = model.into(); + let knowledge_base: KnowledgeBase = + serde_json::from_value(m["document"]["knowledge_base"].to_owned()) + .context("Error parsing knowledge_base")?; + let knowledge_base: &str = knowledge_base.into(); + Ok(HistoryMessage { + side, + content, + brain: model.to_string(), + knowledge_base: knowledge_base.to_string(), + }) }) - .collect::>>()?; - history.reverse(); - let history = history.join("\n"); + .collect(); + messages +} - let pipeline = Pipeline::new("v1", None, None, None); - let context = collection +#[get("/chatbot/get-answer")] +pub async fn chatbot_get_answer(user: User, ws: ws::WebSocket) -> ws::Stream!['static] { + ws::Stream! { ws => + for await message in ws { + let v = process_message(message, &user).await; + match v { + Ok((v, id)) => + match v { + ProcessMessageResponse::StreamResponse((mut it, update_history)) => { + let mut total_text: Vec = Vec::new(); + while let Some(value) = it.next().await { + match value { + Ok(v) => { + let v: &str = v["choices"][0]["delta"]["content"].as_str().unwrap(); + total_text.push(v.to_string()); + yield ws::Message::from(serde_json::to_string(&StreamResponse::from_partial_result(id, v)).unwrap()); + }, + Err(e) => yield ws::Message::from(serde_json::to_string(&StreamResponse::from_error(Some(id), e)).unwrap()) + } + } + update_history.update_history(&total_text.join("")).unwrap(); + }, + ProcessMessageResponse::FullResponse(resp) => { + yield ws::Message::from(serde_json::to_string(&StreamResponse::from_result(id, &resp)).unwrap()); + } + } + Err(e) => { + yield ws::Message::from(serde_json::to_string(&StreamResponse::from_error(None, e)).unwrap()); + } + } + }; + } +} + +enum ProcessMessageResponse { + StreamResponse((GeneralJsonAsyncIterator, UpdateHistory)), + FullResponse(String), +} + +#[derive(Deserialize)] +struct Message { + id: u64, + model: String, + knowledge_base: String, + question: String, +} + +async fn process_message( + message: Result, + user: &User, +) -> anyhow::Result<(ProcessMessageResponse, u64)> { + if let ws::Message::Text(s) = message? { + let data: Message = serde_json::from_str(&s)?; + let brain = ChatbotBrain::try_from(data.model.as_str())?; + let knowledge_base = KnowledgeBase::try_from(data.knowledge_base.as_str())?; + + let user_document = Document::new( + &data.question, + ChatRole::User, + user.chatbot_session_id.clone(), + brain, + knowledge_base, + ); + + let pipeline = Pipeline::new("v1", None, None, None); + let collection = knowledge_base.collection(); + let collection = Collection::new( + collection, + Some(std::env::var("CHATBOT_DATABASE_URL").expect("CHATBOT_DATABASE_URL not set")), + ); + let context = collection .query() .vector_recall(&data.question, &pipeline, Some(json!({ "instruction": "Represent the Wikipedia question for retrieving supporting documents: " @@ -308,37 +546,152 @@ pub async fn wrapped_chatbot_get_answer( .fetch_all() .await? .into_iter() - .map(|(_, context, metadata)| format!("#### Document {}: {}", metadata["id"], context)) + .map(|(_, context, metadata)| format!("\n\n#### Document {}: \n{}\n\n", metadata["id"], context)) .collect::>() .join("\n"); - let answer = - get_openai_chatgpt_answer(knowledge_base, &history, &context, &data.question).await?; - - let new_history_messages: Vec = vec![ - serde_json::to_value(user_document).unwrap().into(), - serde_json::to_value(Document::new( - answer.clone(), - ChatRole::Bot, - user.chatbot_session_id.clone(), - brain, - knowledge_base, - )) - .unwrap() - .into(), - ]; - - // We do not want to block our return waiting for this to happen - tokio::spawn(async move { - history_collection - .upsert_documents(new_history_messages, None) - .await - .expect("Failed to upsert user history"); - }); + let history_collection = Collection::new( + "ChatHistory", + Some(std::env::var("CHATBOT_DATABASE_URL").expect("CHATBOT_DATABASE_URL not set")), + ); + let mut messages = history_collection + .get_documents(Some( + json!({ + "limit": 5, + "order_by": {"timestamp": "desc"}, + "filter": { + "metadata": { + "$and" : [ + { + "$or": + [ + {"role": {"$eq": ChatRole::Bot}}, + {"role": {"$eq": ChatRole::User}} + ] + }, + { + "user_id": { + "$eq": user.chatbot_session_id + } + }, + { + "knowledge_base": { + "$eq": knowledge_base + } + }, + // This is where we would match on the model if we wanted to + ] + } + } - Ok(answer) + }) + .into(), + )) + .await?; + messages.reverse(); + + let (mut history, _) = + messages + .into_iter() + .fold((Vec::new(), None), |(mut new_history, role), value| { + let current_role: ChatRole = + serde_json::from_value(value["document"]["role"].to_owned()) + .expect("Error parsing chat role"); + if let Some(role) = role { + if role == current_role { + match role { + ChatRole::User => new_history.push( + serde_json::json!({ + "role": ChatRole::Bot.to_model_specific_role(&brain), + "content": "*no response due to error*" + }) + .into(), + ), + ChatRole::Bot => new_history.push( + serde_json::json!({ + "role": ChatRole::User.to_model_specific_role(&brain), + "content": "*no response due to error*" + }) + .into(), + ), + _ => panic!("Too many system messages"), + } + } + let new_message: pgml::types::Json = serde_json::json!({ + "role": current_role.to_model_specific_role(&brain), + "content": value["document"]["text"] + }) + .into(); + new_history.push(new_message); + } else if matches!(current_role, ChatRole::User) { + let new_message: pgml::types::Json = serde_json::json!({ + "role": current_role.to_model_specific_role(&brain), + "content": value["document"]["text"] + }) + .into(); + new_history.push(new_message); + } + (new_history, Some(current_role)) + }); + + let system_message = brain.get_system_message(&knowledge_base, &context)?; + history.insert(0, system_message.into()); + + // Need to make sure we aren't about to add two user messages back to back + if let Some(message) = history.last() { + if message["role"].as_str().unwrap() == ChatRole::User.to_model_specific_role(&brain) { + history.push( + serde_json::json!({ + "role": ChatRole::Bot.to_model_specific_role(&brain), + "content": "*no response due to errors*" + }) + .into(), + ); + } + } + history.push( + serde_json::json!({ + "role": ChatRole::User.to_model_specific_role(&brain), + "content": data.question + }) + .into(), + ); + + let update_history = + UpdateHistory::new(history_collection, user_document, brain, knowledge_base); + + if brain.is_open_source() { + let op = OpenSourceAI::new(Some( + std::env::var("CHATBOT_DATABASE_URL").expect("CHATBOT_DATABASE_URL not set"), + )); + let chat_template = brain.get_chat_template(); + let stream = op + .chat_completions_create_stream_async( + brain.into_model_json().into(), + history, + Some(10000), + None, + None, + chat_template.map(|t| t.to_string()), + ) + .await?; + Ok(( + ProcessMessageResponse::StreamResponse((stream, update_history)), + data.id, + )) + } else { + let response = match brain { + ChatbotBrain::OpenAIGPT4 => get_openai_chatgpt_answer(history).await?, + _ => unimplemented!(), + }; + update_history.update_history(&response)?; + Ok((ProcessMessageResponse::FullResponse(response), data.id)) + } + } else { + Err(anyhow::anyhow!("Error invalid message format")) + } } pub fn routes() -> Vec { - routes![chatbot_get_answer] + routes![chatbot_get_answer, chatbot_get_history, clear_history] } diff --git a/pgml-dashboard/src/components/chatbot/chatbot.scss b/pgml-dashboard/src/components/chatbot/chatbot.scss index e4bc2f723..a8b934dd5 100644 --- a/pgml-dashboard/src/components/chatbot/chatbot.scss +++ b/pgml-dashboard/src/components/chatbot/chatbot.scss @@ -19,6 +19,7 @@ div[data-controller="chatbot"] { #chatbot-change-the-brain-title, #knowledge-base-title { + font-size: 1.25rem; padding: 0.5rem; padding-top: 0.85rem; margin-bottom: 1rem; @@ -30,6 +31,7 @@ div[data-controller="chatbot"] { margin-top: calc($spacer * 4); } + div[data-chatbot-target="clear"], .chatbot-brain-option-label, .chatbot-knowledge-base-option-label { cursor: pointer; @@ -37,7 +39,7 @@ div[data-controller="chatbot"] { transition: all 0.1s; } - .chatbot-brain-option-label:hover { + .chatbot-brain-option-label:hover, div[data-chatbot-target="clear"]:hover { background-color: #{$gray-800}; } @@ -59,8 +61,8 @@ div[data-controller="chatbot"] { } .chatbot-brain-option-logo { - height: 30px; width: 30px; + height: 30px; background-position: center; background-repeat: no-repeat; background-size: contain; @@ -70,6 +72,14 @@ div[data-controller="chatbot"] { padding-left: 2rem; } + #brain-knowledge-base-divider-line { + height: 0.15rem; + width: 100%; + background-color: #{$gray-500}; + margin-top: 1.5rem; + margin-bottom: 1.5rem; + } + .chatbot-example-questions { display: none; max-height: 66px; @@ -299,4 +309,10 @@ div[data-controller="chatbot"].chatbot-full { #knowledge-base-wrapper { display: block; } + #brain-knowledge-base-divider-line { + display: none; + } + #clear-history-text { + display: block !important; + } } diff --git a/pgml-dashboard/src/components/chatbot/chatbot_controller.js b/pgml-dashboard/src/components/chatbot/chatbot_controller.js index ef6703b33..d6240c645 100644 --- a/pgml-dashboard/src/components/chatbot/chatbot_controller.js +++ b/pgml-dashboard/src/components/chatbot/chatbot_controller.js @@ -4,6 +4,10 @@ import autosize from "autosize"; import DOMPurify from "dompurify"; import * as marked from "marked"; +const getRandomInt = () => { + return Math.floor(Math.random() * Number.MAX_SAFE_INTEGER); +} + const LOADING_MESSAGE = `
Loading
@@ -11,40 +15,44 @@ const LOADING_MESSAGE = `
`; -const getBackgroundImageURLForSide = (side, knowledgeBase) => { +const getBackgroundImageURLForSide = (side, brain) => { if (side == "user") { return "/dashboard/static/images/chatbot_user.webp"; } else { - if (knowledgeBase == 0) { - return "/dashboard/static/images/owl_gradient.svg"; - } else if (knowledgeBase == 1) { - return "/dashboard/static/images/logos/pytorch.svg"; - } else if (knowledgeBase == 2) { - return "/dashboard/static/images/logos/rust.svg"; - } else if (knowledgeBase == 3) { - return "/dashboard/static/images/logos/postgresql.svg"; + if (brain == "teknium/OpenHermes-2.5-Mistral-7B") { + return "/dashboard/static/images/logos/openhermes.webp" + } else if (brain == "Gryphe/MythoMax-L2-13b") { + return "/dashboard/static/images/logos/mythomax.webp" + } else if (brain == "berkeley-nest/Starling-LM-7B-alpha") { + return "/dashboard/static/images/logos/starling.webp" + } else if (brain == "openai") { + return "/dashboard/static/images/logos/openai.webp" } } }; -const createHistoryMessage = (side, question, id, knowledgeBase) => { - id = id || ""; +const createHistoryMessage = (message) => { + if (message.side == "system") { + return ` +
${message.text}
+ `; + } return ` -
-
- ${question} +
+ ${message.get_html()}
@@ -52,17 +60,29 @@ const createHistoryMessage = (side, question, id, knowledgeBase) => { }; const knowledgeBaseIdToName = (knowledgeBase) => { - if (knowledgeBase == 0) { + if (knowledgeBase == "postgresml") { return "PostgresML"; - } else if (knowledgeBase == 1) { + } else if (knowledgeBase == "pytorch") { return "PyTorch"; - } else if (knowledgeBase == 2) { + } else if (knowledgeBase == "rust") { return "Rust"; - } else if (knowledgeBase == 3) { + } else if (knowledgeBase == "postgresql") { return "PostgreSQL"; } }; +const brainIdToName = (brain) => { + if (brain == "teknium/OpenHermes-2.5-Mistral-7B") { + return "OpenHermes" + } else if (brain == "Gryphe/MythoMax-L2-13b") { + return "MythoMax" + } else if (brain == "berkeley-nest/Starling-LM-7B-alpha") { + return "Starling" + } else if (brain == "openai") { + return "ChatGPT" + } +} + const createKnowledgeBaseNotice = (knowledgeBase) => { return `
Chatting with Knowledge Base ${knowledgeBaseIdToName( @@ -71,21 +91,72 @@ const createKnowledgeBaseNotice = (knowledgeBase) => { `; }; -const getAnswer = async (question, model, knowledgeBase) => { - const response = await fetch("/chatbot/get-answer", { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ question, model, knowledgeBase }), - }); - return response.json(); -}; +class Message { + constructor(id, side, brain, text, is_partial=false) { + this.id = id + this.side = side + this.brain = brain + this.text = text + this.is_partial = is_partial + } + + get_html() { + return DOMPurify.sanitize(marked.parse(this.text)); + } +} + +class RawMessage extends Message { + constructor(id, side, text, is_partial=false) { + super(id, side, text, is_partial); + } + + get_html() { + return this.text; + } +} + +class MessageHistory { + constructor() { + this.messageHistory = {}; + } + + add_message(message, knowledgeBase) { + console.log("ADDDING", message, knowledgeBase); + if (!(knowledgeBase in this.messageHistory)) { + this.messageHistory[knowledgeBase] = []; + } + if (message.is_partial) { + let current_message = this.messageHistory[knowledgeBase].find(item => item.id == message.id); + if (!current_message) { + this.messageHistory[knowledgeBase].push(message); + } else { + current_message.text += message.text; + } + } else { + if (this.messageHistory[knowledgeBase].length == 0 || message.side != "system") { + this.messageHistory[knowledgeBase].push(message); + } else if (this.messageHistory[knowledgeBase][this.messageHistory[knowledgeBase].length -1].side == "system") { + this.messageHistory[knowledgeBase][this.messageHistory[knowledgeBase].length -1] = message + } else { + this.messageHistory[knowledgeBase].push(message); + } + } + } + + get_messages(knowledgeBase) { + if (!(knowledgeBase in this.messageHistory)) { + return []; + } else { + return this.messageHistory[knowledgeBase]; + } + } +} export default class extends Controller { initialize() { - this.alertCount = 0; - this.gettingAnswer = false; + this.messageHistory = new MessageHistory(); + this.messageIdToKnowledgeBaseId = {}; + this.expanded = false; this.chatbot = document.getElementById("chatbot"); this.expandContractImage = document.getElementById( @@ -100,55 +171,105 @@ export default class extends Controller { this.exampleQuestions = document.getElementsByClassName( "chatbot-example-questions", ); - this.handleBrainChange(); // This will set our initial brain this.handleKnowledgeBaseChange(); // This will set our initial knowledge base + this.handleBrainChange(); // This will set our initial brain this.handleResize(); + + const url = ((window.location.protocol === "https:") ? "wss://" : "ws://") + window.location.hostname + (((window.location.port != 80) && (window.location.port != 443)) ? ":" + window.location.port : "") + window.location.pathname + "/get-answer"; + this.socket = new WebSocket(url); + this.socket.onmessage = (message) => { + let result = JSON.parse(message.data); + console.log(result); + + if (result.error) { + this.showChatbotAlert("Error", "Error getting chatbot answer"); + console.log(result.error); + this.redrawChat(); // This clears any loading messages + } else { + let message; + if (result.partial_result) { + message = new Message(result.id, "bot", this.brain, result.partial_result, true); + } else { + message = new Message(result.id, "bot", this.brain, result.result); + } + this.messageHistory.add_message(message, this.messageIdToKnowledgeBaseId[message.id]); + this.redrawChat(); + } + this.chatHistory.scrollTop = this.chatHistory.scrollHeight; + }; + + this.socket.onclose = () => { + window.setTimeout(() => this.openConnection(), 500); + }; + this.getHistory(); + } + + async clearHistory() { + // This endpoint clears the chatbot_sesion_id cookie + await fetch("/chatbot/clear-history"); + window.location.reload(); + } + + async getHistory() { + const result = await fetch("/chatbot/get-history"); + const history = await result.json(); + if (history.error) { + console.log("Error getting chat history", history.error) + } else { + for (const message of history.result) { + const newMessage = new Message(getRandomInt(), message.side, message.brain, message.content, false); + console.log(newMessage); + this.messageHistory.add_message(newMessage, message.knowledge_base); + } + } + this.redrawChat(); + } + + redrawChat() { + this.chatHistory.innerHTML = ""; + const messages = this.messageHistory.get_messages(this.knowledgeBase); + for (const message of messages) { + console.log("Drawing", message); + this.chatHistory.insertAdjacentHTML( + "beforeend", + createHistoryMessage(message), + ); + } + + // Hide or show example questions + this.hideExampleQuestions(); + if (messages.length == 0 || (messages.length == 1 && messages[0].side == "system")) { + document + .getElementById(`chatbot-example-questions-${this.knowledgeBase}`) + .style.setProperty("display", "flex", "important"); + } + + this.chatHistory.scrollTop = this.chatHistory.scrollHeight; } newUserQuestion(question) { + const message = new Message(getRandomInt(), "user", this.brain, question); + this.messageHistory.add_message(message, this.knowledgeBase); + this.messageIdToKnowledgeBaseId[message.id] = this.knowledgeBase; + this.hideExampleQuestions(); + this.redrawChat(); + + let loadingMessage = new Message("loading", "bot", this.brain, LOADING_MESSAGE); this.chatHistory.insertAdjacentHTML( "beforeend", - createHistoryMessage("user", question), - ); - this.chatHistory.insertAdjacentHTML( - "beforeend", - createHistoryMessage( - "bot", - LOADING_MESSAGE, - "chatbot-loading-message", - this.knowledgeBase, - ), + createHistoryMessage(loadingMessage), ); - this.hideExampleQuestions(); this.chatHistory.scrollTop = this.chatHistory.scrollHeight; - - this.gettingAnswer = true; - getAnswer(question, this.brain, this.knowledgeBase) - .then((answer) => { - if (answer.answer) { - this.chatHistory.insertAdjacentHTML( - "beforeend", - createHistoryMessage( - "bot", - DOMPurify.sanitize(marked.parse(answer.answer)), - "", - this.knowledgeBase, - ), - ); - } else { - this.showChatbotAlert("Error", answer.error); - console.log(answer.error); - } - }) - .catch((error) => { - this.showChatbotAlert("Error", "Error getting chatbot answer"); - console.log(error); - }) - .finally(() => { - document.getElementById("chatbot-loading-message").remove(); - this.chatHistory.scrollTop = this.chatHistory.scrollHeight; - this.gettingAnswer = false; - }); + + let id = getRandomInt(); + this.messageIdToKnowledgeBaseId[id] = this.knowledgeBase; + let socketData = { + id, + question, + model: this.brain, + knowledge_base: this.knowledgeBase + }; + this.socket.send(JSON.stringify(socketData)); } handleResize() { @@ -169,12 +290,10 @@ export default class extends Controller { handleEnter(e) { // This prevents adding a return e.preventDefault(); - + // Don't continue if the question is empty const question = this.questionInput.value.trim(); - if (question.length == 0) { + if (question.length == 0) return; - } - // Handle resetting the input // There is probably a better way to do this, but this was the best/easiest I found this.questionInput.value = ""; @@ -185,105 +304,31 @@ export default class extends Controller { } handleBrainChange() { - // Comment this out when we go back to using brains - this.brain = 0; + let selected = document.querySelector('input[name="chatbot-brain-options"]:checked').value; + if (selected == this.brain) + return; + this.brain = selected; this.questionInput.focus(); - - // Uncomment this out when we go back to using brains - // We could just disable the input, but we would then need to listen for click events so this seems easier - // if (this.gettingAnswer) { - // document.querySelector( - // `input[name="chatbot-brain-options"][value="${this.brain}"]`, - // ).checked = true; - // this.showChatbotAlert( - // "Error", - // "Cannot change brain while chatbot is loading answer", - // ); - // return; - // } - // let selected = parseInt( - // document.querySelector('input[name="chatbot-brain-options"]:checked') - // .value, - // ); - // if (selected == this.brain) { - // return; - // } - // brainToContentMap[this.brain] = this.chatHistory.innerHTML; - // this.chatHistory.innerHTML = brainToContentMap[selected] || ""; - // if (this.chatHistory.innerHTML) { - // this.exampleQuestions.style.setProperty("display", "none", "important"); - // } else { - // this.exampleQuestions.style.setProperty("display", "flex", "important"); - // } - // this.brain = selected; - // this.chatHistory.scrollTop = this.chatHistory.scrollHeight; - // this.questionInput.focus(); + this.addBrainAndKnowledgeBaseChangedSystemMessage(); } handleKnowledgeBaseChange() { - // Uncomment this when we go back to using brains - // let selected = parseInt( - // document.querySelector('input[name="chatbot-knowledge-base-options"]:checked') - // .value, - // ); - // this.knowledgeBase = selected; - - // Comment this out when we go back to using brains - // We could just disable the input, but we would then need to listen for click events so this seems easier - if (this.gettingAnswer) { - document.querySelector( - `input[name="chatbot-knowledge-base-options"][value="${this.knowledgeBase}"]`, - ).checked = true; - this.showChatbotAlert( - "Error", - "Cannot change knowledge base while chatbot is loading answer", - ); - return; - } - let selected = parseInt( - document.querySelector( - 'input[name="chatbot-knowledge-base-options"]:checked', - ).value, - ); - if (selected == this.knowledgeBase) { + let selected = document.querySelector('input[name="chatbot-knowledge-base-options"]:checked').value; + if (selected == this.knowledgeBase) return; - } - - // document.getElementById - this.knowledgeBaseToContentMap[this.knowledgeBase] = - this.chatHistory.innerHTML; - this.chatHistory.innerHTML = this.knowledgeBaseToContentMap[selected] || ""; this.knowledgeBase = selected; - - // This should be extended to insert the new knowledge base notice in the correct place - if (this.chatHistory.childElementCount == 0) { - this.chatHistory.insertAdjacentHTML( - "beforeend", - createKnowledgeBaseNotice(this.knowledgeBase), - ); - this.hideExampleQuestions(); - document - .getElementById( - `chatbot-example-questions-${knowledgeBaseIdToName( - this.knowledgeBase, - )}`, - ) - .style.setProperty("display", "flex", "important"); - } else if (this.chatHistory.childElementCount == 1) { - this.hideExampleQuestions(); - document - .getElementById( - `chatbot-example-questions-${knowledgeBaseIdToName( - this.knowledgeBase, - )}`, - ) - .style.setProperty("display", "flex", "important"); - } else { - this.hideExampleQuestions(); - } - - this.chatHistory.scrollTop = this.chatHistory.scrollHeight; + this.redrawChat(); this.questionInput.focus(); + this.addBrainAndKnowledgeBaseChangedSystemMessage(); + } + + addBrainAndKnowledgeBaseChangedSystemMessage() { + let knowledge_base = knowledgeBaseIdToName(this.knowledgeBase); + let brain = brainIdToName(this.brain); + let content = `Chatting with ${brain} about ${knowledge_base}`; + const newMessage = new Message(getRandomInt(), "system", this.brain, content); + this.messageHistory.add_message(newMessage, this.knowledgeBase); + this.redrawChat(); } handleExampleQuestionClick(e) { diff --git a/pgml-dashboard/src/components/chatbot/mod.rs b/pgml-dashboard/src/components/chatbot/mod.rs index 8bcf23fc4..4b149b96e 100644 --- a/pgml-dashboard/src/components/chatbot/mod.rs +++ b/pgml-dashboard/src/components/chatbot/mod.rs @@ -4,7 +4,7 @@ use sailfish::TemplateOnce; type ExampleQuestions = [(&'static str, [(&'static str, &'static str); 4]); 4]; const EXAMPLE_QUESTIONS: ExampleQuestions = [ ( - "PostgresML", + "postgresml", [ ("How do I", "use pgml.transform()?"), ("Show me", "a query to train a model"), @@ -13,7 +13,7 @@ const EXAMPLE_QUESTIONS: ExampleQuestions = [ ], ), ( - "PyTorch", + "pytorch", [ ("What are", "tensors?"), ("How do I", "train a model?"), @@ -22,7 +22,7 @@ const EXAMPLE_QUESTIONS: ExampleQuestions = [ ], ), ( - "Rust", + "rust", [ ("What is", "a lifetime?"), ("How do I", "use a for loop?"), @@ -31,7 +31,7 @@ const EXAMPLE_QUESTIONS: ExampleQuestions = [ ], ), ( - "PostgreSQL", + "postgresql", [ ("How do I", "join two tables?"), ("What is", "a GIN index?"), @@ -41,79 +41,92 @@ const EXAMPLE_QUESTIONS: ExampleQuestions = [ ), ]; -const KNOWLEDGE_BASES: [&str; 0] = [ - // "Knowledge Base 1", - // "Knowledge Base 2", - // "Knowledge Base 3", - // "Knowledge Base 4", -]; - const KNOWLEDGE_BASES_WITH_LOGO: [KnowledgeBaseWithLogo; 4] = [ - KnowledgeBaseWithLogo::new("PostgresML", "/dashboard/static/images/owl_gradient.svg"), - KnowledgeBaseWithLogo::new("PyTorch", "/dashboard/static/images/logos/pytorch.svg"), - KnowledgeBaseWithLogo::new("Rust", "/dashboard/static/images/logos/rust.svg"), KnowledgeBaseWithLogo::new( + "postgresml", + "PostgresML", + "/dashboard/static/images/owl_gradient.svg", + ), + KnowledgeBaseWithLogo::new( + "pytorch", + "PyTorch", + "/dashboard/static/images/logos/pytorch.svg", + ), + KnowledgeBaseWithLogo::new("rust", "Rust", "/dashboard/static/images/logos/rust.svg"), + KnowledgeBaseWithLogo::new( + "postgresql", "PostgreSQL", "/dashboard/static/images/logos/postgresql.svg", ), ]; struct KnowledgeBaseWithLogo { + id: &'static str, name: &'static str, logo: &'static str, } impl KnowledgeBaseWithLogo { - const fn new(name: &'static str, logo: &'static str) -> Self { - Self { name, logo } + const fn new(id: &'static str, name: &'static str, logo: &'static str) -> Self { + Self { id, name, logo } } } -const CHATBOT_BRAINS: [ChatbotBrain; 0] = [ - // ChatbotBrain::new( - // "PostgresML", - // "Falcon 180b", - // "/dashboard/static/images/owl_gradient.svg", - // ), - // ChatbotBrain::new( - // "OpenAI", - // "ChatGPT", - // "/dashboard/static/images/logos/openai.webp", - // ), - // ChatbotBrain::new( - // "Anthropic", - // "Claude", - // "/dashboard/static/images/logos/anthropic.webp", - // ), - // ChatbotBrain::new( - // "Meta", - // "Llama2 70b", - // "/dashboard/static/images/logos/meta.webp", - // ), +const CHATBOT_BRAINS: [ChatbotBrain; 4] = [ + ChatbotBrain::new( + "teknium/OpenHermes-2.5-Mistral-7B", + "OpenHermes", + "teknium/OpenHermes-2.5-Mistral-7B", + "/dashboard/static/images/logos/openhermes.webp", + ), + ChatbotBrain::new( + "Gryphe/MythoMax-L2-13b", + "MythoMax", + "Gryphe/MythoMax-L2-13b", + "/dashboard/static/images/logos/mythomax.webp", + ), + ChatbotBrain::new( + "openai", + "OpenAI", + "ChatGPT", + "/dashboard/static/images/logos/openai.webp", + ), + ChatbotBrain::new( + "berkeley-nest/Starling-LM-7B-alpha", + "Starling", + "berkeley-nest/Starling-LM-7B-alpha", + "/dashboard/static/images/logos/starling.webp", + ), ]; struct ChatbotBrain { + id: &'static str, provider: &'static str, model: &'static str, logo: &'static str, } -// impl ChatbotBrain { -// const fn new(provider: &'static str, model: &'static str, logo: &'static str) -> Self { -// Self { -// provider, -// model, -// logo, -// } -// } -// } +impl ChatbotBrain { + const fn new( + id: &'static str, + provider: &'static str, + model: &'static str, + logo: &'static str, + ) -> Self { + Self { + id, + provider, + model, + logo, + } + } +} #[derive(TemplateOnce)] #[template(path = "chatbot/template.html")] pub struct Chatbot { - brains: &'static [ChatbotBrain; 0], + brains: &'static [ChatbotBrain; 4], example_questions: &'static ExampleQuestions, - knowledge_bases: &'static [&'static str; 0], knowledge_bases_with_logo: &'static [KnowledgeBaseWithLogo; 4], } @@ -122,7 +135,6 @@ impl Default for Chatbot { Chatbot { brains: &CHATBOT_BRAINS, example_questions: &EXAMPLE_QUESTIONS, - knowledge_bases: &KNOWLEDGE_BASES, knowledge_bases_with_logo: &KNOWLEDGE_BASES_WITH_LOGO, } } diff --git a/pgml-dashboard/src/components/chatbot/template.html b/pgml-dashboard/src/components/chatbot/template.html index 1f47cf865..9da069cce 100644 --- a/pgml-dashboard/src/components/chatbot/template.html +++ b/pgml-dashboard/src/components/chatbot/template.html @@ -1,102 +1,72 @@
-
+
- -
Knowledge Base:
+
Change the Brain:
- <% for (index, knowledge_base) in knowledge_bases_with_logo.iter().enumerate() { %> + <% for (index, brain) in brains.iter().enumerate() { %>
checked <% } %> />
<% } %> - - - -
diff --git a/pgml-dashboard/static/images/logos/mythomax.webp b/pgml-dashboard/static/images/logos/mythomax.webp new file mode 100644 index 000000000..6e6c363b2 Binary files /dev/null and b/pgml-dashboard/static/images/logos/mythomax.webp differ diff --git a/pgml-dashboard/static/images/logos/openhermes.webp b/pgml-dashboard/static/images/logos/openhermes.webp new file mode 100644 index 000000000..3c202681e Binary files /dev/null and b/pgml-dashboard/static/images/logos/openhermes.webp differ diff --git a/pgml-dashboard/static/images/logos/starling.webp b/pgml-dashboard/static/images/logos/starling.webp new file mode 100644 index 000000000..988696b14 Binary files /dev/null and b/pgml-dashboard/static/images/logos/starling.webp differ pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy