From be11967857847fa6a6367e4b8bde56394a4f8fd2 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Tue, 12 Dec 2023 13:53:59 -0800 Subject: [PATCH] Working simple python thread metrics --- .../src/bindings/transformers/transformers.py | 68 +++++++++++++++++-- 1 file changed, 63 insertions(+), 5 deletions(-) diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index 7d46d0636..eed9cede7 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -3,6 +3,8 @@ import shutil import time import queue +import sys +import json import datasets from InstructorEmbedding import INSTRUCTOR @@ -40,7 +42,7 @@ TrainingArguments, Trainer, ) -from threading import Thread +import threading __cache_transformer_by_model_id = {} __cache_sentence_transformer_by_name = {} @@ -62,6 +64,26 @@ } +class WorkerThreads: + def __init__(self): + self.worker_threads = {} + + def delete_thread(self, id): + del self.worker_threads[id] + + def update_thread(self, id, value): + self.worker_threads[id] = value + + def get_thread(self, id): + if id in self.worker_threads: + return self.worker_threads[id] + else: + return None + + +worker_threads = WorkerThreads() + + class PgMLException(Exception): pass @@ -105,6 +127,12 @@ def __init__(self, tokenizer, skip_prompt=False, timeout=None, **decode_kwargs): self.token_cache = [] self.text_index_cache = [] + def set_worker_thread_id(self, id): + self.worker_thread_id = id + + def get_worker_thread_id(self): + return self.worker_thread_id + def put(self, values): if self.skip_prompt and self.next_tokens_are_prompt: self.next_tokens_are_prompt = False @@ -149,6 +177,22 @@ def __next__(self): return value +def streaming_worker(worker_threads, model, **kwargs): + thread_id = threading.get_native_id() + try: + worker_threads.update_thread( + thread_id, json.dumps({"model": model.name_or_path}) + ) + except: + worker_threads.update_thread(thread_id, "Error setting data") + try: + model.generate(**kwargs) + except BaseException as error: + print(f"Error in streaming_worker: {error}", file=sys.stderr) + finally: + worker_threads.delete_thread(thread_id) + + class GGMLPipeline(object): def __init__(self, model_name, **task): import ctransformers @@ -185,7 +229,7 @@ def do_work(): self.q.put(x) self.done = True - thread = Thread(target=do_work) + thread = threading.Thread(target=do_work) thread.start() def __iter__(self): @@ -283,7 +327,13 @@ def stream(self, input, timeout=None, **kwargs): input, add_generation_prompt=True, tokenize=False ) input = self.tokenizer(input, return_tensors="pt").to(self.model.device) - generation_kwargs = dict(input, streamer=streamer, **kwargs) + generation_kwargs = dict( + input, + worker_threads=worker_threads, + model=self.model, + streamer=streamer, + **kwargs, + ) else: streamer = TextIteratorStreamer( self.tokenizer, @@ -292,9 +342,17 @@ def stream(self, input, timeout=None, **kwargs): input = self.tokenizer(input, return_tensors="pt", padding=True).to( self.model.device ) - generation_kwargs = dict(input, streamer=streamer, **kwargs) - thread = Thread(target=self.model.generate, kwargs=generation_kwargs) + generation_kwargs = dict( + input, + worker_threads=worker_threads, + model=self.model, + streamer=streamer, + **kwargs, + ) + # thread = Thread(target=self.model.generate, kwargs=generation_kwargs) + thread = threading.Thread(target=streaming_worker, kwargs=generation_kwargs) thread.start() + streamer.set_worker_thread_id(thread.native_id) return streamer def __call__(self, inputs, **kwargs):
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: