embedding.py
# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/embedding.py
# Commit id: f1a73d074002226c42ce65a1df170ecff9f022c0

# Copyright (c) 2022, Tri Dao.

import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor

from transformers.models.xlm_roberta.modeling_xlm_roberta import create_position_ids_from_input_ids


class XLMRobertaEmbeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        type_vocab_size,
        padding_idx=None,
        device=None,
        dtype=None,
    ):
| 25 | """ |
| 26 | If max_position_embeddings <= 0, there's no position embeddings |
| 27 | If type_vocab_size <= 0, there's no token type embeddings |
| 28 | """ |
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.word_embeddings = nn.Embedding(
            vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
        )
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(
                max_position_embeddings, embed_dim, **factory_kwargs
            )
        if self.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)

    def forward(self, input_ids, position_ids=None, token_type_ids=None):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        token_type_ids: (batch, seqlen)
        Returns embeddings of shape (batch, seqlen, embed_dim).
        """
        batch_size, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                # XLM-R convention: non-padding tokens get positions counted from
                # padding_idx + 1, padding tokens keep position padding_idx.
                position_ids = create_position_ids_from_input_ids(
                    input_ids, padding_idx=self.word_embeddings.padding_idx
                ).to(input_ids.device)
                # position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings
        if self.type_vocab_size > 0:
            if token_type_ids is None:
                # (seqlen,) tensor of zeros; broadcasts across the batch dimension.
                token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings = embeddings + token_type_embeddings
        return embeddings
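
# Minimal usage sketch (not part of the original file): the dimensions and
# padding_idx below are illustrative placeholders, loosely modeled on
# xlm-roberta-base, and are assumptions rather than values taken from this codebase.
if __name__ == "__main__":
    emb = XLMRobertaEmbeddings(
        embed_dim=768,
        vocab_size=250002,
        max_position_embeddings=514,
        type_vocab_size=1,
        padding_idx=1,
    )
    # Batch of 2 sequences of length 8: the first 5 positions hold real token ids,
    # the rest are padding (id 1), so the padding-aware position ids differ per token.
    input_ids = torch.full((2, 8), 1, dtype=torch.long)
    input_ids[:, :5] = torch.arange(5, 10)
    out = emb(input_ids)
    print(out.shape)  # torch.Size([2, 8, 768])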