3
3
#
4
4
# This source code is licensed under the terms described in the LICENSE file in
5
5
# the root directory of this source tree.
6
-
6
+ from llama_models . sku_list import resolve_model
7
7
from together import Together
8
8
9
+ from llama_models .llama3 .api .datatypes import * # noqa: F403
10
+ from llama_stack .apis .safety import (
11
+ RunShieldResponse ,
12
+ Safety ,
13
+ SafetyViolation ,
14
+ ViolationLevel ,
15
+ )
9
16
from llama_stack .distribution .request_headers import get_request_provider_data
10
17
11
- from .config import TogetherProviderDataValidator , TogetherSafetyConfig
18
+ from .config import TogetherSafetyConfig
19
+
20
# Maps short Llama Guard model descriptors (the keys accepted after
# resolve_model + descriptor(shorten_default_variant=True)) to the
# corresponding Together-hosted model identifiers.
SAFETY_SHIELD_TYPES = {
    "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
    "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
}
24
+
25
+
26
def shield_type_to_model_name(shield_type: str) -> str:
    """Resolve a shield type to the Together-hosted model name.

    Accepts the generic alias ``"llama_guard"`` (mapped to the default
    Llama-Guard-3-8B) or any model descriptor listed in
    ``SAFETY_SHIELD_TYPES``.

    Args:
        shield_type: shield identifier supplied by the caller.

    Returns:
        The Together model name to pass to the chat completions API.

    Raises:
        ValueError: if the shield type does not resolve to a supported
            safety model.
    """
    if shield_type == "llama_guard":
        shield_type = "Llama-Guard-3-8B"

    model = resolve_model(shield_type)
    # Compute the descriptor once; it is needed for both validation and
    # the final mapping lookup.
    descriptor = (
        model.descriptor(shorten_default_variant=True) if model is not None else None
    )
    if (
        model is None
        or descriptor not in SAFETY_SHIELD_TYPES
        or model.model_family is not ModelFamily.safety
    ):
        raise ValueError(
            f"{shield_type} is not supported, please use one of {','.join(SAFETY_SHIELD_TYPES.keys())}"
        )

    return SAFETY_SHIELD_TYPES[descriptor]
12
41
13
42
14
43
class TogetherSafetyImpl (Safety ):
@@ -21,48 +50,42 @@ async def initialize(self) -> None:
21
50
async def run_shield (
22
51
self , shield_type : str , messages : List [Message ], params : Dict [str , Any ] = None
23
52
) -> RunShieldResponse :
24
- if shield_type != "llama_guard" :
25
- raise ValueError (f"shield type { shield_type } is not supported" )
26
-
27
- provider_data = get_request_provider_data ()
28
53
29
54
together_api_key = None
30
- if provider_data is not None :
31
- if not isinstance (provider_data , TogetherProviderDataValidator ):
32
- raise ValueError (
33
- 'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
34
- )
35
-
36
- together_api_key = provider_data .together_api_key
37
- if not together_api_key :
38
- together_api_key = self .config .api_key
55
+ provider_data = get_request_provider_data ()
56
+ if provider_data is None or not provider_data .together_api_key :
57
+ raise ValueError (
58
+ 'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
59
+ )
60
+ together_api_key = provider_data .together_api_key
39
61
40
- if not together_api_key :
41
- raise ValueError ("The API key must be provider in the header or config" )
62
+ model_name = shield_type_to_model_name (shield_type )
42
63
43
64
# messages can have role assistant or user
44
65
api_messages = []
45
66
for message in messages :
46
67
if message .role in (Role .user .value , Role .assistant .value ):
47
68
api_messages .append ({"role" : message .role , "content" : message .content })
48
69
49
- violation = await get_safety_response (together_api_key , api_messages )
70
+ violation = await get_safety_response (
71
+ together_api_key , model_name , api_messages
72
+ )
50
73
return RunShieldResponse (violation = violation )
51
74
52
75
53
76
async def get_safety_response (
54
- api_key : str , messages : List [Dict [str , str ]]
77
+ api_key : str , model_name : str , messages : List [Dict [str , str ]]
55
78
) -> Optional [SafetyViolation ]:
56
79
client = Together (api_key = api_key )
57
- response = client .chat .completions .create (
58
- messages = messages , model = "meta-llama/Meta-Llama-Guard-3-8B"
59
- )
80
+ response = client .chat .completions .create (messages = messages , model = model_name )
60
81
if len (response .choices ) == 0 :
61
82
return None
62
83
63
84
response_text = response .choices [0 ].message .content
64
85
if response_text == "safe" :
65
- return None
86
+ return SafetyViolation (
87
+ violation_level = ViolationLevel .INFO , user_message = "safe" , metadata = {}
88
+ )
66
89
67
90
parts = response_text .split ("\n " )
68
91
if len (parts ) != 2 :
0 commit comments