diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index 5c6078785..c738fe1f5 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -104,19 +104,42 @@ def __init__(self, tokenizer, skip_prompt=False, timeout=None, **decode_kwargs): self.next_tokens_are_prompt = True self.stop_signal = None self.text_queue = queue.Queue() + self.token_cache = [] + self.text_index_cache = [] - def put(self, value): + def put(self, values): if self.skip_prompt and self.next_tokens_are_prompt: self.next_tokens_are_prompt = False return - # Can't batch this decode - decoded_values = [] - for v in value: - decoded_values.append(self.tokenizer.decode(v, **self.decode_kwargs)) - self.text_queue.put(decoded_values, self.timeout) + output = [] + for i, v in enumerate(values): + if len(self.token_cache) <= i: + self.token_cache.append([]) + self.text_index_cache.append(0) + token = v.tolist() # Returns a list or number + if type(token) == list: + self.token_cache[i].extend(token) + else: + self.token_cache[i].append(token) + text = self.tokenizer.decode(self.token_cache[i], **self.decode_kwargs) + if text.endswith("\n"): + output.append(text[self.text_index_cache[i] :]) + self.token_cache[i] = [] + self.text_index_cache[i] = 0 + else: + printable_text = text[self.text_index_cache[i] : text.rfind(" ") + 1] + self.text_index_cache[i] += len(printable_text) + output.append(printable_text) + if any(output): + self.text_queue.put(output, self.timeout) def end(self): self.next_tokens_are_prompt = True + output = [] + for i, tokens in enumerate(self.token_cache): + text = self.tokenizer.decode(tokens, **self.decode_kwargs) + output.append(text[self.text_index_cache[i] :]) + self.text_queue.put(output, self.timeout) self.text_queue.put(self.stop_signal, self.timeout) def __iter__(self): @@ -127,6 +150,7 @@ def __next__(self): if value != 
self.stop_signal: return value + class GGMLPipeline(object): def __init__(self, model_name, **task): import ctransformers @@ -245,7 +269,8 @@ def stream(self, input, **kwargs): generation_kwargs = None if self.task == "conversational": streamer = TextIteratorStreamer( - self.tokenizer, skip_prompt=True, skip_special_tokens=True + self.tokenizer, + skip_prompt=True, ) if "chat_template" in kwargs: input = self.tokenizer.apply_chat_template( @@ -261,7 +286,7 @@ def stream(self, input, **kwargs): input = self.tokenizer(input, return_tensors="pt").to(self.model.device) generation_kwargs = dict(input, streamer=streamer, **kwargs) else: - streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True) + streamer = TextIteratorStreamer(self.tokenizer) input = self.tokenizer(input, return_tensors="pt", padding=True).to( self.model.device ) diff --git a/pgml-sdks/pgml/build.rs b/pgml-sdks/pgml/build.rs index 82b51670c..f017a04db 100644 --- a/pgml-sdks/pgml/build.rs +++ b/pgml-sdks/pgml/build.rs @@ -8,6 +8,8 @@ async def migrate() -> None Json = Any DateTime = int +GeneralJsonIterator = Any +GeneralJsonAsyncIterator = Any "#; const ADDITIONAL_DEFAULTS_FOR_JAVASCRIPT: &[u8] = br#" @@ -16,6 +18,8 @@ export function migrate(): Promise; export type Json = any; export type DateTime = Date; +export type GeneralJsonIterator = any; +export type GeneralJsonAsyncIterator = any; export function newCollection(name: string, database_url?: string): Collection; export function newModel(name?: string, source?: string, parameters?: Json): Model; diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py index f3b1fbec9..5c3a4df33 100644 --- a/pgml-sdks/pgml/python/tests/test.py +++ b/pgml-sdks/pgml/python/tests/test.py @@ -361,7 +361,6 @@ async def test_open_source_ai_create_async(): ], temperature=0.85, ) - import json assert len(results["choices"]) > 0 diff --git a/pgml-sdks/pgml/src/languages/javascript.rs b/pgml-sdks/pgml/src/languages/javascript.rs 
index c9a09326d..c49b5c493 100644 --- a/pgml-sdks/pgml/src/languages/javascript.rs +++ b/pgml-sdks/pgml/src/languages/javascript.rs @@ -101,13 +101,13 @@ fn transform_stream_iterate_next(mut cx: FunctionContext) -> JsResult .expect("Error converting rust Json to JavaScript Object"); let d = cx.boolean(false); o.set(&mut cx, "value", v) - .expect("Error setting object value in transform_sream_iterate_next"); + .expect("Error setting object value in transform_stream_iterate_next"); o.set(&mut cx, "done", d) - .expect("Error setting object value in transform_sream_iterate_next"); + .expect("Error setting object value in transform_stream_iterate_next"); } else { let d = cx.boolean(true); o.set(&mut cx, "done", d) - .expect("Error setting object value in transform_sream_iterate_next"); + .expect("Error setting object value in transform_stream_iterate_next"); } Ok(o) })