Commit 8c83804

Author: Anil Pathak
Parent: 7c898d5

Fix low_level_api_chat_cpp to match current API


examples/low_level_api/low_level_api_chat_cpp.py

Lines changed: 37 additions & 14 deletions
@@ -18,6 +18,7 @@
 import llama_cpp
 from common import GptParams, gpt_params_parse, gpt_random_prompt
 import util
+import os

 # A LLaMA interactive session
 class LLaMAInteract:
@@ -62,7 +63,7 @@ def __init__(self, params: GptParams) -> None:
         self.multibyte_fix = []

         # model load
-        self.lparams = llama_cpp.llama_context_default_params()
+        self.lparams = llama_cpp.llama_model_default_params()
         self.lparams.n_ctx = self.params.n_ctx
         self.lparams.n_parts = self.params.n_parts
         self.lparams.seed = self.params.seed
@@ -72,7 +73,11 @@ def __init__(self, params: GptParams) -> None:

         self.model = llama_cpp.llama_load_model_from_file(
             self.params.model.encode("utf8"), self.lparams)
-        self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.lparams)
+
+        # Context Params.
+        self.cparams = llama_cpp.llama_context_default_params()
+
+        self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.cparams)
         if (not self.ctx):
             raise RuntimeError(f"error: failed to load model '{self.params.model}'")

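The substance of this hunk is the model/context split in the current llama.cpp API: parameters for loading the weights (llama_model_params) are now distinct from parameters for an inference context (llama_context_params). A minimal sketch of the two-step setup, with a placeholder model path; note that in the newer API n_ctx and seed are context-level fields, so assigning them to the model params (as the hunk above still does) most likely has no effect on the C side:

    import llama_cpp

    # Model-level params govern weight loading (e.g. n_gpu_layers).
    mparams = llama_cpp.llama_model_default_params()
    model = llama_cpp.llama_load_model_from_file(
        b"./models/7B/ggml-model.gguf", mparams)  # placeholder path

    # Context-level params govern inference state (KV cache size, threads).
    cparams = llama_cpp.llama_context_default_params()
    cparams.n_ctx = 2048
    ctx = llama_cpp.llama_new_context_with_model(model, cparams)
    if not ctx:
        raise RuntimeError("failed to create context")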
@@ -244,7 +249,7 @@ def __init__(self, params: GptParams) -> None:
     # tokenize a prompt
     def _tokenize(self, prompt, bos=True):
         _arr = (llama_cpp.llama_token * ((len(prompt) + 1) * 4))()
-        _n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8", errors="ignore"), _arr, len(_arr), bos)
+        _n = llama_cpp.llama_tokenize(self.model, prompt.encode("utf8", errors="ignore"), len(prompt), _arr, len(_arr), bos, False)
         return _arr[:_n]

     def set_color(self, c):
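The tokenizer now hangs off the model rather than the context, takes an explicit input length, and gained a trailing `special` flag controlling whether special tokens are parsed in the input. A hedged sketch of the new call shape; one caveat worth flagging is that the hunk passes len(prompt) (the character count) where the C API expects the byte length of the encoded string, which differs for non-ASCII input:

    text = "Hello, world".encode("utf8", errors="ignore")
    buf = (llama_cpp.llama_token * (len(text) + 1))()
    n = llama_cpp.llama_tokenize(
        model,      # llama_model, previously llama_context
        text,       # encoded input bytes
        len(text),  # explicit byte length (new in this signature)
        buf,        # output token buffer
        len(buf),   # buffer capacity
        True,       # add_bos
        False,      # special: don't parse special tokens in user text
    )
    tokens = list(buf[:n])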
@@ -304,7 +309,7 @@ def generate(self):
            self.n_past += n_eval"""

         if (llama_cpp.llama_eval(
-            self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.params.n_threads
+            self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past
         ) != 0):
             raise Exception("Failed to llama_eval!")

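llama_eval no longer takes a thread count per call; in llama.cpp of this era, thread counts are configured once when the context is created. A sketch under that assumption (both fields exist on llama_context_params in contemporary versions):

    cparams = llama_cpp.llama_context_default_params()
    cparams.n_threads = 8        # threads used for single-token decode
    cparams.n_threads_batch = 8  # threads used for prompt/batch processing
    ctx = llama_cpp.llama_new_context_with_model(model, cparams)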
@@ -332,7 +337,7 @@ def generate(self):
             id = 0

             logits = llama_cpp.llama_get_logits(self.ctx)
-            n_vocab = llama_cpp.llama_n_vocab(self.ctx)
+            n_vocab = llama_cpp.llama_n_vocab(self.model)

             # Apply params.logit_bias map
             for key, value in self.params.logit_bias.items():
@@ -349,12 +354,20 @@ def generate(self):
                 last_n_repeat = min(len(self.last_n_tokens), repeat_last_n, self.n_ctx)

                 _arr = (llama_cpp.llama_token * last_n_repeat)(*self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat:])
-                llama_cpp.llama_sample_repetition_penalty(self.ctx, candidates_p,
-                    _arr,
-                    last_n_repeat, llama_cpp.c_float(self.params.repeat_penalty))
-                llama_cpp.llama_sample_frequency_and_presence_penalties(self.ctx, candidates_p,
-                    _arr,
-                    last_n_repeat, llama_cpp.c_float(self.params.frequency_penalty), llama_cpp.c_float(self.params.presence_penalty))
+                llama_cpp.llama_sample_repetition_penalties(
+                    ctx=self.ctx,
+                    candidates=candidates_p,
+                    last_tokens_data=_arr,
+                    penalty_last_n=last_n_repeat,
+                    penalty_repeat=llama_cpp.c_float(self.params.repeat_penalty),
+                    penalty_freq=llama_cpp.c_float(self.params.frequency_penalty),
+                    penalty_present=llama_cpp.c_float(self.params.presence_penalty),
+                )
+
+                # NOT PRESENT IN CURRENT VERSION ?
+                # llama_cpp.llama_sample_frequency_and_presence_penalties(self.ctx, candidates_p,
+                #     _arr,
+                #     last_n_repeat, llama_cpp.c_float(self.params.frequency_penalty), llama_cpp.c_float(self.params.presence_penalty))

                 if not self.params.penalize_nl:
                     logits[llama_cpp.llama_token_nl()] = nl_logit
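The two older samplers (llama_sample_repetition_penalty and llama_sample_frequency_and_presence_penalties) were folded into a single llama_sample_repetition_penalties call that applies all three penalties in one pass, which is why the second call is commented out above rather than replaced. A self-contained sketch of feeding it a freshly built candidates array, following the same low-level pattern the example uses elsewhere; the history length and penalty values are placeholders:

    import ctypes

    logits = llama_cpp.llama_get_logits(ctx)
    n_vocab = llama_cpp.llama_n_vocab(model)

    # One llama_token_data per vocab entry: (id, logit, p).
    candidates = (llama_cpp.llama_token_data * n_vocab)(*[
        llama_cpp.llama_token_data(tok, logits[tok], 0.0)
        for tok in range(n_vocab)
    ])
    candidates_p = ctypes.pointer(
        llama_cpp.llama_token_data_array(candidates, len(candidates), False))

    recent = (llama_cpp.llama_token * 64)()  # placeholder token history
    llama_cpp.llama_sample_repetition_penalties(
        ctx, candidates_p, recent,
        penalty_last_n=64,
        penalty_repeat=llama_cpp.c_float(1.1),   # classic repetition penalty
        penalty_freq=llama_cpp.c_float(0.0),     # frequency penalty
        penalty_present=llama_cpp.c_float(0.0),  # presence penalty
    )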
@@ -473,7 +486,7 @@ def exit(self):
     def token_to_str(self, token_id: int) -> bytes:
         size = 32
         buffer = (ctypes.c_char * size)()
-        n = llama_cpp.llama_token_to_piece_with_model(
+        n = llama_cpp.llama_token_to_piece(
             self.model, llama_cpp.llama_token(token_id), buffer, size)
         assert n <= size
         return bytes(buffer[:n])
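Detokenization likewise dropped the _with_model suffix. In llama.cpp of this era, llama_token_to_piece returns the number of bytes written, or a negative value whose magnitude is the required size when the buffer is too small, so a grow-and-retry variant is somewhat more robust than the assert (a sketch, assuming that return convention):

    import ctypes

    def token_to_bytes(model, token_id: int) -> bytes:
        size = 32
        buf = (ctypes.c_char * size)()
        n = llama_cpp.llama_token_to_piece(
            model, llama_cpp.llama_token(token_id), buf, size)
        if n < 0:  # buffer too small; retry with the exact required size
            size = -n
            buf = (ctypes.c_char * size)()
            n = llama_cpp.llama_token_to_piece(
                model, llama_cpp.llama_token(token_id), buf, size)
        return bytes(buf[:n])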
@@ -532,6 +545,9 @@ def interact(self):
             print(i,end="",flush=True)
         self.params.input_echo = False

+        # Using string instead of tokens to check for antiprompt,
+        # It is more reliable than tokens for interactive mode.
+        generated_str = ""
         while self.params.interactive:
             self.set_color(util.CONSOLE_COLOR_USER_INPUT)
             if (self.params.instruct):
@@ -546,6 +562,10 @@ def interact(self):
             try:
                 for i in self.output():
                     print(i,end="",flush=True)
+                    generated_str += i
+                    for ap in self.params.antiprompt:
+                        if generated_str.endswith(ap):
+                            raise KeyboardInterrupt
             except KeyboardInterrupt:
                 self.set_color(util.CONSOLE_COLOR_DEFAULT)
                 if not self.params.instruct:
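The antiprompt check now matches on accumulated text rather than on token IDs: the same antiprompt string can be produced by different token sequences depending on what precedes it, so suffix-matching the decoded output is more reliable. A distilled sketch of the pattern; one refinement the commit does not include would be trimming the accumulator to the longest antiprompt so it does not grow without bound:

    def hits_antiprompt(generated: str, antiprompts: list[str]) -> bool:
        return any(generated.endswith(ap) for ap in antiprompts)

    generated = ""
    for piece in ("User", ":"):  # stand-in for streamed model output
        generated += piece
        if hits_antiprompt(generated, ["User:"]):
            break  # hand control back to the user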
@@ -561,7 +581,7 @@ def interact(self):
     time_now = datetime.now()
     prompt = f"""Text transcript of a never ending dialog, where {USER_NAME} interacts with an AI assistant named {AI_NAME}.
 {AI_NAME} is helpful, kind, honest, friendly, good at writing and never fails to answer {USER_NAME}’s requests immediately and with details and precision.
-There are no annotations like (30 seconds passed...) or (to himself), just what {USER_NAME} and {AI_NAME} say aloud to each other.
+Transcript below contains only the recorded dialog between two, without any annotations like (30 seconds passed...) or (to himself), just what {USER_NAME} and {AI_NAME} say aloud to each other.
 The dialog lasts for years, the entirety of it is shared below. It's 10000 pages long.
 The transcript only includes text, it does not include markup like HTML and Markdown.
567587
@@ -575,8 +595,11 @@ def interact(self):
575595
{AI_NAME}: A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
576596
{USER_NAME}: Name a color.
577597
{AI_NAME}: Blue
578-
{USER_NAME}:"""
598+
{USER_NAME}: """
599+
579600
params = gpt_params_parse()
601+
if params.prompt is None and params.file is None:
602+
params.prompt = prompt
580603

581604
with LLaMAInteract(params) as m:
582605
m.interact()
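With the prompt fallback in place, the example can be launched without an explicit -p/--prompt and will use the built-in chat transcript. A hypothetical invocation; the flag names are assumed to follow the llama.cpp conventions that common.py mirrors, and the model path is a placeholder:

    python3 examples/low_level_api/low_level_api_chat_cpp.py \
        -m ./models/7B/ggml-model.gguf -i --color -r "User:"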
