---
license: apache-2.0
pipeline_tag: text-classification
tags:
- transformers
- sentence-transformers
- text-embeddings-inference
language:
- multilingual
---

# Reranker
**For more details, please refer to our GitHub repository: [FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding/tree/master).**
- [Model List](#model-list)
- [Usage](#usage)
- [Fine-tuning](#fine-tune)
- [Evaluation](#evaluation)
- [Citation](#citation)

Unlike an embedding model, a reranker takes a query and a document as input and directly outputs a similarity score instead of an embedding. You can obtain a relevance score by feeding a query and a passage to the reranker, and the raw score can be mapped to a float in [0, 1] with a sigmoid function.
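As a quick illustration, here is a minimal sketch of that mapping in plain Python; the input value matches the `normalize=True` example in the Usage section below:

```python
import math

def sigmoid(x: float) -> float:
    """Map a raw reranker score to the (0, 1) range."""
    return 1 / (1 + math.exp(-x))

print(sigmoid(-5.65234375))  # ~0.0035, the normalized form of the raw score -5.65234375
```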
## Model List

| Model | Base model | Language | Layerwise | Feature |
|:---|:---:|:---:|:---:|:---|
| [BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base) | [xlm-roberta-base](https://huggingface.co/xlm-roberta-base) | Chinese and English | - | Lightweight reranker model, easy to deploy, with fast inference. |
| [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large) | [xlm-roberta-large](https://huggingface.co/FacebookAI/xlm-roberta-large) | Chinese and English | - | Lightweight reranker model, easy to deploy, with fast inference. |
| [BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3) | [bge-m3](https://huggingface.co/BAAI/bge-m3) | Multilingual | - | Lightweight reranker model with strong multilingual capabilities, easy to deploy, with fast inference. |
| [BAAI/bge-reranker-v2-gemma](https://huggingface.co/BAAI/bge-reranker-v2-gemma) | [gemma-2b](https://huggingface.co/google/gemma-2b) | Multilingual | - | Suitable for multilingual contexts; performs well on both English and multilingual tasks. |
| [BAAI/bge-reranker-v2-minicpm-layerwise](https://huggingface.co/BAAI/bge-reranker-v2-minicpm-layerwise) | [MiniCPM-2B-dpo-bf16](https://huggingface.co/openbmb/MiniCPM-2B-dpo-bf16) | Multilingual | 8-40 | Suitable for multilingual contexts; performs well in both English and Chinese; allows free selection of output layers, facilitating accelerated inference. |
You can select a model according to your scenario and resources:

- For **multilingual** use, choose [BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3) or [BAAI/bge-reranker-v2-gemma](https://huggingface.co/BAAI/bge-reranker-v2-gemma).
- For **Chinese or English**, choose [BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3) or [BAAI/bge-reranker-v2-minicpm-layerwise](https://huggingface.co/BAAI/bge-reranker-v2-minicpm-layerwise).
- For **efficiency**, choose [BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3) or a low layer of [BAAI/bge-reranker-v2-minicpm-layerwise](https://huggingface.co/BAAI/bge-reranker-v2-minicpm-layerwise).
- For **better performance**, we recommend [BAAI/bge-reranker-v2-minicpm-layerwise](https://huggingface.co/BAAI/bge-reranker-v2-minicpm-layerwise) and [BAAI/bge-reranker-v2-gemma](https://huggingface.co/BAAI/bge-reranker-v2-gemma).
## Usage
### Using FlagEmbedding

```
pip install -U FlagEmbedding
```

#### For normal reranker (bge-reranker-base / bge-reranker-large / bge-reranker-v2-m3)

Get relevance scores (higher scores indicate more relevance):

```python
from FlagEmbedding import FlagReranker
reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True)  # Setting use_fp16 to True speeds up computation with a slight performance degradation

score = reranker.compute_score(['query', 'passage'])
print(score)  # -5.65234375

# You can map the score into the [0, 1] range by setting normalize=True, which applies a sigmoid function to the score
score = reranker.compute_score(['query', 'passage'], normalize=True)
print(score)  # 0.003497010252573502

scores = reranker.compute_score([['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']])
print(scores)  # [-8.1875, 5.26171875]

# normalize=True works the same way for batched inputs
scores = reranker.compute_score([['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']], normalize=True)
print(scores)  # [0.00027803096387751553, 0.9948403768236574]
```
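In retrieval pipelines you typically score one query against several candidate passages and sort by score. A minimal sketch, reusing the `reranker` instance from above (the passages are illustrative):

```python
query = 'what is panda?'
passages = [
    'hi',
    'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.',
]

# Score all (query, passage) pairs in one batch, then rank passages by score, highest first.
scores = reranker.compute_score([[query, p] for p in passages], normalize=True)
for passage, score in sorted(zip(passages, scores), key=lambda x: x[1], reverse=True):
    print(f'{score:.4f}  {passage}')
```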
#### For LLM-based reranker

```python
from FlagEmbedding import FlagLLMReranker
reranker = FlagLLMReranker('BAAI/bge-reranker-v2-gemma', use_fp16=True)  # Setting use_fp16 to True speeds up computation with a slight performance degradation
# reranker = FlagLLMReranker('BAAI/bge-reranker-v2-gemma', use_bf16=True)  # You can also set use_bf16=True to speed up computation with a slight performance degradation

score = reranker.compute_score(['query', 'passage'])
print(score)

scores = reranker.compute_score([['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']])
print(scores)
```
#### For LLM-based layerwise reranker

```python
from FlagEmbedding import LayerWiseFlagLLMReranker
reranker = LayerWiseFlagLLMReranker('BAAI/bge-reranker-v2-minicpm-layerwise', use_fp16=True)  # Setting use_fp16 to True speeds up computation with a slight performance degradation
# reranker = LayerWiseFlagLLMReranker('BAAI/bge-reranker-v2-minicpm-layerwise', use_bf16=True)  # You can also set use_bf16=True to speed up computation with a slight performance degradation

score = reranker.compute_score(['query', 'passage'], cutoff_layers=[28])  # Adjust cutoff_layers to pick which layers are used for computing the score.
print(score)

scores = reranker.compute_score([['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']], cutoff_layers=[28])
print(scores)
```
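The `cutoff_layers` argument selects which layers' outputs are used to compute the score. For bge-reranker-v2-minicpm-layerwise, the selectable layers range from 8 to 40 (see the Layerwise column in the model list above); lower layers trade some accuracy for faster inference.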
### Using Hugging Face Transformers

#### For normal reranker (bge-reranker-base / bge-reranker-large / bge-reranker-v2-m3)

Get relevance scores (higher scores indicate more relevance):

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-v2-m3')
model = AutoModelForSequenceClassification.from_pretrained('BAAI/bge-reranker-v2-m3')
model.eval()

pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
with torch.no_grad():
    inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
    scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
    print(scores)
```
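These logits are unnormalized. To get scores in [0, 1] (the equivalent of `normalize=True` in FlagEmbedding), apply a sigmoid:

```python
# Map the raw logits to the [0, 1] range
probabilities = torch.sigmoid(scores)
print(probabilities)
```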
#### For LLM-based reranker
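Here the model is prompted to judge whether passage B answers query A with 'Yes' or 'No', and the logit of the 'Yes' token at the final position serves as the relevance score: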
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def get_inputs(pairs, tokenizer, prompt=None, max_length=1024):
    """Build prompted inputs of the form '<bos>A: {query}\nB: {passage}\n{prompt}'."""
    if prompt is None:
        prompt = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."
    sep = "\n"
    prompt_inputs = tokenizer(prompt,
                              return_tensors=None,
                              add_special_tokens=False)['input_ids']
    sep_inputs = tokenizer(sep,
                           return_tensors=None,
                           add_special_tokens=False)['input_ids']
    inputs = []
    for query, passage in pairs:
        query_inputs = tokenizer(f'A: {query}',
                                 return_tensors=None,
                                 add_special_tokens=False,
                                 max_length=max_length * 3 // 4,
                                 truncation=True)
        passage_inputs = tokenizer(f'B: {passage}',
                                   return_tensors=None,
                                   add_special_tokens=False,
                                   max_length=max_length,
                                   truncation=True)
        item = tokenizer.prepare_for_model(
            [tokenizer.bos_token_id] + query_inputs['input_ids'],
            sep_inputs + passage_inputs['input_ids'],
            truncation='only_second',
            max_length=max_length,
            padding=False,
            return_attention_mask=False,
            return_token_type_ids=False,
            add_special_tokens=False
        )
        item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
        item['attention_mask'] = [1] * len(item['input_ids'])
        inputs.append(item)
    return tokenizer.pad(
        inputs,
        padding=True,
        max_length=max_length + len(sep_inputs) + len(prompt_inputs),
        pad_to_multiple_of=8,
        return_tensors='pt',
    )

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-v2-gemma')
model = AutoModelForCausalLM.from_pretrained('BAAI/bge-reranker-v2-gemma')
yes_loc = tokenizer('Yes', add_special_tokens=False)['input_ids'][0]
model.eval()

pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
with torch.no_grad():
    inputs = get_inputs(pairs, tokenizer)
    # The relevance score is the logit of the 'Yes' token at the last position
    scores = model(**inputs, return_dict=True).logits[:, -1, yes_loc].view(-1, ).float()
    print(scores)
```
#### For LLM-based layerwise reranker

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def get_inputs(pairs, tokenizer, prompt=None, max_length=1024):
    """Build prompted inputs of the form '<bos>A: {query}\nB: {passage}\n{prompt}'."""
    if prompt is None:
        prompt = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."
    sep = "\n"
    prompt_inputs = tokenizer(prompt,
                              return_tensors=None,
                              add_special_tokens=False)['input_ids']
    sep_inputs = tokenizer(sep,
                           return_tensors=None,
                           add_special_tokens=False)['input_ids']
    inputs = []
    for query, passage in pairs:
        query_inputs = tokenizer(f'A: {query}',
                                 return_tensors=None,
                                 add_special_tokens=False,
                                 max_length=max_length * 3 // 4,
                                 truncation=True)
        passage_inputs = tokenizer(f'B: {passage}',
                                   return_tensors=None,
                                   add_special_tokens=False,
                                   max_length=max_length,
                                   truncation=True)
        item = tokenizer.prepare_for_model(
            [tokenizer.bos_token_id] + query_inputs['input_ids'],
            sep_inputs + passage_inputs['input_ids'],
            truncation='only_second',
            max_length=max_length,
            padding=False,
            return_attention_mask=False,
            return_token_type_ids=False,
            add_special_tokens=False
        )
        item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
        item['attention_mask'] = [1] * len(item['input_ids'])
        inputs.append(item)
    return tokenizer.pad(
        inputs,
        padding=True,
        max_length=max_length + len(sep_inputs) + len(prompt_inputs),
        pad_to_multiple_of=8,
        return_tensors='pt',
    )

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-v2-minicpm-layerwise', trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained('BAAI/bge-reranker-v2-minicpm-layerwise', trust_remote_code=True, torch_dtype=torch.bfloat16)
model = model.to('cuda')
model.eval()

pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
with torch.no_grad():
    inputs = get_inputs(pairs, tokenizer).to(model.device)
    all_scores = model(**inputs, return_dict=True, cutoff_layers=[28])
    # One tensor of per-pair scores for each requested cutoff layer
    all_scores = [scores[:, -1].view(-1, ).float() for scores in all_scores[0]]
    print(all_scores)
```
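The list comprehension extracts one tensor of per-pair scores for each requested cutoff layer, so with `cutoff_layers=[28]` the printed list contains a single tensor.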
## Fine-tune

### Data Format

Training data should be a JSON Lines file, where each line is a dict in the following format:

```
{"query": str, "pos": List[str], "neg": List[str], "prompt": str}
```
`query` is the query, `pos` is a list of positive texts, `neg` is a list of negative texts, and `prompt` describes the relationship between the query and the texts. If you have no negative texts for a query, you can randomly sample some from the entire corpus as negatives.
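A hypothetical example line (the prompt matches the default used in the inference code above):

```
{"query": "what is panda?", "pos": ["The giant panda (Ailuropoda melanoleuca) is a bear species endemic to China."], "neg": ["hi", "Pandas is a data-analysis library for Python."], "prompt": "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."}
```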
See [toy_finetune_data.jsonl](https://github.com/FlagOpen/FlagEmbedding/tree/master/FlagEmbedding/llm_reranker/toy_finetune_data.jsonl) for a toy data file.

### Train

You can fine-tune the reranker with the following commands.

**For llm-based reranker**

```shell
torchrun --nproc_per_node {number of gpus} \
-m FlagEmbedding.llm_reranker.finetune_for_instruction.run \
--output_dir {path to save model} \
--model_name_or_path google/gemma-2b \
--train_data ./toy_finetune_data.jsonl \
--learning_rate 2e-4 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--dataloader_drop_last True \
--query_max_len 512 \
--passage_max_len 512 \
--train_group_size 16 \
--logging_steps 1 \
--save_steps 2000 \
--save_total_limit 50 \
--ddp_find_unused_parameters False \
--gradient_checkpointing \
--deepspeed stage1.json \
--warmup_ratio 0.1 \
--bf16 \
--use_lora True \
--lora_rank 32 \
--lora_alpha 64 \
--use_flash_attn True \
--target_modules q_proj k_proj v_proj o_proj
```
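Note that `--deepspeed stage1.json` expects a DeepSpeed configuration file, which is not shown here. A minimal ZeRO stage 1 sketch using the Hugging Face `auto` placeholders might look like the following; this is an assumption, not the exact file used for training:

```json
{
  "zero_optimization": { "stage": 1 },
  "bf16": { "enabled": "auto" },
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "train_batch_size": "auto",
  "gradient_clipping": "auto"
}
```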
**For llm-based layerwise reranker**

```shell
torchrun --nproc_per_node {number of gpus} \
-m FlagEmbedding.llm_reranker.finetune_for_layerwise.run \
--output_dir {path to save model} \
--model_name_or_path openbmb/MiniCPM-2B-dpo-bf16 \
--train_data ./toy_finetune_data.jsonl \
--learning_rate 2e-4 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--dataloader_drop_last True \
--query_max_len 512 \
--passage_max_len 512 \
--train_group_size 16 \
--logging_steps 1 \
--save_steps 2000 \
--save_total_limit 50 \
--ddp_find_unused_parameters False \
--gradient_checkpointing \
--deepspeed stage1.json \
--warmup_ratio 0.1 \
--bf16 \
--use_lora True \
--lora_rank 32 \
--lora_alpha 64 \
--use_flash_attn True \
--target_modules q_proj k_proj v_proj o_proj \
--start_layer 8 \
--head_multi True \
--head_type simple \
--lora_extra_parameters linear_head
```
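Compared with the llm-based command, the layerwise variant adds `--start_layer`, `--head_multi`, `--head_type`, and `--lora_extra_parameters`, which configure the layerwise scoring heads; `--start_layer 8` matches the 8-40 layer range in the model list above.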
Our rerankers are initialized from [google/gemma-2b](https://huggingface.co/google/gemma-2b) (for the llm-based reranker) and [openbmb/MiniCPM-2B-dpo-bf16](https://huggingface.co/openbmb/MiniCPM-2B-dpo-bf16) (for the llm-based layerwise reranker), and trained on a mixture of multilingual datasets:

- [bge-m3-data](https://huggingface.co/datasets/Shitao/bge-m3-data)
- [quora train data](https://huggingface.co/datasets/quora)
- [fever train data](https://fever.ai/dataset/fever.html)
## Evaluation

- llama-index.



- BEIR.

Reranking the top 100 results from bge-large-en-v1.5.



Reranking the top 100 results from e5-mistral-7b-instruct.



- CMTEB-retrieval.

Reranking the top 100 results from bge-large-zh-v1.5.



- MIRACL (multilingual).

Reranking the top 100 results from bge-m3.


## Citation

If you find this repository useful, please consider giving it a star and citing our work:
```bibtex
@misc{li2023making,
  title={Making Large Language Models A Better Foundation For Dense Retrieval},
  author={Chaofan Li and Zheng Liu and Shitao Xiao and Yingxia Shao},
  year={2023},
  eprint={2312.15503},
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}

@misc{chen2024bge,
  title={BGE M3-Embedding: Multi-Lingual, Multi-Functionality, Multi-Granularity Text Embeddings Through Self-Knowledge Distillation},
  author={Jianlv Chen and Shitao Xiao and Peitian Zhang and Kun Luo and Defu Lian and Zheng Liu},
  year={2024},
  eprint={2402.03216},
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}
```