From 635476ca6cf9787b2ae642ec8688cce4e3e3ec77 Mon Sep 17 00:00:00 2001 From: Kevin Zimmerman <4733573+kczimm@users.noreply.github.com> Date: Wed, 11 Oct 2023 19:28:40 +0000 Subject: [PATCH 1/9] add vllm binding --- pgml-extension/requirements.txt | 1 + pgml-extension/src/bindings/mod.rs | 2 + pgml-extension/src/bindings/vllm/mod.rs | 250 ++++++++++++++++++++++++ 3 files changed, 253 insertions(+) create mode 100644 pgml-extension/src/bindings/vllm/mod.rs diff --git a/pgml-extension/requirements.txt b/pgml-extension/requirements.txt index 89b3c2742..09b33732d 100644 --- a/pgml-extension/requirements.txt +++ b/pgml-extension/requirements.txt @@ -26,3 +26,4 @@ xgboost==2.0.0 langchain==0.0.287 einops==0.6.1 pynvml==11.5.0 +vllm==0.2.0 \ No newline at end of file diff --git a/pgml-extension/src/bindings/mod.rs b/pgml-extension/src/bindings/mod.rs index 13702106d..e380c290c 100644 --- a/pgml-extension/src/bindings/mod.rs +++ b/pgml-extension/src/bindings/mod.rs @@ -46,6 +46,8 @@ pub mod python; pub mod sklearn; #[cfg(feature = "python")] pub mod transformers; +#[cfg(feature = "python")] +pub mod vllm; pub mod xgboost; pub type Fit = fn(dataset: &Dataset, hyperparams: &Hyperparams) -> Result>; diff --git a/pgml-extension/src/bindings/vllm/mod.rs b/pgml-extension/src/bindings/vllm/mod.rs new file mode 100644 index 000000000..39afcf6aa --- /dev/null +++ b/pgml-extension/src/bindings/vllm/mod.rs @@ -0,0 +1,250 @@ +use std::fmt; + +use anyhow::{anyhow, Result}; +use pgrx::prelude::*; +use pyo3::{prelude::*, types::PyDict}; + +use super::TracebackError; + +pub struct LLMBuilder { + model: String, + tokenizer: Option, + tokenizer_mode: TokenizerMode, + trust_remote_code: bool, + tensor_parallel_size: u8, + dtype: String, + quantization: Option, + revision: Option, + seed: u64, + gpu_memory_utilization: f32, + swap_space: u32, +} + +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub enum TokenizerMode { + Auto, + Slow, +} + +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub enum Quantization { + Awq, +} + +pub struct LLM { + inner: Py, +} + +impl LLMBuilder { + /// Create a builder for a model with the name or path of a HuggingFace Transformers model. + pub fn new(model: &str) -> Self { + Self { + model: model.to_string(), + tokenizer: None, + tokenizer_mode: TokenizerMode::Auto, + trust_remote_code: false, + tensor_parallel_size: 1, + dtype: "auto".to_string(), + quantization: None, + revision: None, + seed: 0, + gpu_memory_utilization: 0.9, + swap_space: 4, + } + } + + /// The name or path of a HuggingFace Transformers tokenizer. + pub fn tokenizer(mut self, tokenizer: &str) -> Self { + self.tokenizer = Some(tokenizer.to_string()); + self + } + + /// The tokenizer mode. "auto" will use the fast tokenizer if available, and "slow" will always use the slow tokenizer. + pub fn tokenizer_mode(mut self, tokenizer_mode: TokenizerMode) -> Self { + self.tokenizer_mode = tokenizer_mode; + self + } + + /// Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer. + pub fn trust_remote_code(mut self, trust_remote_code: bool) -> Self { + self.trust_remote_code = trust_remote_code; + self + } + + /// The number of GPUs to use for distributed execution with tensor parallelism. + pub fn tensor_parallel_size(mut self, tensor_parallel_size: u8) -> Self { + self.tensor_parallel_size = tensor_parallel_size; + self + } + + /// The data type for the model weights and activations. Currently, + /// we support `float32`, `float16`, and `bfloat16`. If `auto`, we use + /// the `torch_dtype` attribute specified in the model config file. + /// However, if the `torch_dtype` in the config is `float32`, we will + /// use `float16` instead. + pub fn dtype(mut self, dtype: &str) -> Self { + self.dtype = dtype.to_string(); + self + } + + /// The method used to quantize the model weights. Currently, + /// we support "awq". If None, we assume the model weights are not + /// quantized and use `dtype` to determine the data type of the weights. + pub fn quantization(mut self, quantization: Quantization) -> Self { + self.quantization = Some(quantization); + self + } + + /// The specific model version to use. It can be a branch name, + /// a tag name, or a commit id. + pub fn revision(mut self, revision: &str) -> Self { + self.revision = Some(revision.to_string()); + self + } + + /// The seed to initialize the random number generator for sampling. + pub fn seed(mut self, seed: u64) -> Self { + self.seed = seed; + self + } + + /// The ratio (between 0 and 1) of GPU memory to + /// reserve for the model weights, activations, and KV cache. Higher + /// values will increase the KV cache size and thus improve the model's + /// throughput. However, if the value is too high, it may cause out-of- + /// memory (OOM) errors. + pub fn gpu_memory_utilization(mut self, gpu_memory_utilization: f32) -> Self { + self.gpu_memory_utilization = gpu_memory_utilization; + self + } + + /// The size (GiB) of CPU memory per GPU to use as swap space. + /// This can be used for temporarily storing the states of the requests + /// when their `best_of` sampling parameters are larger than 1. If all + /// requests will have `best_of=1`, you can safely set this to 0. + /// Otherwise, too small values may cause out-of-memory (OOM) errors. + pub fn swap_space(mut self, swap_space: u32) -> Self { + self.swap_space = swap_space; + self + } + + /// Create a [`LLM`] from the [`LLMBuilder`] + pub fn build(self) -> Result { + let inner = Python::with_gil(|py| -> Result> { + let kwargs = PyDict::new(py); + kwargs.set_item("model", self.model)?; + kwargs.set_item("tokenizer", self.tokenizer)?; + kwargs.set_item("tokenizer_mode", self.tokenizer_mode.to_string())?; + kwargs.set_item("trust_remote_code", self.trust_remote_code)?; + kwargs.set_item("tensor_parallel_size", self.tensor_parallel_size)?; + kwargs.set_item("dtype", self.dtype)?; + kwargs.set_item("quantization", self.quantization.map(|q| q.to_string()))?; + kwargs.set_item("revision", self.revision)?; + kwargs.set_item("seed", self.seed)?; + kwargs.set_item("gpu_memory_utilization", self.gpu_memory_utilization)?; + kwargs.set_item("swap_space", self.swap_space)?; + + let vllm = PyModule::import(py, "vllm").format_traceback(py)?; + vllm.getattr("LLM") + .format_traceback(py)? + .call((), Some(kwargs)) + .format_traceback(py)? + .extract() + .format_traceback(py) + })?; + + Ok(LLM { inner }) + } +} + +impl LLM { + /// Create an LLM for a model with the name or path of a HuggingFace Transformers model. + pub fn new(model: &str) -> Result { + LLMBuilder::new(model).build() + } + + /// Generates the completions for the input prompts. + /// + /// ### NOTE + /// This automatically batches the given prompts, considering + /// the memory constraint. For the best performance, put all of your prompts + /// into a single list and pass it to this method. + pub fn generate(&self, prompts: &[&str]) -> Result> { + let prompts: Vec<_> = prompts.iter().map(|s| s.to_string()).collect(); + Python::with_gil(|py| { + let outputs: Vec> = self + .inner + .getattr(py, "generate") + .format_traceback(py)? + .call1(py, (prompts.into_py(py),)) + .format_traceback(py)? + .extract(py) + .format_traceback(py)?; + + outputs + .iter() + .map(|output| -> Result { + let outputs: Vec> = output + .getattr(py, "outputs") + .format_traceback(py)? + .extract(py) + .format_traceback(py)?; + outputs + .first() + .ok_or_else(|| anyhow!("vllm output.outputs[] empty"))? + .getattr(py, "text") + .format_traceback(py)? + .extract(py) + .format_traceback(py) + }) + .collect::>>() + }) + } +} + +impl fmt::Display for TokenizerMode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}", + match self { + TokenizerMode::Auto => "auto", + TokenizerMode::Slow => "slow", + } + ) + } +} + +impl fmt::Display for Quantization { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}", + match self { + Quantization::Awq => "awq", + } + ) + } +} + +#[cfg(any(test, feature = "pg_test"))] +#[pg_schema] +mod tests { + use super::*; + + #[pg_test] + fn vllm_quickstart() { + crate::bindings::python::activate().unwrap(); + + // quickstart example from https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html + let prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ]; + let llm = LLMBuilder::new("facebook/opt-125m").build().unwrap(); + let outputs = llm.generate(&prompts).unwrap(); + assert_eq!(prompts.len(), outputs.len()); + } +} From 9360ef7a80344b9082a46d4530f4f3cfd02fb588 Mon Sep 17 00:00:00 2001 From: Kevin Zimmerman <4733573+kczimm@users.noreply.github.com> Date: Wed, 11 Oct 2023 20:42:26 +0000 Subject: [PATCH 2/9] add vllm SamplingParams --- pgml-extension/src/bindings/vllm/mod.rs | 76 +++++-- pgml-extension/src/bindings/vllm/params.rs | 252 +++++++++++++++++++++ 2 files changed, 308 insertions(+), 20 deletions(-) create mode 100644 pgml-extension/src/bindings/vllm/params.rs diff --git a/pgml-extension/src/bindings/vllm/mod.rs b/pgml-extension/src/bindings/vllm/mod.rs index 39afcf6aa..42e323f27 100644 --- a/pgml-extension/src/bindings/vllm/mod.rs +++ b/pgml-extension/src/bindings/vllm/mod.rs @@ -6,6 +6,9 @@ use pyo3::{prelude::*, types::PyDict}; use super::TracebackError; +mod params; +pub use params::*; + pub struct LLMBuilder { model: String, tokenizer: Option, @@ -16,7 +19,7 @@ pub struct LLMBuilder { quantization: Option, revision: Option, seed: u64, - gpu_memory_utilization: f32, + gpu_memory_utilization: f64, swap_space: u32, } @@ -32,7 +35,7 @@ pub enum Quantization { } pub struct LLM { - inner: Py, + inner: PyObject, } impl LLMBuilder { @@ -113,7 +116,7 @@ impl LLMBuilder { /// values will increase the KV cache size and thus improve the model's /// throughput. However, if the value is too high, it may cause out-of- /// memory (OOM) errors. - pub fn gpu_memory_utilization(mut self, gpu_memory_utilization: f32) -> Self { + pub fn gpu_memory_utilization(mut self, gpu_memory_utilization: f64) -> Self { self.gpu_memory_utilization = gpu_memory_utilization; self } @@ -130,19 +133,35 @@ impl LLMBuilder { /// Create a [`LLM`] from the [`LLMBuilder`] pub fn build(self) -> Result { - let inner = Python::with_gil(|py| -> Result> { + let inner = Python::with_gil(|py| -> Result { let kwargs = PyDict::new(py); - kwargs.set_item("model", self.model)?; - kwargs.set_item("tokenizer", self.tokenizer)?; - kwargs.set_item("tokenizer_mode", self.tokenizer_mode.to_string())?; - kwargs.set_item("trust_remote_code", self.trust_remote_code)?; - kwargs.set_item("tensor_parallel_size", self.tensor_parallel_size)?; - kwargs.set_item("dtype", self.dtype)?; - kwargs.set_item("quantization", self.quantization.map(|q| q.to_string()))?; - kwargs.set_item("revision", self.revision)?; - kwargs.set_item("seed", self.seed)?; - kwargs.set_item("gpu_memory_utilization", self.gpu_memory_utilization)?; - kwargs.set_item("swap_space", self.swap_space)?; + kwargs.set_item("model", self.model).format_traceback(py)?; + kwargs + .set_item("tokenizer", self.tokenizer) + .format_traceback(py)?; + kwargs + .set_item("tokenizer_mode", self.tokenizer_mode.to_string()) + .format_traceback(py)?; + kwargs + .set_item("trust_remote_code", self.trust_remote_code) + .format_traceback(py)?; + kwargs + .set_item("tensor_parallel_size", self.tensor_parallel_size) + .format_traceback(py)?; + kwargs.set_item("dtype", self.dtype).format_traceback(py)?; + kwargs + .set_item("quantization", self.quantization.map(|q| q.to_string())) + .format_traceback(py)?; + kwargs + .set_item("revision", self.revision) + .format_traceback(py)?; + kwargs.set_item("seed", self.seed).format_traceback(py)?; + kwargs + .set_item("gpu_memory_utilization", self.gpu_memory_utilization) + .format_traceback(py)?; + kwargs + .set_item("swap_space", self.swap_space) + .format_traceback(py)?; let vllm = PyModule::import(py, "vllm").format_traceback(py)?; vllm.getattr("LLM") @@ -169,14 +188,25 @@ impl LLM { /// This automatically batches the given prompts, considering /// the memory constraint. For the best performance, put all of your prompts /// into a single list and pass it to this method. - pub fn generate(&self, prompts: &[&str]) -> Result> { + pub fn generate( + &self, + prompts: &[&str], + params: Option<&SamplingParams>, + ) -> Result> { let prompts: Vec<_> = prompts.iter().map(|s| s.to_string()).collect(); + Python::with_gil(|py| { - let outputs: Vec> = self + let kwargs = PyDict::new(py); + kwargs.set_item("prompts", prompts).format_traceback(py)?; + kwargs + .set_item("sampling_params", params) + .format_traceback(py)?; + + let outputs: Vec = self .inner .getattr(py, "generate") .format_traceback(py)? - .call1(py, (prompts.into_py(py),)) + .call(py, (), Some(kwargs)) .format_traceback(py)? .extract(py) .format_traceback(py)?; @@ -184,7 +214,7 @@ impl LLM { outputs .iter() .map(|output| -> Result { - let outputs: Vec> = output + let outputs: Vec = output .getattr(py, "outputs") .format_traceback(py)? .extract(py) @@ -243,8 +273,14 @@ mod tests { "The capital of France is", "The future of AI is", ]; + let sampling_params = SamplingParamsBuilder::new() + .temperature(0.8) + .top_p(0.95) + .build() + .unwrap(); + let llm = LLMBuilder::new("facebook/opt-125m").build().unwrap(); - let outputs = llm.generate(&prompts).unwrap(); + let outputs = llm.generate(&prompts, Some(&sampling_params)).unwrap(); assert_eq!(prompts.len(), outputs.len()); } } diff --git a/pgml-extension/src/bindings/vllm/params.rs b/pgml-extension/src/bindings/vllm/params.rs new file mode 100644 index 000000000..873d3157e --- /dev/null +++ b/pgml-extension/src/bindings/vllm/params.rs @@ -0,0 +1,252 @@ +use anyhow::Result; +use pyo3::{prelude::*, types::PyDict}; + +use crate::bindings::TracebackError; + +#[derive(Debug, Clone)] +pub struct SamplingParamsBuilder { + n: usize, + best_of: Option, + presence_penalty: f64, + frequency_penalty: f64, + temperature: f64, + top_p: f64, + top_k: i32, + use_beam_search: bool, + length_penalty: f64, + early_stopping: EarlyStopping, + stop: Option>, + stop_token_ids: Option>, + ignore_eos: bool, + max_tokens: usize, + logprobs: Option, + skip_special_tokens: bool, +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum EarlyStopping { + True, + False, + Never, +} + +pub struct SamplingParams { + inner: PyObject, +} + +impl SamplingParamsBuilder { + pub fn new() -> Self { + Self { + n: 1, + best_of: None, + presence_penalty: 0.0, + frequency_penalty: 0.0, + temperature: 1.0, + top_p: 1.0, + top_k: -1, + use_beam_search: false, + length_penalty: 1.0, + early_stopping: EarlyStopping::False, + stop: None, + stop_token_ids: None, + ignore_eos: false, + max_tokens: 16, + logprobs: None, + skip_special_tokens: true, + } + } + + /// Number of output sequences to return for the given prompt. + pub fn n(mut self, n: usize) -> Self { + self.n = n; + self + } + + /// Number of output sequences that are generated from the prompt. + /// From these `best_of` sequences, the top `n` sequences are returned. + /// `best_of` must be greater than or equal to `n`. This is treated as + /// the beam width when `use_beam_search` is true. By default, `best_of` + /// is set to `n`. + pub fn best_of(mut self, best_of: usize) -> Self { + self.best_of = Some(best_of); + self + } + + /// Float that penalizes new tokens based on whether they + /// appear in the generated text so far. Values > 0 encourage the model + /// to use new tokens, while values < 0 encourage the model to repeat + /// tokens. + pub fn presence_penalty(mut self, presence_penalty: f64) -> Self { + self.presence_penalty = presence_penalty; + self + } + + /// Float that penalizes new tokens based on their + /// frequency in the generated text so far. Values > 0 encourage the + /// model to use new tokens, while values < 0 encourage the model to + /// repeat tokens. + pub fn frequency_penalty(mut self, frequency_penalty: f64) -> Self { + self.frequency_penalty = frequency_penalty; + self + } + + /// Float that controls the randomness of the sampling. Lower + /// values make the model more deterministic, while higher values make + /// the model more random. Zero means greedy sampling. + pub fn temperature(mut self, temperature: f64) -> Self { + self.temperature = temperature; + self + } + + /// Float that controls the cumulative probability of the top tokens + /// to consider. Must be in (0, 1]. Set to 1 to consider all tokens. + pub fn top_p(mut self, top_p: f64) -> Self { + self.top_p = top_p; + self + } + + /// Integer that controls the number of top tokens to consider. Set + /// to -1 to consider all tokens. + pub fn top_k(mut self, top_k: i32) -> Self { + self.top_k = top_k; + self + } + + /// Whether to use beam search instead of sampling. + pub fn use_beam_search(mut self, use_beam_search: bool) -> Self { + self.use_beam_search = use_beam_search; + self + } + + /// Float that penalizes sequences based on their length. + /// Used in beam search. + pub fn length_penalty(mut self, length_penalty: f64) -> Self { + self.length_penalty = length_penalty; + self + } + + /// Controls the stopping condition for beam search. It + /// accepts the following values: `true`, where the generation stops as + /// soon as there are `best_of` complete candidates; `false`, where an + /// heuristic is applied and the generation stops when is it very + /// unlikely to find better candidates; `"never"`, where the beam search + /// procedure only stops when there cannot be better candidates + /// (canonical beam search algorithm). + pub fn early_stopping(mut self, early_stopping: EarlyStopping) -> Self { + self.early_stopping = early_stopping; + self + } + + /// List of [`String`]s that stop the generation when they are generated. + /// The returned output will not contain the stop [`String`]s. + pub fn stop(mut self, stop: &[&str]) -> Self { + self.stop = Some(stop.iter().map(|s| s.to_string()).collect()); + self + } + + /// List of tokens that stop the generation when they are + /// generated. The returned output will contain the stop tokens unless + /// the stop tokens are sepcial tokens. + pub fn stop_token_ids(mut self, stop_token_ids: Vec) -> Self { + self.stop_token_ids = Some(stop_token_ids); + self + } + + /// Whether to ignore the EOS token and continue generating + /// tokens after the EOS token is generated. + pub fn ignore_eos(mut self, ignore_eos: bool) -> Self { + self.ignore_eos = ignore_eos; + self + } + + /// Maximum number of tokens to generate per output sequence. + pub fn max_tokens(mut self, max_tokens: usize) -> Self { + self.max_tokens = max_tokens; + self + } + + /// Number of log probabilities to return per output token. + pub fn logprobs(mut self, logprobs: usize) -> Self { + self.logprobs = Some(logprobs); + self + } + + /// Whether to skip special tokens in the output. + /// Defaults to true. + pub fn skip_special_tokens(mut self, skip_special_tokens: bool) -> Self { + self.skip_special_tokens = skip_special_tokens; + self + } + + pub fn build(self) -> Result { + let inner = Python::with_gil(|py| -> Result { + let kwargs = PyDict::new(py); + kwargs.set_item("n", self.n).format_traceback(py)?; + kwargs + .set_item("best_of", self.best_of) + .format_traceback(py)?; + kwargs + .set_item("presence_penalty", self.presence_penalty) + .format_traceback(py)?; + kwargs + .set_item("frequency_penalty", self.frequency_penalty) + .format_traceback(py)?; + kwargs + .set_item("temperature", self.temperature) + .format_traceback(py)?; + kwargs.set_item("top_p", self.top_p).format_traceback(py)?; + kwargs.set_item("top_k", self.top_k).format_traceback(py)?; + kwargs + .set_item("use_beam_search", self.use_beam_search) + .format_traceback(py)?; + kwargs + .set_item("length_penalty", self.length_penalty) + .format_traceback(py)?; + kwargs + .set_item("early_stopping", self.early_stopping) + .format_traceback(py)?; + kwargs.set_item("stop", self.stop).format_traceback(py)?; + kwargs + .set_item("stop_token_ids", self.stop_token_ids) + .format_traceback(py)?; + kwargs + .set_item("ignore_eos", self.ignore_eos) + .format_traceback(py)?; + kwargs + .set_item("max_tokens", self.max_tokens) + .format_traceback(py)?; + kwargs + .set_item("logprobs", self.logprobs) + .format_traceback(py)?; + kwargs + .set_item("skip_special_tokens", self.skip_special_tokens) + .format_traceback(py)?; + + let vllm = PyModule::import(py, "vllm").format_traceback(py)?; + vllm.getattr("SamplingParams") + .format_traceback(py)? + .call((), Some(kwargs)) + .format_traceback(py)? + .extract() + .format_traceback(py) + })?; + + Ok(SamplingParams { inner }) + } +} + +impl ToPyObject for EarlyStopping { + fn to_object(&self, py: Python<'_>) -> PyObject { + match self { + EarlyStopping::True => true.into_py(py), + EarlyStopping::False => false.into_py(py), + EarlyStopping::Never => "never".into_py(py), + } + } +} + +impl ToPyObject for SamplingParams { + fn to_object(&self, _py: Python<'_>) -> PyObject { + self.inner.clone() + } +} From 8be071029b1eaec2e0397728df7b91865fcfcc1d Mon Sep 17 00:00:00 2001 From: Kevin Zimmerman <4733573+kczimm@users.noreply.github.com> Date: Wed, 11 Oct 2023 20:56:40 +0000 Subject: [PATCH 3/9] add test showing vllm model support check --- pgml-extension/src/bindings/vllm/mod.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pgml-extension/src/bindings/vllm/mod.rs b/pgml-extension/src/bindings/vllm/mod.rs index 42e323f27..bfb45cd43 100644 --- a/pgml-extension/src/bindings/vllm/mod.rs +++ b/pgml-extension/src/bindings/vllm/mod.rs @@ -263,6 +263,7 @@ mod tests { use super::*; #[pg_test] + #[ignore = "requires model download"] fn vllm_quickstart() { crate::bindings::python::activate().unwrap(); @@ -283,4 +284,11 @@ mod tests { let outputs = llm.generate(&prompts, Some(&sampling_params)).unwrap(); assert_eq!(prompts.len(), outputs.len()); } + + #[pg_test] + fn model_support() { + if let Err(e) = LLMBuilder::new("intfloat/e5-small").build() { + assert!(e.to_string().contains("not supported")); + } + } } From b212ee0698c954e77dd7b9470e092eb8165acb21 Mon Sep 17 00:00:00 2001 From: Kevin Zimmerman <4733573+kczimm@users.noreply.github.com> Date: Fri, 13 Oct 2023 20:08:31 +0000 Subject: [PATCH 4/9] refactor into llm module, use PyResult --- pgml-extension/src/bindings/vllm/llm.rs | 250 +++++++++++++++++ pgml-extension/src/bindings/vllm/mod.rs | 295 +-------------------- pgml-extension/src/bindings/vllm/params.rs | 74 ++---- 3 files changed, 276 insertions(+), 343 deletions(-) create mode 100644 pgml-extension/src/bindings/vllm/llm.rs diff --git a/pgml-extension/src/bindings/vllm/llm.rs b/pgml-extension/src/bindings/vllm/llm.rs new file mode 100644 index 000000000..0a32b0329 --- /dev/null +++ b/pgml-extension/src/bindings/vllm/llm.rs @@ -0,0 +1,250 @@ +use pyo3::{prelude::*, types::PyDict}; + +use super::SamplingParams; + +pub struct LLMBuilder { + model: String, + tokenizer: Option, + tokenizer_mode: TokenizerMode, + trust_remote_code: bool, + tensor_parallel_size: u8, + dtype: String, + quantization: Option, + revision: Option, + seed: u64, + gpu_memory_utilization: f64, + swap_space: u32, +} + +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub enum TokenizerMode { + Auto, + Slow, +} + +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub enum Quantization { + Awq, +} + +pub struct LLM { + inner: PyObject, +} + +impl LLMBuilder { + /// Create a builder for a model with the name or path of a HuggingFace + /// Transformers model. + pub fn new(model: &str) -> Self { + Self { + model: model.to_string(), + tokenizer: None, + tokenizer_mode: TokenizerMode::Auto, + trust_remote_code: false, + tensor_parallel_size: 1, + dtype: "auto".to_string(), + quantization: None, + revision: None, + seed: 0, + gpu_memory_utilization: 0.9, + swap_space: 4, + } + } + + /// The name or path of a HuggingFace Transformers tokenizer. + pub fn tokenizer(mut self, tokenizer: &str) -> Self { + self.tokenizer = Some(tokenizer.to_string()); + self + } + + /// The tokenizer mode. "auto" will use the fast tokenizer if available, and + /// "slow" will always use the slow tokenizer. + pub fn tokenizer_mode(mut self, tokenizer_mode: TokenizerMode) -> Self { + self.tokenizer_mode = tokenizer_mode; + self + } + + /// Trust remote code (e.g., from HuggingFace) when downloading the model + /// and tokenizer. + pub fn trust_remote_code(mut self, trust_remote_code: bool) -> Self { + self.trust_remote_code = trust_remote_code; + self + } + + /// The number of GPUs to use for distributed execution with tensor + /// parallelism. + pub fn tensor_parallel_size(mut self, tensor_parallel_size: u8) -> Self { + self.tensor_parallel_size = tensor_parallel_size; + self + } + + /// The data type for the model weights and activations. Currently, + /// we support `float32`, `float16`, and `bfloat16`. If `auto`, we use + /// the `torch_dtype` attribute specified in the model config file. + /// However, if the `torch_dtype` in the config is `float32`, we will + /// use `float16` instead. + pub fn dtype(mut self, dtype: &str) -> Self { + self.dtype = dtype.to_string(); + self + } + + /// The method used to quantize the model weights. Currently, + /// we support "awq". If None, we assume the model weights are not + /// quantized and use `dtype` to determine the data type of the weights. + pub fn quantization(mut self, quantization: Quantization) -> Self { + self.quantization = Some(quantization); + self + } + + /// The specific model version to use. It can be a branch name, + /// a tag name, or a commit id. + pub fn revision(mut self, revision: &str) -> Self { + self.revision = Some(revision.to_string()); + self + } + + /// The seed to initialize the random number generator for sampling. + pub fn seed(mut self, seed: u64) -> Self { + self.seed = seed; + self + } + + /// The ratio (between 0 and 1) of GPU memory to + /// reserve for the model weights, activations, and KV cache. Higher + /// values will increase the KV cache size and thus improve the model's + /// throughput. However, if the value is too high, it may cause out-of- + /// memory (OOM) errors. + pub fn gpu_memory_utilization(mut self, gpu_memory_utilization: f64) -> Self { + self.gpu_memory_utilization = gpu_memory_utilization; + self + } + + /// The size (GiB) of CPU memory per GPU to use as swap space. + /// This can be used for temporarily storing the states of the requests + /// when their `best_of` sampling parameters are larger than 1. If all + /// requests will have `best_of=1`, you can safely set this to 0. + /// Otherwise, too small values may cause out-of-memory (OOM) errors. + pub fn swap_space(mut self, swap_space: u32) -> Self { + self.swap_space = swap_space; + self + } + + /// Create a [`LLM`] from the [`LLMBuilder`] + pub fn build(self) -> PyResult { + let inner = Python::with_gil(|py| -> PyResult { + let kwargs = PyDict::new(py); + kwargs.set_item("model", self.model)?; + kwargs.set_item("tokenizer", self.tokenizer)?; + kwargs.set_item("tokenizer_mode", self.tokenizer_mode)?; + kwargs.set_item("trust_remote_code", self.trust_remote_code)?; + kwargs.set_item("tensor_parallel_size", self.tensor_parallel_size)?; + kwargs.set_item("dtype", self.dtype)?; + kwargs.set_item("quantization", self.quantization)?; + kwargs.set_item("revision", self.revision)?; + kwargs.set_item("seed", self.seed)?; + kwargs.set_item("gpu_memory_utilization", self.gpu_memory_utilization)?; + kwargs.set_item("swap_space", self.swap_space)?; + + let vllm = PyModule::import(py, "vllm")?; + vllm.getattr("LLM")?.call((), Some(kwargs))?.extract() + })?; + + Ok(LLM { inner }) + } +} + +impl LLM { + /// Create an LLM for a model with the name or path of a HuggingFace + /// Transformers model. + pub fn new(model: &str) -> PyResult { + LLMBuilder::new(model).build() + } + + /// Generates the completions for the input prompts. + /// + /// ### NOTE + /// This automatically batches the given prompts, considering the memory + /// constraint. For the best performance, put all of your prompts into a + /// single list and pass it to this method. + pub fn generate( + &self, + prompts: &[&str], + params: Option<&SamplingParams>, + ) -> PyResult> { + let prompts: Vec<_> = prompts.iter().map(|s| s.to_string()).collect(); + + Python::with_gil(|py| { + let kwargs = PyDict::new(py); + kwargs.set_item("prompts", prompts)?; + kwargs.set_item("sampling_params", params)?; + + let outputs: Vec = self + .inner + .getattr(py, "generate")? + .call(py, (), Some(kwargs))? + .extract(py)?; + + outputs + .iter() + .map(|output| -> PyResult { + let outputs: Vec = output.getattr(py, "outputs")?.extract(py)?; + outputs.first().unwrap().getattr(py, "text")?.extract(py) + }) + .collect::>>() + }) + } +} + +impl ToPyObject for TokenizerMode { + fn to_object(&self, py: Python<'_>) -> PyObject { + match self { + TokenizerMode::Auto => "auto".to_string(), + TokenizerMode::Slow => "slow".to_string(), + } + .into_py(py) + } +} + +impl ToPyObject for Quantization { + fn to_object(&self, py: Python<'_>) -> PyObject { + match self { + Quantization::Awq => "awg".to_string(), + } + .into_py(py) + } +} + +#[cfg(test)] +mod tests { + use crate::SamplingParamsBuilder; + + use super::*; + + #[test] + #[ignore = "requires model download"] + fn vllm_quickstart() { + // quickstart example from https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html + let prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ]; + let sampling_params = SamplingParamsBuilder::new() + .temperature(0.8) + .top_p(0.95) + .build() + .unwrap(); + + let llm = LLMBuilder::new("facebook/opt-125m").build().unwrap(); + let outputs = llm.generate(&prompts, Some(&sampling_params)).unwrap(); + assert_eq!(prompts.len(), outputs.len()); + } + + #[test] + #[ignore = "requires model download"] + fn model_support() { + if let Err(e) = LLMBuilder::new("intfloat/e5-small").build() { + assert!(e.to_string().contains("not supported")); + } + } +} diff --git a/pgml-extension/src/bindings/vllm/mod.rs b/pgml-extension/src/bindings/vllm/mod.rs index bfb45cd43..3544f85b0 100644 --- a/pgml-extension/src/bindings/vllm/mod.rs +++ b/pgml-extension/src/bindings/vllm/mod.rs @@ -1,294 +1,7 @@ -use std::fmt; - -use anyhow::{anyhow, Result}; -use pgrx::prelude::*; -use pyo3::{prelude::*, types::PyDict}; - -use super::TracebackError; +//! Rust bindings to the Python package `vllm`. +mod llm; mod params; -pub use params::*; - -pub struct LLMBuilder { - model: String, - tokenizer: Option, - tokenizer_mode: TokenizerMode, - trust_remote_code: bool, - tensor_parallel_size: u8, - dtype: String, - quantization: Option, - revision: Option, - seed: u64, - gpu_memory_utilization: f64, - swap_space: u32, -} - -#[derive(Debug, PartialEq, Eq, Copy, Clone)] -pub enum TokenizerMode { - Auto, - Slow, -} - -#[derive(Debug, PartialEq, Eq, Copy, Clone)] -pub enum Quantization { - Awq, -} - -pub struct LLM { - inner: PyObject, -} - -impl LLMBuilder { - /// Create a builder for a model with the name or path of a HuggingFace Transformers model. - pub fn new(model: &str) -> Self { - Self { - model: model.to_string(), - tokenizer: None, - tokenizer_mode: TokenizerMode::Auto, - trust_remote_code: false, - tensor_parallel_size: 1, - dtype: "auto".to_string(), - quantization: None, - revision: None, - seed: 0, - gpu_memory_utilization: 0.9, - swap_space: 4, - } - } - - /// The name or path of a HuggingFace Transformers tokenizer. - pub fn tokenizer(mut self, tokenizer: &str) -> Self { - self.tokenizer = Some(tokenizer.to_string()); - self - } - - /// The tokenizer mode. "auto" will use the fast tokenizer if available, and "slow" will always use the slow tokenizer. - pub fn tokenizer_mode(mut self, tokenizer_mode: TokenizerMode) -> Self { - self.tokenizer_mode = tokenizer_mode; - self - } - - /// Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer. - pub fn trust_remote_code(mut self, trust_remote_code: bool) -> Self { - self.trust_remote_code = trust_remote_code; - self - } - - /// The number of GPUs to use for distributed execution with tensor parallelism. - pub fn tensor_parallel_size(mut self, tensor_parallel_size: u8) -> Self { - self.tensor_parallel_size = tensor_parallel_size; - self - } - - /// The data type for the model weights and activations. Currently, - /// we support `float32`, `float16`, and `bfloat16`. If `auto`, we use - /// the `torch_dtype` attribute specified in the model config file. - /// However, if the `torch_dtype` in the config is `float32`, we will - /// use `float16` instead. - pub fn dtype(mut self, dtype: &str) -> Self { - self.dtype = dtype.to_string(); - self - } - - /// The method used to quantize the model weights. Currently, - /// we support "awq". If None, we assume the model weights are not - /// quantized and use `dtype` to determine the data type of the weights. - pub fn quantization(mut self, quantization: Quantization) -> Self { - self.quantization = Some(quantization); - self - } - - /// The specific model version to use. It can be a branch name, - /// a tag name, or a commit id. - pub fn revision(mut self, revision: &str) -> Self { - self.revision = Some(revision.to_string()); - self - } - - /// The seed to initialize the random number generator for sampling. - pub fn seed(mut self, seed: u64) -> Self { - self.seed = seed; - self - } - - /// The ratio (between 0 and 1) of GPU memory to - /// reserve for the model weights, activations, and KV cache. Higher - /// values will increase the KV cache size and thus improve the model's - /// throughput. However, if the value is too high, it may cause out-of- - /// memory (OOM) errors. - pub fn gpu_memory_utilization(mut self, gpu_memory_utilization: f64) -> Self { - self.gpu_memory_utilization = gpu_memory_utilization; - self - } - /// The size (GiB) of CPU memory per GPU to use as swap space. - /// This can be used for temporarily storing the states of the requests - /// when their `best_of` sampling parameters are larger than 1. If all - /// requests will have `best_of=1`, you can safely set this to 0. - /// Otherwise, too small values may cause out-of-memory (OOM) errors. - pub fn swap_space(mut self, swap_space: u32) -> Self { - self.swap_space = swap_space; - self - } - - /// Create a [`LLM`] from the [`LLMBuilder`] - pub fn build(self) -> Result { - let inner = Python::with_gil(|py| -> Result { - let kwargs = PyDict::new(py); - kwargs.set_item("model", self.model).format_traceback(py)?; - kwargs - .set_item("tokenizer", self.tokenizer) - .format_traceback(py)?; - kwargs - .set_item("tokenizer_mode", self.tokenizer_mode.to_string()) - .format_traceback(py)?; - kwargs - .set_item("trust_remote_code", self.trust_remote_code) - .format_traceback(py)?; - kwargs - .set_item("tensor_parallel_size", self.tensor_parallel_size) - .format_traceback(py)?; - kwargs.set_item("dtype", self.dtype).format_traceback(py)?; - kwargs - .set_item("quantization", self.quantization.map(|q| q.to_string())) - .format_traceback(py)?; - kwargs - .set_item("revision", self.revision) - .format_traceback(py)?; - kwargs.set_item("seed", self.seed).format_traceback(py)?; - kwargs - .set_item("gpu_memory_utilization", self.gpu_memory_utilization) - .format_traceback(py)?; - kwargs - .set_item("swap_space", self.swap_space) - .format_traceback(py)?; - - let vllm = PyModule::import(py, "vllm").format_traceback(py)?; - vllm.getattr("LLM") - .format_traceback(py)? - .call((), Some(kwargs)) - .format_traceback(py)? - .extract() - .format_traceback(py) - })?; - - Ok(LLM { inner }) - } -} - -impl LLM { - /// Create an LLM for a model with the name or path of a HuggingFace Transformers model. - pub fn new(model: &str) -> Result { - LLMBuilder::new(model).build() - } - - /// Generates the completions for the input prompts. - /// - /// ### NOTE - /// This automatically batches the given prompts, considering - /// the memory constraint. For the best performance, put all of your prompts - /// into a single list and pass it to this method. - pub fn generate( - &self, - prompts: &[&str], - params: Option<&SamplingParams>, - ) -> Result> { - let prompts: Vec<_> = prompts.iter().map(|s| s.to_string()).collect(); - - Python::with_gil(|py| { - let kwargs = PyDict::new(py); - kwargs.set_item("prompts", prompts).format_traceback(py)?; - kwargs - .set_item("sampling_params", params) - .format_traceback(py)?; - - let outputs: Vec = self - .inner - .getattr(py, "generate") - .format_traceback(py)? - .call(py, (), Some(kwargs)) - .format_traceback(py)? - .extract(py) - .format_traceback(py)?; - - outputs - .iter() - .map(|output| -> Result { - let outputs: Vec = output - .getattr(py, "outputs") - .format_traceback(py)? - .extract(py) - .format_traceback(py)?; - outputs - .first() - .ok_or_else(|| anyhow!("vllm output.outputs[] empty"))? - .getattr(py, "text") - .format_traceback(py)? - .extract(py) - .format_traceback(py) - }) - .collect::>>() - }) - } -} - -impl fmt::Display for TokenizerMode { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "{}", - match self { - TokenizerMode::Auto => "auto", - TokenizerMode::Slow => "slow", - } - ) - } -} - -impl fmt::Display for Quantization { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "{}", - match self { - Quantization::Awq => "awq", - } - ) - } -} - -#[cfg(any(test, feature = "pg_test"))] -#[pg_schema] -mod tests { - use super::*; - - #[pg_test] - #[ignore = "requires model download"] - fn vllm_quickstart() { - crate::bindings::python::activate().unwrap(); - - // quickstart example from https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html - let prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ]; - let sampling_params = SamplingParamsBuilder::new() - .temperature(0.8) - .top_p(0.95) - .build() - .unwrap(); - - let llm = LLMBuilder::new("facebook/opt-125m").build().unwrap(); - let outputs = llm.generate(&prompts, Some(&sampling_params)).unwrap(); - assert_eq!(prompts.len(), outputs.len()); - } - - #[pg_test] - fn model_support() { - if let Err(e) = LLMBuilder::new("intfloat/e5-small").build() { - assert!(e.to_string().contains("not supported")); - } - } -} +pub use llm::*; +pub use params::*; diff --git a/pgml-extension/src/bindings/vllm/params.rs b/pgml-extension/src/bindings/vllm/params.rs index 873d3157e..31a54006a 100644 --- a/pgml-extension/src/bindings/vllm/params.rs +++ b/pgml-extension/src/bindings/vllm/params.rs @@ -1,8 +1,5 @@ -use anyhow::Result; use pyo3::{prelude::*, types::PyDict}; -use crate::bindings::TracebackError; - #[derive(Debug, Clone)] pub struct SamplingParamsBuilder { n: usize, @@ -178,57 +175,30 @@ impl SamplingParamsBuilder { self } - pub fn build(self) -> Result { - let inner = Python::with_gil(|py| -> Result { + pub fn build(self) -> PyResult { + let inner = Python::with_gil(|py| -> PyResult { let kwargs = PyDict::new(py); - kwargs.set_item("n", self.n).format_traceback(py)?; - kwargs - .set_item("best_of", self.best_of) - .format_traceback(py)?; - kwargs - .set_item("presence_penalty", self.presence_penalty) - .format_traceback(py)?; - kwargs - .set_item("frequency_penalty", self.frequency_penalty) - .format_traceback(py)?; - kwargs - .set_item("temperature", self.temperature) - .format_traceback(py)?; - kwargs.set_item("top_p", self.top_p).format_traceback(py)?; - kwargs.set_item("top_k", self.top_k).format_traceback(py)?; - kwargs - .set_item("use_beam_search", self.use_beam_search) - .format_traceback(py)?; - kwargs - .set_item("length_penalty", self.length_penalty) - .format_traceback(py)?; - kwargs - .set_item("early_stopping", self.early_stopping) - .format_traceback(py)?; - kwargs.set_item("stop", self.stop).format_traceback(py)?; - kwargs - .set_item("stop_token_ids", self.stop_token_ids) - .format_traceback(py)?; - kwargs - .set_item("ignore_eos", self.ignore_eos) - .format_traceback(py)?; - kwargs - .set_item("max_tokens", self.max_tokens) - .format_traceback(py)?; - kwargs - .set_item("logprobs", self.logprobs) - .format_traceback(py)?; - kwargs - .set_item("skip_special_tokens", self.skip_special_tokens) - .format_traceback(py)?; - - let vllm = PyModule::import(py, "vllm").format_traceback(py)?; - vllm.getattr("SamplingParams") - .format_traceback(py)? - .call((), Some(kwargs)) - .format_traceback(py)? + kwargs.set_item("n", self.n)?; + kwargs.set_item("best_of", self.best_of)?; + kwargs.set_item("presence_penalty", self.presence_penalty)?; + kwargs.set_item("frequency_penalty", self.frequency_penalty)?; + kwargs.set_item("temperature", self.temperature)?; + kwargs.set_item("top_p", self.top_p)?; + kwargs.set_item("top_k", self.top_k)?; + kwargs.set_item("use_beam_search", self.use_beam_search)?; + kwargs.set_item("length_penalty", self.length_penalty)?; + kwargs.set_item("early_stopping", self.early_stopping)?; + kwargs.set_item("stop", self.stop)?; + kwargs.set_item("stop_token_ids", self.stop_token_ids)?; + kwargs.set_item("ignore_eos", self.ignore_eos)?; + kwargs.set_item("max_tokens", self.max_tokens)?; + kwargs.set_item("logprobs", self.logprobs)?; + kwargs.set_item("skip_special_tokens", self.skip_special_tokens)?; + + let vllm = PyModule::import(py, "vllm")?; + vllm.getattr("SamplingParams")? + .call((), Some(kwargs))? .extract() - .format_traceback(py) })?; Ok(SamplingParams { inner }) From 746953e49786578c828bd59e66d880beba5c6030 Mon Sep 17 00:00:00 2001 From: Kevin Zimmerman <4733573+kczimm@users.noreply.github.com> Date: Fri, 13 Oct 2023 21:17:03 +0000 Subject: [PATCH 5/9] add vLLM to the transform API --- pgml-extension/src/api.rs | 33 ++++++++++++++++++++++++- pgml-extension/src/bindings/vllm/llm.rs | 23 +++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index ad952e485..989f4b73c 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -2,12 +2,15 @@ use std::fmt::Write; use std::str::FromStr; use ndarray::Zip; +use once_cell::sync::OnceCell; use pgrx::iter::{SetOfIterator, TableIterator}; use pgrx::*; +use serde_json::Value; #[cfg(feature = "python")] use serde_json::json; +use crate::bindings::vllm::{LLMBuilder, LLM}; #[cfg(feature = "python")] use crate::orm::*; @@ -610,7 +613,7 @@ pub fn transform_json( inputs: default!(Vec<&str>, "ARRAY[]::TEXT[]"), cache: default!(bool, false), ) -> JsonB { - match crate::bindings::transformers::transform(&task.0, &args.0, inputs) { + match transform(task.0, args.0, inputs) { Ok(output) => JsonB(output), Err(e) => error!("{e}"), } @@ -632,6 +635,34 @@ pub fn transform_string( } } +fn transform(mut task: Value, args: Value, inputs: Vec<&str>) -> anyhow::Result { + // use vLLM if model present in task and backend is set to vllm + let use_vllm = task.as_object_mut().is_some_and(|obj| { + obj.contains_key("model") && matches!(obj.get("backend"), Some(Value::String(backend)) if backend.to_string().to_ascii_lowercase() == "vllm") + }); + + if use_vllm { + crate::bindings::python::activate().unwrap(); + + static LAZY_LLM: OnceCell = OnceCell::new(); + let llm = LAZY_LLM.get_or_init(move || { + let builder = match LLMBuilder::try_from(task) { + Ok(b) => b, + Err(e) => error!("{e}"), + }; + builder.build().unwrap() + }); + + Ok(json!(llm.generate(&inputs, None)?)) + } else { + if let Some(map) = task.as_object_mut() { + // pop backend keyword, if present + let _ = map.remove("backend"); + } + crate::bindings::transformers::transform(&task, &args, inputs) + } +} + #[cfg(feature = "python")] #[pg_extern(immutable, parallel_safe, name = "generate")] fn generate(project_name: &str, inputs: &str, config: default!(JsonB, "'{}'")) -> String { diff --git a/pgml-extension/src/bindings/vllm/llm.rs b/pgml-extension/src/bindings/vllm/llm.rs index 0a32b0329..0af17a615 100644 --- a/pgml-extension/src/bindings/vllm/llm.rs +++ b/pgml-extension/src/bindings/vllm/llm.rs @@ -1,4 +1,5 @@ use pyo3::{prelude::*, types::PyDict}; +use serde_json::Value; use super::SamplingParams; @@ -213,6 +214,28 @@ impl ToPyObject for Quantization { } } +impl TryFrom for LLMBuilder { + type Error = &'static str; + + fn try_from(value: Value) -> Result { + match value.as_object() { + Some(map) => { + let model = map + .get("model") + .ok_or("Json object must have `model` key")? + .as_str() + .ok_or("`model` key must be a str")?; + + Ok(LLMBuilder::new(model)) + } + None => match value.as_str() { + Some(model) => Ok(LLMBuilder::new(model)), + None => Err("Json value expected as str or object"), + }, + } + } +} + #[cfg(test)] mod tests { use crate::SamplingParamsBuilder; From ca7e4ade99534f08b0c1f534088977ffc778023d Mon Sep 17 00:00:00 2001 From: Kevin Zimmerman <4733573+kczimm@users.noreply.github.com> Date: Mon, 16 Oct 2023 19:37:07 +0000 Subject: [PATCH 6/9] make bindings vllm::outputs --- pgml-extension/src/api.rs | 14 +++++- pgml-extension/src/bindings/vllm/llm.rs | 19 +++------ pgml-extension/src/bindings/vllm/mod.rs | 2 + pgml-extension/src/bindings/vllm/outputs.rs | 47 +++++++++++++++++++++ 4 files changed, 67 insertions(+), 15 deletions(-) create mode 100644 pgml-extension/src/bindings/vllm/outputs.rs diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 989f4b73c..4a91352ea 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -653,7 +653,19 @@ fn transform(mut task: Value, args: Value, inputs: Vec<&str>) -> anyhow::Result< builder.build().unwrap() }); - Ok(json!(llm.generate(&inputs, None)?)) + let outputs = llm + .generate(&inputs, None)? + .iter() + .map(|o| { + o.outputs() + .unwrap() + .iter() + .map(|o| o.text().unwrap()) + .collect::>() + }) + .collect::>>(); + + Ok(json!(outputs)) } else { if let Some(map) = task.as_object_mut() { // pop backend keyword, if present diff --git a/pgml-extension/src/bindings/vllm/llm.rs b/pgml-extension/src/bindings/vllm/llm.rs index 0af17a615..108ca4496 100644 --- a/pgml-extension/src/bindings/vllm/llm.rs +++ b/pgml-extension/src/bindings/vllm/llm.rs @@ -1,7 +1,7 @@ use pyo3::{prelude::*, types::PyDict}; use serde_json::Value; -use super::SamplingParams; +use super::{RequestOutput, SamplingParams}; pub struct LLMBuilder { model: String, @@ -170,7 +170,7 @@ impl LLM { &self, prompts: &[&str], params: Option<&SamplingParams>, - ) -> PyResult> { + ) -> PyResult> { let prompts: Vec<_> = prompts.iter().map(|s| s.to_string()).collect(); Python::with_gil(|py| { @@ -178,19 +178,10 @@ impl LLM { kwargs.set_item("prompts", prompts)?; kwargs.set_item("sampling_params", params)?; - let outputs: Vec = self - .inner + self.inner .getattr(py, "generate")? .call(py, (), Some(kwargs))? - .extract(py)?; - - outputs - .iter() - .map(|output| -> PyResult { - let outputs: Vec = output.getattr(py, "outputs")?.extract(py)?; - outputs.first().unwrap().getattr(py, "text")?.extract(py) - }) - .collect::>>() + .extract(py) }) } } @@ -238,7 +229,7 @@ impl TryFrom for LLMBuilder { #[cfg(test)] mod tests { - use crate::SamplingParamsBuilder; + use crate::bindings::vllm::SamplingParamsBuilder; use super::*; diff --git a/pgml-extension/src/bindings/vllm/mod.rs b/pgml-extension/src/bindings/vllm/mod.rs index 3544f85b0..a17c312d4 100644 --- a/pgml-extension/src/bindings/vllm/mod.rs +++ b/pgml-extension/src/bindings/vllm/mod.rs @@ -1,7 +1,9 @@ //! Rust bindings to the Python package `vllm`. mod llm; +mod outputs; mod params; pub use llm::*; +pub use outputs::*; pub use params::*; diff --git a/pgml-extension/src/bindings/vllm/outputs.rs b/pgml-extension/src/bindings/vllm/outputs.rs new file mode 100644 index 000000000..b83af367a --- /dev/null +++ b/pgml-extension/src/bindings/vllm/outputs.rs @@ -0,0 +1,47 @@ +use pyo3::prelude::*; + +#[derive(Clone)] +pub struct RequestOutput { + inner: PyObject, +} + +#[derive(Clone)] +pub struct CompletionOutput { + inner: PyObject, +} + +impl RequestOutput { + pub fn prompt(&self) -> PyResult { + Python::with_gil(|py| self.inner.getattr(py, "prompt")?.extract(py)) + } + + pub fn outputs(&self) -> PyResult> { + Python::with_gil(|py| self.inner.getattr(py, "outputs")?.extract(py)) + } +} + +impl CompletionOutput { + pub fn finished(&self) -> PyResult { + Python::with_gil(|py| self.inner.getattr(py, "finished")?.call0(py)?.extract(py)) + } + + pub fn text(&self) -> PyResult { + Python::with_gil(|py| self.inner.getattr(py, "text")?.extract(py)) + } +} + +impl<'source> FromPyObject<'source> for RequestOutput { + fn extract(ob: &'source PyAny) -> PyResult { + Ok(Self { + inner: ob.extract()?, + }) + } +} + +impl<'source> FromPyObject<'source> for CompletionOutput { + fn extract(ob: &'source PyAny) -> PyResult { + Ok(Self { + inner: ob.extract()?, + }) + } +} From d017cd6bd9544d14e27eb8ef00b900e5b64a6c89 Mon Sep 17 00:00:00 2001 From: Kevin Zimmerman <4733573+kczimm@users.noreply.github.com> Date: Thu, 19 Oct 2023 18:42:20 +0000 Subject: [PATCH 7/9] swap out vLLM model if new --- pgml-extension/src/api.rs | 27 +------ pgml-extension/src/bindings/vllm/inference.rs | 75 +++++++++++++++++++ pgml-extension/src/bindings/vllm/llm.rs | 12 ++- pgml-extension/src/bindings/vllm/mod.rs | 2 + 4 files changed, 88 insertions(+), 28 deletions(-) create mode 100644 pgml-extension/src/bindings/vllm/inference.rs diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 4a91352ea..1bee51302 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -2,7 +2,6 @@ use std::fmt::Write; use std::str::FromStr; use ndarray::Zip; -use once_cell::sync::OnceCell; use pgrx::iter::{SetOfIterator, TableIterator}; use pgrx::*; use serde_json::Value; @@ -10,7 +9,6 @@ use serde_json::Value; #[cfg(feature = "python")] use serde_json::json; -use crate::bindings::vllm::{LLMBuilder, LLM}; #[cfg(feature = "python")] use crate::orm::*; @@ -642,30 +640,7 @@ fn transform(mut task: Value, args: Value, inputs: Vec<&str>) -> anyhow::Result< }); if use_vllm { - crate::bindings::python::activate().unwrap(); - - static LAZY_LLM: OnceCell = OnceCell::new(); - let llm = LAZY_LLM.get_or_init(move || { - let builder = match LLMBuilder::try_from(task) { - Ok(b) => b, - Err(e) => error!("{e}"), - }; - builder.build().unwrap() - }); - - let outputs = llm - .generate(&inputs, None)? - .iter() - .map(|o| { - o.outputs() - .unwrap() - .iter() - .map(|o| o.text().unwrap()) - .collect::>() - }) - .collect::>>(); - - Ok(json!(outputs)) + Ok(crate::bindings::vllm::vllm_inference(&task, &inputs)?) } else { if let Some(map) = task.as_object_mut() { // pop backend keyword, if present diff --git a/pgml-extension/src/bindings/vllm/inference.rs b/pgml-extension/src/bindings/vllm/inference.rs new file mode 100644 index 000000000..336286c03 --- /dev/null +++ b/pgml-extension/src/bindings/vllm/inference.rs @@ -0,0 +1,75 @@ +use parking_lot::Mutex; +use pyo3::prelude::*; +use serde_json::{json, Value}; + +use super::LLM; + +static MODEL: Mutex> = Mutex::new(None); + +pub fn vllm_inference(task: &Value, inputs: &[&str]) -> PyResult { + crate::bindings::python::activate().expect("python venv activate"); + let mut model = MODEL.lock(); + + let llm = match get_model_name(&model, task) { + ModelName::Same => model.as_mut().expect("ModelName::Same as_mut"), + ModelName::Different(name) => { + if let Some(llm) = model.take() { + // delete old model, exists + destroy_model_parallel(llm)?; + } + // make new model + let llm = LLM::new(&name)?; + model.insert(llm) + } + }; + + let outputs = llm + .generate(&inputs, None)? + .iter() + .map(|o| { + o.outputs() + .expect("RequestOutput::outputs()") + .iter() + .map(|o| o.text().expect("CompletionOutput::text()")) + .collect::>() + }) + .collect::>>(); + + Ok(json!(outputs)) +} + +fn get_model_name(model: &M, task: &Value) -> ModelName +where + M: std::ops::Deref>, +{ + match task + .as_object() + .and_then(|obj| obj.get("model").and_then(|m| m.as_str())) + { + Some(name) => match model.as_ref() { + Some(llm) if llm.model() == name => ModelName::Same, + _ => ModelName::Different(name.to_string()), + }, + None => ModelName::Same, + } +} + +enum ModelName { + Same, + Different(String), +} + +// See https://github.com/vllm-project/vllm/issues/565#issuecomment-1725174811 +fn destroy_model_parallel(llm: LLM) -> PyResult<()> { + Python::with_gil(|py| { + PyModule::import(py, "vllm")? + .getattr("model_executor")? + .getattr("parallel_utils")? + .getattr("parallel_state")? + .getattr("destroy_model_parallel")? + .call0()?; + drop(llm); + PyModule::import(py, "gc")?.getattr("collect")?.call0()?; + Ok(()) + }) +} diff --git a/pgml-extension/src/bindings/vllm/llm.rs b/pgml-extension/src/bindings/vllm/llm.rs index 108ca4496..20b7f9f9b 100644 --- a/pgml-extension/src/bindings/vllm/llm.rs +++ b/pgml-extension/src/bindings/vllm/llm.rs @@ -29,6 +29,7 @@ pub enum Quantization { } pub struct LLM { + model: String, inner: PyObject, } @@ -133,7 +134,7 @@ impl LLMBuilder { pub fn build(self) -> PyResult { let inner = Python::with_gil(|py| -> PyResult { let kwargs = PyDict::new(py); - kwargs.set_item("model", self.model)?; + kwargs.set_item("model", self.model.clone())?; kwargs.set_item("tokenizer", self.tokenizer)?; kwargs.set_item("tokenizer_mode", self.tokenizer_mode)?; kwargs.set_item("trust_remote_code", self.trust_remote_code)?; @@ -149,7 +150,10 @@ impl LLMBuilder { vllm.getattr("LLM")?.call((), Some(kwargs))?.extract() })?; - Ok(LLM { inner }) + Ok(LLM { + inner, + model: self.model, + }) } } @@ -184,6 +188,10 @@ impl LLM { .extract(py) }) } + + pub fn model(&self) -> &str { + self.model.as_str() + } } impl ToPyObject for TokenizerMode { diff --git a/pgml-extension/src/bindings/vllm/mod.rs b/pgml-extension/src/bindings/vllm/mod.rs index a17c312d4..b8ab1a978 100644 --- a/pgml-extension/src/bindings/vllm/mod.rs +++ b/pgml-extension/src/bindings/vllm/mod.rs @@ -1,9 +1,11 @@ //! Rust bindings to the Python package `vllm`. +mod inference; mod llm; mod outputs; mod params; +pub use inference::*; pub use llm::*; pub use outputs::*; pub use params::*; From 74ce6ae706dcaa4b793d3686dd915b7564f72b22 Mon Sep 17 00:00:00 2001 From: Kevin Zimmerman <4733573+kczimm@users.noreply.github.com> Date: Thu, 19 Oct 2023 20:32:56 +0000 Subject: [PATCH 8/9] add vllm docs --- pgml-extension/src/bindings/vllm/inference.rs | 2 ++ pgml-extension/src/bindings/vllm/mod.rs | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pgml-extension/src/bindings/vllm/inference.rs b/pgml-extension/src/bindings/vllm/inference.rs index 336286c03..010d79020 100644 --- a/pgml-extension/src/bindings/vllm/inference.rs +++ b/pgml-extension/src/bindings/vllm/inference.rs @@ -4,6 +4,8 @@ use serde_json::{json, Value}; use super::LLM; +/// Cache a single model per client process. vLLM does not allow multiple, simultaneous models to be loaded. +/// See GH issue, https://github.com/vllm-project/vllm/issues/565 static MODEL: Mutex> = Mutex::new(None); pub fn vllm_inference(task: &Value, inputs: &[&str]) -> PyResult { diff --git a/pgml-extension/src/bindings/vllm/mod.rs b/pgml-extension/src/bindings/vllm/mod.rs index b8ab1a978..4292d70bf 100644 --- a/pgml-extension/src/bindings/vllm/mod.rs +++ b/pgml-extension/src/bindings/vllm/mod.rs @@ -1,4 +1,4 @@ -//! Rust bindings to the Python package `vllm`. +//! Rust bindings to the Python package [vLLM](https://vllm.readthedocs.io/en/latest/) mod inference; mod llm; From aca505ca8fd8c5b59df943b707cebf46090ef455 Mon Sep 17 00:00:00 2001 From: Kevin Zimmerman <4733573+kczimm@users.noreply.github.com> Date: Thu, 19 Oct 2023 21:22:01 +0000 Subject: [PATCH 9/9] add vllm inference docs; fix logic --- pgml-extension/src/bindings/vllm/inference.rs | 27 ++++++++++++------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/pgml-extension/src/bindings/vllm/inference.rs b/pgml-extension/src/bindings/vllm/inference.rs index 010d79020..9c0d2b4e3 100644 --- a/pgml-extension/src/bindings/vllm/inference.rs +++ b/pgml-extension/src/bindings/vllm/inference.rs @@ -40,19 +40,28 @@ pub fn vllm_inference(task: &Value, inputs: &[&str]) -> PyResult { Ok(json!(outputs)) } +/// Determine if the "model" specified in the task is the same model as the one cached. +/// +/// # Panic +/// This function panics if: +/// - `task` is not an object +/// - "model" key is missing from `task` object +/// - "model" value is not a str fn get_model_name(model: &M, task: &Value) -> ModelName where M: std::ops::Deref>, { - match task - .as_object() - .and_then(|obj| obj.get("model").and_then(|m| m.as_str())) - { - Some(name) => match model.as_ref() { - Some(llm) if llm.model() == name => ModelName::Same, - _ => ModelName::Different(name.to_string()), - }, - None => ModelName::Same, + let name = task.as_object() + .expect("`task` is an object") + .get("model") + .expect("model key is present") + .as_str() + .expect("model value is a str"); + + if matches!(model.as_ref(), Some(llm) if llm.model() == name) { + ModelName::Same + } else { + ModelName::Different(name.to_string()) } } pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy