16
16
*/
17
17
import { InferenceOutputError } from "../lib/InferenceOutputError.js" ;
18
18
import { isUrl } from "../lib/isUrl.js" ;
19
+ import type { TextToVideoArgs } from "../tasks/index.js" ;
19
20
import type { BodyParams , UrlParams } from "../types.js" ;
21
+ import { delay } from "../utils/delay.js" ;
20
22
import { omit } from "../utils/omit.js" ;
21
23
import {
22
24
BaseConversationalTask ,
@@ -26,11 +28,11 @@ import {
26
28
} from "./providerHelper.js" ;
27
29
28
30
const NOVITA_API_BASE_URL = "https://api.novita.ai" ;
29
/**
 * Response body returned when a job is submitted to Novita's asynchronous API.
 *
 * The submission call does not return the generated media itself; it only
 * hands back an identifier that `NovitaTextToVideoTask.getResponse` uses to
 * poll the `/v3/async/task-result` endpoint.
 */
export interface NovitaAsyncAPIOutput {
	/** Identifier of the queued task, passed as the `task_id` query parameter when polling. */
	task_id: string;
}
35
+
34
36
export class NovitaTextGenerationTask extends BaseTextGenerationTask {
35
37
constructor ( ) {
36
38
super ( "novita" , NOVITA_API_BASE_URL ) ;
@@ -50,38 +52,94 @@ export class NovitaConversationalTask extends BaseConversationalTask {
50
52
return "/v3/openai/chat/completions" ;
51
53
}
52
54
}
55
+
53
56
/**
 * Text-to-video helper for the Novita provider.
 *
 * Novita's video generation is asynchronous: the initial request returns a
 * task id, and the result has to be polled from `/v3/async/task-result` until
 * the task reaches a terminal status. `getResponse` performs that polling and
 * finally downloads the generated video as a `Blob`.
 */
export class NovitaTextToVideoTask extends TaskProviderHelper implements TextToVideoTaskHelper {
	constructor() {
		super("novita", NOVITA_API_BASE_URL);
	}

	/**
	 * Builds the route used to submit an asynchronous generation job.
	 *
	 * @param params - URL parameters; only `model` is read.
	 * @returns The path `/v3/async/<model>`.
	 */
	override makeRoute(params: UrlParams): string {
		return `/v3/async/${params.model}`;
	}

	/**
	 * Builds the request body for the submission call.
	 *
	 * Everything in `args` except `inputs` and `parameters` is forwarded as-is,
	 * the entries of `parameters` are flattened into the top level,
	 * `num_inference_steps` is renamed to `steps`, and `inputs` is sent as `prompt`.
	 *
	 * @param params - Body parameters carrying the text-to-video arguments.
	 * @returns The payload object to serialize.
	 */
	override preparePayload(params: BodyParams<TextToVideoArgs>): Record<string, unknown> {
		const { num_inference_steps, ...restParameters } = params.args.parameters ?? {};
		return {
			...omit(params.args, ["inputs", "parameters"]),
			...restParameters,
			// Only emit `steps` when the caller actually set `num_inference_steps`;
			// otherwise an explicit `steps` inside `parameters` would be overwritten
			// with `undefined`.
			...(num_inference_steps !== undefined ? { steps: num_inference_steps } : {}),
			prompt: params.args.inputs,
		};
	}

	/**
	 * Polls the task-result endpoint until the task finishes, then downloads the video.
	 *
	 * @param response - Body of the submission call; must contain `task_id`.
	 * @param url - URL the submission request was sent to; its origin is reused for polling.
	 * @param headers - Headers to send with every polling request.
	 * @returns The generated video as a `Blob`.
	 * @throws InferenceOutputError when `url`/`headers` are missing, the task id is
	 * absent, a polling request fails or returns an unexpected shape, the task
	 * reports `TASK_STATUS_FAILED`, or the final payload has no usable video URL.
	 */
	override async getResponse(
		response: NovitaAsyncAPIOutput,
		url?: string,
		headers?: Record<string, string>
	): Promise<Blob> {
		if (!url || !headers) {
			throw new InferenceOutputError("URL and headers are required for text-to-video task");
		}
		// `response` comes from the network, so tolerate a null/undefined body at runtime.
		const taskId = response?.task_id;
		if (!taskId) {
			throw new InferenceOutputError("No task ID found in the response");
		}

		// Reuse the origin of the submission URL. Requests that went through the
		// Hugging Face router need the provider prefix put back in front of the path.
		const parsedUrl = new URL(url);
		const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${
			parsedUrl.host === "router.huggingface.co" ? "/novita" : ""
		}`;
		const resultUrl = `${baseUrl}/v3/async/task-result?task_id=${encodeURIComponent(taskId)}`;

		let status = "";
		let taskResult: unknown;

		// NOTE(review): this loop has no upper bound on attempts; it ends only when
		// a terminal status is reported or one of the checks below throws.
		while (status !== "TASK_STATUS_SUCCEED" && status !== "TASK_STATUS_FAILED") {
			await delay(500);
			const resultResponse = await fetch(resultUrl, { headers });
			if (!resultResponse.ok) {
				throw new InferenceOutputError("Failed to fetch task result");
			}

			// Keep the try block limited to JSON decoding so that the shape error
			// below is not caught and re-labelled as a parse failure.
			try {
				taskResult = await resultResponse.json();
			} catch {
				throw new InferenceOutputError("Failed to parse task result");
			}

			if (
				taskResult &&
				typeof taskResult === "object" &&
				"task" in taskResult &&
				taskResult.task &&
				typeof taskResult.task === "object" &&
				"status" in taskResult.task &&
				typeof taskResult.task.status === "string"
			) {
				status = taskResult.task.status;
			} else {
				throw new InferenceOutputError("Failed to get task status");
			}
		}

		if (status === "TASK_STATUS_FAILED") {
			throw new InferenceOutputError("Task failed");
		}

		if (
			typeof taskResult === "object" &&
			!!taskResult &&
			"videos" in taskResult &&
			Array.isArray(taskResult.videos) &&
			taskResult.videos.length > 0
		) {
			// Check that the first entry is an object before using `in`, which
			// throws a TypeError on primitives.
			const firstVideo: unknown = taskResult.videos[0];
			if (
				typeof firstVideo === "object" &&
				!!firstVideo &&
				"video_url" in firstVideo &&
				typeof firstVideo.video_url === "string" &&
				isUrl(firstVideo.video_url)
			) {
				const urlResponse = await fetch(firstVideo.video_url);
				if (!urlResponse.ok) {
					// Do not hand an HTTP error body back to the caller as if it were the video.
					throw new InferenceOutputError("Failed to fetch the generated video");
				}
				return await urlResponse.blob();
			}
		}

		throw new InferenceOutputError("Expected { videos: [{ video_url: string }] }");
	}
}
0 commit comments