---
language:
- ru
tags:
- SER
- speech
- audio
- russian
license: apache-2.0
pipeline_tag: audio-classification
base_model: facebook/hubert-large-ls960-ft
datasets:
- xbgoose/dusha
---
# HuBERT fine-tuned on the DUSHA dataset for speech emotion recognition in Russian

The pre-trained base model is [facebook/hubert-large-ls960-ft](https://huggingface.co/facebook/hubert-large-ls960-ft).

The DUSHA dataset used for fine-tuning can be found [here](https://github.com/salute-developers/golos/tree/master/dusha#dataset-structure).

# Fine-tuning

Fine-tuned in Google Colab (Pro account) on an A100 GPU.

Froze all layers except the projector, the classifier, and all 24 `HubertEncoderLayerStableLayerNorm` encoder layers.

Used half of the training set.

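The freezing step described above can be sketched as follows. This is a minimal illustration of the pattern, not the exact training script; it uses a tiny randomly initialized `HubertConfig` so it runs without downloading the full checkpoint (the actual fine-tune started from `facebook/hubert-large-ls960-ft`, which has 24 encoder layers):

```python
from transformers import HubertConfig, HubertForSequenceClassification

# Tiny random-weight config for illustration only; the real fine-tune
# loaded facebook/hubert-large-ls960-ft instead of building from a config.
config = HubertConfig(
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    num_labels=5,
    do_stable_layer_norm=True,  # the large model uses HubertEncoderLayerStableLayerNorm
)
model = HubertForSequenceClassification(config)

# Freeze everything first...
for param in model.parameters():
    param.requires_grad = False

# ...then unfreeze the projector, the classifier, and every encoder layer
for module in (model.projector, model.classifier, *model.hubert.encoder.layers):
    for param in module.parameters():
        param.requires_grad = True
```

With this split, the convolutional feature encoder stays frozen while the transformer stack and the classification head are trained.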
# Training parameters

- 2 epochs
- train batch size = 8
- eval batch size = 8
- gradient accumulation steps = 4
- learning rate = 5e-5, with no warmup or decay

# Metrics

Achieved on the test set:

- accuracy = 0.86
- balanced accuracy = 0.76
- macro F1 score = 0.81

improving accuracy and F1 score over the dataset baseline.

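For reference, the three metrics above can be computed with scikit-learn; the label arrays below are hypothetical and only illustrate the calls:

```python
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score

# Hypothetical ground-truth and predicted emotion labels, for illustration only
y_true = [0, 1, 2, 3, 4, 0, 0, 1]
y_pred = [0, 1, 2, 3, 0, 0, 0, 2]

acc = accuracy_score(y_true, y_pred)           # fraction of exact matches
bal = balanced_accuracy_score(y_true, y_pred)  # mean of per-class recall
f1 = f1_score(y_true, y_pred, average="macro") # unweighted mean of per-class F1
```

Balanced accuracy averages recall over the five emotion classes, so it is less forgiving than plain accuracy when the classes are imbalanced, as they are in DUSHA.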
# Usage

```python
import torch
import torchaudio
from transformers import HubertForSequenceClassification, Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/hubert-large-ls960-ft")
model = HubertForSequenceClassification.from_pretrained("xbgoose/hubert-speech-emotion-recognition-russian-dusha-finetuned")
num2emotion = {0: 'neutral', 1: 'angry', 2: 'positive', 3: 'sad', 4: 'other'}

filepath = "path/to/audio.wav"

# Load the audio and resample it to the 16 kHz rate the model expects
waveform, sample_rate = torchaudio.load(filepath, normalize=True)
transform = torchaudio.transforms.Resample(sample_rate, 16000)
waveform = transform(waveform)

inputs = feature_extractor(
    waveform,
    sampling_rate=feature_extractor.sampling_rate,
    return_tensors="pt",
    padding=True,
    max_length=16000 * 10,  # truncate clips longer than 10 seconds
    truncation=True,
)

# Run inference without tracking gradients
with torch.no_grad():
    logits = model(inputs['input_values'][0]).logits
predictions = torch.argmax(logits, dim=-1)
predicted_emotion = num2emotion[predictions.numpy()[0]]
print(predicted_emotion)
```
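To get per-class probabilities rather than just the top label, apply a softmax to the logits. The tensor below is a stand-in for a real model output, used so the snippet runs without an audio file:

```python
import torch

num2emotion = {0: 'neutral', 1: 'angry', 2: 'positive', 3: 'sad', 4: 'other'}

# Stand-in logits for the 5 emotion classes; a real run would take these
# from model(...).logits as in the usage example above
logits = torch.tensor([[2.0, 0.5, 0.1, -1.0, 0.3]])

probs = torch.softmax(logits, dim=-1)      # shape (1, 5), each row sums to 1
top = torch.argmax(probs, dim=-1).item()
print(num2emotion[top], probs[0, top].item())
```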