diff --git a/pgml-extension/requirements.txt b/pgml-extension/requirements.txt index 89b3c2742..09b33732d 100644 --- a/pgml-extension/requirements.txt +++ b/pgml-extension/requirements.txt @@ -26,3 +26,4 @@ xgboost==2.0.0 langchain==0.0.287 einops==0.6.1 pynvml==11.5.0 +vllm==0.2.0 \ No newline at end of file diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index ad952e485..1bee51302 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -4,6 +4,7 @@ use std::str::FromStr; use ndarray::Zip; use pgrx::iter::{SetOfIterator, TableIterator}; use pgrx::*; +use serde_json::Value; #[cfg(feature = "python")] use serde_json::json; @@ -610,7 +611,7 @@ pub fn transform_json( inputs: default!(Vec<&str>, "ARRAY[]::TEXT[]"), cache: default!(bool, false), ) -> JsonB { - match crate::bindings::transformers::transform(&task.0, &args.0, inputs) { + match transform(task.0, args.0, inputs) { Ok(output) => JsonB(output), Err(e) => error!("{e}"), } @@ -632,6 +633,23 @@ pub fn transform_string( } } +fn transform(mut task: Value, args: Value, inputs: Vec<&str>) -> anyhow::Result { + // use vLLM if model present in task and backend is set to vllm + let use_vllm = task.as_object_mut().is_some_and(|obj| { + obj.contains_key("model") && matches!(obj.get("backend"), Some(Value::String(backend)) if backend.to_string().to_ascii_lowercase() == "vllm") + }); + + if use_vllm { + Ok(crate::bindings::vllm::vllm_inference(&task, &inputs)?) + } else { + if let Some(map) = task.as_object_mut() { + // pop backend keyword, if present + let _ = map.remove("backend"); + } + crate::bindings::transformers::transform(&task, &args, inputs) + } +} + #[cfg(feature = "python")] #[pg_extern(immutable, parallel_safe, name = "generate")] fn generate(project_name: &str, inputs: &str, config: default!(JsonB, "'{}'")) -> String { diff --git a/pgml-extension/src/bindings/mod.rs b/pgml-extension/src/bindings/mod.rs index 13702106d..e380c290c 100644 --- a/pgml-extension/src/bindings/mod.rs +++ b/pgml-extension/src/bindings/mod.rs @@ -46,6 +46,8 @@ pub mod python; pub mod sklearn; #[cfg(feature = "python")] pub mod transformers; +#[cfg(feature = "python")] +pub mod vllm; pub mod xgboost; pub type Fit = fn(dataset: &Dataset, hyperparams: &Hyperparams) -> Result>; diff --git a/pgml-extension/src/bindings/vllm/inference.rs b/pgml-extension/src/bindings/vllm/inference.rs new file mode 100644 index 000000000..9c0d2b4e3 --- /dev/null +++ b/pgml-extension/src/bindings/vllm/inference.rs @@ -0,0 +1,86 @@ +use parking_lot::Mutex; +use pyo3::prelude::*; +use serde_json::{json, Value}; + +use super::LLM; + +/// Cache a single model per client process. vLLM does not allow multiple, simultaneous models to be loaded. +/// See GH issue, https://github.com/vllm-project/vllm/issues/565 +static MODEL: Mutex> = Mutex::new(None); + +pub fn vllm_inference(task: &Value, inputs: &[&str]) -> PyResult { + crate::bindings::python::activate().expect("python venv activate"); + let mut model = MODEL.lock(); + + let llm = match get_model_name(&model, task) { + ModelName::Same => model.as_mut().expect("ModelName::Same as_mut"), + ModelName::Different(name) => { + if let Some(llm) = model.take() { + // delete old model, exists + destroy_model_parallel(llm)?; + } + // make new model + let llm = LLM::new(&name)?; + model.insert(llm) + } + }; + + let outputs = llm + .generate(&inputs, None)? + .iter() + .map(|o| { + o.outputs() + .expect("RequestOutput::outputs()") + .iter() + .map(|o| o.text().expect("CompletionOutput::text()")) + .collect::>() + }) + .collect::>>(); + + Ok(json!(outputs)) +} + +/// Determine if the "model" specified in the task is the same model as the one cached. +/// +/// # Panic +/// This function panics if: +/// - `task` is not an object +/// - "model" key is missing from `task` object +/// - "model" value is not a str +fn get_model_name(model: &M, task: &Value) -> ModelName +where + M: std::ops::Deref>, +{ + let name = task.as_object() + .expect("`task` is an object") + .get("model") + .expect("model key is present") + .as_str() + .expect("model value is a str"); + + if matches!(model.as_ref(), Some(llm) if llm.model() == name) { + ModelName::Same + } else { + ModelName::Different(name.to_string()) + } +} + +enum ModelName { + Same, + Different(String), +} + +// See https://github.com/vllm-project/vllm/issues/565#issuecomment-1725174811 +fn destroy_model_parallel(llm: LLM) -> PyResult<()> { + Python::with_gil(|py| { + PyModule::import(py, "vllm")? + .getattr("model_executor")? + .getattr("parallel_utils")? + .getattr("parallel_state")? + .getattr("destroy_model_parallel")? + .call0()?; + drop(llm); + PyModule::import(py, "gc")?.getattr("collect")?.call0()?; + Ok(()) + }) +} diff --git a/pgml-extension/src/bindings/vllm/llm.rs b/pgml-extension/src/bindings/vllm/llm.rs new file mode 100644 index 000000000..20b7f9f9b --- /dev/null +++ b/pgml-extension/src/bindings/vllm/llm.rs @@ -0,0 +1,272 @@ +use pyo3::{prelude::*, types::PyDict}; +use serde_json::Value; + +use super::{RequestOutput, SamplingParams}; + +pub struct LLMBuilder { + model: String, + tokenizer: Option, + tokenizer_mode: TokenizerMode, + trust_remote_code: bool, + tensor_parallel_size: u8, + dtype: String, + quantization: Option, + revision: Option, + seed: u64, + gpu_memory_utilization: f64, + swap_space: u32, +} + +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub enum TokenizerMode { + Auto, + Slow, +} + +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub enum Quantization { + Awq, +} + +pub struct LLM { + model: String, + inner: PyObject, +} + +impl LLMBuilder { + /// Create a builder for a model with the name or path of a HuggingFace + /// Transformers model. + pub fn new(model: &str) -> Self { + Self { + model: model.to_string(), + tokenizer: None, + tokenizer_mode: TokenizerMode::Auto, + trust_remote_code: false, + tensor_parallel_size: 1, + dtype: "auto".to_string(), + quantization: None, + revision: None, + seed: 0, + gpu_memory_utilization: 0.9, + swap_space: 4, + } + } + + /// The name or path of a HuggingFace Transformers tokenizer. + pub fn tokenizer(mut self, tokenizer: &str) -> Self { + self.tokenizer = Some(tokenizer.to_string()); + self + } + + /// The tokenizer mode. "auto" will use the fast tokenizer if available, and + /// "slow" will always use the slow tokenizer. + pub fn tokenizer_mode(mut self, tokenizer_mode: TokenizerMode) -> Self { + self.tokenizer_mode = tokenizer_mode; + self + } + + /// Trust remote code (e.g., from HuggingFace) when downloading the model + /// and tokenizer. + pub fn trust_remote_code(mut self, trust_remote_code: bool) -> Self { + self.trust_remote_code = trust_remote_code; + self + } + + /// The number of GPUs to use for distributed execution with tensor + /// parallelism. + pub fn tensor_parallel_size(mut self, tensor_parallel_size: u8) -> Self { + self.tensor_parallel_size = tensor_parallel_size; + self + } + + /// The data type for the model weights and activations. Currently, + /// we support `float32`, `float16`, and `bfloat16`. If `auto`, we use + /// the `torch_dtype` attribute specified in the model config file. + /// However, if the `torch_dtype` in the config is `float32`, we will + /// use `float16` instead. + pub fn dtype(mut self, dtype: &str) -> Self { + self.dtype = dtype.to_string(); + self + } + + /// The method used to quantize the model weights. Currently, + /// we support "awq". If None, we assume the model weights are not + /// quantized and use `dtype` to determine the data type of the weights. + pub fn quantization(mut self, quantization: Quantization) -> Self { + self.quantization = Some(quantization); + self + } + + /// The specific model version to use. It can be a branch name, + /// a tag name, or a commit id. + pub fn revision(mut self, revision: &str) -> Self { + self.revision = Some(revision.to_string()); + self + } + + /// The seed to initialize the random number generator for sampling. + pub fn seed(mut self, seed: u64) -> Self { + self.seed = seed; + self + } + + /// The ratio (between 0 and 1) of GPU memory to + /// reserve for the model weights, activations, and KV cache. Higher + /// values will increase the KV cache size and thus improve the model's + /// throughput. However, if the value is too high, it may cause out-of- + /// memory (OOM) errors. + pub fn gpu_memory_utilization(mut self, gpu_memory_utilization: f64) -> Self { + self.gpu_memory_utilization = gpu_memory_utilization; + self + } + + /// The size (GiB) of CPU memory per GPU to use as swap space. + /// This can be used for temporarily storing the states of the requests + /// when their `best_of` sampling parameters are larger than 1. If all + /// requests will have `best_of=1`, you can safely set this to 0. + /// Otherwise, too small values may cause out-of-memory (OOM) errors. + pub fn swap_space(mut self, swap_space: u32) -> Self { + self.swap_space = swap_space; + self + } + + /// Create a [`LLM`] from the [`LLMBuilder`] + pub fn build(self) -> PyResult { + let inner = Python::with_gil(|py| -> PyResult { + let kwargs = PyDict::new(py); + kwargs.set_item("model", self.model.clone())?; + kwargs.set_item("tokenizer", self.tokenizer)?; + kwargs.set_item("tokenizer_mode", self.tokenizer_mode)?; + kwargs.set_item("trust_remote_code", self.trust_remote_code)?; + kwargs.set_item("tensor_parallel_size", self.tensor_parallel_size)?; + kwargs.set_item("dtype", self.dtype)?; + kwargs.set_item("quantization", self.quantization)?; + kwargs.set_item("revision", self.revision)?; + kwargs.set_item("seed", self.seed)?; + kwargs.set_item("gpu_memory_utilization", self.gpu_memory_utilization)?; + kwargs.set_item("swap_space", self.swap_space)?; + + let vllm = PyModule::import(py, "vllm")?; + vllm.getattr("LLM")?.call((), Some(kwargs))?.extract() + })?; + + Ok(LLM { + inner, + model: self.model, + }) + } +} + +impl LLM { + /// Create an LLM for a model with the name or path of a HuggingFace + /// Transformers model. + pub fn new(model: &str) -> PyResult { + LLMBuilder::new(model).build() + } + + /// Generates the completions for the input prompts. + /// + /// ### NOTE + /// This automatically batches the given prompts, considering the memory + /// constraint. For the best performance, put all of your prompts into a + /// single list and pass it to this method. + pub fn generate( + &self, + prompts: &[&str], + params: Option<&SamplingParams>, + ) -> PyResult> { + let prompts: Vec<_> = prompts.iter().map(|s| s.to_string()).collect(); + + Python::with_gil(|py| { + let kwargs = PyDict::new(py); + kwargs.set_item("prompts", prompts)?; + kwargs.set_item("sampling_params", params)?; + + self.inner + .getattr(py, "generate")? + .call(py, (), Some(kwargs))? + .extract(py) + }) + } + + pub fn model(&self) -> &str { + self.model.as_str() + } +} + +impl ToPyObject for TokenizerMode { + fn to_object(&self, py: Python<'_>) -> PyObject { + match self { + TokenizerMode::Auto => "auto".to_string(), + TokenizerMode::Slow => "slow".to_string(), + } + .into_py(py) + } +} + +impl ToPyObject for Quantization { + fn to_object(&self, py: Python<'_>) -> PyObject { + match self { + Quantization::Awq => "awg".to_string(), + } + .into_py(py) + } +} + +impl TryFrom for LLMBuilder { + type Error = &'static str; + + fn try_from(value: Value) -> Result { + match value.as_object() { + Some(map) => { + let model = map + .get("model") + .ok_or("Json object must have `model` key")? + .as_str() + .ok_or("`model` key must be a str")?; + + Ok(LLMBuilder::new(model)) + } + None => match value.as_str() { + Some(model) => Ok(LLMBuilder::new(model)), + None => Err("Json value expected as str or object"), + }, + } + } +} + +#[cfg(test)] +mod tests { + use crate::bindings::vllm::SamplingParamsBuilder; + + use super::*; + + #[test] + #[ignore = "requires model download"] + fn vllm_quickstart() { + // quickstart example from https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html + let prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ]; + let sampling_params = SamplingParamsBuilder::new() + .temperature(0.8) + .top_p(0.95) + .build() + .unwrap(); + + let llm = LLMBuilder::new("facebook/opt-125m").build().unwrap(); + let outputs = llm.generate(&prompts, Some(&sampling_params)).unwrap(); + assert_eq!(prompts.len(), outputs.len()); + } + + #[test] + #[ignore = "requires model download"] + fn model_support() { + if let Err(e) = LLMBuilder::new("intfloat/e5-small").build() { + assert!(e.to_string().contains("not supported")); + } + } +} diff --git a/pgml-extension/src/bindings/vllm/mod.rs b/pgml-extension/src/bindings/vllm/mod.rs new file mode 100644 index 000000000..4292d70bf --- /dev/null +++ b/pgml-extension/src/bindings/vllm/mod.rs @@ -0,0 +1,11 @@ +//! Rust bindings to the Python package [vLLM](https://vllm.readthedocs.io/en/latest/) + +mod inference; +mod llm; +mod outputs; +mod params; + +pub use inference::*; +pub use llm::*; +pub use outputs::*; +pub use params::*; diff --git a/pgml-extension/src/bindings/vllm/outputs.rs b/pgml-extension/src/bindings/vllm/outputs.rs new file mode 100644 index 000000000..b83af367a --- /dev/null +++ b/pgml-extension/src/bindings/vllm/outputs.rs @@ -0,0 +1,47 @@ +use pyo3::prelude::*; + +#[derive(Clone)] +pub struct RequestOutput { + inner: PyObject, +} + +#[derive(Clone)] +pub struct CompletionOutput { + inner: PyObject, +} + +impl RequestOutput { + pub fn prompt(&self) -> PyResult { + Python::with_gil(|py| self.inner.getattr(py, "prompt")?.extract(py)) + } + + pub fn outputs(&self) -> PyResult> { + Python::with_gil(|py| self.inner.getattr(py, "outputs")?.extract(py)) + } +} + +impl CompletionOutput { + pub fn finished(&self) -> PyResult { + Python::with_gil(|py| self.inner.getattr(py, "finished")?.call0(py)?.extract(py)) + } + + pub fn text(&self) -> PyResult { + Python::with_gil(|py| self.inner.getattr(py, "text")?.extract(py)) + } +} + +impl<'source> FromPyObject<'source> for RequestOutput { + fn extract(ob: &'source PyAny) -> PyResult { + Ok(Self { + inner: ob.extract()?, + }) + } +} + +impl<'source> FromPyObject<'source> for CompletionOutput { + fn extract(ob: &'source PyAny) -> PyResult { + Ok(Self { + inner: ob.extract()?, + }) + } +} diff --git a/pgml-extension/src/bindings/vllm/params.rs b/pgml-extension/src/bindings/vllm/params.rs new file mode 100644 index 000000000..31a54006a --- /dev/null +++ b/pgml-extension/src/bindings/vllm/params.rs @@ -0,0 +1,222 @@ +use pyo3::{prelude::*, types::PyDict}; + +#[derive(Debug, Clone)] +pub struct SamplingParamsBuilder { + n: usize, + best_of: Option, + presence_penalty: f64, + frequency_penalty: f64, + temperature: f64, + top_p: f64, + top_k: i32, + use_beam_search: bool, + length_penalty: f64, + early_stopping: EarlyStopping, + stop: Option>, + stop_token_ids: Option>, + ignore_eos: bool, + max_tokens: usize, + logprobs: Option, + skip_special_tokens: bool, +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum EarlyStopping { + True, + False, + Never, +} + +pub struct SamplingParams { + inner: PyObject, +} + +impl SamplingParamsBuilder { + pub fn new() -> Self { + Self { + n: 1, + best_of: None, + presence_penalty: 0.0, + frequency_penalty: 0.0, + temperature: 1.0, + top_p: 1.0, + top_k: -1, + use_beam_search: false, + length_penalty: 1.0, + early_stopping: EarlyStopping::False, + stop: None, + stop_token_ids: None, + ignore_eos: false, + max_tokens: 16, + logprobs: None, + skip_special_tokens: true, + } + } + + /// Number of output sequences to return for the given prompt. + pub fn n(mut self, n: usize) -> Self { + self.n = n; + self + } + + /// Number of output sequences that are generated from the prompt. + /// From these `best_of` sequences, the top `n` sequences are returned. + /// `best_of` must be greater than or equal to `n`. This is treated as + /// the beam width when `use_beam_search` is true. By default, `best_of` + /// is set to `n`. + pub fn best_of(mut self, best_of: usize) -> Self { + self.best_of = Some(best_of); + self + } + + /// Float that penalizes new tokens based on whether they + /// appear in the generated text so far. Values > 0 encourage the model + /// to use new tokens, while values < 0 encourage the model to repeat + /// tokens. + pub fn presence_penalty(mut self, presence_penalty: f64) -> Self { + self.presence_penalty = presence_penalty; + self + } + + /// Float that penalizes new tokens based on their + /// frequency in the generated text so far. Values > 0 encourage the + /// model to use new tokens, while values < 0 encourage the model to + /// repeat tokens. + pub fn frequency_penalty(mut self, frequency_penalty: f64) -> Self { + self.frequency_penalty = frequency_penalty; + self + } + + /// Float that controls the randomness of the sampling. Lower + /// values make the model more deterministic, while higher values make + /// the model more random. Zero means greedy sampling. + pub fn temperature(mut self, temperature: f64) -> Self { + self.temperature = temperature; + self + } + + /// Float that controls the cumulative probability of the top tokens + /// to consider. Must be in (0, 1]. Set to 1 to consider all tokens. + pub fn top_p(mut self, top_p: f64) -> Self { + self.top_p = top_p; + self + } + + /// Integer that controls the number of top tokens to consider. Set + /// to -1 to consider all tokens. + pub fn top_k(mut self, top_k: i32) -> Self { + self.top_k = top_k; + self + } + + /// Whether to use beam search instead of sampling. + pub fn use_beam_search(mut self, use_beam_search: bool) -> Self { + self.use_beam_search = use_beam_search; + self + } + + /// Float that penalizes sequences based on their length. + /// Used in beam search. + pub fn length_penalty(mut self, length_penalty: f64) -> Self { + self.length_penalty = length_penalty; + self + } + + /// Controls the stopping condition for beam search. It + /// accepts the following values: `true`, where the generation stops as + /// soon as there are `best_of` complete candidates; `false`, where an + /// heuristic is applied and the generation stops when is it very + /// unlikely to find better candidates; `"never"`, where the beam search + /// procedure only stops when there cannot be better candidates + /// (canonical beam search algorithm). + pub fn early_stopping(mut self, early_stopping: EarlyStopping) -> Self { + self.early_stopping = early_stopping; + self + } + + /// List of [`String`]s that stop the generation when they are generated. + /// The returned output will not contain the stop [`String`]s. + pub fn stop(mut self, stop: &[&str]) -> Self { + self.stop = Some(stop.iter().map(|s| s.to_string()).collect()); + self + } + + /// List of tokens that stop the generation when they are + /// generated. The returned output will contain the stop tokens unless + /// the stop tokens are sepcial tokens. + pub fn stop_token_ids(mut self, stop_token_ids: Vec) -> Self { + self.stop_token_ids = Some(stop_token_ids); + self + } + + /// Whether to ignore the EOS token and continue generating + /// tokens after the EOS token is generated. + pub fn ignore_eos(mut self, ignore_eos: bool) -> Self { + self.ignore_eos = ignore_eos; + self + } + + /// Maximum number of tokens to generate per output sequence. + pub fn max_tokens(mut self, max_tokens: usize) -> Self { + self.max_tokens = max_tokens; + self + } + + /// Number of log probabilities to return per output token. + pub fn logprobs(mut self, logprobs: usize) -> Self { + self.logprobs = Some(logprobs); + self + } + + /// Whether to skip special tokens in the output. + /// Defaults to true. + pub fn skip_special_tokens(mut self, skip_special_tokens: bool) -> Self { + self.skip_special_tokens = skip_special_tokens; + self + } + + pub fn build(self) -> PyResult { + let inner = Python::with_gil(|py| -> PyResult { + let kwargs = PyDict::new(py); + kwargs.set_item("n", self.n)?; + kwargs.set_item("best_of", self.best_of)?; + kwargs.set_item("presence_penalty", self.presence_penalty)?; + kwargs.set_item("frequency_penalty", self.frequency_penalty)?; + kwargs.set_item("temperature", self.temperature)?; + kwargs.set_item("top_p", self.top_p)?; + kwargs.set_item("top_k", self.top_k)?; + kwargs.set_item("use_beam_search", self.use_beam_search)?; + kwargs.set_item("length_penalty", self.length_penalty)?; + kwargs.set_item("early_stopping", self.early_stopping)?; + kwargs.set_item("stop", self.stop)?; + kwargs.set_item("stop_token_ids", self.stop_token_ids)?; + kwargs.set_item("ignore_eos", self.ignore_eos)?; + kwargs.set_item("max_tokens", self.max_tokens)?; + kwargs.set_item("logprobs", self.logprobs)?; + kwargs.set_item("skip_special_tokens", self.skip_special_tokens)?; + + let vllm = PyModule::import(py, "vllm")?; + vllm.getattr("SamplingParams")? + .call((), Some(kwargs))? + .extract() + })?; + + Ok(SamplingParams { inner }) + } +} + +impl ToPyObject for EarlyStopping { + fn to_object(&self, py: Python<'_>) -> PyObject { + match self { + EarlyStopping::True => true.into_py(py), + EarlyStopping::False => false.into_py(py), + EarlyStopping::Never => "never".into_py(py), + } + } +} + +impl ToPyObject for SamplingParams { + fn to_object(&self, _py: Python<'_>) -> PyObject { + self.inner.clone() + } +} pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy