custom_interface.py
| 1 | import torch |
| 2 | from speechbrain.inference.interfaces import Pretrained |
| 3 | |
| 4 | |
class CustomEncoderWav2vec2Classifier(Pretrained):
    """A ready-to-use class for utterance-level classification (e.g., speaker-id,
    language-id, emotion recognition, keyword spotting, etc).

    The class assumes that a self-supervised encoder like wav2vec2/hubert and
    a classifier model are defined in the yaml file. If you want to convert
    the predicted index into a corresponding text label, please provide the
    path of the label_encoder in a variable called 'lab_encoder_file' within
    the yaml.

    The class can be used either to run only the encoder (encode_batch()) to
    extract embeddings or to run a classification step (classify_batch()).

    Example
    -------
    >>> import torchaudio
    >>> from speechbrain.pretrained import EncoderClassifier
    >>> # Model is downloaded from the speechbrain HuggingFace repo
    >>> tmpdir = getfixture("tmpdir")
    >>> classifier = EncoderClassifier.from_hparams(
    ...     source="speechbrain/spkrec-ecapa-voxceleb",
    ...     savedir=tmpdir,
    ... )

    >>> # Compute embeddings
    >>> signal, fs = torchaudio.load("samples/audio_samples/example1.wav")
    >>> embeddings = classifier.encode_batch(signal)

    >>> # Classification
    >>> prediction = classifier.classify_batch(signal)
    """

    def __init__(self, *args, **kwargs):
        """Forwards all arguments unchanged to the ``Pretrained`` constructor."""
        super().__init__(*args, **kwargs)

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        """Encodes the input audio into a single vector embedding.

        The waveforms should already be in the model's desired format.
        You can call:
        ``normalized = <this>.normalizer(signal, sample_rate)``
        to get a correctly converted signal in most cases.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model. Make sure the sample rate is fs=16000 Hz.
            A single 1-D waveform is accepted and treated as a batch of one.
        wav_lens : torch.tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding. Defaults to all ones (no padding).
        normalize : bool
            Accepted for signature compatibility only. This implementation
            does not apply any embedding normalization, so the value is
            currently ignored.

        Returns
        -------
        torch.tensor
            The encoded batch, flattened to [batch, features].
        """
        # Manage single waveforms in input
        if wavs.dim() == 1:
            wavs = wavs.unsqueeze(0)

        # Assign full length if wav_lens is not assigned
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        # Storing waveform in the specified device
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        wavs = wavs.float()

        # Computing features with the self-supervised encoder
        outputs = self.mods.wav2vec2(wavs)

        # Pool the encoder output using the relative lengths, then flatten
        # everything but the batch dimension into one embedding per utterance
        outputs = self.mods.avg_pool(outputs, wav_lens)
        outputs = outputs.view(outputs.shape[0], -1)
        return outputs

    def _predict_from_logits(self, logits):
        """Turns classifier outputs into (out_prob, score, index, text_lab).

        Shared tail of classify_batch() and classify_file(). ``text_lab`` is
        None when the hparams define no ``label_encoder``.
        """
        out_prob = self.hparams.softmax(logits)
        score, index = torch.max(out_prob, dim=-1)

        # The label encoder is optional: only decode when one is configured.
        label_encoder = getattr(self.hparams, "label_encoder", None)
        text_lab = None
        if label_encoder is not None:
            text_lab = label_encoder.decode_torch(index)
        return out_prob, score, index, text_lab

    def classify_batch(self, wavs, wav_lens=None):
        """Performs classification on the top of the encoded features.

        It returns the posterior probabilities, the index and, if the label
        encoder is specified, also the text label.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model. Make sure the sample rate is fs=16000 Hz.
        wav_lens : torch.tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        out_prob
            The output of ``hparams.softmax`` for each class
            ([batch, N_class]); these are log-posteriors if the softmax in
            the yaml is configured to apply a log.
        score:
            It is the value of ``out_prob`` for the best class ([batch,])
        index
            The indexes of the best class ([batch,])
        text_lab:
            List with the text labels corresponding to the indexes, or None
            when no label encoder is provided.
        """
        embeddings = self.encode_batch(wavs, wav_lens)
        logits = self.mods.output_mlp(embeddings)
        return self._predict_from_logits(logits)

    def classify_file(self, path):
        """Classifies the given audiofile into the given set of labels.

        Arguments
        ---------
        path : str
            Path to audio file to classify.

        Returns
        -------
        out_prob
            The output of ``hparams.softmax`` for each class; these are
            log-posteriors if the softmax in the yaml is configured to apply
            a log. The batch dimension has size 1.
        score:
            It is the value of ``out_prob`` for the best class.
        index
            The index of the best class.
        text_lab:
            List with the text labels corresponding to the indexes, or None
            when no label encoder is provided.
        """
        waveform = self.load_audio(path)
        # Fake a batch of one utterance with full relative length
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        embeddings = self.encode_batch(batch, rel_length)
        # Drop a singleton dimension 1 if the classifier emits one
        logits = self.mods.output_mlp(embeddings).squeeze(1)
        return self._predict_from_logits(logits)

    def forward(self, wavs, wav_lens=None, normalize=False):
        """Runs encode_batch() so the object can be called like a module."""
        return self.encode_batch(
            wavs=wavs, wav_lens=wav_lens, normalize=normalize
        )
| 159 | |