
Commit 90ce13c

Wauplin and SBrandeis authored
[InferenceSnippet] Take token from env variable if not set (#1514)
Solve #1361. Long awaited feature for @gary149. I did not go for the cleanest solution but it works well and should be robust/flexible enough if we need to fix something in the future. ## EDIT: breaking change => access token should be passed as `opts.accessToken` now in `snippets.getInferenceSnippets` ## TODO once merged: - [ ] adapt in moon-landing for snippets on model page huggingface-internal/moon-landing#13964 - [ ] adapt in doc-builder for `<inferencesnippet>` html tag (used in hub-docs) huggingface/doc-builder#570 - [ ] hardcoded examples in hub-docs huggingface/hub-docs#1764 ## Some examples: ### JS client ```js import { InferenceClient } from "@huggingface/inference"; const client = new InferenceClient(process.env.HF_TOKEN); const chatCompletion = await client.chatCompletion({ provider: "hf-inference", model: "meta-llama/Llama-3.1-8B-Instruct", messages: [ { role: "user", content: "What is the capital of France?", }, ], }); console.log(chatCompletion.choices[0].message); ``` ### Python client ```py import os from huggingface_hub import InferenceClient client = InferenceClient( provider="hf-inference", api_key=os.environ["HF_TOKEN"], ) completion = client.chat.completions.create( model="meta-llama/Llama-3.1-8B-Instruct", messages=[ { "role": "user", "content": "What is the capital of France?" } ], ) print(completion.choices[0].message) ``` ### openai client ```py import os from openai import OpenAI client = OpenAI( base_url="https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.1-8B-Instruct/v1", api_key=os.environ["HF_TOKEN"], ) completion = client.chat.completions.create( model="meta-llama/Llama-3.1-8B-Instruct", messages=[ { "role": "user", "content": "What is the capital of France?" } ], ) print(completion.choices[0].message) ``` ### curl ```sh curl https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.1-8B-Instruct/v1/chat/completions \ -H "Authorization: Bearer $HF_TOKEN" \ -H 'Content-Type: application/json' \ -d '{ "messages": [ { "role": "user", "content": "What is the capital of France?" } ], "model": "meta-llama/Llama-3.1-8B-Instruct", "stream": false }' ``` ### check out PR diff for more examples --------- Co-authored-by: Simon Brandeis <33657802+SBrandeis@users.noreply.github.com>
1 parent 49d93f2 commit 90ce13c

File tree: 119 files changed, +352 −119 lines

Large commits have some content hidden by default; only a subset of the changed files is shown below.

packages/inference/src/snippets/getInferenceSnippets.ts

Lines changed: 69 additions & 7 deletions
```diff
@@ -14,7 +14,10 @@ import { makeRequestOptionsFromResolvedModel } from "../lib/makeRequestOptions.j
 import type { InferenceProviderOrPolicy, InferenceTask, RequestArgs } from "../types.js";
 import { templates } from "./templates.exported.js";
 
-export type InferenceSnippetOptions = { streaming?: boolean; billTo?: string } & Record<string, unknown>;
+export type InferenceSnippetOptions = { streaming?: boolean; billTo?: string; accessToken?: string } & Record<
+    string,
+    unknown
+>;
 
 const PYTHON_CLIENTS = ["huggingface_hub", "fal_client", "requests", "openai"] as const;
 const JS_CLIENTS = ["fetch", "huggingface.js", "openai"] as const;
@@ -121,11 +124,12 @@ const HF_JS_METHODS: Partial<Record<WidgetType, string>> = {
     translation: "translation",
 };
 
+const ACCESS_TOKEN_PLACEHOLDER = "<ACCESS_TOKEN>"; // Placeholder to replace with env variable in snippets
+
 // Snippet generators
 const snippetGenerator = (templateName: string, inputPreparationFn?: InputPreparationFn) => {
     return (
         model: ModelDataMinimal,
-        accessToken: string,
         provider: InferenceProviderOrPolicy,
         inferenceProviderMapping?: InferenceProviderModelMapping,
         opts?: InferenceSnippetOptions
@@ -149,13 +153,15 @@
             console.error(`Failed to get provider helper for ${provider} (${task})`, e);
             return [];
         }
+        const accessTokenOrPlaceholder = opts?.accessToken ?? ACCESS_TOKEN_PLACEHOLDER;
+
         /// Prepare inputs + make request
         const inputs = inputPreparationFn ? inputPreparationFn(model, opts) : { inputs: getModelInputSnippet(model) };
         const request = makeRequestOptionsFromResolvedModel(
             providerModelId,
             providerHelper,
             {
-                accessToken,
+                accessToken: accessTokenOrPlaceholder,
                 provider,
                 ...inputs,
             } as RequestArgs,
@@ -180,7 +186,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
 
         /// Prepare template injection data
         const params: TemplateParams = {
-            accessToken,
+            accessToken: accessTokenOrPlaceholder,
             authorizationHeader: (request.info.headers as Record<string, string>)?.Authorization,
             baseUrl: removeSuffix(request.url, "/chat/completions"),
             fullUrl: request.url,
@@ -248,6 +254,11 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
             snippet = `${importSection}\n\n${snippet}`;
         }
 
+        /// Replace access token placeholder
+        if (snippet.includes(ACCESS_TOKEN_PLACEHOLDER)) {
+            snippet = replaceAccessTokenPlaceholder(snippet, language, provider);
+        }
+
         /// Snippet is ready!
         return { language, client: client as string, content: snippet };
     })
@@ -299,7 +310,6 @@ const snippets: Partial<
     PipelineType,
     (
         model: ModelDataMinimal,
-        accessToken: string,
         provider: InferenceProviderOrPolicy,
         inferenceProviderMapping?: InferenceProviderModelMapping,
         opts?: InferenceSnippetOptions
@@ -339,13 +349,12 @@
 
 export function getInferenceSnippets(
     model: ModelDataMinimal,
-    accessToken: string,
     provider: InferenceProviderOrPolicy,
     inferenceProviderMapping?: InferenceProviderModelMapping,
     opts?: Record<string, unknown>
 ): InferenceSnippet[] {
     return model.pipeline_tag && model.pipeline_tag in snippets
-        ? snippets[model.pipeline_tag]?.(model, accessToken, provider, inferenceProviderMapping, opts) ?? []
+        ? snippets[model.pipeline_tag]?.(model, provider, inferenceProviderMapping, opts) ?? []
         : [];
 }
@@ -420,3 +429,56 @@ function indentString(str: string): string {
 function removeSuffix(str: string, suffix: string) {
     return str.endsWith(suffix) ? str.slice(0, -suffix.length) : str;
 }
+
+function replaceAccessTokenPlaceholder(
+    snippet: string,
+    language: InferenceSnippetLanguage,
+    provider: InferenceProviderOrPolicy
+): string {
+    // If "opts.accessToken" is not set, the snippets are generated with a placeholder.
+    // Once snippets are rendered, we replace the placeholder with code to fetch the access token from an environment variable.
+
+    // Determine if HF_TOKEN or specific provider token should be used
+    const accessTokenEnvVar =
+        !snippet.includes("https://") || // no URL provided => using a client => use $HF_TOKEN
+        snippet.includes("https://router.huggingface.co") || // explicit routed request => use $HF_TOKEN
+        provider == "hf-inference" // hf-inference provider => use $HF_TOKEN
+            ? "HF_TOKEN"
+            : provider.toUpperCase().replace("-", "_") + "_API_KEY"; // e.g. "REPLICATE_API_KEY"
+
+    // Replace the placeholder with the env variable
+    if (language === "sh") {
+        snippet = snippet.replace(
+            `'Authorization: Bearer ${ACCESS_TOKEN_PLACEHOLDER}'`,
+            `"Authorization: Bearer $${accessTokenEnvVar}"` // e.g. "Authorization: Bearer $HF_TOKEN"
+        );
+    } else if (language === "python") {
+        snippet = "import os\n" + snippet;
+        snippet = snippet.replace(
+            `"${ACCESS_TOKEN_PLACEHOLDER}"`,
+            `os.environ["${accessTokenEnvVar}"]` // e.g. os.environ["HF_TOKEN"]
+        );
+        snippet = snippet.replace(
+            `"Bearer ${ACCESS_TOKEN_PLACEHOLDER}"`,
+            `f"Bearer {os.environ['${accessTokenEnvVar}']}"` // e.g. f"Bearer {os.environ['HF_TOKEN']}"
+        );
+        snippet = snippet.replace(
+            `"Key ${ACCESS_TOKEN_PLACEHOLDER}"`,
+            `f"Key {os.environ['${accessTokenEnvVar}']}"` // e.g. f"Key {os.environ['FAL_AI_API_KEY']}"
+        );
+    } else if (language === "js") {
+        snippet = snippet.replace(
+            `"${ACCESS_TOKEN_PLACEHOLDER}"`,
+            `process.env.${accessTokenEnvVar}` // e.g. process.env.HF_TOKEN
+        );
+        snippet = snippet.replace(
+            `Authorization: "Bearer ${ACCESS_TOKEN_PLACEHOLDER}",`,
+            `Authorization: \`Bearer $\{process.env.${accessTokenEnvVar}}\`,` // e.g. Authorization: `Bearer ${process.env.HF_TOKEN}`,
+        );
+        snippet = snippet.replace(
+            `Authorization: "Key ${ACCESS_TOKEN_PLACEHOLDER}",`,
+            `Authorization: \`Key $\{process.env.${accessTokenEnvVar}}\`,` // e.g. Authorization: `Key ${process.env.FAL_AI_API_KEY}`,
+        );
+    }
+    return snippet;
+}
```
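The environment-variable name is derived from the rendered snippet, not from the provider alone: snippets that go through a client or hit `router.huggingface.co` authenticate with `HF_TOKEN`, while snippets that call a provider URL directly get a provider-specific `*_API_KEY`. A standalone restatement of that rule, for illustration only (this helper is not part of the package):

```js
// Mirrors the ternary inside replaceAccessTokenPlaceholder above.
function envVarFor(snippet, provider) {
    return !snippet.includes("https://") || // no URL => client-based => routed through HF
        snippet.includes("https://router.huggingface.co") || // explicit routed request
        provider === "hf-inference"
        ? "HF_TOKEN"
        : provider.toUpperCase().replace("-", "_") + "_API_KEY";
}

console.log(envVarFor('const client = new InferenceClient(...);', "together")); // "HF_TOKEN" (no URL in snippet)
console.log(envVarFor('baseURL: "https://api.together.xyz/v1"', "together")); // "TOGETHER_API_KEY"
console.log(envVarFor('url: "https://fal.run/..."', "fal-ai")); // "FAL_AI_API_KEY"
```

Two details worth noting: with a string pattern, `String.prototype.replace` only swaps the first `-`, so a provider id containing two hyphens would keep its second one; and for `sh` snippets the replacement deliberately switches the Authorization header from single quotes to double quotes so that the shell actually expands `$HF_TOKEN`.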

packages/tasks-gen/scripts/generate-snippets-fixtures.ts

Lines changed: 12 additions & 1 deletion
```diff
@@ -240,6 +240,18 @@ const TEST_CASES: {
         providers: ["hf-inference"],
         opts: { billTo: "huggingface" },
     },
+    {
+        testName: "with-access-token",
+        task: "conversational",
+        model: {
+            id: "meta-llama/Llama-3.1-8B-Instruct",
+            pipeline_tag: "text-generation",
+            tags: ["conversational"],
+            inference: "",
+        },
+        providers: ["hf-inference"],
+        opts: { accessToken: "hf_xxx" },
+    },
     {
         testName: "text-to-speech",
         task: "text-to-speech",
@@ -314,7 +326,6 @@ function generateInferenceSnippet(
 ): InferenceSnippet[] {
     const allSnippets = snippets.getInferenceSnippets(
         model,
-        "api_token",
         provider,
         {
             hfModelId: model.id,
```
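Because an explicit `opts.accessToken` bypasses the placeholder path entirely, the new `with-access-token` fixtures should embed the literal token. For instance, the huggingface.js variant would plausibly render as follows (a sketch inferred from the sibling fixtures, not copied from the PR diff):

```js
import { InferenceClient } from "@huggingface/inference";

// The token passed via opts.accessToken appears verbatim:
const client = new InferenceClient("hf_xxx");

// ...followed by the same chatCompletion call as the other conversational fixtures.
```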

packages/tasks-gen/snippets-fixtures/automatic-speech-recognition/js/fetch/0.hf-inference.js

Lines changed: 1 addition & 1 deletion
```diff
@@ -3,7 +3,7 @@ async function query(data) {
         "https://router.huggingface.co/hf-inference/models/openai/whisper-large-v3-turbo",
         {
             headers: {
-                Authorization: "Bearer api_token",
+                Authorization: `Bearer ${process.env.HF_TOKEN}`,
                 "Content-Type": "audio/flac",
             },
             method: "POST",
```

packages/tasks-gen/snippets-fixtures/automatic-speech-recognition/js/huggingface.js/0.hf-inference.js

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 import { InferenceClient } from "@huggingface/inference";
 
-const client = new InferenceClient("api_token");
+const client = new InferenceClient(process.env.HF_TOKEN);
 
 const data = fs.readFileSync("sample1.flac");
```

packages/tasks-gen/snippets-fixtures/automatic-speech-recognition/python/huggingface_hub/0.hf-inference.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1,8 +1,9 @@
+import os
 from huggingface_hub import InferenceClient
 
 client = InferenceClient(
     provider="hf-inference",
-    api_key="api_token",
+    api_key=os.environ["HF_TOKEN"],
 )
 
 output = client.automatic_speech_recognition("sample1.flac", model="openai/whisper-large-v3-turbo")
```

packages/tasks-gen/snippets-fixtures/automatic-speech-recognition/python/requests/0.hf-inference.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1,8 +1,9 @@
+import os
 import requests
 
 API_URL = "https://router.huggingface.co/hf-inference/models/openai/whisper-large-v3-turbo"
 headers = {
-    "Authorization": "Bearer api_token",
+    "Authorization": f"Bearer {os.environ['HF_TOKEN']}",
 }
 
 def query(filename):
```

packages/tasks-gen/snippets-fixtures/automatic-speech-recognition/sh/curl/0.hf-inference.sh

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,5 +1,5 @@
 curl https://router.huggingface.co/hf-inference/models/openai/whisper-large-v3-turbo \
     -X POST \
-    -H 'Authorization: Bearer api_token' \
+    -H "Authorization: Bearer $HF_TOKEN" \
     -H 'Content-Type: audio/flac' \
     --data-binary @"sample1.flac"
```

packages/tasks-gen/snippets-fixtures/basic-snippet--token-classification/js/fetch/0.hf-inference.js

Lines changed: 1 addition & 1 deletion
```diff
@@ -3,7 +3,7 @@ async function query(data) {
         "https://router.huggingface.co/hf-inference/models/FacebookAI/xlm-roberta-large-finetuned-conll03-english",
         {
             headers: {
-                Authorization: "Bearer api_token",
+                Authorization: `Bearer ${process.env.HF_TOKEN}`,
                 "Content-Type": "application/json",
             },
             method: "POST",
```

packages/tasks-gen/snippets-fixtures/basic-snippet--token-classification/js/huggingface.js/0.hf-inference.js

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 import { InferenceClient } from "@huggingface/inference";
 
-const client = new InferenceClient("api_token");
+const client = new InferenceClient(process.env.HF_TOKEN);
 
 const output = await client.tokenClassification({
     model: "FacebookAI/xlm-roberta-large-finetuned-conll03-english",
```

packages/tasks-gen/snippets-fixtures/basic-snippet--token-classification/python/huggingface_hub/0.hf-inference.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1,8 +1,9 @@
+import os
 from huggingface_hub import InferenceClient
 
 client = InferenceClient(
     provider="hf-inference",
-    api_key="api_token",
+    api_key=os.environ["HF_TOKEN"],
 )
 
 result = client.token_classification(
```

packages/tasks-gen/snippets-fixtures/basic-snippet--token-classification/python/requests/0.hf-inference.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1,8 +1,9 @@
+import os
 import requests
 
 API_URL = "https://router.huggingface.co/hf-inference/models/FacebookAI/xlm-roberta-large-finetuned-conll03-english"
 headers = {
-    "Authorization": "Bearer api_token",
+    "Authorization": f"Bearer {os.environ['HF_TOKEN']}",
 }
 
 def query(payload):
```

packages/tasks-gen/snippets-fixtures/basic-snippet--token-classification/sh/curl/0.hf-inference.sh

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 curl https://router.huggingface.co/hf-inference/models/FacebookAI/xlm-roberta-large-finetuned-conll03-english \
     -X POST \
-    -H 'Authorization: Bearer api_token' \
+    -H "Authorization: Bearer $HF_TOKEN" \
     -H 'Content-Type: application/json' \
     -d '{
         "inputs": "\"My name is Sarah Jessica Parker but you can call me Jessica\""
```

packages/tasks-gen/snippets-fixtures/bill-to-param/js/huggingface.js/0.hf-inference.js

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 import { InferenceClient } from "@huggingface/inference";
 
-const client = new InferenceClient("api_token");
+const client = new InferenceClient(process.env.HF_TOKEN);
 
 const chatCompletion = await client.chatCompletion({
     provider: "hf-inference",
```

packages/tasks-gen/snippets-fixtures/bill-to-param/js/openai/0.hf-inference.js

Lines changed: 1 addition & 1 deletion
```diff
@@ -2,7 +2,7 @@ import { OpenAI } from "openai";
 
 const client = new OpenAI({
     baseURL: "https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.1-8B-Instruct/v1",
-    apiKey: "api_token",
+    apiKey: process.env.HF_TOKEN,
     defaultHeaders: {
         "X-HF-Bill-To": "huggingface"
     }
```

packages/tasks-gen/snippets-fixtures/bill-to-param/python/huggingface_hub/0.hf-inference.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1,8 +1,9 @@
+import os
 from huggingface_hub import InferenceClient
 
 client = InferenceClient(
     provider="hf-inference",
-    api_key="api_token",
+    api_key=os.environ["HF_TOKEN"],
     bill_to="huggingface",
 )
```

packages/tasks-gen/snippets-fixtures/bill-to-param/python/openai/0.hf-inference.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1,8 +1,9 @@
+import os
 from openai import OpenAI
 
 client = OpenAI(
     base_url="https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.1-8B-Instruct/v1",
-    api_key="api_token",
+    api_key=os.environ["HF_TOKEN"],
     default_headers={
         "X-HF-Bill-To": "huggingface"
     }
```

packages/tasks-gen/snippets-fixtures/bill-to-param/python/requests/0.hf-inference.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1,8 +1,9 @@
+import os
 import requests
 
 API_URL = "https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.1-8B-Instruct/v1/chat/completions"
 headers = {
-    "Authorization": "Bearer api_token",
+    "Authorization": f"Bearer {os.environ['HF_TOKEN']}",
     "X-HF-Bill-To": "huggingface"
 }
```

packages/tasks-gen/snippets-fixtures/bill-to-param/sh/curl/0.hf-inference.sh

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,5 +1,5 @@
 curl https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.1-8B-Instruct/v1/chat/completions \
-    -H 'Authorization: Bearer api_token' \
+    -H "Authorization: Bearer $HF_TOKEN" \
     -H 'Content-Type: application/json' \
     -H 'X-HF-Bill-To: huggingface' \
     -d '{
```

packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/js/huggingface.js/0.hf-inference.js

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 import { InferenceClient } from "@huggingface/inference";
 
-const client = new InferenceClient("api_token");
+const client = new InferenceClient(process.env.HF_TOKEN);
 
 const chatCompletion = await client.chatCompletion({
     provider: "hf-inference",
```

packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/js/huggingface.js/0.together.js

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 import { InferenceClient } from "@huggingface/inference";
 
-const client = new InferenceClient("api_token");
+const client = new InferenceClient(process.env.HF_TOKEN);
 
 const chatCompletion = await client.chatCompletion({
     provider: "together",
```

packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/js/openai/0.hf-inference.js

Lines changed: 1 addition & 1 deletion
```diff
@@ -2,7 +2,7 @@ import { OpenAI } from "openai";
 
 const client = new OpenAI({
     baseURL: "https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.1-8B-Instruct/v1",
-    apiKey: "api_token",
+    apiKey: process.env.HF_TOKEN,
 });
 
 const chatCompletion = await client.chat.completions.create({
```

packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/js/openai/0.together.js

Lines changed: 1 addition & 1 deletion
```diff
@@ -2,7 +2,7 @@ import { OpenAI } from "openai";
 
 const client = new OpenAI({
     baseURL: "https://api.together.xyz/v1",
-    apiKey: "api_token",
+    apiKey: process.env.TOGETHER_API_KEY,
 });
 
 const chatCompletion = await client.chat.completions.create({
```

packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/python/huggingface_hub/0.hf-inference.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1,8 +1,9 @@
+import os
 from huggingface_hub import InferenceClient
 
 client = InferenceClient(
     provider="hf-inference",
-    api_key="api_token",
+    api_key=os.environ["HF_TOKEN"],
 )
 
 completion = client.chat.completions.create(
```

packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/python/huggingface_hub/0.together.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1,8 +1,9 @@
+import os
 from huggingface_hub import InferenceClient
 
 client = InferenceClient(
     provider="together",
-    api_key="api_token",
+    api_key=os.environ["HF_TOKEN"],
 )
 
 completion = client.chat.completions.create(
```
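A detail worth calling out in these `together` fixtures: the openai-client variant above talks to `https://api.together.xyz/v1` directly, so its placeholder was resolved to `TOGETHER_API_KEY`, while this huggingface_hub variant contains no provider URL (the call is routed through Hugging Face), so it resolves to `HF_TOKEN`. That is the snippet-based rule from `replaceAccessTokenPlaceholder` at work.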

packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/python/openai/0.hf-inference.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1,8 +1,9 @@
+import os
 from openai import OpenAI
 
 client = OpenAI(
     base_url="https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.1-8B-Instruct/v1",
-    api_key="api_token",
+    api_key=os.environ["HF_TOKEN"],
 )
 
 completion = client.chat.completions.create(
```

0 commit comments
