# eval.py — ASR evaluation script (originally 6.1 KB, 165 lines, python)
#!/usr/bin/env python3
import argparse
import re
from typing import Dict

import torch
from datasets import Audio, Dataset, load_dataset, load_metric
from transformers import (
    AutoConfig,
    AutoFeatureExtractor,
    AutoModelForCTC,
    AutoTokenizer,
    Wav2Vec2Processor,
    Wav2Vec2ProcessorWithLM,
    pipeline,
)
8
def log_results(result: Dataset, args: argparse.Namespace):
    """DO NOT CHANGE. This function computes and logs the result metrics.

    Writes overall WER/CER to ``<dataset_id>_eval_results.txt`` and,
    additionally, per-example predictions/targets to two log files.
    """

    log_outputs = args.log_outputs
    # e.g. "mozilla-foundation_common_voice_7_0_en_test"
    dataset_id = "_".join(args.dataset.split("/") + [args.config, args.split])

    # load metric
    wer = load_metric("wer")
    cer = load_metric("cer")

    # compute metrics over the full prediction/target columns
    wer_result = wer.compute(references=result["target"], predictions=result["prediction"])
    cer_result = cer.compute(references=result["target"], predictions=result["prediction"])

    # print & log results
    result_str = (
        f"WER: {wer_result}\n"
        f"CER: {cer_result}"
    )
    print(result_str)

    with open(f"{dataset_id}_eval_results.txt", "w") as f:
        f.write(result_str)

    # log all results in text file. Possibly interesting for analysis
    # NOTE(review): --log_outputs is a store_true flag, so log_outputs is
    # False (not None) when omitted and this branch always runs — confirm
    # whether `if log_outputs:` was intended (left as-is per DO NOT CHANGE).
    if log_outputs is not None:
        pred_file = f"log_{dataset_id}_predictions.txt"
        target_file = f"log_{dataset_id}_targets.txt"

        with open(pred_file, "w") as p, open(target_file, "w") as t:

            # mapping function to write output; runs for its side effects only
            def write_to_file(batch, i):
                p.write(f"{i}" + "\n")
                p.write(batch["prediction"] + "\n")
                t.write(f"{i}" + "\n")
                t.write(batch["target"] + "\n")

            # with_indices=True passes the example index as the second argument
            result.map(write_to_file, with_indices=True)
48
49
def normalize_text(text: str, invalid_chars_regex: str, to_lower: bool) -> str:
    """DO ADAPT FOR YOUR USE CASE. This function normalizes the target text.

    Args:
        text: raw target transcription.
        invalid_chars_regex: pattern matching characters that are NOT in the
            model vocabulary; every match is replaced with a space.
        to_lower: lowercase the text when True, uppercase it when False
            (chosen to match the casing of the model vocabulary).

    Returns:
        The normalized text with whitespace collapsed to single spaces.
    """
    # Match the casing of the model's vocabulary.
    text = text.lower() if to_lower else text.upper()

    # Replace out-of-vocabulary characters with spaces.
    text = re.sub(invalid_chars_regex, " ", text)

    # Collapse whitespace runs; raw string so "\s" is a regex class, not a
    # (deprecated) Python string escape.
    text = re.sub(r"\s+", " ", text).strip()

    return text
60
61
def main(args):
    """Run ASR inference over an evaluation dataset and log WER/CER.

    Args:
        args: parsed CLI namespace (model_id, dataset, config, split,
            chunk_length_s, stride_length_s, log_outputs, greedy, device).
    """
    # load dataset (use_auth_token=True so private/gated datasets work)
    dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True)

    # for testing: only process the first few examples
    # dataset = dataset.select(range(10))

    # load processor: plain greedy CTC decoding with --greedy,
    # otherwise beam-search decoding with the bundled language model
    if args.greedy:
        processor = Wav2Vec2Processor.from_pretrained(args.model_id)
        decoder = None
    else:
        processor = Wav2Vec2ProcessorWithLM.from_pretrained(args.model_id)
        decoder = processor.decoder

    feature_extractor = processor.feature_extractor
    tokenizer = processor.tokenizer

    # resample audio to the sampling rate the model was trained with
    dataset = dataset.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))

    # default device: first GPU if available, otherwise CPU
    if args.device is None:
        args.device = 0 if torch.cuda.is_available() else -1

    config = AutoConfig.from_pretrained(args.model_id)
    model = AutoModelForCTC.from_pretrained(args.model_id)

    asr = pipeline(
        "automatic-speech-recognition",
        config=config,
        model=model,
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        decoder=decoder,
        device=args.device,
    )

    # build normalizer config from the vocabulary of the (already loaded)
    # processor tokenizer — no need to reload it from the hub
    tokens = tokenizer.convert_ids_to_tokens(list(range(tokenizer.vocab_size)))
    special_tokens = [
        tokenizer.pad_token,
        tokenizer.word_delimiter_token,
        tokenizer.unk_token,
        tokenizer.bos_token,
        tokenizer.eos_token,
    ]
    non_special_tokens = [x for x in tokens if x not in special_tokens]
    # raw f-string so "\s" is a regex whitespace class, not a string escape
    invalid_chars_regex = rf"[^\s{re.escape(''.join(set(non_special_tokens)))}]"
    # lowercase targets iff the vocabulary contains any lowercase letter
    normalize_to_lower = any(t.isalpha() and t.islower() for t in non_special_tokens)

    # map function to decode audio (defaults bind the loop-invariant objects)
    def map_to_pred(batch, args=args, asr=asr, invalid_chars_regex=invalid_chars_regex, normalize_to_lower=normalize_to_lower):
        prediction = asr(batch["audio"]["array"], chunk_length_s=args.chunk_length_s, stride_length_s=args.stride_length_s)

        batch["prediction"] = prediction["text"]
        batch["target"] = normalize_text(batch["sentence"], invalid_chars_regex, normalize_to_lower)
        return batch

    # run inference on all examples
    result = dataset.map(map_to_pred, remove_columns=dataset.column_names)

    # filter out examples whose target is empty after normalization
    result = result.filter(lambda example: example["target"] != "")

    # compute and log results
    # do not change function below
    log_results(result, args)
127
128
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # CLI definition as a (flag, options) table; order matches --help output
    cli_args = [
        ("--model_id",
         dict(type=str, required=True,
              help="Model identifier. Should be loadable with 🤗 Transformers")),
        ("--dataset",
         dict(type=str, required=True,
              help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets")),
        ("--config",
         dict(type=str, required=True,
              help="Config of the dataset. *E.g.* `'en'` for Common Voice")),
        ("--split",
         dict(type=str, required=True,
              help="Split of the dataset. *E.g.* `'test'`")),
        ("--chunk_length_s",
         dict(type=float, default=None,
              help="Chunk length in seconds. Defaults to None. For long audio files a good value would be 5.0 seconds.")),
        ("--stride_length_s",
         dict(type=float, default=None,
              help="Stride of the audio chunks. Defaults to None. For long audio files a good value would be 1.0 seconds.")),
        ("--log_outputs",
         dict(action="store_true",
              help="If defined, write outputs to log file for analysis.")),
        ("--greedy",
         dict(action="store_true",
              help="If defined, the LM will be ignored during inference.")),
        ("--device",
         dict(type=int, default=None,
              help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.")),
    ]
    for flag, options in cli_args:
        parser.add_argument(flag, **options)

    main(parser.parse_args())
165