train_script.py
| 1 | from codecs import EncodedFile |
| 2 | from datetime import datetime |
| 3 | from typing import Optional |
| 4 | |
| 5 | import datasets |
| 6 | import torch |
| 7 | from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything |
| 8 | from torch.utils.data import DataLoader |
| 9 | from transformers import ( |
| 10 | AutoConfig, |
| 11 | AutoModelForSequenceClassification, |
| 12 | AutoTokenizer, |
| 13 | get_linear_schedule_with_warmup, |
| 14 | get_scheduler, |
| 15 | ) |
| 16 | import torch |
| 17 | import sys |
| 18 | import os |
| 19 | from argparse import ArgumentParser |
| 20 | from datasets import load_dataset |
| 21 | import tqdm |
| 22 | import json |
| 23 | import gzip |
| 24 | import random |
| 25 | from pytorch_lightning.callbacks import ModelCheckpoint |
| 26 | import numpy as np |
| 27 | from shutil import copyfile |
| 28 | from pytorch_lightning.loggers import WandbLogger |
| 29 | import transformers |
| 30 | |
| 31 | |
class MSMARCOData(LightningDataModule):
    """Lightning data module serving multilingual mMARCO lists for a list-rank loss.

    All data is loaded eagerly in ``__init__``:

    * ``self.queries[lang][qid]`` -- query text from the ``train`` split of the
      ``queries-{lang}`` config of ``unicamp-dl/mmarco``.
    * ``self.collections[lang]`` -- the ``collection`` split of the
      ``collection-{lang}`` config. Passages are looked up as
      ``self.collections[lang][pid]['text']``, i.e. by row position; this
      presumably relies on row index == passage id (TODO confirm).
    * ``self.triplets`` -- one decoded JSON object per line of the gzipped
      JSON-lines file at ``triplets_path``. Each object must provide the keys
      ``qid``, ``pos``, ``dense_neg`` and ``bm25_neg`` read by ``collate_fn``.

    Args:
        model_name: Hugging Face model id/path; only its tokenizer is loaded here.
        triplets_path: Path to the gzipped JSON-lines triplets file.
        langs: mMARCO language config names (e.g. ``'english'``).
        max_seq_length: ``max_length`` for each tokenized (query, passage)
            pair; truncation only cuts the passage.
        train_batch_size: Triplet rows per training batch.
        eval_batch_size: Stored; not used by this class.
        num_negs: Negatives per query, so each list holds ``1 + num_negs`` passages.
        cross_lingual_chance: Probability that a batch is cross-lingual, i.e.
            each passage's language is drawn independently of its query's.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        model_name: str,
        triplets_path: str,
        langs,
        max_seq_length: int = 250,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        num_negs: int = 3,
        cross_lingual_chance: float = 0.0,
        **kwargs,
    ):
        super().__init__()
        self.model_name = model_name
        self.triplets_path = triplets_path
        self.max_seq_length = max_seq_length
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.langs = langs
        self.num_negs = num_negs
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.cross_lingual_chance = cross_lingual_chance  # Probability for cross-lingual batches

        # Data is loaded here rather than in a ``setup`` hook, so every process
        # that constructs this module loads everything.
        print(f"!!!!!!!!!!!!!!!!!! SETUP {os.getpid()} !!!!!!!!!!!!!!!")

        # Queries: lang -> {query id -> text}
        self.queries = {}
        for lang in self.langs:
            rows = load_dataset('unicamp-dl/mmarco', f'queries-{lang}')['train']
            self.queries[lang] = {row['id']: row['text'] for row in tqdm.tqdm(rows, desc=lang)}

        # Passages: lang -> collection split
        self.collections = {
            lang: load_dataset('unicamp-dl/mmarco', f'collection-{lang}')['collection']
            for lang in self.langs
        }

        # Training rows (qid, positives, dense / BM25 negatives).
        # total=502938 only sizes the progress bar.
        with gzip.open(self.triplets_path, 'rt') as fIn:
            self.triplets = [json.loads(line) for line in tqdm.tqdm(fIn, desc="triplets", total=502938)]

    def _sample_negatives(self, row):
        """Return exactly ``self.num_negs`` negative passage ids for one triplet row.

        Dense and BM25 negatives are merged and de-duplicated in first-seen
        order (``dict.fromkeys`` instead of ``set``, so the candidate order
        does not depend on hash ordering). If the row has fewer unique
        negatives than ``num_negs``, the shortfall is filled by sampling with
        replacement instead of letting ``random.sample`` raise.

        Raises:
            ValueError: if more than zero negatives are needed but the row has none.
        """
        candidates = list(dict.fromkeys(row['dense_neg'] + row['bm25_neg']))
        if len(candidates) >= self.num_negs:
            return random.sample(candidates, self.num_negs)
        if not candidates:
            raise ValueError(f"Query {row['qid']} has no negative passages to sample from")
        return candidates + random.choices(candidates, k=self.num_negs - len(candidates))

    def collate_fn(self, batch):
        """Turn a list of triplet rows into ``1 + num_negs`` tokenized columns.

        Returns a list ``features`` where ``features[0]`` holds the
        (query, positive) pairs and ``features[k]`` for ``k >= 1`` holds the
        (query, k-th negative) pairs, row-aligned across columns. Each column
        is tokenized separately with padding to its longest pair and
        ``truncation='only_second'`` (only the passage is cut).
        """
        # One coin flip per batch: either every passage uses its query's
        # language, or each passage language is drawn independently.
        cross_lingual_batch = random.random() < self.cross_lingual_chance

        # Column 0: positives; columns 1..num_negs: negatives.
        query_doc_pairs = [[] for _ in range(1 + self.num_negs)]

        for row in batch:
            qid = row['qid']
            pos_id = random.choice(row['pos'])

            query_lang = random.choice(self.langs)
            query_text = self.queries[query_lang][qid]

            doc_lang = random.choice(self.langs) if cross_lingual_batch else query_lang
            query_doc_pairs[0].append((query_text, self.collections[doc_lang][pos_id]['text']))

            for column, neg_id in enumerate(self._sample_negatives(row), start=1):
                doc_lang = random.choice(self.langs) if cross_lingual_batch else query_lang
                query_doc_pairs[column].append((query_text, self.collections[doc_lang][neg_id]['text']))

        return [
            self.tokenizer(
                pairs,
                max_length=self.max_seq_length,
                padding=True,
                truncation='only_second',
                return_tensors="pt",
            )
            for pairs in query_doc_pairs
        ]

    def train_dataloader(self):
        """Shuffled loader over ``self.triplets``, batched through ``collate_fn``."""
        return DataLoader(
            self.triplets,
            shuffle=True,
            batch_size=self.train_batch_size,
            num_workers=1,
            pin_memory=True,
            collate_fn=self.collate_fn,
        )
| 110 | |
| 111 | |
| 112 | |
| 113 | |
| 114 | |
class ListRankLoss(LightningModule):
    """Cross-encoder trained with a list-wise softmax cross-entropy loss.

    Wraps ``AutoModelForSequenceClassification`` configured with one output
    logit. A training batch is a list of tokenized columns: column 0 holds
    the (query, positive) pairs and every further column a (query, negative)
    pair for the same queries in the same row order. The per-column logits
    are stacked into a ``(batch, num_columns)`` score matrix and
    ``CrossEntropyLoss`` is applied with target class 0 for every row.

    ``global_train_step`` is a plain Python counter incremented once per
    ``training_step`` call. It is not part of the state dict, so it restarts
    at 0 whenever the module is rebuilt.

    Args:
        model_name: Hugging Face model id/path of the base model.
        learning_rate: Peak AdamW learning rate.
        warmup_steps: Linear warm-up steps of the LR schedule.
        weight_decay: AdamW weight decay, applied to all parameters except
            biases and ``LayerNorm.weight``.
        train_batch_size: Per-device batch size; only used for the
            diagnostic print in ``setup``.
        eval_batch_size: Stored in hparams; unused here.
        **kwargs: Accepted and otherwise ignored.
    """

    def __init__(
        self,
        model_name: str,
        learning_rate: float = 2e-5,
        warmup_steps: int = 1000,
        weight_decay: float = 0.01,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        **kwargs,
    ):
        super().__init__()

        self.save_hyperparameters()
        print(self.hparams)

        # num_labels=1: the classifier head emits a single relevance logit.
        self.config = AutoConfig.from_pretrained(model_name, num_labels=1)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config)
        self.loss_fct = torch.nn.CrossEntropyLoss()
        self.global_train_step = 0

    def forward(self, **inputs):
        """Run the wrapped Hugging Face model on tokenizer outputs."""
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        """Compute the list-wise loss for one batch of tokenized columns."""
        # squeeze(-1) removes only the num_labels=1 axis. A bare squeeze()
        # would also remove the batch axis for a single-query batch and make
        # the stack below fail.
        pred_scores = torch.stack(
            [self(**feature).logits.squeeze(-1) for feature in batch], dim=1
        )
        # The positive passage always sits in column 0.
        labels = torch.zeros(pred_scores.size(0), dtype=torch.long, device=pred_scores.device)
        loss_value = self.loss_fct(pred_scores, labels)

        self.global_train_step += 1
        self.log('global_train_step', self.global_train_step)
        self.log("train/loss", loss_value)

        return loss_value

    def setup(self, stage=None) -> None:
        """Pre-compute ``self.total_steps`` (optimizer steps) for the LR schedule.

        Only acts for ``stage == "fit"``; ``configure_optimizers`` reads
        ``self.total_steps``.
        """
        if stage != "fit":
            return
        # Get dataloader by calling it - train_dataloader() is called after setup() by default
        train_loader = self.trainer.datamodule.train_dataloader()

        # This loader is requested directly, so (assuming Lightning's default
        # distributed-sampler replacement) it has no DistributedSampler yet and
        # its length covers the whole dataset. Under DDP each of the
        # world_size processes only iterates ~1/world_size of those batches.
        world_size = max(1, self.trainer.world_size)
        batches_per_device = -(-len(train_loader) // world_size)  # ceil division

        tb_size = self.hparams.train_batch_size * world_size
        ab_size = self.trainer.accumulate_grad_batches
        self.total_steps = (batches_per_device // ab_size) * self.trainer.max_epochs

        print(f"{tb_size=}")
        print(f"{ab_size=}")
        print(f"{len(train_loader)=}")
        print(f"{self.total_steps=}")

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)"""
        model = self.model
        # Biases and LayerNorm weights are conventionally excluded from weight decay.
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate)

        lr_scheduler = get_scheduler(
            name="linear",
            optimizer=optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.total_steps,
        )

        # Step the scheduler after every optimizer step, not once per epoch.
        scheduler = {"scheduler": lr_scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]
| 197 | return [optimizer], [scheduler] |
| 198 | |
| 199 | |
| 200 | |
def main(args):
    """Train ``ListRankLoss`` on mMARCO and export the final Hugging Face model.

    Creates ``output/<model with '/' replaced by '-'>-<timestamp>``, copies
    this script there with the invoking command line appended, logs to the
    Weights & Biases project ``multilingual-cross-encoder``, checkpoints
    every 25k training steps, and finally writes tokenizer and model to
    ``<output_dir>/final`` (from the global-zero process only).

    Args:
        args: Parsed CLI namespace; reads ``model``, ``langs``, ``batch_size``,
            ``cross_lingual_chance``, ``num_negs``, ``epochs``, ``num_gpus``,
            ``precision`` and ``strategy``.
    """
    dm = MSMARCOData(
        model_name=args.model,
        langs=args.langs,
        triplets_path='data/msmarco-hard-triplets.jsonl.gz',
        train_batch_size=args.batch_size,
        cross_lingual_chance=args.cross_lingual_chance,
        num_negs=args.num_negs,
    )
    # NOTE(review): strategies that re-launch this script per rank recompute
    # this timestamp in every process; confirm all ranks agree on output_dir.
    output_dir = f"output/{args.model.replace('/', '-')}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
    print("Output_dir:", output_dir)

    os.makedirs(output_dir, exist_ok=True)

    wandb_logger = WandbLogger(project="multilingual-cross-encoder", name=output_dir.split("/")[-1])

    # Keep a copy of the exact training script next to the outputs.
    train_script_path = os.path.join(output_dir, 'train_script.py')
    copyfile(__file__, train_script_path)
    with open(train_script_path, 'a', encoding='utf-8') as fOut:
        fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))

    # Every 25k training steps save a checkpoint, keeping the 5 with the
    # largest global_train_step (i.e. the 5 most recent ones).
    checkpoint_callback = ModelCheckpoint(
        every_n_train_steps=25000,
        save_top_k=5,
        monitor="global_train_step",
        mode="max",
        dirpath=output_dir,
        filename="ckpt-{global_train_step}",
    )

    # Pass the real batch size so the module's hparams match the data module.
    model = ListRankLoss(model_name=args.model, train_batch_size=args.batch_size)

    trainer = Trainer(
        max_epochs=args.epochs,
        accelerator="gpu",
        devices=args.num_gpus,
        precision=args.precision,
        strategy=args.strategy,
        default_root_dir=output_dir,
        callbacks=[checkpoint_callback],
        logger=wandb_logger,
    )

    trainer.fit(model, datamodule=dm)

    # Save final HF model once; in multi-process runs every rank reaches this
    # point and would otherwise write the same files concurrently.
    if trainer.is_global_zero:
        final_path = os.path.join(output_dir, "final")
        dm.tokenizer.save_pretrained(final_path)
        model.model.save_pretrained(final_path)
| 252 | |
| 253 | |
def eval(args):
    """Evaluate a ``ListRankLoss`` checkpoint by re-ranking MS MARCO dev-small.

    Note: this shadows the builtin ``eval``; the name is kept because the
    ``__main__`` block dispatches to it.

    Inputs:
        args.ckpt: Lightning checkpoint, loaded via ``ListRankLoss.load_from_checkpoint``.
        ``data/qrels.dev.tsv``: tab-separated ``qid _ pid _`` relevance judgements.
        ``data/top1000.dev``: tab-separated ``qid pid query passage`` candidates.

    Queries come from ``ir_datasets`` ``msmarco-passage/dev/small``; only
    queries with at least one qrel are scored, and a query without any
    candidates scores 0. MRR@10 (as a percentage) is printed every 10
    queries and once at the end. Requires a CUDA device.

    NOTE(review): all candidates of a query are scored in one forward pass,
    which may need chunking on smaller GPUs. Training truncates to
    ``max_seq_length`` (passage only); here the tokenizer's default maximum
    length with ``truncation=True`` is used.
    """
    import ir_datasets

    model = ListRankLoss.load_from_checkpoint(args.ckpt)
    hf_model = model.model.cuda()
    tokenizer = AutoTokenizer.from_pretrained(model.hparams.model_name)

    dataset = ir_datasets.load("msmarco-passage/dev/small")
    all_queries = {query.query_id: query.text for query in dataset.queries_iter()}

    # qid -> set of relevant pids, restricted to dev-small queries.
    dev_rel_docs = {}
    with open('data/qrels.dev.tsv') as fIn:
        for line in fIn:
            qid, _, pid, _ = line.strip().split('\t')
            if qid in all_queries:
                dev_rel_docs.setdefault(qid, set()).add(pid)

    dev_queries = {qid: text for qid, text in all_queries.items() if qid in dev_rel_docs}

    corpus = {}
    retrieved_docs = {qid: set() for qid in dev_rel_docs}
    with open('data/top1000.dev', 'rt') as fIn:
        for line in fIn:
            qid, pid, _query, passage = line.strip().split("\t")
            # Ignore candidates of queries that are not scored instead of
            # failing with a KeyError.
            if qid not in retrieved_docs:
                continue
            corpus[pid] = passage
            retrieved_docs[qid].add(pid)

    ## Run evaluator
    print("Queries: {}".format(len(dev_queries)))

    mrr_scores = []
    hf_model.eval()

    with torch.no_grad():
        for qid, query in tqdm.tqdm(dev_queries.items(), total=len(dev_queries)):
            top_pids = list(retrieved_docs[qid])

            rank_score = 0
            if top_pids:
                cross_inp = [[query, corpus[pid]] for pid in top_pids]
                encoded = tokenizer(cross_inp, padding=True, truncation=True, return_tensors="pt").to('cuda')
                output = model(**encoded)
                # reshape(-1) rather than squeeze(): stays 1-D even with a
                # single candidate, so the slice below keeps working.
                bert_score = output.logits.detach().cpu().numpy().reshape(-1)

                for rank, idx in enumerate(np.argsort(-bert_score)[:10]):
                    if top_pids[idx] in dev_rel_docs[qid]:
                        rank_score = 1 / (rank + 1)
                        break

            mrr_scores.append(rank_score)

            if len(mrr_scores) % 10 == 0:
                print("{} MRR@10: {:.2f}".format(len(mrr_scores), 100*np.mean(mrr_scores)))

    print("MRR@10: {:.2f}".format(np.mean(mrr_scores)*100))
| 335 | print("MRR@10: {:.2f}".format(np.mean(mrr_scores)*100)) |
| 336 | |
| 337 | |
if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPU devices for the Trainer.")
    parser.add_argument("--batch_size", type=int, default=32, help="Training batch size (triplet rows per batch).")
    parser.add_argument("--epochs", type=int, default=10, help="Maximum number of training epochs.")
    parser.add_argument("--strategy", default=None, help="Lightning strategy name, e.g. 'ddp'.")
    parser.add_argument("--model", default='microsoft/mdeberta-v3-base', help="Hugging Face model id or path.")
    parser.add_argument("--eval", action="store_true", help="Evaluate --ckpt on MS MARCO dev instead of training.")
    parser.add_argument("--ckpt", help="Checkpoint path used by --eval.")
    parser.add_argument("--cross_lingual_chance", type=float, default=0.33,
                        help="Probability that a training batch mixes query and passage languages.")
    # Purely numeric values ("16", "32") stay ints as before; anything else
    # (e.g. "bf16") is passed to the Trainer unchanged.
    parser.add_argument("--precision", type=lambda value: int(value) if value.isdigit() else value, default=16,
                        help="Trainer precision, e.g. 16, 32 or bf16.")
    parser.add_argument("--num_negs", type=int, default=3, help="Negative passages per query.")
    parser.add_argument("--langs", nargs="+",
                        default=['english', 'chinese', 'french', 'german', 'indonesian', 'italian', 'portuguese', 'russian', 'spanish', 'arabic', 'dutch', 'hindi', 'japanese', 'vietnamese'],
                        help="mMARCO languages to train on.")

    args = parser.parse_args()

    # Fail fast with a usage message instead of a deep load_from_checkpoint(None) error.
    if args.eval and not args.ckpt:
        parser.error("--eval requires --ckpt")

    if args.eval:
        eval(args)
    else:
        main(args)
| 358 | main(args) |
| 359 | |
| 360 | |
| 361 | # Script was called via: |
| 362 | #python cross_mutlilingual.py --model nreimers/mMiniLMv2-L12-H384-distilled-from-XLMR-Large |