Commit 139898d

Removed Instructor

1 parent 63a8f4a · commit 139898d
File tree: 1 file changed (+10 -19 lines)

pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 10 additions & 19 deletions
@@ -8,7 +8,6 @@
 from datetime import datetime
 
 import datasets
-from InstructorEmbedding import INSTRUCTOR
 import numpy
 import orjson
 from rouge import Rouge
@@ -502,23 +501,17 @@ def transform(task, args, inputs, stream=False):
 
 
 def create_embedding(transformer):
-    instructor = transformer.startswith("hkunlp/instructor")
-    klass = INSTRUCTOR if instructor else SentenceTransformer
-    return klass(transformer)
+    return SentenceTransformer(transformer)
 
 
 def embed_using(model, transformer, inputs, kwargs):
     if isinstance(kwargs, str):
         kwargs = orjson.loads(kwargs)
 
     instructor = transformer.startswith("hkunlp/instructor")
-    if instructor:
-        texts_with_instructions = []
+    if instructor and "instruction" in kwargs:
         instruction = kwargs.pop("instruction")
-        for text in inputs:
-            texts_with_instructions.append([instruction, text])
-
-        inputs = texts_with_instructions
+        kwargs["prompt"] = instruction
 
     return model.encode(inputs, **kwargs)
 
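
For context, this change works because recent sentence-transformers releases can load the hkunlp/instructor models directly and accept the instruction text through the prompt keyword of encode(), which is prepended to each input. A minimal sketch of the new code path, assuming sentence-transformers >= 2.4; the instruction and input strings below are placeholders, not taken from the extension:

from sentence_transformers import SentenceTransformer

# Sketch only: mirrors create_embedding() / embed_using() above,
# not the extension's exact call site.
model = SentenceTransformer("hkunlp/instructor-base")
embeddings = model.encode(
    ["PostgresML runs machine learning inside Postgres."],  # placeholder input
    prompt="Represent the sentence for retrieval: ",         # was the "instruction" kwarg
)
print(embeddings.shape)
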

@@ -1029,7 +1022,6 @@ def __init__(
         path: str,
         hyperparameters: dict,
     ) -> None:
-
         # initialize class variables
         self.project_id = project_id
         self.model_id = model_id
@@ -1100,8 +1092,9 @@ def print_number_of_trainable_model_parameters(self, model):
         # Calculate and print the number and percentage of trainable parameters
         r_log("info", f"Trainable model parameters: {trainable_model_params}")
         r_log("info", f"All model parameters: {all_model_params}")
-        r_log("info",
-            f"Percentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"
+        r_log(
+            "info",
+            f"Percentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%",
         )
 
     def tokenize_function(self):
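
The r_log change above is purely cosmetic. The counts it reports are conventionally gathered by iterating over the model's parameters; a short, self-contained sketch of that bookkeeping follows (the count_parameters helper and the toy Linear module are made up for illustration, since the method body is not shown in this diff):

import torch

def count_parameters(model: torch.nn.Module) -> tuple[int, int]:
    # Trainable vs. total parameter counts, as the log lines above report.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total

trainable_model_params, all_model_params = count_parameters(torch.nn.Linear(16, 4))
print(f"Percentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%")
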
@@ -1396,23 +1389,22 @@ def __init__(
                 "bias": "none",
                 "task_type": "CAUSAL_LM",
             }
-            r_log("info",
+            r_log(
+                "info",
                 "LoRA configuration are not set. Using default parameters"
-                + json.dumps(self.lora_config_params)
+                + json.dumps(self.lora_config_params),
             )
 
         self.prompt_template = None
         if "prompt_template" in hyperparameters.keys():
             self.prompt_template = hyperparameters.pop("prompt_template")
 
     def train(self):
-
         args = TrainingArguments(
             output_dir=self.path, logging_dir=self.path, **self.training_args
         )
 
         def formatting_prompts_func(example):
-
             system_content = example["system"]
             user_content = example["user"]
             assistant_content = example["assistant"]
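
Only the "bias" and "task_type" entries of the default LoRA dictionary are visible in this hunk. As a rough sketch of how such a dictionary typically feeds peft, with the r, lora_alpha, and lora_dropout values assumed for illustration (they are not the project's defaults):

import json
from peft import LoraConfig

lora_config_params = {
    "r": 8,                # assumed value, not shown in the diff
    "lora_alpha": 16,      # assumed value, not shown in the diff
    "lora_dropout": 0.05,  # assumed value, not shown in the diff
    "bias": "none",
    "task_type": "CAUSAL_LM",
}
print("LoRA configuration are not set. Using default parameters" + json.dumps(lora_config_params))
peft_config = LoraConfig(**lora_config_params)
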
@@ -1463,7 +1455,7 @@ def formatting_prompts_func(example):
             peft_config=LoraConfig(**self.lora_config_params),
             callbacks=[PGMLCallback(self.project_id, self.model_id)],
         )
-        r_log("info","Creating Supervised Fine Tuning trainer done. Training ... ")
+        r_log("info", "Creating Supervised Fine Tuning trainer done. Training ... ")
 
         # Train
         self.trainer.train()
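
The SFTTrainer construction is only partially visible here (its peft_config and callbacks arguments). A hedged, self-contained sketch of the general shape of such a call, assuming a trl 0.7/0.8-era SFTTrainer API; the model name, one-row dataset, and output directory are placeholders, and the project-specific PGMLCallback is omitted:

from datasets import Dataset
from peft import LoraConfig
from transformers import TrainingArguments
from trl import SFTTrainer

# Placeholder single-row dataset with a plain "text" field.
train_dataset = Dataset.from_dict({"text": ["### User:\nhi\n\n### Assistant:\nhello"]})

trainer = SFTTrainer(
    model="sshleifer/tiny-gpt2",  # placeholder model for illustration
    args=TrainingArguments(output_dir="/tmp/sft_sketch", max_steps=1, report_to=[]),
    train_dataset=train_dataset,
    dataset_text_field="text",
    peft_config=LoraConfig(task_type="CAUSAL_LM"),
)
trainer.train()
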
@@ -1582,7 +1574,6 @@ def finetune_conversation(
     project_id,
     model_id,
 ):
-
     train_dataset = datasets.Dataset.from_dict(
         {
             "system": system_train,

0 commit comments