train_script.py
12.0 KB · 362 lines · python Raw
1 from codecs import EncodedFile
2 from datetime import datetime
3 from typing import Optional
4
5 import datasets
6 import torch
7 from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything
8 from torch.utils.data import DataLoader
9 from transformers import (
10 AutoConfig,
11 AutoModelForSequenceClassification,
12 AutoTokenizer,
13 get_linear_schedule_with_warmup,
14 get_scheduler,
15 )
16 import torch
17 import sys
18 import os
19 from argparse import ArgumentParser
20 from datasets import load_dataset
21 import tqdm
22 import json
23 import gzip
24 import random
25 from pytorch_lightning.callbacks import ModelCheckpoint
26 import numpy as np
27 from shutil import copyfile
28 from pytorch_lightning.loggers import WandbLogger
29 import transformers
30
31
class MSMARCOData(LightningDataModule):
    """Lightning data module serving multilingual mMARCO batches for list-wise training.

    Everything is loaded eagerly in ``__init__``:
      - query texts from the ``unicamp-dl/mmarco`` dataset for every language in ``langs``;
      - passage collections from the same dataset for every language;
      - all rows of a gzipped JSONL triplets file.

    Each triplet row must provide the keys ``qid``, ``pos``, ``dense_neg`` and
    ``bm25_neg``; these are read in ``collate_fn``.
    """

    def __init__(
        self,
        model_name: str,
        triplets_path: str,
        langs,
        max_seq_length: int = 250,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        num_negs: int = 3,
        cross_lingual_chance: float = 0.0,
        **kwargs,
    ):
        """
        Args:
            model_name: Hugging Face model id; its tokenizer is loaded here.
            triplets_path: Path of the gzipped JSONL file with one training row per line.
            langs: mMARCO language names (e.g. ``'english'``) used for queries and passages.
            max_seq_length: Maximum token length of each tokenized (query, passage) pair.
            train_batch_size: Number of triplet rows per training batch.
            eval_batch_size: Stored but not used by this class.
            num_negs: Number of negative passages paired with every query.
            cross_lingual_chance: Probability that a batch draws passage languages
                independently of the query language.
            **kwargs: Ignored.
        """
        super().__init__()
        self.model_name = model_name
        self.triplets_path = triplets_path
        self.max_seq_length = max_seq_length
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.langs = langs
        self.num_negs = num_negs
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.cross_lingual_chance = cross_lingual_chance  # Probability for cross-lingual batches

        # Data is loaded here rather than in a setup() hook, so every process that
        # constructs the data module loads its own copy.
        print(f"!!!!!!!!!!!!!!!!!! SETUP {os.getpid()} !!!!!!!!!!!!!!!")

        # Query texts: {lang: {query id: text}}
        self.queries = {lang: {} for lang in self.langs}
        for lang in self.langs:
            for row in tqdm.tqdm(load_dataset('unicamp-dl/mmarco', f'queries-{lang}')['train'], desc=lang):
                self.queries[lang][row['id']] = row['text']

        # Passage collections: {lang: dataset}. Passage ids from the triplets are used
        # directly as row indices into these datasets.
        # NOTE(review): this assumes passage id == row position; confirm against the dataset.
        self.collections = {lang: load_dataset('unicamp-dl/mmarco', f'collection-{lang}')['collection'] for lang in self.langs}

        # Training rows. `total` only sizes the progress bar.
        with gzip.open(self.triplets_path, 'rt') as fIn:
            self.triplets = [json.loads(line) for line in tqdm.tqdm(fIn, desc="triplets", total=502938)]

    def _sample_negatives(self, row):
        """Return exactly ``num_negs`` negative passage ids for one triplet row.

        The dense and BM25 negative lists are merged and de-duplicated first.
        If fewer than ``num_negs`` distinct negatives exist, ids are drawn with
        replacement. This way every row still contributes one pair per negative
        slot; the per-slot score tensors are stacked column-wise during training
        and must all have one row per query.

        Raises:
            ValueError: if ``num_negs > 0`` and the row has no negatives at all.
        """
        # sorted() makes the draw depend only on the RNG state, not on set iteration
        # order (which is not stable across runs for string ids).
        candidates = sorted(set(row['dense_neg'] + row['bm25_neg']))
        if len(candidates) >= self.num_negs:
            return random.sample(candidates, self.num_negs)
        if not candidates:
            raise ValueError(f"Triplet row for qid {row['qid']} has no negatives")
        # random.sample would raise ValueError for a too-small pool.
        return random.choices(candidates, k=self.num_negs)

    def collate_fn(self, batch):
        """Turn a list of triplet rows into ``1 + num_negs`` tokenized batches.

        Entry 0 holds the (query, positive passage) pairs. Entry ``1 + i`` holds
        the (query, i-th sampled negative) pairs. Rows are appended in the same
        query order to every entry, so row ``j`` of each entry belongs to the same
        query.
        """
        # One coin flip per batch: either every passage uses its query's language,
        # or every passage language is drawn independently.
        cross_lingual_batch = random.random() < self.cross_lingual_chance

        query_doc_pairs = [[] for _ in range(1 + self.num_negs)]

        for row in batch:
            qid = row['qid']
            pos_id = random.choice(row['pos'])

            query_lang = random.choice(self.langs)
            query_text = self.queries[query_lang][qid]

            doc_lang = random.choice(self.langs) if cross_lingual_batch else query_lang
            query_doc_pairs[0].append((query_text, self.collections[doc_lang][pos_id]['text']))

            for num_neg, neg_id in enumerate(self._sample_negatives(row)):
                doc_lang = random.choice(self.langs) if cross_lingual_batch else query_lang
                query_doc_pairs[1 + num_neg].append((query_text, self.collections[doc_lang][neg_id]['text']))

        # Only the passage is truncated; the query is kept intact.
        return [
            self.tokenizer(qd_pair, max_length=self.max_seq_length, padding=True, truncation='only_second', return_tensors="pt")
            for qd_pair in query_doc_pairs
        ]

    def train_dataloader(self):
        """Shuffled loader over the raw triplet rows; ``collate_fn`` builds the tokenized batches."""
        return DataLoader(self.triplets, shuffle=True, batch_size=self.train_batch_size, num_workers=1, pin_memory=True, collate_fn=self.collate_fn)
110
111
112
113
114
class ListRankLoss(LightningModule):
    """Cross-encoder re-ranker trained with a list-wise cross-entropy loss.

    A Hugging Face sequence-classification model with a single output logit scores
    every (query, passage) pair. For each query, the scores of its positive and
    its negatives form one row of logits. ``CrossEntropyLoss`` is applied with the
    positive (column 0) as the target class.
    """

    def __init__(
        self,
        model_name: str,
        learning_rate: float = 2e-5,
        warmup_steps: int = 1000,
        weight_decay: float = 0.01,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        **kwargs,
    ):
        """
        Args:
            model_name: Hugging Face model id or path of the base encoder.
            learning_rate: Peak AdamW learning rate.
            warmup_steps: Number of linear warm-up steps of the LR schedule.
            weight_decay: Weight decay for all parameters except biases and LayerNorm weights.
            train_batch_size: Training batch size; only used for the diagnostic print in ``setup``.
            eval_batch_size: Stored in hparams; not used by this class.
            **kwargs: Accepted for compatibility; not used directly.
        """
        super().__init__()

        self.save_hyperparameters()
        print(self.hparams)

        self.config = AutoConfig.from_pretrained(model_name, num_labels=1)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config)
        self.loss_fct = torch.nn.CrossEntropyLoss()
        # Counts training_step calls. It is logged so ModelCheckpoint can monitor it.
        self.global_train_step = 0

    def forward(self, **inputs):
        """Run the wrapped Hugging Face model on tokenizer outputs."""
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        """Return the list-wise loss for one batch.

        ``batch`` is a list of tokenizer outputs. Entry 0 holds the (query,
        positive) pairs and every further entry holds one negative per query,
        with rows in the same query order across entries.
        """
        # logits are (batch, 1). squeeze(-1) keeps the batch dimension even when the
        # batch has a single row; a bare squeeze() would produce a 0-d tensor and
        # break the stack below.
        pred_scores = torch.stack([self(**feature).logits.squeeze(-1) for feature in batch], dim=1)

        # The positive passage is always in column 0.
        labels = torch.zeros(pred_scores.size(0), dtype=torch.long, device=pred_scores.device)
        loss_value = self.loss_fct(pred_scores, labels)

        self.global_train_step += 1
        self.log('global_train_step', self.global_train_step)
        self.log("train/loss", loss_value)

        return loss_value

    def setup(self, stage=None) -> None:
        """Compute ``self.total_steps`` (optimizer steps) for the LR schedule before fitting."""
        if stage != "fit":
            return
        # The data module's loader is built here only to measure its length.
        train_loader = self.trainer.datamodule.train_dataloader()

        # Number of processes the data is sharded across. `Trainer.gpus` is
        # deprecated/removed in newer Lightning releases.
        world_size = max(1, getattr(self.trainer, "world_size", 1) or 1)

        tb_size = self.hparams.train_batch_size * world_size
        ab_size = self.trainer.accumulate_grad_batches

        # This loader has no distributed sampler, so len() counts batches over the
        # whole dataset. With several processes, each one sees only ~1/world_size
        # of them per epoch (ceil division).
        batches_per_epoch = -(-len(train_loader) // world_size)
        self.total_steps = (batches_per_epoch // ab_size) * self.trainer.max_epochs

        print(f"{tb_size=}")
        print(f"{ab_size=}")
        print(f"{len(train_loader)=}")
        print(f"{self.total_steps=}")

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay).

        AdamW with weight decay on every parameter except biases and LayerNorm
        weights. The linear schedule is stepped every optimizer step and relies on
        ``self.total_steps`` from ``setup``.
        """
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate)

        lr_scheduler = get_scheduler(
            name="linear",
            optimizer=optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.total_steps,
        )

        scheduler = {"scheduler": lr_scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]
198
199
200
def main(args):
    """Train the cross-encoder on mMARCO and export the final Hugging Face model.

    Creates ``output/<model-name>-<timestamp>/`` containing:
      - a copy of this script with the invoking command line appended;
      - the Lightning checkpoints;
      - a ``final/`` directory with the tokenizer and model saved via ``save_pretrained``.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` block).
    """
    dm = MSMARCOData(
        model_name=args.model,
        langs=args.langs,
        triplets_path='data/msmarco-hard-triplets.jsonl.gz',
        train_batch_size=args.batch_size,
        cross_lingual_chance=args.cross_lingual_chance,
        num_negs=args.num_negs,
    )
    # NOTE(review): the timestamp is taken in each process. With launchers that
    # re-run the script per rank, ranks could disagree on output_dir; confirm.
    output_dir = f"output/{args.model.replace('/', '-')}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
    print("Output_dir:", output_dir)

    os.makedirs(output_dir, exist_ok=True)

    wandb_logger = WandbLogger(project="multilingual-cross-encoder", name=output_dir.split("/")[-1])

    # Keep the exact training code (plus the command line) next to the checkpoints.
    train_script_path = os.path.join(output_dir, 'train_script.py')
    copyfile(__file__, train_script_path)
    with open(train_script_path, 'a') as fOut:
        fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))

    # Save a checkpoint every 25k training steps. The monitored "global_train_step"
    # counter only grows, so mode="max" with save_top_k=5 keeps the five most
    # recent checkpoints.
    checkpoint_callback = ModelCheckpoint(
        every_n_train_steps=25000,
        save_top_k=5,
        monitor="global_train_step",
        mode="max",
        dirpath=output_dir,
        filename="ckpt-{global_train_step}",
    )

    # Pass the batch size so the model's saved hyperparameters match the data module.
    model = ListRankLoss(model_name=args.model, train_batch_size=args.batch_size)

    trainer = Trainer(
        max_epochs=args.epochs,
        accelerator="gpu",
        devices=args.num_gpus,
        precision=args.precision,
        strategy=args.strategy,
        default_root_dir=output_dir,
        callbacks=[checkpoint_callback],
        logger=wandb_logger,
    )

    trainer.fit(model, datamodule=dm)

    # Save the final HF model from rank 0 only. With multi-process strategies every
    # rank reaches this line and would otherwise write the same files concurrently.
    if trainer.is_global_zero:
        final_path = os.path.join(output_dir, "final")
        dm.tokenizer.save_pretrained(final_path)
        model.model.save_pretrained(final_path)
252
253
def _load_msmarco_dev(qrels_path='data/qrels.dev.tsv', top1000_path='data/top1000.dev'):
    """Load MS MARCO passage dev/small queries, their qrels and their top-1000 candidates.

    Returns:
        dev_queries: ``{qid: query text}`` for dev/small queries with at least one
            qrel, in ``ir_datasets`` iteration order.
        dev_rel_docs: ``{qid: set of relevant pids}``.
        retrieved_docs: ``{qid: {pid: None}}`` candidate pids in file order. The
            dict serves as an insertion-ordered set so ranking ties are deterministic.
        corpus: ``{pid: passage text}`` for every kept candidate.
    """
    import ir_datasets

    dataset = ir_datasets.load("msmarco-passage/dev/small")
    # A single pass over the queries; later duplicates overwrite earlier ones.
    all_queries = {query.query_id: query.text for query in dataset.queries_iter()}

    # qrels lines are tab-separated: qid, _, pid, _
    dev_rel_docs = {}
    with open(qrels_path) as fIn:
        for line in fIn:
            qid, _, pid, _ = line.strip().split('\t')
            if qid in all_queries:
                dev_rel_docs.setdefault(qid, set()).add(pid)

    dev_queries = {qid: text for qid, text in all_queries.items() if qid in dev_rel_docs}

    # top1000 lines are tab-separated: qid, pid, query, passage
    retrieved_docs = {qid: {} for qid in dev_queries}
    corpus = {}
    with open(top1000_path, 'rt') as fIn:
        for line in fIn:
            qid, pid, _query, passage = line.strip().split("\t")
            candidates = retrieved_docs.get(qid)
            if candidates is None:
                # The query is not evaluated: skip it instead of raising KeyError,
                # and don't keep its passages in memory.
                continue
            corpus[pid] = passage
            candidates[pid] = None

    return dev_queries, dev_rel_docs, retrieved_docs, corpus


def _reciprocal_rank(ranked_pids, relevant_pids, k=10):
    """Return 1/rank of the first relevant pid among the top ``k`` of ``ranked_pids``, else 0."""
    for rank, pid in enumerate(ranked_pids[:k], start=1):
        if pid in relevant_pids:
            return 1 / rank
    return 0


def eval(args):
    """Print MRR@10 of a trained checkpoint on MS MARCO passage dev/small.

    Loads the ``ListRankLoss`` checkpoint from ``args.ckpt`` and runs the model
    on CUDA. Each query's top-1000 candidates from ``data/top1000.dev`` are
    re-scored and scored against ``data/qrels.dev.tsv``. The running MRR@10 is
    printed every 10 queries and the final MRR@10 at the end (both in percent).

    Note: the name shadows the builtin ``eval`` in this module; it is kept for
    compatibility with existing callers.
    """
    model = ListRankLoss.load_from_checkpoint(args.ckpt)
    # Module.cuda() moves parameters in place, so model.model is on the GPU too.
    hf_model = model.model.cuda()
    tokenizer = AutoTokenizer.from_pretrained(model.hparams.model_name)

    dev_queries, dev_rel_docs, retrieved_docs, corpus = _load_msmarco_dev()

    ## Run evaluator
    print("Queries: {}".format(len(dev_queries)))

    mrr_scores = []
    hf_model.eval()

    with torch.no_grad():
        for qid, query in tqdm.tqdm(dev_queries.items(), total=len(dev_queries)):
            top_pids = list(retrieved_docs[qid])
            rank_score = 0
            # A query without candidates contributes 0 instead of crashing the tokenizer.
            if top_pids:
                cross_inp = [[query, corpus[pid]] for pid in top_pids]
                encoded = tokenizer(cross_inp, padding=True, truncation=True, return_tensors="pt").to('cuda')
                output = model(**encoded)
                # reshape(-1) instead of squeeze(): a single candidate must stay 1-D
                # so argsort/slicing still work.
                bert_score = output.logits.detach().cpu().numpy().reshape(-1)
                ranked_pids = [top_pids[idx] for idx in np.argsort(-bert_score)[:10]]
                rank_score = _reciprocal_rank(ranked_pids, dev_rel_docs[qid], k=10)

            mrr_scores.append(rank_score)

            if len(mrr_scores) % 10 == 0:
                print("{} MRR@10: {:.2f}".format(len(mrr_scores), 100*np.mean(mrr_scores)))

    print("MRR@10: {:.2f}".format(np.mean(mrr_scores)*100))
336
337
if __name__ == '__main__':
    def _parse_precision(value):
        """Numeric precisions (e.g. '16', '32') become ints; other values are passed through as strings."""
        return int(value) if value.isdigit() else value

    parser = ArgumentParser()
    parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs (Trainer `devices`).")
    parser.add_argument("--batch_size", type=int, default=32, help="Training batch size.")
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--strategy", default=None, help="Lightning strategy, e.g. 'ddp'.")
    parser.add_argument("--model", default='microsoft/mdeberta-v3-base')
    parser.add_argument("--eval", action="store_true", help="Evaluate --ckpt on MS MARCO dev instead of training.")
    parser.add_argument("--ckpt", help="Checkpoint to evaluate (required with --eval).")
    parser.add_argument("--cross_lingual_chance", type=float, default=0.33)
    parser.add_argument("--precision", type=_parse_precision, default=16, help="Trainer precision, e.g. 16 or 32; non-numeric values are passed through as strings.")
    parser.add_argument("--num_negs", type=int, default=3)
    parser.add_argument("--langs", nargs="+", default=['english', 'chinese', 'french', 'german', 'indonesian', 'italian', 'portuguese', 'russian', 'spanish', 'arabic', 'dutch', 'hindi', 'japanese', 'vietnamese'])

    args = parser.parse_args()

    if args.eval:
        # Fail fast with a usage message instead of an error deep inside checkpoint loading.
        if not args.ckpt:
            parser.error("--eval requires --ckpt")
        eval(args)
    else:
        main(args)
359
360
361 # Script was called via:
362 #python cross_mutlilingual.py --model nreimers/mMiniLMv2-L12-H384-distilled-from-XLMR-Large