Open
Description
I think that the SyntheticDataKit
ports are hardcoded to 8000. I would like to suggest changing it to take in a port argument and then editing the load_vllm code to ping the port inputs taken. Should be quite a quick fix!
Additional context
class SyntheticDataKit:
def __init__(
self,
model_name = "unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit",
max_seq_length = 2048,
gpu_memory_utilization = 0.98,
float8_kv_cache = False,
conservativeness = 1.0,
token = None,
**kwargs,
):
assert(type(model_name) is str)
assert(type(max_seq_length) is int)
assert(type(gpu_memory_utilization) is float)
assert(type(float8_kv_cache) is bool)
assert(type(conservativeness) is float)
assert(token is None or type(token) is str)
self.model_name = model_name
self.max_seq_length = max_seq_length
from transformers import AutoConfig, AutoTokenizer
self.config = AutoConfig.from_pretrained(
model_name,
token = token,
)
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
token = token,
)
patch_vllm()
engine_args = load_vllm(
model_name = model_name,
config = self.config,
gpu_memory_utilization = gpu_memory_utilization,
max_seq_length = max_seq_length,
disable_log_stats = True,
float8_kv_cache = float8_kv_cache,
conservativeness = conservativeness,
return_args = True,
enable_lora = False,
use_bitsandbytes = False,
**kwargs,
)
if "device" in engine_args: del engine_args["device"]
if "model" in engine_args: del engine_args["model"]
subprocess_commands = [
"vllm", "serve", str(model_name),
]
..... #other codes
@staticmethod
def check_vllm_status():
try:
response = requests.get("http://localhost:8000/metrics")
if response.status_code == 200:
return True
except requests.exceptions.ConnectionError:
return False
pass
pass