---
datasets:
- agender
- mozillacommonvoice
- timit
- voxceleb2
inference: true
tags:
- speech
- audio
- wav2vec2
- audio-classification
- age-recognition
- gender-recognition
license: cc-by-nc-sa-4.0
base_model:
- facebook/wav2vec2-large-robust
---

# Model for Age and Gender Recognition based on Wav2vec 2.0 (24 layers)

The model expects a raw audio signal as input and outputs predictions
for age in a range of approximately 0...1 (0...100 years)
and gender, expressed as probabilities for being a child, female, or male.
In addition, it provides the pooled states of the last transformer layer.
The model was created by fine-tuning
[Wav2Vec2-Large-Robust](https://huggingface.co/facebook/wav2vec2-large-robust)
on [aGender](https://paperswithcode.com/dataset/agender),
[Mozilla Common Voice](https://commonvoice.mozilla.org/),
[TIMIT](https://catalog.ldc.upenn.edu/LDC93s1), and
[VoxCeleb2](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox2.html).
For this version of the model, we trained all 24 transformer layers.
An [ONNX](https://onnx.ai/) export of the model is available from
[doi:10.5281/zenodo.7761387](https://doi.org/10.5281/zenodo.7761387).
Further details are given in the associated [paper](https://arxiv.org/abs/2306.16962)
and [tutorial](https://github.com/audeering/w2v2-age-gender-how-to).
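
For the [ONNX](https://onnx.ai/) export mentioned above, a minimal sketch with
[onnxruntime](https://onnxruntime.ai/) follows; the file name is hypothetical,
and it is an assumption that the exported graph mirrors the PyTorch model from
the usage example below (raw 16 kHz signal in; hidden states, age, and gender out):

```python
import numpy as np
import onnxruntime

# hypothetical path to the export downloaded from Zenodo
session = onnxruntime.InferenceSession('model.onnx')

signal = np.zeros((1, 16000), dtype=np.float32)  # 1 s of silence at 16 kHz
input_name = session.get_inputs()[0].name        # query the name instead of guessing

# assumption: output order matches the PyTorch model
# (pooled hidden states, age, gender)
hidden_states, age, gender = session.run(None, {input_name: signal})
print(age, gender)
```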

# Usage

```python
import numpy as np
import torch
import torch.nn as nn
from transformers import Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2Model,
    Wav2Vec2PreTrainedModel,
)


class ModelHead(nn.Module):
    r"""Classification head."""

    def __init__(self, config, num_labels):

        super().__init__()

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, num_labels)

    def forward(self, features, **kwargs):

        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)

        return x


class AgeGenderModel(Wav2Vec2PreTrainedModel):
    r"""Age and gender classifier."""

    def __init__(self, config):

        super().__init__(config)

        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.age = ModelHead(config, 1)
        self.gender = ModelHead(config, 3)
        self.init_weights()

    def forward(
            self,
            input_values,
    ):

        outputs = self.wav2vec2(input_values)
        hidden_states = outputs[0]
        # average pool the hidden states over time
        hidden_states = torch.mean(hidden_states, dim=1)
        logits_age = self.age(hidden_states)
        logits_gender = torch.softmax(self.gender(hidden_states), dim=1)

        return hidden_states, logits_age, logits_gender


# load model from hub
device = 'cpu'
model_name = 'audeering/wav2vec2-large-robust-24-ft-age-gender'
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = AgeGenderModel.from_pretrained(model_name)

# dummy signal
sampling_rate = 16000
signal = np.zeros((1, sampling_rate), dtype=np.float32)


def process_func(
    x: np.ndarray,
    sampling_rate: int,
    embeddings: bool = False,
) -> np.ndarray:
    r"""Predict age and gender or extract embeddings from raw audio signal."""

    # run through processor to normalize signal
    # always returns a batch, so we just get the first entry
    # then we put it on the device
    y = processor(x, sampling_rate=sampling_rate)
    y = y['input_values'][0]
    y = y.reshape(1, -1)
    y = torch.from_numpy(y).to(device)

    # run through model
    with torch.no_grad():
        y = model(y)
        if embeddings:
            y = y[0]
        else:
            y = torch.hstack([y[1], y[2]])

    # convert to numpy
    y = y.detach().cpu().numpy()

    return y


print(process_func(signal, sampling_rate))
#  Age        female     male       child
# [[ 0.33793038  0.2715511   0.2275236   0.5009253 ]]
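# age is on an approximate 0...1 scale (see model description above);
# multiply by 100 to read it as years, i.e. roughly 34 years here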

print(process_func(signal, sampling_rate, embeddings=True))
# Pooled hidden states of last transformer layer
# [[ 0.024444    0.0508722   0.04930823 ...  0.07247854 -0.0697901
#   -0.0170537 ]]
```
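
The dummy signal above is all zeros. To run the model on a real recording,
provide a mono signal sampled at 16 kHz; a sketch, assuming
[torchaudio](https://pytorch.org/audio/) is installed and using a hypothetical
file path:

```python
import torchaudio

# hypothetical input file; replace with your own recording
waveform, rate = torchaudio.load('speech.wav')  # shape: (channels, samples)
waveform = waveform.mean(dim=0, keepdim=True)   # downmix to mono

# the model expects 16 kHz audio, so resample if needed
if rate != 16000:
    waveform = torchaudio.transforms.Resample(rate, 16000)(waveform)

pred = process_func(waveform.numpy(), 16000)
age_years = 100 * pred[0, 0]  # age is on a 0...1 scale
gender = ['female', 'male', 'child'][pred[0, 1:].argmax()]
print(f'age: {age_years:.0f} years, gender: {gender}')
```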