# custom_interface.py
# Custom SpeechBrain inference interface for wav2vec2-based
# utterance-level classification (embeddings + classification).
1 import torch
2 from speechbrain.inference.interfaces import Pretrained
3
4
class CustomEncoderWav2vec2Classifier(Pretrained):
    """A ready-to-use class for utterance-level classification (e.g.,
    speaker-id, language-id, emotion recognition, keyword spotting, etc.).

    The class assumes that a self-supervised encoder like wav2vec2/hubert
    and a classifier model are defined in the yaml file. If you want to
    convert the predicted index into a corresponding text label, please
    provide the path of the label_encoder in a variable called
    'lab_encoder_file' within the yaml.

    The class can be used either to run only the encoder (encode_batch()) to
    extract embeddings or to run a classification step (classify_batch()).

    Example
    -------
    >>> import torchaudio
    >>> # Model is downloaded from the speechbrain HuggingFace repo
    >>> tmpdir = getfixture("tmpdir")
    >>> classifier = CustomEncoderWav2vec2Classifier.from_hparams(
    ...     source="speechbrain/spkrec-ecapa-voxceleb",
    ...     savedir=tmpdir,
    ... )

    >>> # Compute embeddings
    >>> signal, fs = torchaudio.load("samples/audio_samples/example1.wav")
    >>> embeddings = classifier.encode_batch(signal)

    >>> # Classification
    >>> prediction = classifier.classify_batch(signal)
    """

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        """Encodes the input audio into a single vector embedding.

        The waveforms should already be in the model's desired format.
        You can call:
        ``normalized = <this>.normalizer(signal, sample_rate)``
        to get a correctly converted signal in most cases.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model. Make sure the sample rate is fs=16000 Hz.
        wav_lens : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.
        normalize : bool
            If True, normalizes the embeddings with the statistics
            contained in ``self.hparams.mean_var_norm_emb`` (skipped
            silently if that module is not defined in the yaml, which
            preserves the previous behavior).

        Returns
        -------
        torch.Tensor
            The encoded batch.
        """
        # Manage single waveforms in input
        if len(wavs.shape) == 1:
            wavs = wavs.unsqueeze(0)

        # Assign full length if wav_lens is not assigned
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        # Storing waveform in the specified device
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        wavs = wavs.float()

        # Computing features and embeddings
        outputs = self.mods.wav2vec2(wavs)

        # last dim will be used for AdaptativeAVG pool
        outputs = self.mods.avg_pool(outputs, wav_lens)
        outputs = outputs.view(outputs.shape[0], -1)

        # BUGFIX: `normalize` used to be accepted but silently ignored.
        # Apply embedding mean/var normalization when requested and when
        # the normalizer is defined in the hparams (as the docstring states).
        if normalize and hasattr(self.hparams, "mean_var_norm_emb"):
            outputs = self.hparams.mean_var_norm_emb(
                outputs, torch.ones(outputs.shape[0], device=self.device)
            )
        return outputs

    def classify_batch(self, wavs, wav_lens=None):
        """Performs classification on the top of the encoded features.

        It returns the posterior probabilities, the index and, if the label
        encoder is specified, also the text label.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model. Make sure the sample rate is fs=16000 Hz.
        wav_lens : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        out_prob
            The log posterior probabilities of each class ([batch, N_class])
        score
            It is the value of the log-posterior for the best class ([batch,])
        index
            The indexes of the best class ([batch,])
        text_lab
            List with the text labels corresponding to the indexes.
            (label encoder should be provided).
        """
        outputs = self.encode_batch(wavs, wav_lens)
        outputs = self.mods.output_mlp(outputs)
        out_prob = self.hparams.softmax(outputs)
        score, index = torch.max(out_prob, dim=-1)
        text_lab = self.hparams.label_encoder.decode_torch(index)
        return out_prob, score, index, text_lab

    def classify_file(self, path):
        """Classifies the given audiofile into the given set of labels.

        Arguments
        ---------
        path : str
            Path to audio file to classify.

        Returns
        -------
        out_prob
            The log posterior probabilities of each class ([batch, N_class])
        score
            It is the value of the log-posterior for the best class ([batch,])
        index
            The indexes of the best class ([batch,])
        text_lab
            List with the text labels corresponding to the indexes.
            (label encoder should be provided).
        """
        waveform = self.load_audio(path)
        # Fake a batch: the model expects [batch, time] input.
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        outputs = self.encode_batch(batch, rel_length)
        outputs = self.mods.output_mlp(outputs).squeeze(1)
        out_prob = self.hparams.softmax(outputs)
        score, index = torch.max(out_prob, dim=-1)
        text_lab = self.hparams.label_encoder.decode_torch(index)
        return out_prob, score, index, text_lab

    def forward(self, wavs, wav_lens=None, normalize=False):
        """Runs the encoder; alias of :meth:`encode_batch`."""
        return self.encode_batch(
            wavs=wavs, wav_lens=wav_lens, normalize=normalize
        )