Agents / Tool Calling w/ llama.cpp

This is a simple end-to-end example of an agent + isolated tools that leverages llama.cpp's new universal tool call support (which I added in ggml-org/llama.cpp#9639).

While any model should work (thanks to generic fallback support), bigger and more recent models fine-tuned specifically for function calling tend to perform better.
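
To give an idea of what this looks like at the API level, here's a minimal sketch of an OpenAI-style tool call request against a local llama-server (started as shown below; the get_weather tool is made up for illustration). When the model decides to call the tool, the response carries a tool_calls array in choices[0].message:

    curl http://localhost:8080/v1/chat/completions -H "Content-Type: application/json" -d '{
      "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
      "tools": [{
        "type": "function",
        "function": {
          "name": "get_weather",
          "description": "Get the current weather for a city",
          "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"]
          }
        }
      }]
    }'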

Here's how to run an agent w/ local tool calls:

  • Install the prerequisites (build and install llama.cpp from source):

    git clone https://github.com/ggerganov/llama.cpp
    cd llama.cpp
    cmake -B build -DLLAMA_CURL=1
    cmake --build build --parallel
    cmake --install build
  • Clone this gist and cd to its contents:

    git clone https://gist.github.com/9246d289b7d38d49e1ee2755698d6c79.git llama.cpp-agent-example
    cd llama.cpp-agent-example
  • Run llama-server w/ any model good at function calling. Bigger / more recent = better (try Q8_0 quants if you can). You might also need a chat template override to make the most of your model's native tool call format (see server/README.md). The following models have been verified:

    # Native support for Llama 3.x, Mistral Nemo, Qwen 2.5, Hermes 3, Functionary 3.x, Firefunction v2...
    
    llama-server --jinja -fa -hf bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M
    
    llama-server --jinja -fa -hf bartowski/phi-4-GGUF:Q4_0
    
    llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q6_K_L
    
    llama-server --jinja -fa -hf bartowski/functionary-small-v3.2-GGUF:Q4_K_M
    
    llama-server --jinja -fa -hf bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M \
      --chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use )
      
    llama-server --jinja -fa -hf bartowski/Llama-3.3-70B-Instruct-GGUF:Q4_K_M
    
    llama-server --jinja -fa -hf bartowski/firefunction-v2-GGUF -hff firefunction-v2-Q5_K_M.gguf \
      --chat-template-file <( python scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 tool_use )
      
    llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \
      --chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use )
  • Run the tools in this gist inside a docker container for some level of isolation (+ sneaky logging of outgoing HTTP and HTTPS traffic: you'll want to watch over those agents' shoulders for the time being 🧐). Check http://localhost:8088/docs to see the tools exposed, or query the OpenAPI catalog directly as sketched below.

    export BRAVE_SEARCH_API_KEY=... # Get one at https://api.search.brave.com/
    ./serve_tools_inside_docker.sh

    [!WARNING] The command above gives tools (and your agent) access to the web, and read-only access to examples/agent/**. You can loosen / restrict web access in examples/agent/squid/conf/squid.conf.
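
    For instance, to peek at the tool catalog the agent will discover (this is the same openapi.json the agent itself fetches; jq is only used here for readability):

    curl -s http://localhost:8088/openapi.json | jq '.paths | keys'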

  • Run the agent with some goal:

    uv run agent.py "What is the sum of 2535 squared and 32222000403?"
    Example output w/ Hermes-3-Llama-3.1-8B:
    🛠️  Tools: python, fetch_page, brave_search
    ⚙️  python(code="print(2535**2 + 32222000403)")
    → 15 chars
    The sum of 2535 squared and 32222000403 is 32228426628.
    
    uv run agent.py "What is the best BBQ joint in Laguna Beach?"
    Example output w/ Hermes-3-Llama-3.1-8B:
    🛠️  Tools: python, fetch_page, brave_search
    ⚙️  brave_search(query="best bbq joint in laguna beach")
    → 4283 chars
    Based on the search results, Beach Pit BBQ seems to be a popular and highly-rated BBQ joint in Laguna Beach. They offer a variety of BBQ options, including ribs, pulled pork, brisket, salads, wings, and more. They have dine-in, take-out, and catering options available.
    
    uv run agent.py "Search (with brave), fetch and summarize the homepage of llama.cpp"
    Example output w/ Hermes-3-Llama-3.1-8B:
    🛠️  Tools: python, fetch_page, brave_search
    ⚙️  brave_search(query="llama.cpp")
    → 3330 chars
    Llama.cpp is an open-source software library written in C++ that performs inference on various Large Language Models (LLMs). Alongside the library, it includes a CLI and web server. It is co-developed alongside the GGML project, a general-purpose tensor library. Llama.cpp is also available with Python bindings, known as llama.cpp-python. It has gained popularity for its ability to run LLMs on local machines, such as Macs with NVIDIA RTX systems. Users can leverage this library to accelerate LLMs and integrate them into various applications. There are numerous resources available, including tutorials and guides, for getting started with Llama.cpp and llama.cpp-python.
    
  • To compare the above results w/ a cloud provider's tool usage behaviour, just set the --provider flag (accepts openai, together, groq) and/or use --endpoint, --api-key, and --model:

    export LLAMA_API_KEY=...      # for --provider=llama.cpp https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md
    export OPENAI_API_KEY=...     # for --provider=openai    https://platform.openai.com/api-keys
    export TOGETHER_API_KEY=...   # for --provider=together  https://api.together.ai/settings/api-keys
    export GROQ_API_KEY=...       # for --provider=groq      https://console.groq.com/keys
    uv run agent.py "Search for, fetch and summarize the homepage of llama.cpp" --provider=openai
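
    You can also point the agent at any other OpenAI-compatible endpoint directly (hypothetical values shown; note that the endpoint must end in /v1/ since the agent appends chat/completions to it):

    uv run agent.py "Summarize the homepage of llama.cpp" \
      --endpoint=https://api.example.com/v1/ --api-key=$MY_API_KEY --model=some-model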
# agent.py
# SPDX-License-Identifier: Apache-2.0
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "aiohttp",
#     "fastapi",
#     "pydantic",
#     "typer",
#     "uvicorn",
# ]
# ///
import aiohttp
import asyncio
from functools import wraps
import json
from openapi import discover_tools
import os
from pydantic import BaseModel
import sys
import typer
from typing import Annotated, Literal, Optional


def typer_async_workaround():
    'Adapted from https://github.com/fastapi/typer/issues/950#issuecomment-2351076467'
    def decorator(f):
        @wraps(f)
        def wrapper(*args, **kwargs):
            return asyncio.run(f(*args, **kwargs))
        return wrapper
    return decorator


_PROVIDERS = {
    'llama.cpp': {
        'endpoint': 'http://localhost:8080/v1/',
        'api_key_env': 'LLAMA_API_KEY',  # https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md
    },
    'openai': {
        'endpoint': 'https://api.openai.com/v1/',
        'default_model': 'gpt-4o',
        'api_key_env': 'OPENAI_API_KEY',  # https://platform.openai.com/api-keys
    },
    'together': {
        'endpoint': 'https://api.together.xyz',
        'default_model': 'meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo',
        'api_key_env': 'TOGETHER_API_KEY',  # https://api.together.ai/settings/api-keys
    },
    'groq': {
        'endpoint': 'https://api.groq.com/openai',
        'default_model': 'llama-3.1-70b-versatile',
        'api_key_env': 'GROQ_API_KEY',  # https://console.groq.com/keys
    },
}


@typer_async_workaround()
async def main(
    goal: str,
    model: str = 'gpt-4o',
    tool_endpoints: Optional[list[str]] = None,
    think: bool = False,
    max_iterations: Optional[int] = 10,
    system: Optional[str] = None,
    verbose: bool = False,
    cache_prompt: bool = True,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    greedy: bool = False,
    seed: Optional[int] = None,
    interactive: bool = True,
    provider: Annotated[str, Literal['llama.cpp', 'openai', 'together', 'groq']] = 'llama.cpp',
    endpoint: Optional[str] = None,
    api_key: Optional[str] = None,
):
    if not tool_endpoints:
        tool_endpoints = ['http://localhost:8088']

    provider_info = _PROVIDERS[provider]
    if endpoint is None:
        endpoint = provider_info['endpoint']
    if api_key is None:
        api_key = os.environ.get(provider_info['api_key_env'])

    tool_map, tools = await discover_tools(tool_endpoints or [], verbose)

    if greedy:
        if temperature is None:
            temperature = 0.0
        if top_k is None:
            top_k = 1
        if top_p is None:
            top_p = 0.0

    if think:
        tools.append({
            'type': 'function',
            'function': {
                'name': 'think',
                'description': 'Call this function at every step to explain your thought process, before taking any other action',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'thought': {
                            'type': 'string'
                        }
                    },
                    'required': ['thought']
                }
            }
        })

        # Must be awaitable, like the tools discovered via OpenAPI.
        async def _ack_think(thought: str) -> str:
            return 'ACK'
        tool_map['think'] = _ack_think

    sys.stdout.write(f'🛠️  Tools: {", ".join(tool_map.keys()) if tool_map else "<none>"}\n')

    try:
        messages = []
        if system:
            messages.append(dict(
                role='system',
                content=system,
            ))
        messages.append(
            dict(
                role='user',
                content=goal,
            )
        )

        headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {api_key}',
        }

        async def run_turn():
            for i in range(max_iterations or sys.maxsize):
                url = f'{endpoint}chat/completions'
                payload = dict(
                    messages=messages,
                    model=model,
                    tools=tools,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    seed=seed,
                )
                if provider == 'llama.cpp':
                    payload.update(dict(
                        cache_prompt=cache_prompt,
                    ))  # type: ignore
                if verbose:
                    print(f'Calling {url} with {json.dumps(payload, indent=2)}', file=sys.stderr)
                async with aiohttp.ClientSession(headers=headers) as session:
                    async with session.post(url, json=payload) as response:
                        response.raise_for_status()
                        response = await response.json()
                if verbose:
                    print(f'Response: {json.dumps(response, indent=2)}', file=sys.stderr)

                assert len(response['choices']) == 1
                choice = response['choices'][0]
                content = choice['message']['content']
                if choice['finish_reason'] == 'tool_calls':
                    messages.append(choice['message'])
                    assert choice['message']['tool_calls']
                    for tool_call in choice['message']['tool_calls']:
                        if content:
                            print(f'💭 {content}', file=sys.stderr)

                        name = tool_call['function']['name']
                        args = json.loads(tool_call['function']['arguments'])
                        if verbose:
                            print(f'tool_call: {json.dumps(tool_call, indent=2)}', file=sys.stderr)
                        if think and name == 'think':
                            print(f'🧠 {args["thought"]}', file=sys.stderr)
                        else:
                            pretty_call = f'{name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
                            print(f'⚙️  {pretty_call}', file=sys.stderr, end=None)
                        sys.stderr.flush()

                        try:
                            tool_result = await tool_map[name](**args)
                        except Exception as e:
                            tool_result = 'ERROR: ' + str(e)
                        tool_result_str = tool_result if isinstance(tool_result, str) else json.dumps(tool_result)

                        if not (think and name == 'think'):
                            def describe(res, res_str, max_len=1000):
                                if isinstance(res, list):
                                    return f'{len(res)} items'
                                # Only append an ellipsis when the result was actually truncated.
                                if len(res_str) > max_len:
                                    return f'{len(res_str)} chars\n  {res_str[:max_len]}...'
                                return f'{len(res_str)} chars\n  {res_str}'
                            print(f' → {describe(tool_result, tool_result_str)}', file=sys.stderr)
                        if verbose:
                            print(tool_result_str, file=sys.stderr)

                        messages.append(dict(
                            tool_call_id=tool_call.get('id'),
                            role='tool',
                            name=name,
                            content=tool_result_str,
                        ))
                else:
                    assert content
                    print(content)
                    return

            if max_iterations is not None:
                raise Exception(f'Failed to get a valid response after {max_iterations} tool calls')

        # Always run at least one turn; in interactive mode, keep prompting for follow-ups.
        while True:
            await run_turn()
            if not interactive:
                break
            messages.append(dict(
                role='user',
                content=input('💬 ')
            ))

    except aiohttp.ClientResponseError as e:
        sys.stdout.write(f'💥 {e}\n')
        sys.exit(1)


if __name__ == '__main__':
    typer.run(main)
# docker-compose.yml
# SPDX-License-Identifier: Apache-2.0
services:

  # Forwards tool calls to the `siloed_tools` container.
  tools_endpoint:
    container_name: tools_endpoint
    depends_on:
      - siloed_tools
    image: alpine/socat:latest
    networks:
      - private_net
      - external_net
    ports:
      - 8088:8088
    command: TCP-LISTEN:8088,fork,bind=tools_endpoint TCP-CONNECT:siloed_tools:8088

  # Runs tools w/o *direct* internet access.
  #
  # All outgoing tool traffic must go through outgoing_proxy, which will log even HTTPS requests
  # (the proxy's self-signed cert is added to this container's root CAs).
  #
  # Even if you trust your agents (which you shouldn't), please verify the kind of traffic they emit.
  siloed_tools:
    container_name: siloed_tools
    depends_on:
      # - embeddings_server
      - outgoing_proxy
    image: local/llama.cpp:isolated-tools
    # sqlite-vec isn't compiled for linux/arm64, so to virtualize on Mac we force this to be x86_64
    platform: linux/amd64
    build:
      context: .
      dockerfile: Dockerfile.tools
    ports:
      - 8088:8088
    volumes:
      - ./data:/data:rw
    networks:
      - private_net
    environment:
      - BRAVE_SEARCH_API_KEY=${BRAVE_SEARCH_API_KEY}
      - EMBEDDINGS_DIMS=768
      - EMBEDDINGS_MODEL_FILE=/models/nomic-embed-text-v1.5.Q4_K_M.gguf
      # - EMBEDDINGS_ENDPOINT=http://embeddings_server:8081/v1/embeddings
      - EXCLUDE_TOOLS=${EXCLUDE_TOOLS:-}
      - INCLUDE_TOOLS=${INCLUDE_TOOLS:-}
      - MEMORY_SQLITE_DB=/data/memory.db
      - REQUESTS_CA_BUNDLE=/usr/local/share/ca-certificates/squidCA.crt
      - VERBOSE=1
      - http_proxy=http://outgoing_proxy:3128
      - https_proxy=http://outgoing_proxy:3128
    # entrypoint: /usr/bin/bash
    # command: ["-c", "pip install --upgrade gguf && apt update && apt install -y curl && curl https://ochafik.com && pip install gguf"]

  # Logs all outgoing traffic, and caches pip & apt packages.
  outgoing_proxy:
    container_name: outgoing_proxy
    image: local/llama.cpp:squid
    build:
      context: .
      dockerfile: Dockerfile.squid
    volumes:
      - ./squid.conf:/etc/squid/squid.conf:ro
      - ./squid/cache:/var/spool/squid:rw
      - ./squid/logs:/var/log/squid:rw
      - ./squid/ssl_cert:/etc/squid/ssl_cert:ro
      - ./squid/ssl_db:/var/spool/squid/ssl_db:rw
    extra_hosts:
      - host.docker.internal:host-gateway
    networks:
      - private_net
      - external_net
    ports:
      - "3128:3128"
    restart: unless-stopped
    entrypoint: /usr/bin/bash
    command: -c "squid -N -z && ( test -d /var/spool/squid/ssl_db/db || /usr/lib/squid/security_file_certgen -c -s /var/spool/squid/ssl_db/db -M 20MB ) && /usr/sbin/squid -N -d 1 -s"

networks:
  private_net:
    driver: bridge
    internal: true
  external_net:
    driver: bridge
# Dockerfile.squid
# SPDX-License-Identifier: Apache-2.0
FROM debian:stable

ENV SQUID_CACHE_DIR=/var/spool/squid \
    SQUID_LOG_DIR=/var/log/squid

RUN apt update && \
    apt install -y squid-openssl && \
    apt clean cache
# Dockerfile.tools
# SPDX-License-Identifier: Apache-2.0
FROM python:3.12-slim

RUN python -m pip install --upgrade pip && \
    apt update && \
    apt install -y wget && \
    apt clean cache

COPY requirements.txt /root/
WORKDIR /root
RUN pip install docling --extra-index-url https://download.pytorch.org/whl/cpu && \
    pip install -r requirements.txt
RUN wget https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF/resolve/main/nomic-embed-text-v1.5.Q4_K_M.gguf \
    -O /root/nomic-embed-text-v1.5.Q4_K_M.gguf

COPY tool*.py /root/tools/
ENV PYTHONPATH=/root/tools

COPY ./squid/ssl_cert/squidCA.crt /usr/local/share/ca-certificates/squidCA.crt
RUN chmod 644 /usr/local/share/ca-certificates/squidCA.crt && update-ca-certificates

ENTRYPOINT [ "uvicorn" ]
CMD ["tools:app", "--host", "0.0.0.0", "--port", "8088"]
# openapi.py
# SPDX-License-Identifier: Apache-2.0
import aiohttp
import json
import sys
import urllib.parse


class OpenAPIMethod:
    def __init__(self, url, name, descriptor, catalog):
        '''
        Wraps a remote OpenAPI method as an async Python function.
        '''
        self.url = url
        self.__name__ = name

        assert 'post' in descriptor, 'Only POST methods are supported'
        post_descriptor = descriptor['post']

        self.__doc__ = post_descriptor.get('description', '')
        parameters = post_descriptor.get('parameters', [])
        request_body = post_descriptor.get('requestBody')

        self.parameters = {p['name']: p for p in parameters}
        assert all(param['in'] == 'query' for param in self.parameters.values()), f'Only query parameters are supported (path: {url}, descriptor: {json.dumps(descriptor)})'

        self.body = None
        if request_body:
            assert 'application/json' in request_body['content'], f'Only application/json is supported for request body (path: {url}, descriptor: {json.dumps(descriptor)})'

            # Pick a name for the body parameter that doesn't clash with any query parameter.
            body_name = 'body'
            i = 2
            while body_name in self.parameters:
                body_name = f'body{i}'
                i += 1

            self.body = dict(
                name=body_name,
                required=request_body['required'],
                schema=request_body['content']['application/json']['schema'],
            )

        self.parameters_schema = dict(
            type='object',
            properties={
                **({
                    self.body['name']: self.body['schema']
                } if self.body else {}),
                **{
                    name: param['schema']
                    for name, param in self.parameters.items()
                }
            },
            required=[name for name, param in self.parameters.items() if param['required']] + ([self.body['name']] if self.body and self.body['required'] else [])
        )

        if (components := catalog.get('components', {})) is not None:
            if (schemas := components.get('schemas')) is not None:
                # pop() rather than del: the catalog dict is shared across an endpoint's methods,
                # so these keys may already have been removed (or be absent entirely).
                schemas.pop('HTTPValidationError', None)
                schemas.pop('ValidationError', None)
                if not schemas:
                    del components['schemas']
            if components:
                self.parameters_schema['components'] = components

    async def __call__(self, **kwargs):
        if self.body:
            body = kwargs.pop(self.body['name'], None)
            if self.body['required']:
                assert body is not None, f'Missing required body parameter: {self.body["name"]}'
        else:
            body = None

        query_params = {}
        for name, param in self.parameters.items():
            value = kwargs.pop(name, None)
            if param['required']:
                assert value is not None, f'Missing required parameter: {name}'

            assert param['in'] == 'query', 'Only query parameters are supported'
            query_params[name] = value

        params = '&'.join(f'{name}={urllib.parse.quote(str(value))}' for name, value in query_params.items() if value is not None)
        url = f'{self.url}?{params}'
        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=body) as response:
                if response.status == 500:
                    raise Exception(await response.text())
                response.raise_for_status()
                response_json = await response.json()

        return response_json


async def discover_tools(tool_endpoints: list[str], verbose) -> tuple[dict, list]:
    tool_map = {}
    tools = []

    async with aiohttp.ClientSession() as session:
        for url in tool_endpoints:
            assert url.startswith('http://') or url.startswith('https://'), f'Tools must be URLs, not local files: {url}'

            catalog_url = f'{url}/openapi.json'
            async with session.get(catalog_url) as response:
                response.raise_for_status()
                catalog = await response.json()

            for path, descriptor in catalog['paths'].items():
                fn = OpenAPIMethod(url=f'{url}{path}', name=path.replace('/', ' ').strip().replace(' ', '_'), descriptor=descriptor, catalog=catalog)
                tool_map[fn.__name__] = fn
                if verbose:
                    print(f'Function {fn.__name__}: params schema: {fn.parameters_schema}', file=sys.stderr)
                tools.append(dict(
                    type='function',
                    function=dict(
                        name=fn.__name__,
                        description=fn.__doc__ or '',
                        parameters=fn.parameters_schema,
                    )
                ))

    return tool_map, tools
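
For reference, here's a minimal sketch of how agent.py consumes this module: discover_tools fetches each endpoint's openapi.json and returns both the OpenAI-style tools array and a map of async callables (this assumes the tools server from this gist is up on port 8088):

    import asyncio
    from openapi import discover_tools

    async def demo():
        # Each FastAPI route (e.g. /python) becomes a tool named after its path.
        tool_map, tools = await discover_tools(['http://localhost:8088'], verbose=False)
        print([t['function']['name'] for t in tools])
        print(await tool_map['python'](code='print(6 * 7)'))

    asyncio.run(demo())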
# requirements.txt
aiosqlite
docling
fastapi[standard]
# html2text
ipython
requests
sparqlwrapper
sqlite-lembed
sqlite-rembed
sqlite-vec
uvicorn
#!/bin/bash
#
# serve_tools_inside_docker.sh
#
# SPDX-License-Identifier: Apache-2.0
#
# Serves tools inside a docker container.
#
# All outgoing HTTP *and* HTTPS traffic will be logged to `examples/agent/squid/logs/access.log`.
# Direct traffic to the host machine will be ~blocked, but clever AIs may find a way around it:
# make sure to have proper firewall rules in place.
#
# Take a look at `examples/agent/squid/conf/squid.conf` if you want tools to access your local llama-server(s).
#
# Usage:
#   examples/agent/serve_tools_inside_docker.sh
#
set -euo pipefail

mkdir -p squid/{cache,logs,ssl_cert,ssl_db}
rm -f squid/logs/{access,cache}.log

# Generate a self-signed certificate for the outgoing proxy.
# Tools can only reach out to HTTPS endpoints through that proxy, which they are told to trust blindly.
openssl req -new -newkey rsa:4096 -days 3650 -nodes -x509 \
    -keyout squid/ssl_cert/squidCA.pem \
    -out squid/ssl_cert/squidCA.pem \
    -subj "/C=US/ST=State/L=City/O=Organization/OU=Org Unit/CN=outgoing_proxy"
openssl x509 -outform PEM -in squid/ssl_cert/squidCA.pem -out squid/ssl_cert/squidCA.crt

docker compose up --build "$@"
# squid.conf
# SPDX-License-Identifier: Apache-2.0

# Squid Proxy w/ logging of both HTTP *and* HTTPS requests.
# We set up SSL Bump so http_proxy & https_proxy environment variables can be set to
# `http://<this_host>:3128` on any client that trusts the CA certificate.
http_port 3128 ssl-bump cert=/etc/squid/ssl_cert/squidCA.pem tls-cafile=/etc/squid/ssl_cert/squidCA.crt
sslcrtd_program /usr/lib/squid/security_file_certgen -s /var/spool/squid/ssl_db/db -M 20MB
sslcrtd_children 5 startup=1
acl step1 at_step SslBump1
ssl_bump peek step1
ssl_bump bump all

dns_nameservers 8.8.8.8 8.8.4.4
dns_timeout 5 seconds
positive_dns_ttl 24 hours
negative_dns_ttl 1 minutes

# Forbid access to the host.
# If you want to allow tools to call llama-server on the host (e.g. embeddings, or recursive thoughts),
# you can comment out the next two lines.
acl blocked_sites dstdomain host.docker.internal host-gateway docker.for.mac.localhost docker.for.mac.host.internal
http_access deny blocked_sites

# Allow all other traffic (you may want to restrict this in a production environment)
http_access allow all

request_header_access Cache-Control deny all
request_header_add Cache-Control "no-cache" all
# refresh_pattern ^.*$ 0 0% 0

# Cache Python packages
refresh_pattern -i ($|\.)(files\.pythonhosted\.org|pypi\.org)/.*?\.(whl|zip|tar\.gz)$ 10080 90% 43200 reload-into-ims

# Cache Debian packages
refresh_pattern \.debian\.org/.*?\.(deb|udeb|tar\.(gz|xz|bz2))$ 129600 100% 129600

# Configure cache
cache_dir ufs /var/spool/squid 10000 16 256
cache_mem 256 MB
maximum_object_size 1024 MB
maximum_object_size_in_memory 512 MB

# Configure logs
strip_query_terms off
cache_log stdio:/var/log/squid/cache.log
access_log stdio:/var/log/squid/access.log squid
cache_store_log none
# tool_fetch.py
# SPDX-License-Identifier: Apache-2.0
import logging

from docling.document_converter import DocumentConverter


def fetch(url: str) -> str:
    '''
    Fetch a document at the provided URL and convert it to Markdown.
    '''
    logging.debug('[fetch] Fetching %s', url)
    converter = DocumentConverter()
    result = converter.convert(url)
    return result.document.export_to_markdown()
# tool_memory.py
# SPDX-License-Identifier: Apache-2.0
'''
Memory tools that use sqlite-vec as a vector database (combined w/ sqlite-lembed or sqlite-rembed for embeddings).

Note: it's best to run this in a silo w/:

    ./examples/agent/serve_tools_inside_docker.sh

# Run w/o other tools:

## Prerequisites:

    pip install aiosqlite "fastapi[standard]" sqlite-lembed sqlite-rembed sqlite-vec uvicorn

## Usage w/ sqlite-rembed:

    ./llama-server --port 8081 -fa -c 0 --embeddings --rope-freq-scale 0.75 \
        -hfr nomic-ai/nomic-embed-text-v1.5-GGUF -hff nomic-embed-text-v1.5.Q4_K_M.gguf
    MEMORY_SQLITE_DB=memory_rembed.db \
        EMBEDDINGS_DIMS=768 \
        EMBEDDINGS_ENDPOINT=http://localhost:8081/v1/embeddings \
        python examples/agent/tools/memory.py

## Usage w/ sqlite-lembed:

    MEMORY_SQLITE_DB=memory_lembed.db \
        EMBEDDINGS_DIMS=768 \
        EMBEDDINGS_MODEL_FILE=~/Library/Caches/llama.cpp/nomic-embed-text-v1.5.Q4_K_M.gguf \
        python examples/agent/tools/memory.py

## Test:

    curl -X POST "http://localhost:8000/memorize" -H "Content-Type: application/json" -d '["User is Olivier Chafik", "User is a Software Engineer"]'
    curl -X POST "http://localhost:8000/search_memory?text=What%20do%20we%20do%3F"
'''
import logging

import aiosqlite
import fastapi
import os
import sqlite_lembed
import sqlite_rembed
import sqlite_vec

verbose = os.environ.get('VERBOSE', '0') == '1'

db_path = os.environ['MEMORY_SQLITE_DB']

# Embeddings configuration:
# Can either provide an embeddings model file (to be loaded locally by sqlite-lembed)
# or an embeddings endpoint w/ optional api key (to be queried remotely by sqlite-rembed).
embeddings_dims = int(os.environ['EMBEDDINGS_DIMS'])
if 'EMBEDDINGS_MODEL_FILE' in os.environ:
    local = True
    embed_fn = 'lembed'
    embeddings_model_file = os.environ['EMBEDDINGS_MODEL_FILE']
    logging.info(f'Using local embeddings model: {embeddings_model_file}')
elif 'EMBEDDINGS_ENDPOINT' in os.environ:
    local = False
    embed_fn = 'rembed'
    embeddings_endpoint = os.environ['EMBEDDINGS_ENDPOINT']
    embeddings_api_key = os.environ.get('EMBEDDINGS_API_KEY')
    logging.info(f'Using remote embeddings endpoint: {embeddings_endpoint}')
else:
    raise ValueError('Either EMBEDDINGS_MODEL_FILE or EMBEDDINGS_ENDPOINT must be set')


async def setup_db(db: aiosqlite.Connection):
    await db.enable_load_extension(True)
    await db.load_extension(sqlite_vec.loadable_path())
    if local:
        await db.load_extension(sqlite_lembed.loadable_path())
    else:
        await db.load_extension(sqlite_rembed.loadable_path())
    await db.enable_load_extension(False)

    client_name = 'default'
    if local:
        await db.execute(f'''
            INSERT INTO lembed_models(name, model) VALUES (
                '{client_name}', lembed_model_from_file(?)
            );
        ''', (embeddings_model_file,))
    else:
        await db.execute(f'''
            INSERT INTO rembed_clients(name, options) VALUES (
                '{client_name}', rembed_client_options('format', 'llamafile', 'url', ?, 'key', ?)
            );
        ''', (embeddings_endpoint, embeddings_api_key))

    async def create_vector_index(table_name, text_column, embedding_column):
        '''
        Create an sqlite-vec virtual table w/ an embedding column
        kept in sync with a source table's text column.
        '''
        await db.execute(f'''
            CREATE VIRTUAL TABLE IF NOT EXISTS {table_name}_{embedding_column} USING vec0(
                {embedding_column} float[{embeddings_dims}]
            )
        ''')
        await db.execute(f'''
            CREATE TRIGGER IF NOT EXISTS insert_{table_name}_{embedding_column}
            AFTER INSERT ON {table_name}
            BEGIN
                INSERT INTO {table_name}_{embedding_column} (rowid, {embedding_column})
                VALUES (NEW.rowid, {embed_fn}('{client_name}', NEW.{text_column}));
            END;
        ''')
        await db.execute(f'''
            CREATE TRIGGER IF NOT EXISTS update_{table_name}_{embedding_column}
            AFTER UPDATE OF {text_column} ON {table_name}
            BEGIN
                UPDATE {table_name}_{embedding_column}
                SET {embedding_column} = {embed_fn}('{client_name}', NEW.{text_column})
                WHERE rowid = NEW.rowid;
            END;
        ''')
        await db.execute(f'''
            CREATE TRIGGER IF NOT EXISTS delete_{table_name}_{embedding_column}
            AFTER DELETE ON {table_name}
            BEGIN
                DELETE FROM {table_name}_{embedding_column}
                WHERE rowid = OLD.rowid;
            END;
        ''')

        def search(text: str, top_n: int, columns: list[str] = ['rowid', text_column]):
            '''
            Search the vector index for the embedding of the provided text and return
            the distance of the top_n nearest matches + their corresponding original table's columns.
            '''
            col_seq = ', '.join(['distance', *(f"{table_name}.{c}" for c in columns)])
            return db.execute(
                f'''
                    SELECT {col_seq}
                    FROM (
                        SELECT rowid, distance
                        FROM {table_name}_{embedding_column}
                        WHERE {table_name}_{embedding_column}.{embedding_column} MATCH {embed_fn}('{client_name}', ?)
                        ORDER BY distance
                        LIMIT ?
                    )
                    JOIN {table_name} USING (rowid)
                ''',
                (text, top_n)
            )
        return search

    await db.execute('''
        CREATE TABLE IF NOT EXISTS facts (
            rowid INTEGER PRIMARY KEY AUTOINCREMENT,
            content TEXT NOT NULL
        )
    ''')
    facts_search = await create_vector_index('facts', 'content', 'embedding')

    await db.commit()

    return dict(
        facts_search=facts_search,
    )


async def memorize(facts: list[str]):
    'Memorize a set of statements / facts.'
    async with aiosqlite.connect(db_path) as db:
        await setup_db(db)
        await db.executemany(
            'INSERT INTO facts (content) VALUES (?)',
            [(fact,) for fact in facts]
        )
        await db.commit()


async def search_memory(text: str, top_n: int = 10):
    'Search the memory for the pieces of information closest to the provided text (returns only the top_n best matches).'
    async with aiosqlite.connect(db_path) as db:
        db_functions = await setup_db(db)
        async with db_functions['facts_search'](text, top_n) as cursor:
            # Return a json array of objects w/ columns
            results = await cursor.fetchall()
            cols = [c[0] for c in cursor.description]
            return [dict(zip(cols, row)) for row in results]


# This main entry point is just here for easy debugging
if __name__ == '__main__':
    import uvicorn

    logging.basicConfig(level=logging.DEBUG if verbose else logging.INFO)
    app = fastapi.FastAPI()
    app.post('/memorize')(memorize)
    app.post('/search_memory')(search_memory)
    uvicorn.run(app)
# tool_python.py
# SPDX-License-Identifier: Apache-2.0
import re
from IPython.core.interactiveshell import InteractiveShell
from io import StringIO
import logging
import sys

python_tools_registry = {}


def _strip_ansi_codes(text):
    ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
    return ansi_escape.sub('', text)


def python(code: str) -> str:
    '''
    Execute Python code in a siloed environment using IPython and return the output.

    Parameters:
        code (str): The Python code to execute.

    Returns:
        str: The output of the executed code.
    '''
    logging.debug('[python] Executing %s', code)
    shell = InteractiveShell(
        colors='neutral',
    )
    shell.user_global_ns.update(python_tools_registry)

    old_stdout = sys.stdout
    sys.stdout = out = StringIO()

    try:
        shell.run_cell(code)
    except Exception as e:
        # logging.debug('[python] Execution failed: %s\nCode: %s', e, code)
        return f'An error occurred:\n{_strip_ansi_codes(str(e))}'
    finally:
        sys.stdout = old_stdout

    return _strip_ansi_codes(out.getvalue())
# tool_search.py
# SPDX-License-Identifier: Apache-2.0
import itertools
import json
import logging
import os
from typing import Dict, List
import urllib.parse

import requests


def _extract_values(keys, obj):
    return dict((k, v) for k in keys if (v := obj.get(k)) is not None)


# Let's keep this tool aligned w/ llama_stack.providers.impls.meta_reference.agents.tools.builtin.BraveSearch
# (see https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/impls/meta_reference/agents/tools/builtin.py)
_brave_search_result_keys_by_type = {
    'web': ('type', 'title', 'url', 'description', 'date', 'extra_snippets'),
    'videos': ('type', 'title', 'url', 'description', 'date'),
    'news': ('type', 'title', 'url', 'description'),
    'infobox': ('type', 'title', 'url', 'description', 'long_desc'),
    'locations': ('type', 'title', 'url', 'description', 'coordinates', 'postal_address', 'contact', 'rating', 'distance', 'zoom_level'),
    'faq': ('type', 'title', 'url', 'question', 'answer'),
}


async def brave_search(*, query: str) -> List[Dict]:
    '''
    Search the Brave Search API for the specified query.

    Parameters:
        query (str): The query to search for.

    Returns:
        List[Dict]: The search results.
    '''
    logging.debug('[brave_search] Searching for %s', query)

    max_results = 10
    url = f'https://api.search.brave.com/res/v1/web/search?q={urllib.parse.quote(query)}'
    headers = {
        'Accept': 'application/json',
        'Accept-Encoding': 'gzip',
        'X-Subscription-Token': os.environ['BRAVE_SEARCH_API_KEY'],
    }

    def extract_results(search_response):
        # print("SEARCH RESPONSE: " + json.dumps(search_response, indent=2))
        for m in search_response['mixed']['main']:
            result_type = m['type']
            keys = _brave_search_result_keys_by_type.get(result_type)
            if keys is None:
                logging.warning('[brave_search] Unknown result type: %s', result_type)
                continue

            results_of_type = search_response[result_type]['results']
            if (idx := m.get('index')) is not None:
                yield _extract_values(keys, results_of_type[idx])
            elif m['all']:
                for r in results_of_type:
                    yield _extract_values(keys, r)

    response = requests.get(url, headers=headers)
    if not response.ok:
        raise Exception(response.text)
    response.raise_for_status()
    response_json = response.json()

    results = list(itertools.islice(extract_results(response_json), max_results))
    print(json.dumps(dict(query=query, response=response_json, results=results), indent=2))
    return results
# tool_sparql.py
# SPDX-License-Identifier: Apache-2.0
import json
import logging

from SPARQLWrapper import JSON, SPARQLWrapper


def execute_sparql(endpoint: str, query: str) -> str:
    '''
    Execute a SPARQL query on a given endpoint.
    '''
    logging.debug('[sparql] Executing on %s:\n%s', endpoint, query)
    sparql = SPARQLWrapper(endpoint)
    sparql.setQuery(query)
    sparql.setReturnFormat(JSON)
    return json.dumps(sparql.query().convert(), indent=2)


def wikidata_sparql(query: str) -> str:
    'Execute a SPARQL query on Wikidata.'
    return execute_sparql("https://query.wikidata.org/sparql", query)


def dbpedia_sparql(query: str) -> str:
    'Execute a SPARQL query on DBpedia.'
    return execute_sparql("https://dbpedia.org/sparql", query)
# tools.py
# SPDX-License-Identifier: Apache-2.0
'''
Runs simple tools as a FastAPI server.

Usage (docker isolation - with network access):

    export BRAVE_SEARCH_API_KEY=...
    ./serve_tools_inside_docker.sh

Usage (non-siloed, DANGEROUS):

    pip install -r requirements.txt --upgrade
    export BRAVE_SEARCH_API_KEY=...
    export EMBEDDINGS_DIMS=768
    export EMBEDDINGS_MODEL_FILE=/models/nomic-embed-text-v1.5.Q4_K_M.gguf
    export MEMORY_SQLITE_DB=./data/memory.db
    fastapi dev tools.py --port 8088
'''
import logging

import fastapi
import os
import re
import sys

sys.path.insert(0, os.path.dirname(__file__))

from tool_fetch import fetch
from tool_search import brave_search
from tool_python import python, python_tools_registry
from tool_memory import memorize, search_memory
from tool_sparql import wikidata_sparql, dbpedia_sparql

verbose = os.environ.get('VERBOSE', '0') == '1'
include = os.environ.get('INCLUDE_TOOLS')
exclude = os.environ.get('EXCLUDE_TOOLS')

logging.basicConfig(level=logging.DEBUG if verbose else logging.INFO)

ALL_TOOLS = {
    fn.__name__: fn
    for fn in [
        python,
        fetch,
        brave_search,
        memorize,
        search_memory,
        wikidata_sparql,
        dbpedia_sparql,
    ]
}

app = fastapi.FastAPI()

for name, fn in ALL_TOOLS.items():
    if include and not re.match(include, fn.__name__):
        continue
    if exclude and re.match(exclude, fn.__name__):
        continue
    app.post(f'/{name}')(fn)
    if name != 'python':
        python_tools_registry[name] = fn
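
Note that INCLUDE_TOOLS and EXCLUDE_TOOLS are regular expressions matched against tool names w/ re.match, so you can expose just a subset of the tools. For instance (a hypothetical invocation; docker-compose.yml forwards both variables into the container):

    INCLUDE_TOOLS='brave_search|fetch' ./serve_tools_inside_docker.sh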