
Commit e811a81

Merge branch 'main' of https://github.com/abetlen/llama-cpp-python into main

2 parents: ca8e3c9 + 5212fb0

3 files changed: +49 −5 lines

llama_cpp/llama.py (21 additions, 5 deletions)
@@ -410,8 +410,8 @@ def __init__(
         if self.verbose:
             print(f"Model metadata: {self.metadata}", file=sys.stderr)

-        eos_token_id = int(self.metadata.get("tokenizer.ggml.eos_token_id", self.token_eos()))
-        bos_token_id = int(self.metadata.get("tokenizer.ggml.bos_token_id", self.token_bos()))
+        eos_token_id = self.token_eos()
+        bos_token_id = self.token_bos()

         eos_token = self._model.token_get_text(eos_token_id)
         bos_token = self._model.token_get_text(bos_token_id)
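
The constructor now takes the EOS/BOS ids straight from the loaded vocabulary via token_eos()/token_bos() instead of re-reading the tokenizer.ggml.* metadata keys. A minimal sketch of inspecting those tokens after the change, assuming a local GGUF model at a placeholder path:

from llama_cpp import Llama

llm = Llama(model_path="./model.gguf", verbose=False)  # hypothetical model path

eos_id = llm.token_eos()  # EOS id as reported by the loaded model
bos_id = llm.token_bos()  # BOS id as reported by the loaded model
print(eos_id, llm._model.token_get_text(eos_id))  # e.g. 2 "</s>"
print(bos_id, llm._model.token_get_text(bos_id))  # e.g. 1 "<s>"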
@@ -961,9 +961,9 @@ def _create_completion(

         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
-        prefix_token_id: int = int(self.metadata.get("tokenizer.ggml.prefix_token_id", self._model.token_prefix()))
-        middle_token_id: int = int(self.metadata.get("tokenizer.ggml.middle_token_id", self._model.token_middle()))
-        suffix_token_id: int = int(self.metadata.get("tokenizer.ggml.suffix_token_id", self._model.token_suffix()))
+        prefix_token_id: int = self._model.token_prefix()
+        middle_token_id: int = self._model.token_middle()
+        suffix_token_id: int = self._model.token_suffix()
         # If prompt is empty, initialize completion with BOS token to avoid
         # detokenization including a space at the beginning of the completion
         completion_tokens: List[int] = [] if len(prompt) > 0 else [self.token_bos()]
@@ -2084,3 +2084,19 @@ def __call__(
         self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
     ) -> bool:
         return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
+
+
+class MinTokensLogitsProcessor(LogitsProcessor):
+    def __init__(self, min_tokens: int, token_eos: int):
+        self.min_tokens = min_tokens
+        self.token_eos = token_eos
+        self.prompt_tokens = None
+
+    def __call__(
+        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
+    ) -> npt.NDArray[np.single]:
+        if self.prompt_tokens is None:
+            self.prompt_tokens = len(input_ids)
+        if len(input_ids) - self.prompt_tokens < self.min_tokens:
+            scores[self.token_eos] = -np.inf
+        return scores
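
The new MinTokensLogitsProcessor masks the EOS logit (sets its score to -inf) until at least min_tokens completion tokens have been generated. A minimal direct-usage sketch, assuming a local GGUF model at a placeholder path:

from llama_cpp import Llama, LogitsProcessorList, MinTokensLogitsProcessor

llm = Llama(model_path="./model.gguf", verbose=False)  # hypothetical model path

# Force at least 16 new tokens by suppressing EOS until that minimum is reached.
processors = LogitsProcessorList(
    [MinTokensLogitsProcessor(16, llm.token_eos())]
)
out = llm.create_completion(
    "Q: Name the planets in the solar system. A: ",
    max_tokens=64,
    logits_processor=processors,
)
print(out["choices"][0]["text"])

Note that the processor is stateful: it records the prompt length on its first call, so a fresh instance is needed per request. That is why the server code below constructs a new one for each incoming request rather than reusing a shared instance.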

llama_cpp/server/app.py (20 additions, 0 deletions)
@@ -275,6 +275,7 @@ async def create_completion(
         "best_of",
         "logit_bias_type",
         "user",
+        "min_tokens",
     }
     kwargs = body.model_dump(exclude=exclude)

@@ -288,6 +289,15 @@ async def create_completion(
     if body.grammar is not None:
         kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

+    if body.min_tokens > 0:
+        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
+            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
+        )
+        if "logits_processor" not in kwargs:
+            kwargs["logits_processor"] = _min_tokens_logits_processor
+        else:
+            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
+
     iterator_or_completion: Union[
         llama_cpp.CreateCompletionResponse,
         Iterator[llama_cpp.CreateCompletionStreamResponse],
@@ -445,6 +455,7 @@ async def create_chat_completion(
         "n",
         "logit_bias_type",
         "user",
+        "min_tokens",
     }
     kwargs = body.model_dump(exclude=exclude)
     llama = llama_proxy(body.model)
@@ -458,6 +469,15 @@ async def create_chat_completion(
     if body.grammar is not None:
         kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

+    if body.min_tokens > 0:
+        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
+            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
+        )
+        if "logits_processor" not in kwargs:
+            kwargs["logits_processor"] = _min_tokens_logits_processor
+        else:
+            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
+
     iterator_or_completion: Union[
         llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
     ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
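
With these handlers in place, clients can send min_tokens alongside the usual completion parameters. A hedged end-to-end sketch, assuming the server was started with defaults (python3 -m llama_cpp.server --model ./model.gguf, listening on localhost:8000) and that the requests package is installed:

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "prompt": "Q: Name the planets in the solar system. A: ",
        "max_tokens": 64,
        "min_tokens": 16,  # new field added by this change
    },
)
print(resp.json()["choices"][0]["text"])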

llama_cpp/server/types.py (8 additions, 0 deletions)
@@ -16,6 +16,12 @@
     default=16, ge=1, description="The maximum number of tokens to generate."
 )

+min_tokens_field = Field(
+    default=0,
+    ge=0,
+    description="The minimum number of tokens to generate. It may return fewer tokens if another condition is met (e.g. max_tokens, stop).",
+)
+
 temperature_field = Field(
     default=0.8,
     description="Adjust the randomness of the generated text.\n\n"
@@ -111,6 +117,7 @@ class CreateCompletionRequest(BaseModel):
     max_tokens: Optional[int] = Field(
         default=16, ge=0, description="The maximum number of tokens to generate."
     )
+    min_tokens: int = min_tokens_field
     temperature: float = temperature_field
     top_p: float = top_p_field
     min_p: float = min_p_field
@@ -206,6 +213,7 @@ class CreateChatCompletionRequest(BaseModel):
         default=None,
         description="The maximum number of tokens to generate. Defaults to inf",
     )
+    min_tokens: int = min_tokens_field
     logprobs: Optional[bool] = Field(
         default=False,
         description="Whether to output the logprobs or not. Default is True"
