Skip to content

Commit cc81afe

Browse files
committed
feat: Add stopping_criteria to ChatFormatter, allow stopping on arbitrary token ids, fixes llama3 instruct
1 parent d17c188 commit cc81afe

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

llama_cpp/llama.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,10 @@ def __init__(
426426
print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
427427

428428
self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
429-
template=template, eos_token=eos_token, bos_token=bos_token
429+
template=template,
430+
eos_token=eos_token,
431+
bos_token=bos_token,
432+
stop_token_ids=[eos_token_id],
430433
).to_chat_handler()
431434

432435
if self.chat_format is None and self.chat_handler is None:

llama_cpp/llama_chat_format.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010

1111
import jinja2
1212

13+
import numpy as np
14+
import numpy.typing as npt
15+
1316
import llama_cpp.llama as llama
1417
import llama_cpp.llama_types as llama_types
1518
import llama_cpp.llama_grammar as llama_grammar
@@ -150,6 +153,7 @@ class ChatFormatterResponse:
150153

151154
prompt: str
152155
stop: Optional[Union[str, List[str]]] = None
156+
stopping_criteria: Optional[llama.StoppingCriteriaList] = None
153157

154158

155159
class ChatFormatter(Protocol):
@@ -173,12 +177,14 @@ def __init__(
173177
eos_token: str,
174178
bos_token: str,
175179
add_generation_prompt: bool = True,
180+
stop_token_ids: Optional[List[int]] = None,
176181
):
177182
"""A chat formatter that uses jinja2 templates to format the prompt."""
178183
self.template = template
179184
self.eos_token = eos_token
180185
self.bos_token = bos_token
181186
self.add_generation_prompt = add_generation_prompt
187+
self.stop_token_ids = set(stop_token_ids) if stop_token_ids is not None else None
182188

183189
self._environment = jinja2.Environment(
184190
loader=jinja2.BaseLoader(),
@@ -211,7 +217,16 @@ def raise_exception(message: str):
211217
tool_choice=tool_choice,
212218
)
213219

214-
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token])
220+
stopping_criteria = None
221+
if self.stop_token_ids is not None:
222+
def stop_on_last_token(
223+
tokens: npt.NDArray[np.intc],
224+
logits: npt.NDArray[np.single]
225+
) -> bool:
226+
return tokens[-1] in self.stop_token_ids
227+
stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token])
228+
229+
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token], stopping_criteria=stopping_criteria)
215230

216231
def to_chat_handler(self) -> LlamaChatCompletionHandler:
217232
return chat_formatter_to_chat_completion_handler(self)
@@ -533,6 +548,10 @@ def chat_completion_handler(
533548
rstop = result.stop if isinstance(result.stop, list) else [result.stop]
534549
stop = stop + rstop
535550

551+
stopping_criteria = None
552+
if result.stopping_criteria is not None:
553+
stopping_criteria = result.stopping_criteria
554+
536555
if response_format is not None and response_format["type"] == "json_object":
537556
grammar = _grammar_for_response_format(response_format, verbose=llama.verbose)
538557

@@ -598,6 +617,7 @@ def chat_completion_handler(
598617
mirostat_eta=mirostat_eta,
599618
model=model,
600619
logits_processor=logits_processor,
620+
stopping_criteria=stopping_criteria,
601621
grammar=grammar,
602622
logit_bias=logit_bias,
603623
)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy