feat(separation): add PixIT task, ToTaToNet model and SpeechSeparation pipeline (#1676)

Co-authored-by: Hervé BREDIN <[email protected]>
joonaskalda and hbredin committed May 30, 2024
1 parent f1951a6 commit 49d3b8e
Showing 9 changed files with 2,316 additions and 0 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
@@ -2,8 +2,19 @@

## develop

### TL;DR

`pyannote.audio` does [speech separation](https://hf.co/pyannote/speech-separation-ami-1.0): multi-speaker audio in, one audio channel per speaker out!

```bash
pip install pyannote.audio[separation]==3.3.0
```
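
A minimal usage sketch, following the linked model card; `HUGGINGFACE_ACCESS_TOKEN` and `example.wav` are placeholders:

```python
from pyannote.audio import Pipeline
import scipy.io.wavfile

# gated checkpoint: accept the user conditions on Hugging Face first
pipeline = Pipeline.from_pretrained(
    "pyannote/speech-separation-ami-1.0",
    use_auth_token="HUGGINGFACE_ACCESS_TOKEN",
)

# diarization is a pyannote.core.Annotation;
# sources holds one separated waveform per speaker
diarization, sources = pipeline("example.wav")

# dump each separated source to disk
for s, speaker in enumerate(diarization.labels()):
    scipy.io.wavfile.write(f"{speaker}.wav", 16000, sources.data[:, s])
```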

### New features

- feat(task): add `PixIT` joint speaker diarization and speech separation task (with [@joonaskalda](https://github.com/joonaskalda/))
- feat(model): add `ToTaToNet` joint speaker diarization and speech separation model (with [@joonaskalda](https://github.com/joonaskalda/))
- feat(pipeline): add `SpeechSeparation` pipeline (with [@joonaskalda](https://github.com/joonaskalda/))
- feat(io): add option to select torchaudio `backend`

### Fixes
351 changes: 351 additions & 0 deletions pyannote/audio/models/separation/ToTaToNet.py
@@ -0,0 +1,351 @@
# MIT License
#
# Copyright (c) 2024- CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# AUTHOR: Joonas Kalda (github.com/joonaskalda)

from functools import lru_cache
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from asteroid_filterbanks import make_enc_dec
from pyannote.core.utils.generators import pairwise

from pyannote.audio.core.model import Model
from pyannote.audio.core.task import Task
from pyannote.audio.utils.params import merge_dict
from pyannote.audio.utils.receptive_field import (
conv1d_num_frames,
conv1d_receptive_field_center,
conv1d_receptive_field_size,
)

try:
from asteroid.masknn import DPRNN
from asteroid.utils.torch_utils import pad_x_to_y

ASTEROID_IS_AVAILABLE = True
except ImportError:
ASTEROID_IS_AVAILABLE = False


try:
from transformers import AutoModel

TRANSFORMERS_IS_AVAILABLE = True
except ImportError:
TRANSFORMERS_IS_AVAILABLE = False


class ToTaToNet(Model):
"""ToTaToNet joint speaker diarization and speech separation model
/--------------\\
Conv1D Encoder --------+--- DPRNN --X------- Conv1D Decoder
WavLM -- upsampling --/ \\--- Avg pool -- Linear -- Classifier

    Parameters
    ----------
    sample_rate : int, optional
        Audio sample rate. Defaults to 16000 (16 kHz).
    num_channels : int, optional
        Number of channels. Defaults to mono (1).
    encoder_decoder : dict, optional
        Keyword arguments used to initialize the encoder and decoder.
        See ToTaToNet.ENCODER_DECODER_DEFAULTS for default values.
    linear : dict, optional
        Keyword arguments used to initialize linear layers.
        See ToTaToNet.LINEAR_DEFAULTS for default values.
    diar : dict, optional
        Keyword arguments used to initialize the average pooling in the
        diarization branch. See ToTaToNet.DIAR_DEFAULTS for default values.
    dprnn : dict, optional
        Keyword arguments used to initialize the DPRNN model.
        See ToTaToNet.DPRNN_DEFAULTS for default values.
    task : Task, optional
        Task to perform. Defaults to None.
    n_sources : int, optional
        Number of separated sources. Defaults to 3.
    use_wavlm : bool, optional
        Whether to use the WavLM large model for feature extraction.
        Defaults to True.
    gradient_clip_val : float, optional
        Gradient clipping value. Required when fine-tuning the WavLM model
        and thus using two different optimizers. Defaults to 5.0.

References
----------
Joonas Kalda, Clément Pagés, Ricard Marxer, Tanel Alumäe, and Hervé Bredin.
"PixIT: Joint Training of Speaker Diarization and Speech Separation
from Real-world Multi-speaker Recordings"
Odyssey 2024. https://arxiv.org/abs/2403.02288
"""

ENCODER_DECODER_DEFAULTS = {
"fb_name": "free",
"kernel_size": 32,
"n_filters": 64,
"stride": 16,
}
LINEAR_DEFAULTS = {"hidden_size": 64, "num_layers": 2}
DPRNN_DEFAULTS = {
"n_repeats": 6,
"bn_chan": 128,
"hid_size": 128,
"chunk_size": 100,
"norm_type": "gLN",
"mask_act": "relu",
"rnn_type": "LSTM",
}
DIAR_DEFAULTS = {"frames_per_second": 125}

def __init__(
self,
        encoder_decoder: Optional[dict] = None,
        linear: Optional[dict] = None,
        diar: Optional[dict] = None,
        dprnn: Optional[dict] = None,
sample_rate: int = 16000,
num_channels: int = 1,
task: Optional[Task] = None,
n_sources: int = 3,
use_wavlm: bool = True,
gradient_clip_val: float = 5.0,
):
if not ASTEROID_IS_AVAILABLE:
raise ImportError(
"'asteroid' must be installed to use ToTaToNet separation. "
"`pip install pyannote-audio[separation]` should do the trick."
)

if not TRANSFORMERS_IS_AVAILABLE:
raise ImportError(
"'transformers' must be installed to use ToTaToNet separation. "
"`pip install pyannote-audio[separation]` should do the trick."
)

super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task)

linear = merge_dict(self.LINEAR_DEFAULTS, linear)
dprnn = merge_dict(self.DPRNN_DEFAULTS, dprnn)
encoder_decoder = merge_dict(self.ENCODER_DECODER_DEFAULTS, encoder_decoder)
diar = merge_dict(self.DIAR_DEFAULTS, diar)
self.use_wavlm = use_wavlm
self.save_hyperparameters("encoder_decoder", "linear", "dprnn", "diar")
self.n_sources = n_sources

if encoder_decoder["fb_name"] == "free":
n_feats_out = encoder_decoder["n_filters"]
elif encoder_decoder["fb_name"] == "stft":
n_feats_out = int(2 * (encoder_decoder["n_filters"] / 2 + 1))
else:
raise ValueError("Filterbank type not recognized.")
self.encoder, self.decoder = make_enc_dec(
sample_rate=sample_rate, **self.hparams.encoder_decoder
)

if self.use_wavlm:
self.wavlm = AutoModel.from_pretrained("microsoft/wavlm-large")
downsampling_factor = 1
for conv_layer in self.wavlm.feature_extractor.conv_layers:
if isinstance(conv_layer.conv, nn.Conv1d):
downsampling_factor *= conv_layer.conv.stride[0]
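            # WavLM's convolutional feature extractor downsamples waveforms by
            # the product of its strides (320 for wavlm-large); dividing by the
            # encoder stride (16 by default) gives the number of encoder frames
            # per WavLM frame (20 with the defaults)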
self.wavlm_scaling = int(downsampling_factor / encoder_decoder["stride"])

self.masker = DPRNN(
encoder_decoder["n_filters"]
+ self.wavlm.feature_projection.projection.out_features,
out_chan=encoder_decoder["n_filters"],
n_src=n_sources,
**self.hparams.dprnn,
)
else:
self.masker = DPRNN(
encoder_decoder["n_filters"],
out_chan=encoder_decoder["n_filters"],
n_src=n_sources,
**self.hparams.dprnn,
)

# diarization can use a lower resolution than separation
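        # with the defaults, 16000 Hz / 125 fps = 128 samples per diarization
        # frame, i.e. 128 / 16 = 8 encoder frames pooled per diarization frame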
self.diarization_scaling = int(
sample_rate / diar["frames_per_second"] / encoder_decoder["stride"]
)
self.average_pool = nn.AvgPool1d(
self.diarization_scaling, stride=self.diarization_scaling
)
        linear_input_features = n_feats_out
if linear["num_layers"] > 0:
self.linear = nn.ModuleList(
[
nn.Linear(in_features, out_features)
for in_features, out_features in pairwise(
[
                            linear_input_features,
]
+ [self.hparams.linear["hidden_size"]]
* self.hparams.linear["num_layers"]
)
]
)
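        # manual optimization: fine-tuning WavLM uses a second optimizer, so
        # gradients are clipped manually with gradient_clip_val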
self.gradient_clip_val = gradient_clip_val
self.automatic_optimization = False

@property
def dimension(self) -> int:
"""Dimension of output"""
return 1

def build(self):
        if self.hparams.linear["num_layers"] > 0:
            self.classifier = nn.Linear(
                self.hparams.linear["hidden_size"], self.dimension
            )
        else:
            # without linear layers, forward() collapses each frame to a single
            # energy value, so the classifier takes one input feature
            self.classifier = nn.Linear(1, self.dimension)
self.activation = self.default_activation()

@lru_cache
def num_frames(self, num_samples: int) -> int:
"""Compute number of output frames
Parameters
----------
num_samples : int
Number of input samples.
Returns
-------
num_frames : int
Number of output frames.
"""

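        # the encoder Conv1d followed by average pooling is equivalent to a
        # single Conv1d whose kernel size and stride are both scaled by the
        # diarization pooling factor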
equivalent_stride = (
self.diarization_scaling * self.hparams.encoder_decoder["stride"]
)
equivalent_kernel_size = (
self.diarization_scaling * self.hparams.encoder_decoder["kernel_size"]
)

return conv1d_num_frames(
num_samples, kernel_size=equivalent_kernel_size, stride=equivalent_stride
)

def receptive_field_size(self, num_frames: int = 1) -> int:
"""Compute size of receptive field
Parameters
----------
num_frames : int, optional
Number of frames in the output signal
Returns
-------
receptive_field_size : int
Receptive field size.
"""

equivalent_stride = (
self.diarization_scaling * self.hparams.encoder_decoder["stride"]
)
equivalent_kernel_size = (
self.diarization_scaling * self.hparams.encoder_decoder["kernel_size"]
)

return conv1d_receptive_field_size(
num_frames, kernel_size=equivalent_kernel_size, stride=equivalent_stride
)

def receptive_field_center(self, frame: int = 0) -> int:
"""Compute center of receptive field
Parameters
----------
frame : int, optional
Frame index
Returns
-------
receptive_field_center : int
Index of receptive field center.
"""

equivalent_stride = (
self.diarization_scaling * self.hparams.encoder_decoder["stride"]
)
equivalent_kernel_size = (
self.diarization_scaling * self.hparams.encoder_decoder["kernel_size"]
)

return conv1d_receptive_field_center(
frame, kernel_size=equivalent_kernel_size, stride=equivalent_stride
)

def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
"""Pass forward
Parameters
----------
waveforms : (batch, channel, sample)
Returns
-------
scores : (batch, frame, classes)
sources : (batch, sample, n_sources)
"""
bsz = waveforms.shape[0]
tf_rep = self.encoder(waveforms)
if self.use_wavlm:
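            # extract WavLM features, upsample them to the encoder frame rate,
            # and concatenate them with the encoder output along the filter axis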
wavlm_rep = self.wavlm(waveforms.squeeze(1)).last_hidden_state
wavlm_rep = wavlm_rep.transpose(1, 2)
wavlm_rep = wavlm_rep.repeat_interleave(self.wavlm_scaling, dim=-1)
wavlm_rep = pad_x_to_y(wavlm_rep, tf_rep)
wavlm_rep = torch.cat((tf_rep, wavlm_rep), dim=1)
masks = self.masker(wavlm_rep)
else:
masks = self.masker(tf_rep)
# shape: (batch, nsrc, nfilters, nframes)
masked_tf_rep = masks * tf_rep.unsqueeze(1)
decoded_sources = self.decoder(masked_tf_rep)
decoded_sources = pad_x_to_y(decoded_sources, waveforms)
decoded_sources = decoded_sources.transpose(1, 2)
outputs = torch.flatten(masked_tf_rep, start_dim=0, end_dim=1)
# shape (batch * nsrc, nfilters, nframes)
outputs = self.average_pool(outputs)
outputs = outputs.transpose(1, 2)
# shape (batch, nframes, nfilters)
if self.hparams.linear["num_layers"] > 0:
for linear in self.linear:
outputs = F.leaky_relu(linear(outputs))
if self.hparams.linear["num_layers"] == 0:
outputs = (outputs**2).sum(dim=2).unsqueeze(-1)
outputs = self.classifier(outputs)
outputs = outputs.reshape(bsz, self.n_sources, -1)
outputs = outputs.transpose(1, 2)

return self.activation[0](outputs), decoded_sources
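
For reference, the frame-count arithmetic implemented by `num_frames` can be reproduced standalone. This sketch assumes the default hyper-parameters and that `conv1d_num_frames` follows the standard unpadded Conv1d output-length formula; `num_output_frames` is a hypothetical helper, not part of the library:

```python
# Hypothetical recomputation of ToTaToNet.num_frames under the defaults:
# encoder kernel_size=32 and stride=16, scaled by the diarization pooling
# factor of 8 (16000 Hz / 125 fps / stride 16).
def num_output_frames(num_samples: int, kernel_size: int = 32 * 8, stride: int = 16 * 8) -> int:
    # standard Conv1d output length with no padding and no dilation
    return (num_samples - kernel_size) // stride + 1

print(num_output_frames(16000))  # 1 second at 16 kHz -> 124 frames (~125 fps)
```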
25 changes: 25 additions & 0 deletions pyannote/audio/models/separation/__init__.py
@@ -0,0 +1,25 @@
# MIT License
#
# Copyright (c) 2024- CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from .ToTaToNet import ToTaToNet

__all__ = ["ToTaToNet"]
