Skip to content

Commit a0cd70b

Browse files
lucataco, hanouticelina, and coyotte508
authored and committed
[Inference] Add ASR support for Replicate provider (huggingface#1679)
Hello! This PR adds support for the `Automatic Speech Recognition` task type for Replicate models. Example: - [huggingface.co/openai/whisper-large-v3](https://huggingface.co/openai/whisper-large-v3) - [replicate.com/openai/whisper](https://replicate.com/openai/whisper) cc @hanouticelina --------- Co-authored-by: Celina Hanouti <hanouticelina@gmail.com> Co-authored-by: Eliott C. <coyotte508@gmail.com>
1 parent 4e276ce commit a0cd70b

File tree

3 files changed

+62
-5
lines changed

3 files changed

+62
-5
lines changed

packages/inference/src/lib/getProviderHelper.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
144144
"text-to-speech": new Replicate.ReplicateTextToSpeechTask(),
145145
"text-to-video": new Replicate.ReplicateTextToVideoTask(),
146146
"image-to-image": new Replicate.ReplicateImageToImageTask(),
147+
"automatic-speech-recognition": new Replicate.ReplicateAutomaticSpeechRecognitionTask(),
147148
},
148149
sambanova: {
149150
conversational: new Sambanova.SambanovaConversationalTask(),

packages/inference/src/providers/replicate.ts

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,14 @@ import type { BodyParams, HeaderParams, RequestArgs, UrlParams } from "../types.
2020
import { omit } from "../utils/omit.js";
2121
import {
2222
TaskProviderHelper,
23+
type AutomaticSpeechRecognitionTaskHelper,
2324
type ImageToImageTaskHelper,
2425
type TextToImageTaskHelper,
2526
type TextToVideoTaskHelper,
2627
} from "./providerHelper.js";
2728
import type { ImageToImageArgs } from "../tasks/cv/imageToImage.js";
29+
import type { AutomaticSpeechRecognitionArgs } from "../tasks/audio/automaticSpeechRecognition.js";
30+
import type { AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
2831
import { base64FromBytes } from "../utils/base64FromBytes.js";
2932
export interface ReplicateOutput {
3033
output?: string | string[];
@@ -163,6 +166,64 @@ export class ReplicateTextToVideoTask extends ReplicateTask implements TextToVid
163166
}
164167
}
165168

169+
export class ReplicateAutomaticSpeechRecognitionTask
170+
extends ReplicateTask
171+
implements AutomaticSpeechRecognitionTaskHelper
172+
{
173+
override preparePayload(params: BodyParams): Record<string, unknown> {
174+
return {
175+
input: {
176+
...omit(params.args, ["inputs", "parameters"]),
177+
...(params.args.parameters as Record<string, unknown>),
178+
audio: params.args.inputs, // This will be processed in preparePayloadAsync
179+
},
180+
version: params.model.includes(":") ? params.model.split(":")[1] : undefined,
181+
};
182+
}
183+
184+
async preparePayloadAsync(args: AutomaticSpeechRecognitionArgs): Promise<RequestArgs> {
185+
const blob = "data" in args && args.data instanceof Blob ? args.data : "inputs" in args ? args.inputs : undefined;
186+
187+
if (!blob || !(blob instanceof Blob)) {
188+
throw new Error("Audio input must be a Blob");
189+
}
190+
191+
// Convert Blob to base64 data URL
192+
const bytes = new Uint8Array(await blob.arrayBuffer());
193+
const base64 = base64FromBytes(bytes);
194+
const audioInput = `data:${blob.type || "audio/wav"};base64,${base64}`;
195+
196+
return {
197+
...("data" in args ? omit(args, "data") : omit(args, "inputs")),
198+
inputs: audioInput,
199+
};
200+
}
201+
202+
override async getResponse(response: ReplicateOutput): Promise<AutomaticSpeechRecognitionOutput> {
203+
if (typeof response?.output === "string") return { text: response.output };
204+
if (Array.isArray(response?.output) && typeof response.output[0] === "string") return { text: response.output[0] };
205+
206+
const out = response?.output as
207+
| undefined
208+
| {
209+
transcription?: string;
210+
translation?: string;
211+
txt_file?: string;
212+
};
213+
if (out && typeof out === "object") {
214+
if (typeof out.transcription === "string") return { text: out.transcription };
215+
if (typeof out.translation === "string") return { text: out.translation };
216+
if (typeof out.txt_file === "string") {
217+
const r = await fetch(out.txt_file);
218+
return { text: await r.text() };
219+
}
220+
}
221+
throw new InferenceClientProviderOutputError(
222+
"Received malformed response from Replicate automatic-speech-recognition API"
223+
);
224+
}
225+
}
226+
166227
export class ReplicateImageToImageTask extends ReplicateTask implements ImageToImageTaskHelper {
167228
override preparePayload(params: BodyParams<ImageToImageArgs>): Record<string, unknown> {
168229
return {

packages/inference/src/tasks/audio/automaticSpeechRecognition.ts

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import { getProviderHelper } from "../../lib/getProviderHelper.js";
44
import type { BaseArgs, Options } from "../../types.js";
55
import { innerRequest } from "../../utils/request.js";
66
import type { LegacyAudioInput } from "./utils.js";
7-
import { InferenceClientProviderOutputError } from "../../errors.js";
87

98
export type AutomaticSpeechRecognitionArgs = BaseArgs & (AutomaticSpeechRecognitionInput | LegacyAudioInput);
109
/**
@@ -22,9 +21,5 @@ export async function automaticSpeechRecognition(
2221
...options,
2322
task: "automatic-speech-recognition",
2423
});
25-
const isValidOutput = typeof res?.text === "string";
26-
if (!isValidOutput) {
27-
throw new InferenceClientProviderOutputError("Received malformed response from automatic-speech-recognition API");
28-
}
2924
return providerHelper.getResponse(res);
3025
}

0 commit comments

Comments
 (0)