@@ -20,11 +20,14 @@ import type { BodyParams, HeaderParams, RequestArgs, UrlParams } from "../types.
20
20
import { omit } from "../utils/omit.js" ;
21
21
import {
22
22
TaskProviderHelper ,
23
+ type AutomaticSpeechRecognitionTaskHelper ,
23
24
type ImageToImageTaskHelper ,
24
25
type TextToImageTaskHelper ,
25
26
type TextToVideoTaskHelper ,
26
27
} from "./providerHelper.js" ;
27
28
import type { ImageToImageArgs } from "../tasks/cv/imageToImage.js" ;
29
+ import type { AutomaticSpeechRecognitionArgs } from "../tasks/audio/automaticSpeechRecognition.js" ;
30
+ import type { AutomaticSpeechRecognitionOutput } from "@huggingface/tasks" ;
28
31
import { base64FromBytes } from "../utils/base64FromBytes.js" ;
29
32
export interface ReplicateOutput {
30
33
output ?: string | string [ ] ;
@@ -163,6 +166,64 @@ export class ReplicateTextToVideoTask extends ReplicateTask implements TextToVid
163
166
}
164
167
}
165
168
169
+ export class ReplicateAutomaticSpeechRecognitionTask
170
+ extends ReplicateTask
171
+ implements AutomaticSpeechRecognitionTaskHelper
172
+ {
173
+ override preparePayload ( params : BodyParams ) : Record < string , unknown > {
174
+ return {
175
+ input : {
176
+ ...omit ( params . args , [ "inputs" , "parameters" ] ) ,
177
+ ...( params . args . parameters as Record < string , unknown > ) ,
178
+ audio : params . args . inputs , // This will be processed in preparePayloadAsync
179
+ } ,
180
+ version : params . model . includes ( ":" ) ? params . model . split ( ":" ) [ 1 ] : undefined ,
181
+ } ;
182
+ }
183
+
184
+ async preparePayloadAsync ( args : AutomaticSpeechRecognitionArgs ) : Promise < RequestArgs > {
185
+ const blob = "data" in args && args . data instanceof Blob ? args . data : "inputs" in args ? args . inputs : undefined ;
186
+
187
+ if ( ! blob || ! ( blob instanceof Blob ) ) {
188
+ throw new Error ( "Audio input must be a Blob" ) ;
189
+ }
190
+
191
+ // Convert Blob to base64 data URL
192
+ const bytes = new Uint8Array ( await blob . arrayBuffer ( ) ) ;
193
+ const base64 = base64FromBytes ( bytes ) ;
194
+ const audioInput = `data:${ blob . type || "audio/wav" } ;base64,${ base64 } ` ;
195
+
196
+ return {
197
+ ...( "data" in args ? omit ( args , "data" ) : omit ( args , "inputs" ) ) ,
198
+ inputs : audioInput ,
199
+ } ;
200
+ }
201
+
202
+ override async getResponse ( response : ReplicateOutput ) : Promise < AutomaticSpeechRecognitionOutput > {
203
+ if ( typeof response ?. output === "string" ) return { text : response . output } ;
204
+ if ( Array . isArray ( response ?. output ) && typeof response . output [ 0 ] === "string" ) return { text : response . output [ 0 ] } ;
205
+
206
+ const out = response ?. output as
207
+ | undefined
208
+ | {
209
+ transcription ?: string ;
210
+ translation ?: string ;
211
+ txt_file ?: string ;
212
+ } ;
213
+ if ( out && typeof out === "object" ) {
214
+ if ( typeof out . transcription === "string" ) return { text : out . transcription } ;
215
+ if ( typeof out . translation === "string" ) return { text : out . translation } ;
216
+ if ( typeof out . txt_file === "string" ) {
217
+ const r = await fetch ( out . txt_file ) ;
218
+ return { text : await r . text ( ) } ;
219
+ }
220
+ }
221
+ throw new InferenceClientProviderOutputError (
222
+ "Received malformed response from Replicate automatic-speech-recognition API"
223
+ ) ;
224
+ }
225
+ }
226
+
166
227
export class ReplicateImageToImageTask extends ReplicateTask implements ImageToImageTaskHelper {
167
228
override preparePayload ( params : BodyParams < ImageToImageArgs > ) : Record < string , unknown > {
168
229
return {
0 commit comments