Commits


kunal-vaishnavi authored and GitHub committed 5b663d6797f
Whisper Multitask and Multilingual (#15936)

### Description

This PR enables Whisper's multitask format and allows a user to use Whisper for multiple tasks (e.g. transcription, translation) and for multilingual purposes (e.g. English, Spanish). This PR also removes `attention_mask` as a required input for Whisper with beam search.

### Usage

Here is an example of how you can use Whisper for English transcription.

```
import numpy as np
import onnxruntime as ort
from datasets import load_dataset
from transformers import AutoConfig, AutoProcessor

model = "openai/whisper-tiny"
config = AutoConfig.from_pretrained(model)
processor = AutoProcessor.from_pretrained(model)

forced_decoder_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")
# forced_decoder_ids is of the format [(1, 50259), (2, 50359), (3, 50363)] and needs to be
# of the format [50258, 50259, 50359, 50363], where 50258 is the start token id
forced_decoder_ids = [config.decoder_start_token_id] + list(map(lambda token: token[1], forced_decoder_ids))

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
input_features = processor(ds[0]["audio"]["array"], return_tensors="np").input_features

inputs = {
    "input_features": np.float32(input_features),
    "max_length": np.array([26], dtype=np.int32),
    "min_length": np.array([1], dtype=np.int32),
    "num_beams": np.array([2], dtype=np.int32),
    "num_return_sequences": np.array([1], dtype=np.int32),
    "length_penalty": np.array([1.0], dtype=np.float32),
    "repetition_penalty": np.array([1.0], dtype=np.float32),
    "decoder_input_ids": np.array([forced_decoder_ids], dtype=np.int32),
}

sess = ort.InferenceSession("whisper-tiny_beamsearch.onnx", providers=["CPUExecutionProvider"])
outputs = sess.run(None, inputs)

# Print tokens and decoded output
print(outputs[0][0][0])
print(processor.decode(outputs[0][0][0]))
```

If you don't want to provide specific decoder input ids, or you want Whisper to predict the output language and task, you can set `forced_decoder_ids = [config.decoder_start_token_id]` instead.

### Motivation and Context

As seen in the figure from the [OpenAI Whisper paper](https://cdn.openai.com/papers/whisper.pdf), Whisper can be used for multiple tasks and languages.
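Building on the transcription example above, the same exported model can be pointed at other tasks or languages by changing only the decoder prompt. The sketch below is illustrative, not part of this PR: it assumes the `processor`, `config`, `inputs`, and `sess` objects from the example above are still in scope, and it keeps the assumed model file name `whisper-tiny_beamsearch.onnx` from that example.

```
import numpy as np

# Translation (e.g. Spanish speech to English text): ask the processor for the
# corresponding prompt ids and prepend the start token, as in the example above.
forced_decoder_ids = processor.get_decoder_prompt_ids(language="spanish", task="translate")
forced_decoder_ids = [config.decoder_start_token_id] + [token for _, token in forced_decoder_ids]
inputs["decoder_input_ids"] = np.array([forced_decoder_ids], dtype=np.int32)
outputs = sess.run(None, inputs)
print(processor.decode(outputs[0][0][0]))

# Language/task auto-detection: pass only the start token and let Whisper
# predict the language and task tokens itself.
inputs["decoder_input_ids"] = np.array([[config.decoder_start_token_id]], dtype=np.int32)
outputs = sess.run(None, inputs)
print(processor.decode(outputs[0][0][0]))
```

Because the prompt tokens are ordinary `decoder_input_ids`, no re-export of the ONNX model is needed to switch between tasks or languages; only the input tensor changes between runs.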