
Commit b21ba0e

Merge branch 'main' of https://github.com/abetlen/llama-cpp-python into main

2 parents: 159cc4e + 0281214

File tree

4 files changed: +31 −3 lines

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+## [0.2.63]
+
+- feat: Update llama.cpp to ggerganov/llama.cpp@0e4802b2ecbaab04b4f829fde4a3096ca19c84b5
+- feat: Add stopping_criteria to ChatFormatter, allow stopping on arbitrary token ids, fixes llama3 instruct by @abetlen in cc81afebf04d26ca1ac3cf72f23f18da6ab58588
+
 ## [0.2.62]
 
 - feat: Update llama.cpp to ggerganov/llama.cpp@3b8f1ec4b18770531d0b1d792f3edf08254e4f0c
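The stopping_criteria entry above is the fix for llama3-style instruct models, which end assistant turns with a dedicated token (llama3's <|eot_id|>) rather than the tokenizer's eos_token, so string-based stop matching alone misses it. A minimal sketch of the pattern this release wires in, assuming 128009 (llama3's <|eot_id|> id) as a stand-in for whatever ids a given model actually uses:

import numpy as np
import numpy.typing as npt

import llama_cpp.llama as llama

# Token ids to stop on; 128009 is llama3's <|eot_id|>, used here as a
# placeholder for whatever end-of-turn ids a given model defines.
stop_token_ids = {128009}

def stop_on_last_token(
    tokens: npt.NDArray[np.intc],
    logits: npt.NDArray[np.single],
) -> bool:
    # Called after each sampled token; returning True halts generation.
    return tokens[-1] in stop_token_ids

stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token])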

llama_cpp/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 from .llama_cpp import *
 from .llama import *
 
-__version__ = "0.2.62"
+__version__ = "0.2.63"

llama_cpp/llama.py

Lines changed: 4 additions & 1 deletion
@@ -426,7 +426,10 @@ def __init__(
                 print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
 
             self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
-                template=template, eos_token=eos_token, bos_token=bos_token
+                template=template,
+                eos_token=eos_token,
+                bos_token=bos_token,
+                stop_token_ids=[eos_token_id],
             ).to_chat_handler()
 
         if self.chat_format is None and self.chat_handler is None:
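Here eos_token_id comes from the model's gguf metadata, read earlier in __init__. For illustration, the new parameter can also be exercised by constructing the formatter directly; a sketch assuming a toy template string and a hypothetical stop id of 2 (real templates and ids come from the model's tokenizer.chat_template metadata):

import llama_cpp.llama_chat_format as llama_chat_format

# Toy template; real ones come from the gguf's tokenizer.chat_template.
template = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
)

formatter = llama_chat_format.Jinja2ChatFormatter(
    template=template,
    eos_token="</s>",
    bos_token="<s>",
    stop_token_ids=[2],  # hypothetical id; use the model's real eos/eot ids
)

# to_chat_handler() wraps the formatter for use as Llama(chat_handler=...)
chat_handler = formatter.to_chat_handler()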

llama_cpp/llama_chat_format.py

Lines changed: 21 additions & 1 deletion
@@ -10,6 +10,9 @@
 
 import jinja2
 
+import numpy as np
+import numpy.typing as npt
+
 import llama_cpp.llama as llama
 import llama_cpp.llama_types as llama_types
 import llama_cpp.llama_grammar as llama_grammar
@@ -150,6 +153,7 @@ class ChatFormatterResponse:
 
     prompt: str
     stop: Optional[Union[str, List[str]]] = None
+    stopping_criteria: Optional[llama.StoppingCriteriaList] = None
 
 
 class ChatFormatter(Protocol):
@@ -173,12 +177,14 @@ def __init__(
         eos_token: str,
         bos_token: str,
         add_generation_prompt: bool = True,
+        stop_token_ids: Optional[List[int]] = None,
     ):
         """A chat formatter that uses jinja2 templates to format the prompt."""
         self.template = template
         self.eos_token = eos_token
         self.bos_token = bos_token
         self.add_generation_prompt = add_generation_prompt
+        self.stop_token_ids = set(stop_token_ids) if stop_token_ids is not None else None
 
         self._environment = jinja2.Environment(
             loader=jinja2.BaseLoader(),
@@ -211,7 +217,16 @@ def raise_exception(message: str):
             tool_choice=tool_choice,
         )
 
-        return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token])
+        stopping_criteria = None
+        if self.stop_token_ids is not None:
+            def stop_on_last_token(
+                tokens: npt.NDArray[np.intc],
+                logits: npt.NDArray[np.single]
+            ) -> bool:
+                return tokens[-1] in self.stop_token_ids
+            stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token])
+
+        return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token], stopping_criteria=stopping_criteria)
 
     def to_chat_handler(self) -> LlamaChatCompletionHandler:
         return chat_formatter_to_chat_completion_handler(self)
@@ -533,6 +548,10 @@ def chat_completion_handler(
         rstop = result.stop if isinstance(result.stop, list) else [result.stop]
         stop = stop + rstop
 
+    stopping_criteria = None
+    if result.stopping_criteria is not None:
+        stopping_criteria = result.stopping_criteria
+
     if response_format is not None and response_format["type"] == "json_object":
         grammar = _grammar_for_response_format(response_format, verbose=llama.verbose)
 
@@ -598,6 +617,7 @@ def chat_completion_handler(
         mirostat_eta=mirostat_eta,
         model=model,
         logits_processor=logits_processor,
+        stopping_criteria=stopping_criteria,
         grammar=grammar,
         logit_bias=logit_bias,
     )
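End to end, a ChatFormatterResponse carrying stopping_criteria now flows through chat_completion_handler into the underlying completion call, so handlers built from gguf chat templates stop on the right token ids automatically. The same hook can also be used by hand through the completion API; a sketch assuming a placeholder model path:

from llama_cpp import Llama
from llama_cpp.llama import StoppingCriteriaList

llm = Llama(model_path="./model.gguf")  # placeholder path

eos_id = llm.token_eos()

def stop_on_eos(tokens, logits) -> bool:
    # Stop as soon as the last sampled token is the model's eos token.
    return tokens[-1] == eos_id

output = llm(
    "Q: Name the planets in the solar system. A:",
    max_tokens=64,
    stopping_criteria=StoppingCriteriaList([stop_on_eos]),
)
print(output["choices"][0]["text"])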
