@@ -159,11 +159,14 @@ class llama_token_data_array(Structure):
 
 
 # struct llama_context_params {
-#     uint32_t seed;         // RNG seed, -1 for random
-#     int32_t  n_ctx;        // text context
-#     int32_t  n_batch;      // prompt processing batch size
-#     int32_t  n_gpu_layers; // number of layers to store in VRAM
-#     int32_t  main_gpu;     // the GPU that is used for scratch and small tensors
+#     uint32_t seed;         // RNG seed, -1 for random
+#     int32_t  n_ctx;        // text context
+#     int32_t  n_batch;      // prompt processing batch size
+#     int32_t  n_gqa;        // grouped-query attention (TEMP - will be moved to model hparams)
+#     float    rms_norm_eps; // rms norm epsilon (TEMP - will be moved to model hparams)
+#     int32_t  n_gpu_layers; // number of layers to store in VRAM
+#     int32_t  main_gpu;     // the GPU that is used for scratch and small tensors
+#
 #     const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
 
 #     // ref: https://github.com/ggerganov/llama.cpp/pull/2054
@@ -190,6 +193,8 @@ class llama_context_params(Structure):
         ("seed", c_uint32),
         ("n_ctx", c_int32),
         ("n_batch", c_int32),
+        ("n_gqa", c_int32),
+        ("rms_norm_eps", c_float),
         ("n_gpu_layers", c_int32),
         ("main_gpu", c_int32),
         ("tensor_split", POINTER(c_float)),
@@ -265,6 +270,57 @@ class llama_model_quantize_params(Structure):
     ]
 
 
+# // grammar types
+# struct llama_grammar;
+llama_grammar_p = c_void_p
+
+# // grammar element type
+# enum llama_gretype {
+#     // end of rule definition
+#     LLAMA_GRETYPE_END            = 0,
+
+#     // start of alternate definition for rule
+#     LLAMA_GRETYPE_ALT            = 1,
+
+#     // non-terminal element: reference to rule
+#     LLAMA_GRETYPE_RULE_REF       = 2,
+
+#     // terminal element: character (code point)
+#     LLAMA_GRETYPE_CHAR           = 3,
+
+#     // inverse char(s) ([^a], [^a-b] [^abc])
+#     LLAMA_GRETYPE_CHAR_NOT       = 4,
+
+#     // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
+#     // be an inclusive range ([a-z])
+#     LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
+
+#     // modifies a preceding LLAMA_GRETYPE_CHAR or
+#     // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
+#     LLAMA_GRETYPE_CHAR_ALT       = 6,
+# };
+LLAMA_GRETYPE_END = c_int(0)
+LLAMA_GRETYPE_ALT = c_int(1)
+LLAMA_GRETYPE_RULE_REF = c_int(2)
+LLAMA_GRETYPE_CHAR = c_int(3)
+LLAMA_GRETYPE_CHAR_NOT = c_int(4)
+LLAMA_GRETYPE_CHAR_RNG_UPPER = c_int(5)
+LLAMA_GRETYPE_CHAR_ALT = c_int(6)
+
+
+# typedef struct llama_grammar_element {
+#     enum llama_gretype type;
+#     uint32_t           value; // Unicode code point or rule ID
+# } llama_grammar_element;
+class llama_grammar_element(Structure):
+    _fields_ = [
+        ("type", c_int),
+        ("value", c_uint32),
+    ]
+
+
+llama_grammar_element_p = POINTER(llama_grammar_element)
+
 
 # // performance timing information
 # struct llama_timings {
 #     double t_start_ms;
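Since llama_grammar_element is a two-field Structure, a rule is built as a flat array of elements terminated by LLAMA_GRETYPE_END, and the rule set is an array of pointers to those arrays. A minimal sketch of a one-rule grammar matching the single character "a"; the rule0/rules names are illustrative, and ctypes.cast is used to convert each rule array to the element-pointer type:

from ctypes import cast  # imported here for the sketch; converts arrays to pointers

# rule 0: match the character "a", then the end-of-rule marker
rule0 = (llama_grammar_element * 2)(
    llama_grammar_element(LLAMA_GRETYPE_CHAR.value, ord("a")),
    llama_grammar_element(LLAMA_GRETYPE_END.value, 0),
)
# llama_grammar_init expects a const llama_grammar_element **: one pointer per rule
rules = (llama_grammar_element_p * 1)(cast(rule0, llama_grammar_element_p))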
@@ -871,6 +927,37 @@ def llama_token_nl() -> int:
 _lib.llama_token_nl.restype = llama_token
 
 
+# // Grammar
+# //
+# LLAMA_API struct llama_grammar * llama_grammar_init(
+#         const llama_grammar_element ** rules,
+#         size_t    n_rules,
+#         size_t    start_rule_index);
+def llama_grammar_init(
+    rules,  # type: Array[llama_grammar_element_p] # type: ignore
+    n_rules: c_size_t,
+    start_rule_index: c_size_t,
+) -> llama_grammar_p:
+    return _lib.llama_grammar_init(rules, n_rules, start_rule_index)
+
+
+_lib.llama_grammar_init.argtypes = [
+    POINTER(llama_grammar_element_p),
+    c_size_t,
+    c_size_t,
+]
+_lib.llama_grammar_init.restype = llama_grammar_p
+
+
+# LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
+def llama_grammar_free(grammar: llama_grammar_p):
+    return _lib.llama_grammar_free(grammar)
+
+
+_lib.llama_grammar_free.argtypes = [llama_grammar_p]
+_lib.llama_grammar_free.restype = None
+
+
 
 # Sampling functions
 
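Continuing the sketch above the previous hunk, the pointer array goes to llama_grammar_init together with the rule count and the start-rule index, and the opaque handle it returns must eventually be released with llama_grammar_free. A hypothetical lifecycle:

grammar = llama_grammar_init(rules, c_size_t(1), c_size_t(0))
try:
    ...  # use the handle with the grammar-aware sampling API
finally:
    llama_grammar_free(grammar)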