modeling_xlm_roberta.py
1 # This implementation was adopted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/bert.py
2 # Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
3 # Copyright (c) 2022, Tri Dao.
4 # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
5 # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
6 # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
7
8 # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
9
10 import importlib.util
11 import logging
12 import re
13 from collections import OrderedDict
14 from collections.abc import Sequence
15 from functools import partial
16 import numpy as np
17
18 import torch
19 import torch.nn as nn
20 import torch.nn.functional as F
21 import torch.utils.checkpoint
22 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23 from einops import rearrange
24 from transformers import PretrainedConfig
25 from transformers.modeling_utils import PreTrainedModel
26 from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
27 from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
28
29 from transformers.models.bert.modeling_bert import (
30 BaseModelOutputWithPoolingAndCrossAttentions,
31 BertForPreTrainingOutput,
32 )
33
34 from typing import List, Optional, Tuple, Union
35
36 from .xlm_padding import (
37 index_first_axis,
38 index_first_axis_residual,
39 pad_input,
40 unpad_input,
41 )
42 from .configuration_xlm_roberta import XLMRobertaFlashConfig
43 from .block import Block
44 from .embedding import XLMRobertaEmbeddings
45 from .mha import MHA
46 from .mlp import FusedMLP, Mlp
47
48 try:
49 from flash_attn.ops.fused_dense import FusedDense
50 except ImportError:
51 FusedDense = None
52
53 try:
54 from flash_attn.ops.triton.layer_norm import layer_norm_fn
55 except ImportError:
56 layer_norm_fn = None
57
58
59 try:
60 from flash_attn.losses.cross_entropy import CrossEntropyLoss
61 except ImportError:
62 CrossEntropyLoss = torch.nn.CrossEntropyLoss
63
64 try:
65 from tqdm.autonotebook import trange
66 except ImportError:
67 trange = None
68
69
70 logger = logging.getLogger(__name__)
71
72
73 def get_use_flash_attn(config: XLMRobertaFlashConfig):
74 if not getattr(config, "use_flash_attn", False):
75 return False
76 if not torch.cuda.is_available():
77 return False
78 if importlib.util.find_spec("flash_attn") is None:
79 logger.warning(
80 'flash_attn is not installed. Using PyTorch native attention implementation.'
81 )
82 return False
83 return True
84
85
86 def create_mixer_cls(config, cross_attn=False, return_residual=False):
87 use_flash_attn = get_use_flash_attn(config)
88 fused_bias_fc = getattr(config, "fused_bias_fc", False)
89
90 mixer_cls = partial(
91 MHA,
92 num_heads=config.num_attention_heads,
93 cross_attn=cross_attn,
94 dropout=config.attention_probs_dropout_prob,
95 causal=False,
96 fused_bias_fc=fused_bias_fc,
97 use_flash_attn=use_flash_attn,
98 return_residual=return_residual,
99 )
100 return mixer_cls
101
102
103 def create_mlp_cls(config, layer_idx=None, return_residual=False):
104 inner_dim = config.intermediate_size
105 fused_mlp = getattr(config, "fused_mlp", False)
106 if fused_mlp:
107 assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
108 "fused_mlp only " "supports approximate gelu"
109 )
110 if not fused_mlp:
111 approximate = (
112 "tanh"
113 if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
114 else "none"
115 )
116 mlp_cls = partial(
117 Mlp,
118 hidden_features=inner_dim,
119 activation=partial(F.gelu, approximate=approximate),
120 return_residual=return_residual,
121 )
122 else:
123 if FusedMLP is None:
124 raise ImportError("fused_dense is not installed")
125 mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
126 # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
127 if isinstance(mlp_checkpoint_lvl, Sequence):
128 assert layer_idx is not None
129 mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
130 mlp_cls = partial(
131 FusedMLP,
132 hidden_features=inner_dim,
133 checkpoint_lvl=mlp_checkpoint_lvl,
134 return_residual=return_residual,
135 )
136 return mlp_cls
137
138
139 def create_block(config, layer_idx=None):
140 last_layer_subset = getattr(config, "last_layer_subset", False)
141 cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1
142 # TD [2022-12-19]: For cross attention (last layer), we actually want to return the
143 # residual x_kv, not residual x. But it's annoying to change the API (and it only affects
144 # one layer) so we just choose not to return residual in this case.
145 return_residual = not cross_attn
146 mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
147 mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
148 norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
149 block = Block(
150 config.hidden_size,
151 mixer_cls,
152 mlp_cls,
153 norm_cls=norm_cls,
154 prenorm=False,
155 resid_dropout1=config.hidden_dropout_prob,
156 resid_dropout2=config.hidden_dropout_prob,
157 fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
158 return_residual=return_residual,
159 )
160 return block
161
162
163 # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
164 def _init_weights(module, initializer_range=0.02):
165 if isinstance(module, nn.Linear):
166 nn.init.normal_(module.weight, std=initializer_range)
167 if module.bias is not None:
168 nn.init.zeros_(module.bias)
169 elif isinstance(module, nn.Embedding):
170 nn.init.normal_(module.weight, std=initializer_range)
171 if module.padding_idx is not None:
172 nn.init.zeros_(module.weight[module.padding_idx])
173
174
175 class XLMRobertaEncoder(nn.Module):
176 def __init__(self, config: XLMRobertaFlashConfig):
177 super().__init__()
178 self.use_flash_attn = get_use_flash_attn(config)
179 self.layers = nn.ModuleList(
180 [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
181 )
182 self._grad_checkpointing = False
183
184 @property
185 def gradient_checkpointing(self):
186 return self._grad_checkpointing
187
188 @gradient_checkpointing.setter
189 def gradient_checkpointing(self, value):
190 self._grad_checkpointing = value
191
192 def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
193 """If subset_mask is not None, we only want output for the subset of the sequence.
194 This means that we only compute the last layer output for these tokens.
195 subset_mask: (batch, seqlen), dtype=torch.bool
196 """
197 if key_padding_mask is None or not self.use_flash_attn:
198 mixer_kwargs = (
199 {"key_padding_mask": key_padding_mask.bool()}
200 if key_padding_mask is not None
201 else None
202 )
203 for layer in self.layers:
204 if self._grad_checkpointing:
205 hidden_states = torch.utils.checkpoint.checkpoint(
206 layer,
207 hidden_states,
208 use_reentrant=False,
209 mixer_kwargs=mixer_kwargs,
210 )
211 else:
212 hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
213 if subset_mask is not None:
214 hidden_states = hidden_states[subset_mask]
215 else:
216 batch, seqlen = hidden_states.shape[:2]
217 hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
218 hidden_states, key_padding_mask
219 )
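# unpad_input packs the padded (batch, seqlen, hidden) tensor into a (total_tokens, hidden)
# tensor containing only non-padding tokens; cu_seqlens holds the cumulative sequence
# lengths that the varlen flash-attention kernels use to locate sequence boundaries.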
220 mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
221 if subset_mask is None:
222 for layer in self.layers:
223 if self._grad_checkpointing:
224 hidden_states = torch.utils.checkpoint.checkpoint(
225 layer,
226 hidden_states,
227 use_reentrant=False,
228 mixer_kwargs=mixer_kwargs,
229 )
230 else:
231 hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
232 hidden_states = pad_input(hidden_states, indices, batch, seqlen)
233 else:
234 for layer in self.layers[:-1]:
235 if self._grad_checkpointing:
236 hidden_states = torch.utils.checkpoint.checkpoint(
237 layer,
238 hidden_states,
239 use_reentrant=False,
240 mixer_kwargs=mixer_kwargs,
241 )
242 else:
243 hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
244 if key_padding_mask is not None:
245 subset_idx = torch.nonzero(
246 subset_mask[key_padding_mask], as_tuple=False
247 ).flatten()
248 subset_seqlens = (subset_mask & key_padding_mask).sum(
249 dim=-1, dtype=torch.int32
250 )
251 subset_cu_seqlens = F.pad(
252 torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32),
253 (1, 0),
254 )
255 else:
256 subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
257 subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
258 subset_cu_seqlens = F.pad(
259 torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32),
260 (1, 0),
261 )
262 hidden_states_subset, hidden_states = index_first_axis_residual(
263 hidden_states, subset_idx
264 )
265 # It's ok to set max_seqlen_q to be much larger
266 mixer_kwargs = {
267 "x_kv": hidden_states,
268 "cu_seqlens": subset_cu_seqlens,
269 "max_seqlen": max_seqlen_in_batch,
270 "cu_seqlens_k": cu_seqlens,
271 "max_seqlen_k": max_seqlen_in_batch,
272 }
273 if self._grad_checkpointing:
274 hidden_states = torch.utils.checkpoint.checkpoint(
275 self.layers[-1],
276 hidden_states_subset,
277 use_reentrant=False,
278 mixer_kwargs=mixer_kwargs,
279 )
280 else:
281 hidden_states = self.layers[-1](
282 hidden_states_subset, mixer_kwargs=mixer_kwargs
283 )
284 return hidden_states
285
286
287 class XLMRobertaPooler(nn.Module):
288 def __init__(self, config):
289 super().__init__()
290 fused_bias_fc = getattr(config, "fused_bias_fc", False)
291 if fused_bias_fc and FusedDense is None:
292 raise ImportError("fused_dense is not installed")
293 linear_cls = nn.Linear if not fused_bias_fc else FusedDense
294 self.dense = linear_cls(config.hidden_size, config.hidden_size)
295 self.activation = nn.Tanh()
296
297 def forward(self, hidden_states, pool=True):
298 # We "pool" the model by simply taking the hidden state corresponding
299 # to the first token.
300 first_token_tensor = hidden_states[:, 0] if pool else hidden_states
301 pooled_output = self.dense(first_token_tensor)
302 pooled_output = self.activation(pooled_output)
303 return pooled_output
304
305
306 class XLMRobertaPredictionHeadTransform(nn.Module):
307 def __init__(self, config):
308 super().__init__()
309 fused_bias_fc = getattr(config, "fused_bias_fc", False)
310 if fused_bias_fc and FusedDense is None:
311 raise ImportError("fused_dense is not installed")
312 self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
313 if self.fused_dropout_add_ln and layer_norm_fn is None:
314 raise ImportError("Triton is not installed")
315 linear_cls = nn.Linear if not fused_bias_fc else FusedDense
316 self.dense = linear_cls(config.hidden_size, config.hidden_size)
317 approximate = (
318 "tanh"
319 if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
320 else "none"
321 )
322 self.transform_act_fn = nn.GELU(approximate=approximate)
323 self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
324
325 def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
326 hidden_states = self.dense(hidden_states)
327 hidden_states = self.transform_act_fn(hidden_states)
328 if not self.fused_dropout_add_ln:
329 hidden_states = self.layer_norm(hidden_states)
330 else:
331 hidden_states = layer_norm_fn(
332 hidden_states,
333 self.layer_norm.weight,
334 self.layer_norm.bias,
335 eps=self.layer_norm.eps,
336 )
337 return hidden_states
338
339
340 class XLMRobertaLMPredictionHead(nn.Module):
341 def __init__(self, config):
342 super().__init__()
343 fused_bias_fc = getattr(config, "fused_bias_fc", False)
344 if fused_bias_fc and FusedDense is None:
345 raise ImportError("fused_dense is not installed")
346 linear_cls = nn.Linear if not fused_bias_fc else FusedDense
347
348 self.transform = XLMRobertaPredictionHeadTransform(config)
349
350 # The output weights are the same as the input embeddings, but there is
351 # an output-only bias for each token.
352 self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)
353
354 def forward(self, hidden_states):
355 hidden_states = self.transform(hidden_states)
356 hidden_states = self.decoder(hidden_states)
357 return hidden_states
358
359
360 class XLMRobertaPreTrainingHeads(nn.Module):
361 def __init__(self, config):
362 super().__init__()
363 self.predictions = XLMRobertaLMPredictionHead(config)
364 self.seq_relationship = nn.Linear(config.hidden_size, 2)
365
366 def forward(self, sequence_output, pooled_output):
367 prediction_scores = self.predictions(sequence_output)
368 seq_relationship_score = self.seq_relationship(pooled_output)
369 return prediction_scores, seq_relationship_score
370
371
372 class XLMRobertaPreTrainedModel(PreTrainedModel):
373 """An abstract class to handle weights initialization and
374 a simple interface for downloading and loading pretrained models.
375 """
376
377 config_class = XLMRobertaFlashConfig
378 base_model_prefix = "roberta"
379 supports_gradient_checkpointing = True
380
381 def _set_gradient_checkpointing(self, module, value=False):
382 if isinstance(module, XLMRobertaEncoder):
383 module.gradient_checkpointing = value
384
385 @classmethod
386 def from_pretrained(
387 cls,
388 *args,
389 **kwargs,
390 ):
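# torch_dtype defaults to 'auto' so the weights are loaded in the dtype stored in the
# checkpoint (e.g. fp16 or bf16) instead of being upcast to fp32.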
391 if 'torch_dtype' not in kwargs:
392 kwargs['torch_dtype'] = 'auto'
393 return super().from_pretrained(*args, **kwargs)
394
395
396
397 class XLMRobertaModel(XLMRobertaPreTrainedModel):
398 def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
399 super().__init__(config)
400 self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
401 if config.vocab_size % self.pad_vocab_size_multiple != 0:
402 config.vocab_size += self.pad_vocab_size_multiple - (
403 config.vocab_size % self.pad_vocab_size_multiple
404 )
405 self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
406 if self.fused_dropout_add_ln and layer_norm_fn is None:
407 raise ImportError("Triton is not installed")
408 assert config.hidden_act in [
409 "gelu",
410 "gelu_new",
411 "gelu_fast",
412 "gelu_pytorch_tanh",
413 ]
414
415 self.embeddings = XLMRobertaEmbeddings(
416 config.hidden_size,
417 config.vocab_size,
418 config.max_position_embeddings if config.position_embedding_type == 'absolute' else -1,
419 config.type_vocab_size,
420 padding_idx=config.pad_token_id,
421 )
422 self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
423 self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
424 self.encoder = XLMRobertaEncoder(config)
425 self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
426
427 self.apply(partial(_init_weights, initializer_range=config.initializer_range))
428
429
430 @torch.inference_mode()
431 def encode(
432 self: 'XLMRobertaModel',
433 sentences: Union[str, List[str]],
434 batch_size: int = 32,
435 show_progress_bar: Optional[bool] = None,
436 output_value: str = 'sentence_embedding',
437 convert_to_numpy: bool = True,
438 convert_to_tensor: bool = False,
439 device: Optional[torch.device] = None,
440 normalize_embeddings: bool = False,
441 truncate_dim: Optional[int] = None,
442 **tokenizer_kwargs,
443 ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
444 """
445 Computes sentence embeddings
446 Args:
447 sentences(`str` or `List[str]`):
448 Sentence or sentences to be encoded
449 batch_size(`int`, *optional*, defaults to 32):
450 Batch size for the computation
451 show_progress_bar(`bool`, *optional*, defaults to None):
452 Show a progress bar when encoding sentences.
453 If set to None, progress bar is only shown when
454 `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
455 output_value(`str`, *optional*, defaults to 'sentence_embedding'):
456 The kind of output to return. Only 'sentence_embedding' is
457 currently supported; passing 'token_embeddings' or None raises
458 NotImplementedError.
459 convert_to_numpy(`bool`, *optional*, defaults to True):
460 If true, the output is a list of numpy vectors.
461 Else, it is a list of pytorch tensors.
462 convert_to_tensor(`bool`, *optional*, defaults to False):
463 If true, a single stacked tensor is returned.
464 Overrides any setting of convert_to_numpy.
465 device(`torch.device`, *optional*, defaults to None):
466 Which torch.device to use for the computation
467 normalize_embeddings(`bool`, *optional*, defaults to False):
468 If set to true, returned vectors will have unit length. In that case,
469 the faster dot-product (util.dot_score) can be used instead of cosine
470 similarity.
471 truncate_dim(`int`, *optional*, defaults to None):
472 The dimension to truncate sentence embeddings to. `None` does no truncation.
473 tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
474 Keyword arguments for the tokenizer
475 Returns:
476 By default, a list of tensors is returned.
477 If convert_to_tensor, a stacked tensor is returned.
478 If convert_to_numpy, a numpy matrix is returned. See the commented usage sketch after this method.
479 """
480 from transformers import AutoTokenizer
481
482 if not hasattr(self, 'tokenizer'):
483 self.tokenizer = AutoTokenizer.from_pretrained(
484 self.name_or_path, trust_remote_code=True)
485
486 is_training = self.training
487 self.eval()
488
489 if show_progress_bar is None:
490 show_progress_bar = (
491 logger.getEffectiveLevel() == logging.INFO
492 or logger.getEffectiveLevel() == logging.DEBUG
493 )
494
495 if convert_to_tensor:
496 convert_to_numpy = False
497
498 if output_value != 'sentence_embedding':
499 convert_to_tensor = False
500 convert_to_numpy = False
501
502 input_was_string = False
503 if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
504 sentences = [sentences]
505 input_was_string = True
506
507 if device is not None:
508 self.to(device)
509
510 permutation = np.argsort([-len(i) for i in sentences])
511 inverse_permutation = np.argsort(permutation)
512 sentences = [sentences[idx] for idx in permutation]
513
514 tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
515 tokenizer_kwargs['max_length'] = tokenizer_kwargs.get(
516 'max_length', self.tokenizer.init_kwargs.get('model_max_length', 8192)
517 )
518 tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
519
520 all_embeddings = []
521
522 if trange is not None:
523 range_iter = trange(
524 0,
525 len(sentences),
526 batch_size,
527 desc="Encoding",
528 disable=not show_progress_bar,
529 )
530 else:
531 range_iter = range(0, len(sentences), batch_size)
532
533 for i in range_iter:
534 encoded_input = self.tokenizer(
535 sentences[i : i + batch_size],
536 return_tensors='pt',
537 **tokenizer_kwargs,
538 ).to(self.device)
539 token_embs = self.forward(**encoded_input)[0]
540
541 # Accumulate in fp32 to avoid overflow
542 token_embs = token_embs.float()
543
544 if output_value == 'token_embeddings':
545 raise NotImplementedError
546 elif output_value is None:
547 raise NotImplementedError
548 else:
549 if self.config.emb_pooler == 'cls':
550 embeddings = self.cls_pooling(
551 token_embs, encoded_input['attention_mask']
552 )
553 else:
554 embeddings = self.mean_pooling(
555 token_embs, encoded_input['attention_mask']
556 )
557
558 if normalize_embeddings:
559 embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
560
561 if convert_to_numpy:
562 embeddings = embeddings.cpu()
563 all_embeddings.extend(embeddings)
564
565 all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
566
567 truncate_dim = truncate_dim or self.config.truncate_dim
568 if truncate_dim:
569 all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
570
571 if convert_to_tensor:
572 all_embeddings = torch.stack(all_embeddings)
573 elif convert_to_numpy:
574 all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
575
576 if input_was_string:
577 all_embeddings = all_embeddings[0]
578
579 self.train(is_training)
580 return all_embeddings
581
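# Illustrative usage sketch for `encode`; the checkpoint name is a placeholder for any
# checkpoint whose auto_map resolves to this class and that ships a compatible tokenizer:
#
#     from transformers import AutoModel
#
#     model = AutoModel.from_pretrained(
#         "<your-xlm-roberta-flash-checkpoint>", trust_remote_code=True
#     )
#     embeddings = model.encode(
#         ["How is the weather today?", "What is the current weather like today?"],
#         batch_size=2,
#         convert_to_numpy=True,
#         normalize_embeddings=True,
#     )
#     print(embeddings.shape)  # (2, hidden_size); rows are L2-normalized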
582
583 def truncate_embeddings(self, embeddings, truncate_dim):
584 if not self.config.matryoshka_dimensions:
585 logger.warning(
586 'Matryoshka embeddings are not supported, so dimension truncation will not be performed.'
587 )
588 return embeddings
589 elif truncate_dim in self.config.matryoshka_dimensions:
590 return [tensor[:truncate_dim] for tensor in embeddings]
591 else:
592 raise ValueError(f'The provided `truncate_dim` value of {truncate_dim} is not supported. '
593 f'Supported dimensions are {self.config.matryoshka_dimensions}.')
594
595 def mean_pooling(
596 self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
597 ):
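# Masked mean pooling: padding positions are zeroed out via the attention mask and the
# remaining token embeddings are averaged; the clamp guards against division by zero.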
598 input_mask_expanded = (
599 attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
600 )
601 return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
602 input_mask_expanded.sum(1), min=1e-9
603 )
604
605
606 def cls_pooling(
607 self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
608 ):
609 return token_embeddings[:, 0]
610
611
612 def forward(
613 self,
614 input_ids,
615 position_ids=None,
616 token_type_ids=None,
617 attention_mask=None,
618 masked_tokens_mask=None,
619 return_dict=None,
620 **kwargs,
621 ):
622 """If masked_tokens_mask is not None (i.e. last_layer_subset == True in XLMForPreTraining),
623 we only want the output for the masked tokens. This means that we only compute the last
624 layer output for these tokens.
625 masked_tokens_mask: (batch, seqlen), dtype=torch.bool
626 """
627
628 if kwargs:
629 for key, value in kwargs.items():
630 if value is not None:
631 logger.warning(
632 'Flash attention implementation does not support kwargs: %s',
633 key,
634 )
635
636 return_dict = (
637 return_dict if return_dict is not None else self.config.use_return_dict
638 )
639
640 hidden_states = self.embeddings(
641 input_ids, position_ids=position_ids, token_type_ids=token_type_ids
642 )
643 # TD [2022-12-18]: Don't need to force residual in fp32
644 # BERT puts embedding LayerNorm before embedding dropout.
645 if not self.fused_dropout_add_ln:
646 hidden_states = self.emb_ln(hidden_states)
647 else:
648 hidden_states = layer_norm_fn(
649 hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
650 )
651 hidden_states = self.emb_drop(hidden_states)
652
653 if masked_tokens_mask is not None:
654 batch_size, seqlen = input_ids.shape[:2]
655 # We also need the first column for the CLS token
656 first_col_mask = torch.zeros(
657 batch_size, seqlen, dtype=torch.bool, device=input_ids.device
658 )
659 first_col_mask[:, 0] = True
660 subset_mask = masked_tokens_mask | first_col_mask
661 else:
662 subset_mask = None
663
664 sequence_output = self.encoder(
665 hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
666 )
667
668 if masked_tokens_mask is None:
669 pooled_output = (
670 self.pooler(sequence_output) if self.pooler is not None else None
671 )
672 else:
673 # TD [2022-03-01]: the indexing here is very tricky.
674 if attention_mask is not None:
675 subset_idx = subset_mask[attention_mask]
676 pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
677 sequence_output = sequence_output[
678 masked_tokens_mask[attention_mask][subset_idx]
679 ]
680 else:
681 pool_input = sequence_output[first_col_mask[subset_mask]]
682 sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
683 pooled_output = (
684 self.pooler(pool_input, pool=False) if self.pooler is not None else None
685 )
686
687 if not return_dict:
688 return sequence_output, pooled_output
689
690 return BaseModelOutputWithPoolingAndCrossAttentions(
691 last_hidden_state=sequence_output,
692 pooler_output=pooled_output,
693 )
694
695
696 class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
697 _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
698
699 def __init__(self, config):
700 super().__init__(config)
701
702 if config.is_decoder:
703 logger.warning(
704 "If you want to use `XLMRobertaForMaskedLM` make sure `config.is_decoder=False` for "
705 "bi-directional self-attention."
706 )
707
708 self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
709 self.lm_head = XLMRobertaLMHead(config)
710
711 # Initialize weights and apply final processing
712 self.post_init()
713
714 def get_input_embeddings(self):
715 return self.roberta.embeddings.word_embeddings
716
717 def get_output_embeddings(self):
718 return self.lm_head.decoder
719
720 def set_output_embeddings(self, new_embeddings):
721 self.lm_head.decoder = new_embeddings
722
723 def forward(
724 self,
725 input_ids: Optional[torch.LongTensor] = None,
726 attention_mask: Optional[torch.FloatTensor] = None,
727 token_type_ids: Optional[torch.LongTensor] = None,
728 position_ids: Optional[torch.LongTensor] = None,
729 head_mask: Optional[torch.FloatTensor] = None,
730 inputs_embeds: Optional[torch.FloatTensor] = None,
731 encoder_hidden_states: Optional[torch.FloatTensor] = None,
732 encoder_attention_mask: Optional[torch.FloatTensor] = None,
733 labels: Optional[torch.LongTensor] = None,
734 output_attentions: Optional[bool] = None,
735 output_hidden_states: Optional[bool] = None,
736 return_dict: Optional[bool] = None,
737 ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
738 r"""
739 labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
740 Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
741 config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
742 loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
745 """
746 return_dict = (
747 return_dict if return_dict is not None else self.config.use_return_dict
748 )
749
750 outputs = self.roberta(
751 input_ids,
752 attention_mask=attention_mask,
753 token_type_ids=token_type_ids,
754 position_ids=position_ids,
755 head_mask=head_mask,
756 inputs_embeds=inputs_embeds,
757 encoder_hidden_states=encoder_hidden_states,
758 encoder_attention_mask=encoder_attention_mask,
759 output_attentions=output_attentions,
760 output_hidden_states=output_hidden_states,
761 return_dict=return_dict,
762 )
763 sequence_output = outputs[0]
764 prediction_scores = self.lm_head(sequence_output)
765
766 masked_lm_loss = None
767 if labels is not None:
768 # move labels to correct device to enable model parallelism
769 labels = labels.to(prediction_scores.device)
770 loss_fct = CrossEntropyLoss()
771 masked_lm_loss = loss_fct(
772 prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
773 )
774
775 if not return_dict:
776 output = (prediction_scores,) + outputs[2:]
777 return (
778 ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
779 )
780
781 return MaskedLMOutput(
782 loss=masked_lm_loss,
783 logits=prediction_scores,
784 hidden_states=outputs.hidden_states,
785 attentions=outputs.attentions,
786 )
787
788
789 # Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->XLMRoberta
790 class XLMRobertaClassificationHead(nn.Module):
791 """Head for sentence-level classification tasks."""
792
793 def __init__(self, config):
794 super().__init__()
795 fused_bias_fc = getattr(config, "fused_bias_fc", False)
796 if fused_bias_fc and FusedDense is None:
797 raise ImportError("fused_dense is not installed")
798 linear_cls = nn.Linear if not fused_bias_fc else FusedDense
799 self.dense = linear_cls(config.hidden_size, config.hidden_size)
800 classifier_dropout = (
801 config.classifier_dropout
802 if config.classifier_dropout is not None
803 else config.hidden_dropout_prob
804 )
805 self.dropout = nn.Dropout(classifier_dropout)
806 self.out_proj = linear_cls(config.hidden_size, config.num_labels)
807
808 def forward(self, features, **kwargs):
809 x = features[:, 0, :] # take <s> token (equiv. to [CLS])
810 x = self.dropout(x)
811 x = self.dense(x)
812 x = torch.tanh(x)
813 x = self.dropout(x)
814 x = self.out_proj(x)
815 return x
816
817
818 # Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA
819 class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
820 def __init__(self, config):
821 super().__init__(config)
822 self.num_labels = config.num_labels
823 self.config = config
824
825 self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
826 self.classifier = XLMRobertaClassificationHead(config)
827
828 # Initialize weights and apply final processing
829 self.post_init()
830
831 def forward(
832 self,
833 input_ids: Optional[torch.LongTensor] = None,
834 attention_mask: Optional[torch.FloatTensor] = None,
835 token_type_ids: Optional[torch.LongTensor] = None,
836 position_ids: Optional[torch.LongTensor] = None,
837 head_mask: Optional[torch.FloatTensor] = None,
838 inputs_embeds: Optional[torch.FloatTensor] = None,
839 labels: Optional[torch.LongTensor] = None,
840 output_attentions: Optional[bool] = None,
841 output_hidden_states: Optional[bool] = None,
842 return_dict: Optional[bool] = None,
843 ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
844 r"""
845 labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
846 Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
847 config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
848 `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
849 """
850 return_dict = (
851 return_dict if return_dict is not None else self.config.use_return_dict
852 )
853
854 outputs = self.roberta(
855 input_ids,
856 attention_mask=attention_mask,
857 token_type_ids=token_type_ids,
858 position_ids=position_ids,
859 head_mask=head_mask,
860 inputs_embeds=inputs_embeds,
861 output_attentions=output_attentions,
862 output_hidden_states=output_hidden_states,
863 return_dict=return_dict,
864 )
865 sequence_output = outputs[0]
866 logits = self.classifier(sequence_output)
867
868 loss = None
869 if labels is not None:
870 # move labels to correct device to enable model parallelism
871 labels = labels.to(logits.device)
872 if self.config.problem_type is None:
873 if self.num_labels == 1:
874 self.config.problem_type = "regression"
875 elif self.num_labels > 1 and (
876 labels.dtype == torch.long or labels.dtype == torch.int
877 ):
878 self.config.problem_type = "single_label_classification"
879 else:
880 self.config.problem_type = "multi_label_classification"
881
882 if self.config.problem_type == "regression":
883 loss_fct = MSELoss()
884 if self.num_labels == 1:
885 loss = loss_fct(logits.squeeze(), labels.squeeze())
886 else:
887 loss = loss_fct(logits, labels)
888 elif self.config.problem_type == "single_label_classification":
889 loss_fct = CrossEntropyLoss()
890 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
891 elif self.config.problem_type == "multi_label_classification":
892 loss_fct = BCEWithLogitsLoss()
893 loss = loss_fct(logits, labels)
894
895 if not return_dict:
896 output = (logits,) + outputs[2:]
897 return ((loss,) + output) if loss is not None else output
898
899 return SequenceClassifierOutput(
900 loss=loss,
901 logits=logits,
902 hidden_states=outputs.hidden_states,
903 attentions=outputs.attentions,
904 )
905
906
907 @torch.inference_mode()
908 def compute_score(
909 self,
910 sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
911 batch_size: int = 32,
912 max_length: Optional[int] = None,
913 ) -> List[float]:
914
915 if not hasattr(self, "_tokenizer"):
916 from transformers import AutoTokenizer
917
918 self._tokenizer = AutoTokenizer.from_pretrained(
919 self.name_or_path, trust_remote_code=True
920 )
921
922 assert isinstance(sentence_pairs, (list, tuple))
923 if isinstance(sentence_pairs[0], str):
924 sentence_pairs = [sentence_pairs]
925
926 all_scores = []
927 for start_index in range(
928 0, len(sentence_pairs), batch_size
929 ):
930 sentences_batch = sentence_pairs[
931 start_index : start_index + batch_size
932 ]
933 inputs = self._tokenizer(
934 sentences_batch,
935 padding=True,
936 truncation=True,
937 return_tensors='pt',
938 max_length=max_length,
939 ).to(self.device)
940 scores = (
941 self.forward(**inputs, return_dict=True)
942 .logits.view(
943 -1,
944 )
945 .float()
946 )
947 scores = torch.sigmoid(scores)
948 all_scores.extend(scores.cpu().numpy().tolist())
949
950 if len(all_scores) == 1:
951 return all_scores[0]
952 return all_scores
953
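# Illustrative usage sketch for `compute_score`; the checkpoint name is a placeholder for a
# cross-encoder reranker whose auto_map resolves to this class and whose config defines a
# single relevance label:
#
#     from transformers import AutoModelForSequenceClassification
#
#     reranker = AutoModelForSequenceClassification.from_pretrained(
#         "<your-reranker-checkpoint>", trust_remote_code=True
#     )
#     scores = reranker.compute_score(
#         [("what is a panda?", "The giant panda is a bear species endemic to China.")],
#         max_length=1024,
#     )
#     print(scores)  # a single sigmoid-activated relevance score in [0, 1]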
954 def predict(
955 self,
956 sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
957 batch_size: int = 32,
958 max_length: Optional[int] = None,
959 ) -> List[float]:
960 # used for beir evaluation
961 return self.compute_score(sentence_pairs, batch_size=batch_size, max_length=max_length)
962
963 def rerank(
964 self,
965 query: str,
966 documents: List[str],
967 batch_size: int = 32,
968 max_length: int = 1024,
969 max_query_length: int = 512,
970 overlap_tokens: int = 80,
971 top_n: Optional[int] = None,
972 **kwargs,
973 ):
974 assert max_length >= max_query_length * 2, (
975 f'max_length ({max_length}) must be greater than or equal to '
976 f'max_query_length ({max_query_length}) * 2'
977 )
978
979 if not hasattr(self, "_tokenizer"):
980 from transformers import AutoTokenizer
981
982 self._tokenizer = AutoTokenizer.from_pretrained(
983 self.name_or_path, trust_remote_code=True
984 )
985
986 # preproc of tokenization
987 sentence_pairs, sentence_pairs_pids = reranker_tokenize_preproc(
988 query,
989 documents,
990 tokenizer=self._tokenizer,
991 max_length=max_length,
992 max_query_length=max_query_length,
993 overlap_tokens=overlap_tokens,
994 )
995
996 tot_scores = []
997 with torch.no_grad():
998 for k in range(0, len(sentence_pairs), batch_size):
999 batch = self._tokenizer.pad(
1000 sentence_pairs[k : k + batch_size],
1001 padding=True,
1002 max_length=max_length,
1003 pad_to_multiple_of=None,
1004 return_tensors="pt",
1005 )
1006 batch_on_device = {k: v.to(self.device) for k, v in batch.items()}
1007 scores = (
1008 self.forward(**batch_on_device, return_dict=True)
1009 .logits.view(
1010 -1,
1011 )
1012 .float()
1013 )
1014 scores = torch.sigmoid(scores)
1015 tot_scores.extend(scores.cpu().numpy().tolist())
1016
1017 # ranking
1018 merge_scores = [0 for _ in range(len(documents))]
1019 for pid, score in zip(sentence_pairs_pids, tot_scores):
1020 merge_scores[pid] = max(merge_scores[pid], score)
1021
1022 merge_scores_argsort = np.argsort(merge_scores)[::-1]
1023 sorted_documents = []
1024 sorted_scores = []
1025 for mid in merge_scores_argsort:
1026 sorted_scores.append(merge_scores[mid])
1027 sorted_documents.append(documents[mid])
1028
1029 top_n = min(top_n or len(sorted_documents), len(sorted_documents))
1030
1031 return [
1032 {
1033 'document': sorted_documents[i],
1034 'relevance_score': sorted_scores[i],
1035 'index': merge_scores_argsort[i],
1036 }
1037 for i in range(top_n)
1038 ]
1039
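# Illustrative usage sketch for `rerank`, assuming the same placeholder reranker checkpoint
# as in the `compute_score` sketch above. Long documents are split into overlapping chunks
# by `reranker_tokenize_preproc` below, and each document keeps the score of its best chunk:
#
#     results = reranker.rerank(
#         query="what is a panda?",
#         documents=[
#             "The giant panda is a bear species endemic to China.",
#             "Paris is the capital of France.",
#         ],
#         top_n=1,
#     )
#     print(results[0]['document'], results[0]['relevance_score'], results[0]['index'])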
1040
1041 def reranker_tokenize_preproc(
1042 query: str,
1043 passages: List[str],
1044 tokenizer=None,
1045 max_length: int = 1024,
1046 max_query_length: int = 512,
1047 overlap_tokens: int = 80,
1048 ):
1049 from copy import deepcopy
1050
1051 assert tokenizer is not None, "Please provide a valid tokenizer for tokenization!"
1052 sep_id = tokenizer.sep_token_id
1053
1054 def _merge_inputs(chunk1_raw, chunk2):
1055 chunk1 = deepcopy(chunk1_raw)
1056 chunk1['input_ids'].append(sep_id)
1057 chunk1['input_ids'].extend(chunk2['input_ids'])
1058 chunk1['input_ids'].append(sep_id)
1059 chunk1['attention_mask'].append(1)
1060 chunk1['attention_mask'].extend(chunk2['attention_mask'])
1061 chunk1['attention_mask'].append(1)
1062 if 'token_type_ids' in chunk1:
1063 token_type_ids = [1 for _ in range(len(chunk2['token_type_ids']) + 2)]
1064 chunk1['token_type_ids'].extend(token_type_ids)
1065 return chunk1
1066
1067 # Note: a long query is truncated to max_query_length tokens (512 by default)
1068 query_inputs = tokenizer.encode_plus(
1069 query, truncation=True, padding=False, max_length=max_query_length
1070 )
1071
1072 max_passage_inputs_length = max_length - len(query_inputs['input_ids']) - 2
1073 # assert (
1074 # max_passage_inputs_length > 100
1075 # ), "Your query is too long! Please make sure your query less than 500 tokens!"
1076
1077 overlap_tokens_implt = min(overlap_tokens, max_passage_inputs_length // 4)
1078
1079 res_merge_inputs = []
1080 res_merge_inputs_pids = []
1081 for pid, passage in enumerate(passages):
1082 passage_inputs = tokenizer.encode_plus(
1083 passage,
1084 truncation=False,
1085 padding=False,
1086 add_special_tokens=False,
1087 max_length=0,
1088 )
1089 passage_inputs_length = len(passage_inputs['input_ids'])
1090
1091 if passage_inputs_length <= max_passage_inputs_length:
1092 qp_merge_inputs = _merge_inputs(query_inputs, passage_inputs)
1093 res_merge_inputs.append(qp_merge_inputs)
1094 res_merge_inputs_pids.append(pid)
1095 else:
1096 start_id = 0
1097 while start_id < passage_inputs_length:
1098 end_id = start_id + max_passage_inputs_length
1099 # make sure the length of the last chunk is `max_passage_inputs_length`
1100 if end_id >= passage_inputs_length:
1101 sub_passage_inputs = {
1102 k: v[-max_passage_inputs_length:]
1103 for k, v in passage_inputs.items()
1104 }
1105 else:
1106 sub_passage_inputs = {
1107 k: v[start_id:end_id] for k, v in passage_inputs.items()
1108 }
1109 start_id = (
1110 end_id - overlap_tokens_implt
1111 if end_id < passage_inputs_length
1112 else end_id
1113 )
1114
1115 qp_merge_inputs = _merge_inputs(query_inputs, sub_passage_inputs)
1116 res_merge_inputs.append(qp_merge_inputs)
1117 res_merge_inputs_pids.append(pid)
1118
1119 return res_merge_inputs, res_merge_inputs_pids
1120
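# Summary of the chunking behaviour of `reranker_tokenize_preproc`: each passage chunk may
# hold at most max_length - len(query_input_ids) - 2 tokens, since `_merge_inputs` adds two
# extra separator tokens; consecutive chunks overlap by min(overlap_tokens, chunk_size // 4)
# tokens, and the final chunk is taken from the end of the passage so it is always
# full-length. Each chunk is paired with the query and mapped back to its passage id, which
# is how `rerank` keeps the maximum chunk score per document.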