Commit 940968e

fixing safety inference and safety adapter for new API spec. Pinned t… (meta-llama#105)
* Fix safety inference and the safety adapter for the new API spec. Pin the llama_models version to 0.0.24, since the latest release (0.0.35) changed the model descriptor name; the package was also missing at runtime, so the dependency is added to requirements.txt.
* Support Llama 3.2 models in the Together inference adapter and clean up the Together safety adapter.
* Fix model names.
* Add a vision guard (Llama-Guard-3-11B-Vision) to Together safety.
1 parent 0a3999a commit 940968e
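
With this commit, both Together adapters read the API key from the request's provider data instead of a header extractor or static config. A minimal sketch of a client request under that scheme, assuming a Llama Stack server listening locally; the route, port, and payload shape are assumptions for illustration, not part of this commit:

import json
import requests  # any HTTP client works; requests is used here for brevity

# Hypothetical local endpoint; adjust host, port, and route for your deployment.
url = "http://localhost:5000/inference/chat_completion"

# The adapters changed in this commit expect the key under "together_api_key"
# inside the X-LlamaStack-ProviderData header.
headers = {
    "Content-Type": "application/json",
    "X-LlamaStack-ProviderData": json.dumps({"together_api_key": "<your api key>"}),
}

payload = {
    "model": "Llama3.2-3B-Instruct",  # one of the names in TOGETHER_SUPPORTED_MODELS
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": False,
}

response = requests.post(url, headers=headers, json=payload)
print(response.json())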

File tree

5 files changed: +68 / -40 lines changed

llama_stack/providers/adapters/inference/together/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .config import TogetherImplConfig, TogetherHeaderExtractor
+from .config import TogetherImplConfig
 
 
 async def get_adapter_impl(config: TogetherImplConfig, _deps):

llama_stack/providers/adapters/inference/together/config.py

Lines changed: 1 addition & 10 deletions
@@ -4,17 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from pydantic import BaseModel, Field
-
 from llama_models.schema_utils import json_schema_type
-
-from llama_stack.distribution.request_headers import annotate_header
-
-
-class TogetherHeaderExtractor(BaseModel):
-    api_key: annotate_header(
-        "X-LlamaStack-Together-ApiKey", str, "The API Key for the request"
-    )
+from pydantic import BaseModel, Field
 
 
 @json_schema_type

llama_stack/providers/adapters/inference/together/together.py

Lines changed: 19 additions & 5 deletions
@@ -18,13 +18,17 @@
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
+from llama_stack.distribution.request_headers import get_request_provider_data
 
 from .config import TogetherImplConfig
 
 TOGETHER_SUPPORTED_MODELS = {
-    "Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct-Turbo",
-    "Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct-Turbo",
-    "Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-Turbo",
+    "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
+    "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
+    "Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
+    "Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
+    "Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
+    "Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
 }
 
 
@@ -97,6 +101,16 @@ async def chat_completion(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+
+        together_api_key = None
+        provider_data = get_request_provider_data()
+        if provider_data is None or not provider_data.together_api_key:
+            raise ValueError(
+                'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
+            )
+        together_api_key = provider_data.together_api_key
+
+        client = Together(api_key=together_api_key)
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = ChatCompletionRequest(
             model=model,
@@ -116,7 +130,7 @@ async def chat_completion(
 
         if not request.stream:
             # TODO: might need to add back an async here
-            r = self.client.chat.completions.create(
+            r = client.chat.completions.create(
                 model=together_model,
                 messages=self._messages_to_together_messages(messages),
                 stream=False,
@@ -151,7 +165,7 @@ async def chat_completion(
         ipython = False
         stop_reason = None
 
-        for chunk in self.client.chat.completions.create(
+        for chunk in client.chat.completions.create(
             model=together_model,
             messages=self._messages_to_together_messages(messages),
             stream=True,
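
The TOGETHER_SUPPORTED_MODELS table above is how the adapter translates Llama Stack model identifiers into Together model strings. A small standalone sketch of that lookup, copying the mapping exactly as added in this commit; the resolve_together_model helper is a hypothetical illustration, not part of the adapter:

# Copied from the diff above: Llama Stack model name -> Together model string.
TOGETHER_SUPPORTED_MODELS = {
    "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
    "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    "Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
    "Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
    "Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
    "Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
}


def resolve_together_model(model: str) -> str:
    """Return the Together model string for a Llama Stack model name (hypothetical helper)."""
    if model not in TOGETHER_SUPPORTED_MODELS:
        raise ValueError(
            f"{model} is not supported; choose one of {', '.join(TOGETHER_SUPPORTED_MODELS)}"
        )
    return TOGETHER_SUPPORTED_MODELS[model]


print(resolve_together_model("Llama3.2-11B-Vision-Instruct"))
# -> meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo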

llama_stack/providers/adapters/safety/together/together.py

Lines changed: 46 additions & 23 deletions
@@ -3,12 +3,41 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
+from llama_models.sku_list import resolve_model
 from together import Together
 
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.safety import (
+    RunShieldResponse,
+    Safety,
+    SafetyViolation,
+    ViolationLevel,
+)
 from llama_stack.distribution.request_headers import get_request_provider_data
 
-from .config import TogetherProviderDataValidator, TogetherSafetyConfig
+from .config import TogetherSafetyConfig
+
+SAFETY_SHIELD_TYPES = {
+    "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
+    "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
+}
+
+
+def shield_type_to_model_name(shield_type: str) -> str:
+    if shield_type == "llama_guard":
+        shield_type = "Llama-Guard-3-8B"
+
+    model = resolve_model(shield_type)
+    if (
+        model is None
+        or not model.descriptor(shorten_default_variant=True) in SAFETY_SHIELD_TYPES
+        or model.model_family is not ModelFamily.safety
+    ):
+        raise ValueError(
+            f"{shield_type} is not supported, please use of {','.join(SAFETY_SHIELD_TYPES.keys())}"
+        )
+
+    return SAFETY_SHIELD_TYPES.get(model.descriptor(shorten_default_variant=True))
 
 
 class TogetherSafetyImpl(Safety):
@@ -21,48 +50,42 @@ async def initialize(self) -> None:
     async def run_shield(
         self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
     ) -> RunShieldResponse:
-        if shield_type != "llama_guard":
-            raise ValueError(f"shield type {shield_type} is not supported")
-
-        provider_data = get_request_provider_data()
 
         together_api_key = None
-        if provider_data is not None:
-            if not isinstance(provider_data, TogetherProviderDataValidator):
-                raise ValueError(
-                    'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
-                )
-
-            together_api_key = provider_data.together_api_key
-        if not together_api_key:
-            together_api_key = self.config.api_key
+        provider_data = get_request_provider_data()
+        if provider_data is None or not provider_data.together_api_key:
+            raise ValueError(
+                'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
+            )
+        together_api_key = provider_data.together_api_key
 
-        if not together_api_key:
-            raise ValueError("The API key must be provider in the header or config")
+        model_name = shield_type_to_model_name(shield_type)
 
         # messages can have role assistant or user
         api_messages = []
         for message in messages:
            if message.role in (Role.user.value, Role.assistant.value):
                 api_messages.append({"role": message.role, "content": message.content})
 
-        violation = await get_safety_response(together_api_key, api_messages)
+        violation = await get_safety_response(
+            together_api_key, model_name, api_messages
+        )
         return RunShieldResponse(violation=violation)
 
 
 async def get_safety_response(
-    api_key: str, messages: List[Dict[str, str]]
+    api_key: str, model_name: str, messages: List[Dict[str, str]]
 ) -> Optional[SafetyViolation]:
     client = Together(api_key=api_key)
-    response = client.chat.completions.create(
-        messages=messages, model="meta-llama/Meta-Llama-Guard-3-8B"
-    )
+    response = client.chat.completions.create(messages=messages, model=model_name)
     if len(response.choices) == 0:
         return None
 
     response_text = response.choices[0].message.content
     if response_text == "safe":
-        return None
+        return SafetyViolation(
+            violation_level=ViolationLevel.INFO, user_message="safe", metadata={}
+        )
 
     parts = response_text.split("\n")
     if len(parts) != 2:
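
The safety adapter now reports a plain "safe" reply as an INFO-level violation and otherwise expects the two-line "unsafe\nS<category>" format that Llama Guard 3 typically produces. A self-contained sketch of that parsing logic, mirroring the code above without the llama_stack types; the parse_guard_response helper, the returned dict shape, and the handling of the unsafe branch (which is not visible in this diff) are illustrative assumptions:

from typing import Optional


def parse_guard_response(response_text: str) -> Optional[dict]:
    """Interpret a Llama Guard completion: 'safe' or 'unsafe\\nS<category>' (illustrative)."""
    if response_text == "safe":
        # Mirrors the adapter above: "safe" becomes an INFO-level result.
        return {"violation_level": "info", "user_message": "safe", "metadata": {}}

    parts = response_text.split("\n")
    if len(parts) != 2:
        # Unexpected shape; treat as unparseable (assumption for this sketch).
        return None

    verdict, category = parts
    if verdict != "unsafe":
        return None

    # Assumed shape for an unsafe verdict; the real adapter builds a SafetyViolation.
    return {
        "violation_level": "error",
        "user_message": "unsafe content detected",
        "metadata": {"violation_type": category},
    }


print(parse_guard_response("unsafe\nS2"))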

llama_stack/providers/registry/inference.py

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ def available_providers() -> List[ProviderSpec]:
             ],
             module="llama_stack.providers.adapters.inference.together",
             config_class="llama_stack.providers.adapters.inference.together.TogetherImplConfig",
-            header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor",
+            provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator",
         ),
     ),
 ]
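
The registry entry now points at TogetherProviderDataValidator instead of the removed TogetherHeaderExtractor. The validator's source is not part of this diff; based on how both adapters access provider_data.together_api_key, it is presumably a small pydantic model along these lines. This is a sketch under that assumption, not the actual file:

from pydantic import BaseModel


class TogetherProviderDataValidator(BaseModel):
    """Assumed shape: validates the JSON carried in X-LlamaStack-ProviderData."""

    together_api_key: str


# Example: validating the header payload a client would send.
data = TogetherProviderDataValidator(together_api_key="<your api key>")
print(data.together_api_key)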

0 commit comments
