Skip to content

Commit 4ce6670

Browse files
authored
Merge pull request abetlen#87 from SagsMug/main
Fix TypeError in low_level chat
2 parents eb7f278 + 1b73a15 commit 4ce6670

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

examples/low_level_api/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class GptParams:
5050
# If chat ended prematurely, append this to the conversation to fix it.
5151
# Set to "\nUser:" etc.
5252
# This is an alternative to input_prefix which always adds it, so it potentially duplicates "User:""
53-
fix_prefix: str = " "
53+
fix_prefix: str = ""
5454
output_postfix: str = ""
5555
input_echo: bool = True,
5656

@@ -75,7 +75,7 @@ def gpt_params_parse(argv = None, params: Optional[GptParams] = None):
7575
parser.add_argument("--top_p", type=float, default=0.95, help="top-p samplin",dest="top_p")
7676
parser.add_argument("--top_k", type=int, default=40, help="top-k sampling",dest="top_k")
7777
parser.add_argument("--temp", type=float, default=0.80, help="temperature",dest="temp")
78-
parser.add_argument("--n_predict", type=int, default=128, help="number of model parts",dest="n_predict")
78+
parser.add_argument("--n_predict", type=int, default=128, help="number of tokens to predict (-1 = infinity)",dest="n_predict")
7979
parser.add_argument("--repeat_last_n", type=int, default=64, help="last n tokens to consider for penalize ",dest="repeat_last_n")
8080
parser.add_argument("--repeat_penalty", type=float, default=1.10, help="penalize repeat sequence of tokens",dest="repeat_penalty")
8181
parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size for prompt processing",dest="n_batch")

examples/low_level_api/low_level_api_chat_cpp.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ def __init__(self, params: GptParams) -> None:
144144

145145
# determine newline token
146146
self.llama_token_newline = self._tokenize("\n", False)
147+
self.llama_token_eot = self._tokenize(" [end of text]\n", False)
147148

148149
if (self.params.verbose_prompt):
149150
print(f"""
@@ -203,16 +204,16 @@ def _tokenize(self, prompt, bos=True):
203204
_n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), _arr, len(_arr), bos)
204205
return _arr[:_n]
205206

206-
def use_antiprompt(self):
207-
return len(self.first_antiprompt) > 0
208-
209207
def set_color(self, c):
210208
if (self.params.use_color):
211209
print(c, end="")
212210

211+
def use_antiprompt(self):
212+
return len(self.first_antiprompt) > 0
213+
213214
# generate tokens
214215
def generate(self):
215-
while self.remaining_tokens > 0 or self.params.interactive:
216+
while self.remaining_tokens > 0 or self.params.interactive or self.params.n_predict == -1:
216217
# predict
217218
if len(self.embd) > 0:
218219
# infinite text generation via context swapping
@@ -313,7 +314,7 @@ def generate(self):
313314
# end of text token
314315
if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
315316
if (not self.params.instruct):
316-
for i in " [end of text]\n":
317+
for i in self.llama_token_eot:
317318
yield i
318319
break
319320

0 commit comments

Comments (0)
pFad - Phonifier reborn

pFad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy