---
language:
- en
license: mit
tags:
- text-classification
- zero-shot-classification
datasets:
- multi_nli
- facebook/anli
- fever
- lingnli
- alisawuffles/WANLI
metrics:
- accuracy
pipeline_tag: zero-shot-classification
model-index:
- name: DeBERTa-v3-large-mnli-fever-anli-ling-wanli
  results:
  - task:
      type: text-classification
      name: Natural Language Inference
    dataset:
      name: MultiNLI-matched
      type: multi_nli
      split: validation_matched
    metrics:
    - type: accuracy
      value: 0.912
      verified: false
  - task:
      type: text-classification
      name: Natural Language Inference
    dataset:
      name: MultiNLI-mismatched
      type: multi_nli
      split: validation_mismatched
    metrics:
    - type: accuracy
      value: 0.908
      verified: false
  - task:
      type: text-classification
      name: Natural Language Inference
    dataset:
      name: ANLI-all
      type: anli
      split: test_r1+test_r2+test_r3
    metrics:
    - type: accuracy
      value: 0.702
      verified: false
  - task:
      type: text-classification
      name: Natural Language Inference
    dataset:
      name: ANLI-r3
      type: anli
      split: test_r3
    metrics:
    - type: accuracy
      value: 0.64
      verified: false
  - task:
      type: text-classification
      name: Natural Language Inference
    dataset:
      name: WANLI
      type: alisawuffles/WANLI
      split: test
    metrics:
    - type: accuracy
      value: 0.77
      verified: false
  - task:
      type: text-classification
      name: Natural Language Inference
    dataset:
      name: LingNLI
      type: lingnli
      split: test
    metrics:
    - type: accuracy
      value: 0.87
      verified: false
---

# DeBERTa-v3-large-mnli-fever-anli-ling-wanli
## Model description
This model was fine-tuned on the [MultiNLI](https://huggingface.co/datasets/multi_nli), [Fever-NLI](https://github.com/easonnie/combine-FEVER-NSMN/blob/master/other_resources/nli_fever.md), Adversarial-NLI ([ANLI](https://huggingface.co/datasets/anli)), [LingNLI](https://arxiv.org/pdf/2104.07179.pdf) and [WANLI](https://huggingface.co/datasets/alisawuffles/WANLI) datasets, which together comprise 885,242 NLI hypothesis-premise pairs. This model is the best-performing NLI model on the Hugging Face Hub as of 06.06.22 and can be used for zero-shot classification. It significantly outperforms all other large models on the [ANLI benchmark](https://github.com/facebookresearch/anli).

The foundation model is [DeBERTa-v3-large from Microsoft](https://huggingface.co/microsoft/deberta-v3-large). DeBERTa-v3 combines several recent innovations compared to classical masked language models like BERT and RoBERTa; see the [paper](https://arxiv.org/abs/2111.09543).

### How to use the model
#### Simple zero-shot classification pipeline
```python
from transformers import pipeline

classifier = pipeline("zero-shot-classification", model="MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli")
sequence_to_classify = "Angela Merkel is a politician in Germany and leader of the CDU"
candidate_labels = ["politics", "economy", "entertainment", "environment"]
output = classifier(sequence_to_classify, candidate_labels, multi_label=False)
print(output)
```
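
Under the hood, the zero-shot pipeline phrases each candidate label as an NLI hypothesis (e.g. "This example is about politics.") and scores it with the model. For the single-label case (`multi_label=False`) the per-label entailment logits are normalized with a softmax, so the label scores sum to 1. A simplified pure-Python sketch of that final aggregation step (the logit values are made up for illustration):

```python
import math

# Hypothetical entailment logits, one per candidate label
# (in practice these come from the NLI model's forward pass).
entailment_logits = [3.2, -1.0, 0.5, 0.1]

# Single-label case: softmax across the labels' entailment logits,
# making the scores comparable and summing to 1.
exps = [math.exp(z) for z in entailment_logits]
scores = [e / sum(exps) for e in exps]
print([round(s, 3) for s in scores])
```

The label with the highest entailment logit ends up with the highest score, which is what the pipeline returns first.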
#### NLI use-case
```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model_name = "MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)  # move the model to the same device as the inputs

premise = "I first thought that I liked the movie, but upon second thought it was actually disappointing."
hypothesis = "The movie was not good."

inputs = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt").to(device)
output = model(**inputs)
prediction = torch.softmax(output["logits"][0], -1).tolist()
label_names = ["entailment", "neutral", "contradiction"]
prediction = {name: round(float(pred) * 100, 1) for pred, name in zip(prediction, label_names)}
print(prediction)
```

### Training data
DeBERTa-v3-large-mnli-fever-anli-ling-wanli was trained on the [MultiNLI](https://huggingface.co/datasets/multi_nli), [Fever-NLI](https://github.com/easonnie/combine-FEVER-NSMN/blob/master/other_resources/nli_fever.md), Adversarial-NLI ([ANLI](https://huggingface.co/datasets/anli)), [LingNLI](https://arxiv.org/pdf/2104.07179.pdf) and [WANLI](https://huggingface.co/datasets/alisawuffles/WANLI) datasets, which together comprise 885,242 NLI hypothesis-premise pairs. Note that [SNLI](https://huggingface.co/datasets/snli) was explicitly excluded due to quality issues with the dataset. More data does not necessarily make for better NLI models.

### Training procedure
DeBERTa-v3-large-mnli-fever-anli-ling-wanli was trained using the Hugging Face Trainer with the following hyperparameters. Note that longer training with more epochs hurt performance in my tests (overfitting).


```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    num_train_epochs=4,              # total number of training epochs
    learning_rate=5e-06,
    per_device_train_batch_size=16,  # batch size per device during training
    gradient_accumulation_steps=2,   # doubles the effective batch size to 32, while decreasing memory requirements
    per_device_eval_batch_size=64,   # batch size for evaluation
    warmup_ratio=0.06,               # fraction of training steps used for learning rate warmup
    weight_decay=0.01,               # strength of weight decay
    fp16=True,                       # mixed precision training
)
```
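
As a quick sanity check, the effective batch size and the number of warmup steps implied by these settings can be computed directly. The 885,242-pair dataset size is from the training data section; the single-GPU assumption and the exact step rounding are mine:

```python
# Hyperparameters from the card; single-GPU training is an assumption.
num_train_pairs = 885_242
per_device_train_batch_size = 16
gradient_accumulation_steps = 2
num_train_epochs = 4
warmup_ratio = 0.06

# Gradient accumulation multiplies the effective batch size.
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps

# Optimizer steps per epoch and in total (dropping the last partial batch).
steps_per_epoch = num_train_pairs // effective_batch_size
total_steps = steps_per_epoch * num_train_epochs

# warmup_ratio is turned into a concrete number of warmup steps.
warmup_steps = int(total_steps * warmup_ratio)

print(effective_batch_size, total_steps, warmup_steps)
```

So each optimizer step sees 32 pairs, and roughly the first 6% of the ~110k total steps are spent warming up the learning rate.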

### Eval results
The model was evaluated using the test sets for MultiNLI, ANLI, LingNLI and WANLI, and the dev set for Fever-NLI. The metric used is accuracy.
The model achieves state-of-the-art performance on each dataset. Surprisingly, it outperforms the previous [state-of-the-art on ANLI](https://github.com/facebookresearch/anli) (ALBERT-XXL) by 8.3%. I assume this is because ANLI was created to fool masked language models like RoBERTa (or ALBERT), while DeBERTa-v3 uses a better pre-training objective (RTD) and disentangled attention, and was fine-tuned on higher-quality NLI data.

|Datasets|mnli_test_m|mnli_test_mm|anli_test|anli_test_r3|ling_test|wanli_test|
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
|Accuracy|0.912|0.908|0.702|0.64|0.87|0.77|
|Speed (text/sec, A100 GPU)|696.0|697.0|488.0|425.0|828.0|980.0|
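
The accuracy numbers above are plain classification accuracy, i.e. the fraction of examples whose predicted NLI label matches the gold label. A minimal sketch with made-up toy labels (0=entailment, 1=neutral, 2=contradiction):

```python
def accuracy(predictions, gold_labels):
    """Fraction of examples where the predicted label matches the gold label."""
    assert len(predictions) == len(gold_labels)
    correct = sum(p == g for p, g in zip(predictions, gold_labels))
    return correct / len(gold_labels)

# Toy labels for illustration only; not from any real evaluation run.
preds = [0, 2, 1, 0, 2]
gold = [0, 2, 2, 0, 1]
print(accuracy(preds, gold))
```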

## Limitations and bias
Please consult the original DeBERTa-v3 paper and literature on different NLI datasets for more information on the training data and potential biases. The model will reproduce statistical patterns in the training data.

## Citation
If you use this model, please cite: Laurer, Moritz, Wouter van Atteveldt, Andreu Salleras Casas, and Kasper Welbers. 2022. ‘Less Annotating, More Classifying – Addressing the Data Scarcity Issue of Supervised Machine Learning with Deep Transfer Learning and BERT-NLI’. Preprint, June. Open Science Framework. https://osf.io/74b8k.

### Ideas for cooperation or questions?
If you have questions or ideas for cooperation, contact me at m{dot}laurer{at}vu{dot}nl or on [LinkedIn](https://www.linkedin.com/in/moritz-laurer/).

### Debugging and issues
Note that DeBERTa-v3 was released on 06.12.21, and older versions of HF Transformers seem to have issues running the model (e.g. an issue with the tokenizer). Using Transformers>=4.13 might solve some issues.

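A minimal sketch of guarding against too-old versions before loading the model; the `>=4.13` threshold is from the note above, and the helper below (which compares only major and minor version components) is my own illustration:

```python
def version_tuple(version):
    """Parse the major and minor components of a version string like '4.13.0'."""
    return tuple(int(part) for part in version.split(".")[:2])

def meets_minimum(installed, minimum="4.13"):
    """True if the installed Transformers version is at least the minimum."""
    return version_tuple(installed) >= version_tuple(minimum)

print(meets_minimum("4.12.5"), meets_minimum("4.21.0"))
```

In practice you would pass `transformers.__version__` as the installed version and upgrade with `pip install --upgrade transformers` if the check fails.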