Skip to content

Commit 804b7af

Browse files
authored
Working simple python thread metrics (#1239)
1 parent 4ceda37 commit 804b7af

File tree

1 file changed

+63
-5
lines changed

1 file changed

+63
-5
lines changed

pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import shutil
44
import time
55
import queue
6+
import sys
7+
import json
68

79
import datasets
810
from InstructorEmbedding import INSTRUCTOR
@@ -40,7 +42,7 @@
4042
TrainingArguments,
4143
Trainer,
4244
)
43-
from threading import Thread
45+
import threading
4446

4547
__cache_transformer_by_model_id = {}
4648
__cache_sentence_transformer_by_name = {}
@@ -62,6 +64,26 @@
6264
}
6365

6466

67+
class WorkerThreads:
    """Registry mapping a thread id to a status value for that thread.

    Entries are stored in the plain dict ``self.worker_threads``. The
    registry does not interpret the stored values; it only stores, returns
    and removes them.
    """

    def __init__(self):
        # thread id -> status value supplied by the caller of update_thread()
        self.worker_threads = {}

    def delete_thread(self, id):
        """Remove the entry for ``id``.

        A missing entry is ignored so that cleanup code (for example a
        ``finally`` clause) can call this unconditionally without raising
        ``KeyError``.
        """
        self.worker_threads.pop(id, None)

    def update_thread(self, id, value):
        """Store ``value`` for ``id``, replacing any previous entry."""
        self.worker_threads[id] = value

    def get_thread(self, id):
        """Return the value stored for ``id``, or ``None`` if there is none."""
        return self.worker_threads.get(id)
82+
83+
84+
worker_threads = WorkerThreads()
85+
86+
6587
class PgMLException(Exception):
    """Exception type defined by this module; adds nothing to ``Exception``."""
6789

@@ -105,6 +127,12 @@ def __init__(self, tokenizer, skip_prompt=False, timeout=None, **decode_kwargs):
105127
self.token_cache = []
106128
self.text_index_cache = []
107129

130+
def set_worker_thread_id(self, id):
    """Record ``id`` on this object as ``worker_thread_id``."""
    self.worker_thread_id = id
132+
133+
def get_worker_thread_id(self):
    """Return the recorded ``worker_thread_id``, or ``None`` if none was set.

    The attribute is only created by ``set_worker_thread_id``; falling back
    to ``None`` avoids an ``AttributeError`` when the getter is called first.
    """
    return getattr(self, "worker_thread_id", None)
135+
108136
def put(self, values):
109137
if self.skip_prompt and self.next_tokens_are_prompt:
110138
self.next_tokens_are_prompt = False
@@ -149,6 +177,22 @@ def __next__(self):
149177
return value
150178

151179

180+
def streaming_worker(worker_threads, model, **kwargs):
    """Run ``model.generate(**kwargs)`` while tracking the current thread.

    The native id of the calling thread is registered in ``worker_threads``
    for the duration of the call and removed afterwards.

    Args:
        worker_threads: registry object exposing ``update_thread(id, value)``
            and ``delete_thread(id)``.
        model: object exposing ``generate(**kwargs)``; its ``name_or_path``
            attribute, when available, is stored as JSON in the registry.
        **kwargs: forwarded unchanged to ``model.generate``.

    Errors raised by ``model.generate`` are printed to stderr and not
    re-raised.
    NOTE(review): nothing here notifies the streamer when generate() fails;
    confirm the consumer side cannot block forever in that case.
    """
    thread_id = threading.get_native_id()

    # Building the status is the only step expected to fail (for example a
    # model without ``name_or_path``), so keep the try body that small.
    try:
        status = json.dumps({"model": model.name_or_path})
    except Exception:
        status = "Error setting data"

    # Bookkeeping is best-effort: a registry failure must not prevent the
    # generation itself from running.
    registered = False
    try:
        worker_threads.update_thread(thread_id, status)
        registered = True
    except Exception as error:
        print(f"Error registering streaming_worker: {error}", file=sys.stderr)

    try:
        model.generate(**kwargs)
    except BaseException as error:
        print(f"Error in streaming_worker: {error}", file=sys.stderr)
    finally:
        # Only remove what this call actually added.
        if registered:
            worker_threads.delete_thread(thread_id)
194+
195+
152196
class GGMLPipeline(object):
153197
def __init__(self, model_name, **task):
154198
import ctransformers
@@ -185,7 +229,7 @@ def do_work():
185229
self.q.put(x)
186230
self.done = True
187231

188-
thread = Thread(target=do_work)
232+
thread = threading.Thread(target=do_work)
189233
thread.start()
190234

191235
def __iter__(self):
@@ -283,7 +327,13 @@ def stream(self, input, timeout=None, **kwargs):
283327
input, add_generation_prompt=True, tokenize=False
284328
)
285329
input = self.tokenizer(input, return_tensors="pt").to(self.model.device)
286-
generation_kwargs = dict(input, streamer=streamer, **kwargs)
330+
generation_kwargs = dict(
331+
input,
332+
worker_threads=worker_threads,
333+
model=self.model,
334+
streamer=streamer,
335+
**kwargs,
336+
)
287337
else:
288338
streamer = TextIteratorStreamer(
289339
self.tokenizer,
@@ -292,9 +342,17 @@ def stream(self, input, timeout=None, **kwargs):
292342
input = self.tokenizer(input, return_tensors="pt", padding=True).to(
293343
self.model.device
294344
)
295-
generation_kwargs = dict(input, streamer=streamer, **kwargs)
296-
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
345+
generation_kwargs = dict(
346+
input,
347+
worker_threads=worker_threads,
348+
model=self.model,
349+
streamer=streamer,
350+
**kwargs,
351+
)
352+
# Run generation through streaming_worker (instead of calling self.model.generate directly) so the thread is registered in worker_threads while it runs.
353+
thread = threading.Thread(target=streaming_worker, kwargs=generation_kwargs)
297354
thread.start()
355+
streamer.set_worker_thread_id(thread.native_id)
298356
return streamer
299357

300358
def __call__(self, inputs, **kwargs):

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy