custom_st.py
import json
import logging
import os
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import AutoConfig, AutoModel, AutoTokenizer

logger = logging.getLogger(__name__)


class Transformer(nn.Module):
    """Huggingface AutoModel to generate token embeddings.
    Loads the correct class, e.g. BERT / RoBERTa etc.

    Args:
        model_name_or_path: Huggingface model name
            (https://huggingface.co/models)
        max_seq_length: Truncate any inputs longer than max_seq_length
        model_args: Keyword arguments passed to the Huggingface
            Transformers model
        tokenizer_args: Keyword arguments passed to the Huggingface
            Transformers tokenizer
        config_args: Keyword arguments passed to the Huggingface
            Transformers config
        cache_dir: Cache dir for Huggingface Transformers to store/load
            models
        do_lower_case: If true, lowercases the input (regardless of
            whether the model is cased or not)
        tokenizer_name_or_path: Name or path of the tokenizer. When
            None, then model_name_or_path is used
    """

    # sentence-transformers convention: save this module's files in the root of
    # the output directory instead of a numbered sub-folder.
    save_in_root: bool = True

    def __init__(
        self,
        model_name_or_path: str,
        max_seq_length: Optional[int] = None,
        model_args: Optional[Dict[str, Any]] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        config_args: Optional[Dict[str, Any]] = None,
        cache_dir: Optional[str] = None,
        do_lower_case: bool = False,
        tokenizer_name_or_path: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__()
        self.config_keys = ["max_seq_length", "do_lower_case"]
        self.do_lower_case = do_lower_case
        if model_args is None:
            model_args = {}
        if tokenizer_args is None:
            tokenizer_args = {}
        if config_args is None:
            config_args = {}

        if cache_dir is not None:
            config_args["cache_dir"] = cache_dir
            model_args["cache_dir"] = cache_dir
            tokenizer_args["cache_dir"] = cache_dir

        # This custom module only supports the PyTorch backend; warn and fall back
        # if another backend was requested.
        if kwargs.get("backend", "torch") != "torch":
            logger.warning(
                f'"jinaai/jina-embeddings-v3" is currently not compatible with the {kwargs["backend"]} backend. '
                'Continuing with the "torch" backend.'
            )

        self.config = AutoConfig.from_pretrained(model_name_or_path, **config_args)

        self._lora_adaptations = self.config.lora_adaptations
        if (
            not isinstance(self._lora_adaptations, list)
            or len(self._lora_adaptations) < 1
        ):
            raise ValueError(
                "`lora_adaptations` must be a list and contain at least one element"
            )
        self._adaptation_map = {
            name: idx for idx, name in enumerate(self._lora_adaptations)
        }

        self.default_task = model_args.pop("default_task", None)

        self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=self.config, **model_args)

        if max_seq_length is not None and "model_max_length" not in tokenizer_args:
            tokenizer_args["model_max_length"] = max_seq_length
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path,
            **tokenizer_args,
        )

        # No max_seq_length set. Try to infer from model
        if max_seq_length is None:
            if (
                hasattr(self.auto_model, "config")
                and hasattr(self.auto_model.config, "max_position_embeddings")
                and hasattr(self.tokenizer, "model_max_length")
            ):
                max_seq_length = min(self.auto_model.config.max_position_embeddings, self.tokenizer.model_max_length)

        self.max_seq_length = max_seq_length

        if tokenizer_name_or_path is not None:
            self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__

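    # Note: `default_task` is a property so that every assignment is re-validated
    # against `config.lora_adaptations` via the setter below. For the published
    # jina-embeddings-v3 checkpoint these are task names such as "retrieval.query"
    # or "retrieval.passage" (illustrative here; the authoritative list always
    # comes from the model config, not from this module).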
    @property
    def default_task(self):
        return self._default_task

    @default_task.setter
    def default_task(self, task: Union[None, str]):
        self._validate_task(task)
        self._default_task = task

    def _validate_task(self, task: Optional[str]):
        if task and task not in self._lora_adaptations:
            raise ValueError(
                f"Unsupported task '{task}'. "
                f"Supported tasks are: {', '.join(self.config.lora_adaptations)}. "
                "Alternatively, don't pass the `task` argument to disable LoRA."
            )

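    # How task routing works in forward(): each LoRA adaptation has an integer id
    # (see `_adaptation_map` built in __init__). For a batch, forward() builds an
    # `adapter_mask` tensor with one id per example and passes it to the backbone,
    # which applies the matching LoRA weights. Illustrative sketch (assuming
    # "retrieval.query" is the first entry in `lora_adaptations`):
    #     task_id = self._adaptation_map["retrieval.query"]   # -> 0
    #     adapter_mask = torch.full((batch_size,), task_id, dtype=torch.int32)
    # Passing no task leaves `adapter_mask` as None, i.e. LoRA routing is disabled.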
    def forward(
        self, features: Dict[str, torch.Tensor], task: Optional[str] = None
    ) -> Dict[str, torch.Tensor]:
        """Adds the model's token embeddings to `features` and returns the dict."""
        self._validate_task(task)
        task = task or self.default_task
        adapter_mask = None
        if task:
            task_id = self._adaptation_map[task]
            num_examples = features["input_ids"].size(0)
            # One adapter id per example in the batch.
            adapter_mask = torch.full(
                (num_examples,), task_id, dtype=torch.int32, device=features["input_ids"].device
            )

        lora_arguments = (
            {"adapter_mask": adapter_mask} if adapter_mask is not None else {}
        )
        # `prompt_length` is added by sentence-transformers for prompt handling and
        # is not a valid model input, so drop it before the forward pass.
        features.pop("prompt_length", None)
        output_states = self.auto_model.forward(**features, **lora_arguments, return_dict=False)
        output_tokens = output_states[0]
        features["token_embeddings"] = output_tokens
        return features

    def get_word_embedding_dimension(self) -> int:
        return self.auto_model.config.hidden_size

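    # tokenize() accepts the three input shapes used throughout sentence-transformers:
    #   ["a sentence", "another sentence"]                -> one batch of plain texts
    #   [{"title": "a sentence"}, {"title": "another"}]   -> dicts; each key is kept in output["text_keys"]
    #   [("query", "passage"), ("query", "passage")]      -> text pairs, tokenized as two columns
    # (The keys and strings above are only illustrative examples.)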
    def tokenize(
        self,
        texts: Union[List[str], List[dict], List[Tuple[str, str]]],
        padding: Union[str, bool] = True
    ) -> Dict[str, torch.Tensor]:
        """Tokenizes a text and maps tokens to token-ids"""
        output = {}
        if isinstance(texts[0], str):
            to_tokenize = [texts]
        elif isinstance(texts[0], dict):
            to_tokenize = []
            output["text_keys"] = []
            for lookup in texts:
                text_key, text = next(iter(lookup.items()))
                to_tokenize.append(text)
                output["text_keys"].append(text_key)
            to_tokenize = [to_tokenize]
        else:
            batch1, batch2 = [], []
            for text_tuple in texts:
                batch1.append(text_tuple[0])
                batch2.append(text_tuple[1])
            to_tokenize = [batch1, batch2]

        # strip
        to_tokenize = [[str(s).strip() for s in col] for col in to_tokenize]

        # Lowercase
        if self.do_lower_case:
            to_tokenize = [[s.lower() for s in col] for col in to_tokenize]

        output.update(
            self.tokenizer(
                *to_tokenize,
                padding=padding,
                truncation="longest_first",
                return_tensors="pt",
                max_length=self.max_seq_length,
            )
        )
        return output

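    # save()/load() round trip: get_config_dict() captures the keys listed in
    # `config_keys` ("max_seq_length", "do_lower_case"); save() writes them to
    # sentence_bert_config.json alongside the model and tokenizer files, and
    # load() below reads that file back to reconstruct the module.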
    def get_config_dict(self) -> Dict[str, Any]:
        return {key: self.__dict__[key] for key in self.config_keys}

    def save(self, output_path: str, safe_serialization: bool = True) -> None:
        self.auto_model.save_pretrained(output_path, safe_serialization=safe_serialization)
        self.tokenizer.save_pretrained(output_path)

        with open(os.path.join(output_path, "sentence_bert_config.json"), "w") as fOut:
            json.dump(self.get_config_dict(), fOut, indent=2)

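    # load() picks up whichever legacy sentence_*_config.json file is present and
    # strips any `trust_remote_code` entries, so a saved config cannot silently
    # opt back into executing remote code.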
    @classmethod
    def load(cls, input_path: str) -> "Transformer":
        # Old classes used other config names than 'sentence_bert_config.json'
        for config_name in [
            "sentence_bert_config.json",
            "sentence_roberta_config.json",
            "sentence_distilbert_config.json",
            "sentence_camembert_config.json",
            "sentence_albert_config.json",
            "sentence_xlm-roberta_config.json",
            "sentence_xlnet_config.json",
        ]:
            sbert_config_path = os.path.join(input_path, config_name)
            if os.path.exists(sbert_config_path):
                break

        with open(sbert_config_path) as fIn:
            config = json.load(fIn)
        # Don't allow configs to set trust_remote_code
        if "model_args" in config and "trust_remote_code" in config["model_args"]:
            config["model_args"].pop("trust_remote_code")
        if "tokenizer_args" in config and "trust_remote_code" in config["tokenizer_args"]:
            config["tokenizer_args"].pop("trust_remote_code")
        if "config_args" in config and "trust_remote_code" in config["config_args"]:
            config["config_args"].pop("trust_remote_code")
        return cls(model_name_or_path=input_path, **config)

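# Minimal usage sketch (an assumption, not part of the original module): this file
# is normally consumed through sentence-transformers with `trust_remote_code=True`,
# and the `task` keyword is routed into `Transformer.forward` above. The model id
# and task name below are illustrative; valid tasks come from
# `config.lora_adaptations`.
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("jinaai/jina-embeddings-v3", trust_remote_code=True)
    embeddings = model.encode(
        ["What is the weather like today?"],
        task="retrieval.query",  # assumed task name; must be one of config.lora_adaptations
    )
    print(embeddings.shape)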