1 """
2 Train script for a single file
3
4 Need to set the TPU address first:
5 export XRT_TPU_CONFIG="localservice;0;localhost:51011"
6 """

import argparse
import gzip
import json
import logging
import os
import random
import sys
from shutil import copyfile

import torch
import torch.multiprocessing as mp
import tqdm
from torch import nn

import torch_xla.core.functions
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

from transformers import (
    AdamW,
    AutoModel,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)


class AutoModelForSentenceEmbedding(nn.Module):
    def __init__(self, model_name, tokenizer, normalize=True):
        super(AutoModelForSentenceEmbedding, self).__init__()

        self.model = AutoModel.from_pretrained(model_name)
        self.normalize = normalize
        self.tokenizer = tokenizer

    def forward(self, **kwargs):
        model_output = self.model(**kwargs)
        embeddings = self.mean_pooling(model_output, kwargs['attention_mask'])
        if self.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return embeddings

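    # Mean pooling: average the token embeddings, weighted by the attention
    # mask so that padding tokens do not contribute:
    #   sentence_emb = sum_t(mask_t * emb_t) / max(sum_t(mask_t), 1e-9)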
    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def save_pretrained(self, output_path):
        # Tokenizer and model config are small; only the master process writes them
        if xm.is_master_ordinal():
            self.tokenizer.save_pretrained(output_path)
            self.model.config.save_pretrained(output_path)

        # xm.save() synchronizes the cores and writes the checkpoint once
        xm.save(self.model.state_dict(), os.path.join(output_path, "pytorch_model.bin"))

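
# One training process per TPU core; index is the process ordinal passed in by
# xmp.spawn(). Each process embeds its local batch, all-gathers the embeddings
# from the other cores, and computes the contrastive loss over the global batch:
# scores[i][j] = sim(a_i, b_j) * scale, with cross-entropy pulling the diagonal
# entries (the true pairs) to be the largest in each row.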
def train_function(index, args, queue):
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoModelForSentenceEmbedding(args.model, tokenizer)

    ### Train Loop
    device = xm.xla_device()
    model = model.to(device)

    # Instantiate optimizer
    optimizer = AdamW(params=model.parameters(), lr=2e-5, correct_bias=True)

    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=500,
        num_training_steps=args.steps,
    )

    # Now we train the model
    cross_entropy_loss = nn.CrossEntropyLoss()
    max_grad_norm = 1

    model.train()

    for global_step in tqdm.trange(args.steps, disable=not xm.is_master_ordinal()):
        #### Get the batch data
        batch = queue.get()

        if len(batch[0]) == 2:  # (anchor, positive) pairs
            text1 = tokenizer([b[0] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
            text2 = tokenizer([b[1] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")

            ### Compute embeddings
            embeddings_a = model(**text1.to(device))
            embeddings_b = model(**text2.to(device))

            ### Gather all embeddings from the other cores
            embeddings_a = torch_xla.core.functions.all_gather(embeddings_a)
            embeddings_b = torch_xla.core.functions.all_gather(embeddings_b)

            ### Compute similarity scores: global_batch x global_batch (512 x 512 with the defaults)
            scores = torch.mm(embeddings_a, embeddings_b.transpose(0, 1)) * args.scale

            ### Compute cross-entropy loss
            labels = torch.tensor(range(len(scores)), dtype=torch.long, device=embeddings_a.device)  # Example a[i] should match with b[i]

            ## Symmetric loss as in CLIP
            loss = (cross_entropy_loss(scores, labels) + cross_entropy_loss(scores.transpose(0, 1), labels)) / 2

        else:  # (anchor, positive, negative) triplets with hard negatives
            text1 = tokenizer([b[0] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
            text2 = tokenizer([b[1] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
            text3 = tokenizer([b[2] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")

            embeddings_a = model(**text1.to(device))
            embeddings_b1 = model(**text2.to(device))
            embeddings_b2 = model(**text3.to(device))

            embeddings_a = torch_xla.core.functions.all_gather(embeddings_a)
            embeddings_b1 = torch_xla.core.functions.all_gather(embeddings_b1)
            embeddings_b2 = torch_xla.core.functions.all_gather(embeddings_b2)

            # Positives first, hard negatives appended as additional columns
            embeddings_b = torch.cat([embeddings_b1, embeddings_b2])

            ### Compute similarity scores: global_batch x 2*global_batch (512 x 1024 with the defaults)
            scores = torch.mm(embeddings_a, embeddings_b.transpose(0, 1)) * args.scale

            ### Compute cross-entropy loss
            labels = torch.tensor(range(len(scores)), dtype=torch.long, device=embeddings_a.device)  # Example a[i] should match with b[i]

            ## One-way loss
            loss = cross_entropy_loss(scores, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        # xm.optimizer_step() reduces the gradients across the cores before stepping
        xm.optimizer_step(optimizer, barrier=True)
        lr_scheduler.step()

        # Save model
        if (global_step+1) % args.save_steps == 0:
            output_path = os.path.join(args.output, str(global_step+1))
            xm.master_print("save model: " + output_path)
            model.save_pretrained(output_path)

    output_path = os.path.join(args.output, "final")
    xm.master_print("save model final: " + output_path)
    model.save_pretrained(output_path)

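
# Single producer process: it reads all dataset files and feeds per-device
# batches into the shared queue. All batches within one global batch come from
# datasets with the same column format (2 or 3 columns), so every training
# process takes the same branch in train_function for a given step.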
def produce_data(args, queue, filepaths, dataset_indices):
    global_batch_size = args.batch_size * args.nprocs  # Global batch size
    size_per_dataset = int(global_batch_size / args.datasets_per_batch)  # How many samples per dataset in a global batch
    num_same_dataset = int(size_per_dataset / args.batch_size)  # How many per-device batches come from the same dataset
    print("producer", "global_batch_size", global_batch_size)
    print("producer", "size_per_dataset", size_per_dataset)
    print("producer", "num_same_dataset", num_same_dataset)

    datasets = []
    for filepath in filepaths:
        if "reddit_" in filepath:  # Special dataset class for Reddit files
            data_obj = RedditDataset(filepath)
        else:
            data_obj = Dataset(filepath)
        datasets.append(iter(data_obj))

    # Store if dataset is in a 2 col or 3 col format (consumes the first sample of each dataset)
    num_cols = {idx: len(next(dataset)) for idx, dataset in enumerate(datasets)}

    while True:
        texts_in_batch = set()
        batch_format = None  # 2 vs 3 col format for this batch

        # Add data from several sub datasets
        for _ in range(args.datasets_per_batch):
            valid_dataset = False  # Check that datasets have the same 2/3 col format
            while not valid_dataset:
                data_idx = random.choice(dataset_indices)
                if batch_format is None:
                    batch_format = num_cols[data_idx]
                    valid_dataset = True
                else:  # Check that this dataset has the same format
                    valid_dataset = (batch_format == num_cols[data_idx])

            # Get data from this dataset
            dataset = datasets[data_idx]
            for _ in range(num_same_dataset):
                for _ in range(args.nprocs):
                    batch_device = []  # A batch for one device
                    while len(batch_device) < args.batch_size:
                        sample = next(dataset)

                        # Skip samples that share a text with the current global
                        # batch, so duplicates cannot act as false negatives
                        in_batch = False
                        for text in sample:
                            if text in texts_in_batch:
                                in_batch = True
                                break

                        if not in_batch:
                            for text in sample:
                                texts_in_batch.add(text)
                            batch_device.append(sample)

                    queue.put(batch_device)

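
# Both dataset classes below are infinite iterators: they loop over their file
# forever, so next() in the producer loop never raises StopIteration.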
class RedditDataset:
    """
    A class that handles the reddit data files
    """
    def __init__(self, filepath):
        self.filepath = filepath

    def __iter__(self):
        while True:
            with gzip.open(self.filepath, "rt") as fIn:
                for line in fIn:
                    data = json.loads(line)

                    if "response" in data and "context" in data:
                        yield [data["response"], data["context"]]

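
# Dataset streams (anchor, positive[, negative]) samples from a gzipped JSONL
# file; each line is either a JSON list of texts or a dict with a 'texts' key.
# Files with up to 10M samples are cached in memory after the first pass and
# reshuffled every epoch; larger files are re-read from disk in file order.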
class Dataset:
    """
    A class that handles one dataset
    """
    def __init__(self, filepath):
        self.filepath = filepath

    def __iter__(self):
        max_dataset_size = 10*1000*1000  # Cache small datasets in memory
        dataset = []
        data_format = None

        # Stream from disk; if the file fits into memory, `dataset` collects all
        # samples, otherwise it is set to None and this loop re-reads the file forever
        while dataset is None or len(dataset) == 0:
            with gzip.open(self.filepath, "rt") as fIn:
                for line in fIn:
                    data = json.loads(line)
                    if isinstance(data, dict):
                        data = data['texts']

                    if data_format is None:
                        data_format = len(data)

                    # Ensure that all entries are of the same 2/3 col format
                    assert len(data) == data_format

                    if dataset is not None:
                        dataset.append(data)
                        if len(dataset) >= max_dataset_size:
                            dataset = None  # Too large to cache

                    yield data

        # Data loaded. Now stream from memory,
        # shuffling anew for each epoch
        while True:
            random.shuffle(dataset)
            for data in dataset:
                yield data

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='nreimers/MiniLM-L6-H384-uncased')
    parser.add_argument('--steps', type=int, default=2000)
    parser.add_argument('--save_steps', type=int, default=10000)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--max_length', type=int, default=128)
    parser.add_argument('--nprocs', type=int, default=8)
    parser.add_argument('--datasets_per_batch', type=int, default=2, help="Number of datasets per batch")
    parser.add_argument('--scale', type=float, default=20, help="Use 20 for cossim, and 1 when you work with unnormalized embeddings with dot product")
    parser.add_argument('--data_folder', default="/data", help="Folder with your dataset files")
    parser.add_argument('data_config', help="A data_config.json file")
    parser.add_argument('output')
    args = parser.parse_args()

    # Ensure the global batch size is divisible by datasets_per_batch
    assert (args.batch_size*args.nprocs) % args.datasets_per_batch == 0
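    # With the defaults: batch_size=64 * nprocs=8 = 512 samples per global
    # batch; datasets_per_batch=2 means each global batch draws 512/2 = 256
    # samples (= 4 per-device batches) from each of two datasets.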

    logging.info("Output: " + args.output)
    if os.path.exists(args.output):
        print("Output folder already exists.")
        input("Continue?")

    # Write train script to output path
    os.makedirs(args.output, exist_ok=True)

    data_config_path = os.path.join(args.output, 'data_config.json')
    copyfile(args.data_config, data_config_path)

    train_script_path = os.path.join(args.output, 'train_script.py')
    copyfile(__file__, train_script_path)
    with open(train_script_path, 'a') as fOut:
        fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))

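    # The data config is expected to be a JSON list with one entry per dataset,
    # e.g. (illustrative file name, not taken from this script):
    # [{"name": "my_dataset.jsonl.gz", "weight": 10}, ...]
    # where 'weight' sets how often a dataset is sampled relative to the others.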
    # Load data config
    with open(args.data_config) as fIn:
        data_config = json.load(fIn)

    queue = mp.Queue(maxsize=100*args.nprocs)

    filepaths = []
    dataset_indices = []
    for idx, data in enumerate(data_config):
        filepaths.append(os.path.join(os.path.expanduser(args.data_folder), data['name']))
        dataset_indices.extend([idx]*data['weight'])  # Higher weight => dataset is sampled more often

    # Start producer
    p = mp.Process(target=produce_data, args=(args, queue, filepaths, dataset_indices))
    p.start()

    # Run training
    print("Start processes:", args.nprocs)
    xmp.spawn(train_function, args=(args, queue), nprocs=args.nprocs, start_method='fork')
    print("Training done")
    print("It might be that not all processes exit automatically. In that case you must manually kill this process.")
    print("With 'pkill python' you can kill all remaining python processes")
    p.kill()
    exit()


# Script was called via:
#python train_many_data_files_v2.py --steps 1000000 --batch_size 64 --model microsoft/mpnet-base train_data_configs/all_datasets_v4.json output/all_datasets_v4_mpnet-base