How to use distil-whisper-large-v3-de-kd model from HF? #95

Open
Arche151 opened this issue Mar 4, 2024 · 10 comments
Arche151 commented Mar 4, 2024

Officially, multi-language support is still not implemented in Distil-Whisper.

But I noticed that the esteemed @sanchit-gandhi uploaded a German model for Distil-Whisper to Hugging Face, called 'distil-whisper-large-v3-de-kd'.

How can I use this specific model for transcribing something?

sanchit-gandhi (Collaborator) commented Mar 4, 2024

Hey @Arche151 - the "official" Distil-Whisper checkpoints were trained only on English speech recognition data, and thus can only be used for English. However, the training code provided in this repository generalises to all languages:
https://github.com/huggingface/distil-whisper/tree/main/training

The checkpoint you've mentioned is trained using exactly this approach on German speech recognition data, giving a model compatible with German audio. You can use it in exactly the same way as the original Distil-Whisper checkpoints:

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset


# Run on GPU in float16 if available, otherwise on CPU in float32
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "sanchit-gandhi/distil-whisper-large-v3-de-kd"

# Load the distilled German checkpoint and move it to the target device
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

# Wrap the model, tokenizer and feature extractor in an ASR pipeline
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    torch_dtype=torch_dtype,
    device=device,
)

# Stream a German validation sample from Multilingual LibriSpeech
dataset = load_dataset("facebook/multilingual_librispeech", "german", split="validation", streaming=True)
sample = next(iter(dataset))["audio"]

result = pipe(sample, generate_kwargs={"language": "german", "task": "transcribe"})
print(result["text"])

Print output:

 dann, und als er das sagte, übertrieb er sehr arg, wie alle, die in Italien geliebt haben.

sanchit-gandhi (Collaborator) commented

You can even use the recipe for this checkpoint to train a model in a different language: https://huggingface.co/sanchit-gandhi/distil-whisper-large-v3-de-kd#training-procedure

Simply swap out the --train_dataset_config_name and --eval_dataset_config_name arguments to the Common Voice split of your choice (a hypothetical example is sketched below): https://huggingface.co/datasets/mozilla-foundation/common_voice_16_1
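
For illustration, a hypothetical invocation with the dataset config swapped to French might look as follows. Only the two --*_dataset_config_name flags come from the recipe above; every other flag and path here is an assumption loosely based on the training README, so check the linked recipe for the real command:

# hypothetical sketch — see the linked training recipe for the exact flags
python run_distillation.py \
  --model_name_or_path ./distil-large-v3-init \
  --teacher_model_name_or_path openai/whisper-large-v3 \
  --train_dataset_name mozilla-foundation/common_voice_16_1 \
  --train_dataset_config_name fr \
  --eval_dataset_name mozilla-foundation/common_voice_16_1 \
  --eval_dataset_config_name fr \
  --output_dir ./distil-whisper-large-v3-fr-kd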


Arche151 commented Mar 4, 2024

@sanchit-gandhi Thanks a lot for the quick and insightful response!!

Last question: do you maybe also know whether I can use the German model via faster-whisper, or would it be necessary to convert it with CTranslate2?

Okay, last last question: do you know the speed difference between transcription via distil-whisper and faster-distil-whisper? I couldn't find a comparison.

sanchit-gandhi (Collaborator) commented

Hey @Arche151, no problem!

  1. You would indeed need to convert the weights from HF Transformers format to faster-whisper (CTranslate2) format. You can use this script for the conversion; a sketch of the command is shown below the list.
  2. Distil-Whisper is an architectural change that yields an inherently faster model. Faster-Whisper is an implementation change: the model is the same, but the code is more efficient. The speed gains from Distil-Whisper should therefore carry over to Faster-Whisper (a more efficient model running on more efficient code), though the combined speed-up won't necessarily be the product of the two individual gains.
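
For reference, a conversion along these lines should work, using the ct2-transformers-converter CLI that ships with CTranslate2 (the output directory name and the float16 quantization are illustrative choices, not part of the recipe):

# sketch: convert the HF checkpoint into a CTranslate2 model directory
ct2-transformers-converter \
  --model sanchit-gandhi/distil-whisper-large-v3-de-kd \
  --output_dir distil-whisper-large-v3-de-kd-ct2 \
  --copy_files tokenizer.json preprocessor_config.json \
  --quantization float16

The resulting directory can then be passed to faster_whisper.WhisperModel as a local path.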


Arche151 commented Mar 5, 2024

@sanchit-gandhi Ahh okay, that background information is super useful. Thanks for taking the time to explain!

Thanks for the script! Should've read the repo better 😅


Arche151 commented Mar 5, 2024

@sanchit-gandhi So, I converted the model using float16 quantization, and the quality of the transcription compared to the original large-v3 is really bad :(

A lot of words are transcribed incorrectly, some words are just not transcribed at all, and some words are transcribed twice, so there are duplicates.

In my test script I wrote this: model = WhisperModel("/path/distil-whisper-large-v3-de-kd-ct2", device="cpu", compute_type="int8")

sanchit-gandhi (Collaborator) commented

Hey @Arche151 - could you provide a reproducible code snippet for the behaviour you're seeing, both for the original large-v3 model and for the distil-whisper one? This should help discern where the divergence is occurring.
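
For example, something along these lines would do (a sketch: the audio path and the converted-model directory are placeholders):

from faster_whisper import WhisperModel

# Transcribe the same German clip with both models and print the outputs side by side
for model_path in ["large-v3", "/path/to/distil-whisper-large-v3-de-kd-ct2"]:
    model = WhisperModel(model_path, device="cpu", compute_type="int8")
    segments, info = model.transcribe("sample_de.wav", language="de")
    print(model_path, "->", " ".join(segment.text for segment in segments).strip())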


sanchit-gandhi commented Mar 6, 2024

I'm guessing here that you're using a long audio file (> 30 seconds)? One of the limitations of the current Distil-Whisper training code is that it shifts the model's distribution towards shorter audio lengths. This means it often breaks when used with OpenAI's sequential long-form algorithm. I'm going to push some updated training code in the coming days that addresses this phenomenon.

Once this is ready, we can re-train the model and it should be compatible with other Whisper libraries, like Faster-Whisper.
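
In the meantime, a possible stopgap on the Transformers side is chunked decoding, which splits the audio into fixed-length windows and never invokes the sequential algorithm. A minimal sketch, reusing the model, processor, torch_dtype and device from the snippet earlier in this thread; the 25-second chunk length is an assumption based on the value suggested for distil-large-v3:

# Chunked long-form decoding: ~25 s windows are transcribed independently,
# sidestepping the sequential long-form algorithm entirely
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    chunk_length_s=25,  # assumed value — tune for your checkpoint
    batch_size=16,
    max_new_tokens=128,
    torch_dtype=torch_dtype,
    device=device,
)

result = pipe("long_audio.wav", generate_kwargs={"language": "german", "task": "transcribe"})
print(result["text"])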


Arche151 commented Mar 8, 2024

@sanchit-gandhi Thanks for getting back to me!

I'm actually using audios that are shorter than 30 seconds.

Here's the complete Python script that I'm using with distil-whisper-large-v3-de-kd:

import os
import subprocess
from faster_whisper import WhisperModel

audio_file = "/tmp/audio_recording.wav"
recording_state_file = "/tmp/recording_state"

def start_recording():
    # Record from the microphone in CD quality until the process is killed
    subprocess.Popen(["arecord", "-f", "cd", audio_file])
    open(recording_state_file, 'w').close()

def stop_recording():
    subprocess.call(["pkill", "arecord"])
    if os.path.exists(recording_state_file):
        os.remove(recording_state_file)
    transcribe_audio()
    os.remove(audio_file)

def is_recording():
    return os.path.exists(recording_state_file)

def transcribe_audio():
    model = WhisperModel("path/to/distil-whisper-large-v3-de-kd-ct2", device="cpu")
    segments, info = model.transcribe(audio_file)
    transcription = " ".join([segment.text for segment in segments]).strip()
    # Copy the transcription to the clipboard
    subprocess.Popen(["xclip", "-selection", "c"], stdin=subprocess.PIPE).communicate(input=transcription.encode())
    # Notify the user that transcription is complete and copied to clipboard
    subprocess.call(["notify-send", "Transcription Complete", "The transcription has been copied to the clipboard."])

def main():
    # Toggle: stop and transcribe if a recording is in progress, otherwise start one
    if is_recording():
        stop_recording()
    else:
        start_recording()

if __name__ == "__main__":
    main()

And here's the same script with large-v3 instead, where the transcription works fine. It's identical except for the model line:

model = WhisperModel("large-v3", device="cpu", compute_type="int8")


Arche151 commented Apr 2, 2024

@sanchit-gandhi I wanted to ask whether you've had the time to look at my scripts, and if the updated training code that you mentioned has already been pushed :)
