diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py
index c738fe1f5..7ff3bdc18 100644
--- a/pgml-extension/src/bindings/transformers/transformers.py
+++ b/pgml-extension/src/bindings/transformers/transformers.py
@@ -131,7 +131,7 @@ def put(self, values):
             self.text_index_cache[i] += len(printable_text)
             output.append(printable_text)
         if any(output):
-            self.text_queue.put(output, self.timeout)
+            self.text_queue.put(output)
 
     def end(self):
         self.next_tokens_are_prompt = True
@@ -139,8 +139,8 @@ def end(self):
         for i, tokens in enumerate(self.token_cache):
             text = self.tokenizer.decode(tokens, **self.decode_kwargs)
             output.append(text[self.text_index_cache[i] :])
-        self.text_queue.put(output, self.timeout)
-        self.text_queue.put(self.stop_signal, self.timeout)
+        self.text_queue.put(output)
+        self.text_queue.put(self.stop_signal)
 
     def __iter__(self):
         return self
@@ -264,12 +264,13 @@ def __init__(self, model_name, **kwargs):
         if self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token
 
-    def stream(self, input, **kwargs):
+    def stream(self, input, timeout=None, **kwargs):
         streamer = None
         generation_kwargs = None
         if self.task == "conversational":
             streamer = TextIteratorStreamer(
                 self.tokenizer,
+                timeout=timeout,
                 skip_prompt=True,
             )
             if "chat_template" in kwargs:
@@ -286,7 +287,10 @@ def stream(self, input, timeout=None, **kwargs):
                 input = self.tokenizer(input, return_tensors="pt").to(self.model.device)
             generation_kwargs = dict(input, streamer=streamer, **kwargs)
         else:
-            streamer = TextIteratorStreamer(self.tokenizer)
+            streamer = TextIteratorStreamer(
+                self.tokenizer,
+                timeout=timeout,
+            )
             input = self.tokenizer(input, return_tensors="pt", padding=True).to(
                 self.model.device
             )
@@ -355,7 +359,7 @@ def create_pipeline(task):
     return pipe
 
 
-def transform_using(pipeline, args, inputs, stream=False):
+def transform_using(pipeline, args, inputs, stream=False, timeout=None):
     args = orjson.loads(args)
     inputs = orjson.loads(inputs)
 
@@ -364,7 +368,7 @@ def transform_using(pipeline, args, inputs, stream=False):
 
     convert_eos_token(pipeline.tokenizer, args)
 
     if stream:
-        return pipeline.stream(inputs, **args)
+        return pipeline.stream(inputs, timeout=timeout, **args)
 
     return orjson.dumps(pipeline(inputs, **args), default=orjson_default).decode()