---
license: apache-2.0
pipeline_tag: tabular-classification
---

# Mitra Classifier

The Mitra classifier is a tabular foundation model pre-trained purely on synthetic datasets sampled from a mix of random classifiers.
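
This README does not spell out how the random classifiers are generated. As an illustration only, the sketch below draws synthetic classification tasks from a mix of two hypothetical priors: Gaussian features labelled either by a random axis-aligned decision tree or by a random linear classifier. The function names and every distribution choice are assumptions, not Mitra's actual pre-training prior.

```python
import random


def sample_tree_task(rng, n_rows, n_features, n_classes, depth=3):
    """Hypothetical prior 1: Gaussian features labelled by a random decision tree."""

    def build(d):
        # Internal node: (feature index, threshold, left subtree, right subtree); leaf: class label.
        if d == 0:
            return rng.randrange(n_classes)
        return (rng.randrange(n_features), rng.gauss(0.0, 1.0), build(d - 1), build(d - 1))

    tree = build(depth)

    def label(x):
        # Walk from the root to a leaf and return the leaf's class label.
        node = tree
        while isinstance(node, tuple):
            feature, threshold, left, right = node
            node = left if x[feature] <= threshold else right
        return node

    X = [[rng.gauss(0.0, 1.0) for _ in range(n_features)] for _ in range(n_rows)]
    return X, [label(x) for x in X]


def sample_linear_task(rng, n_rows, n_features, n_classes):
    """Hypothetical prior 2: label = argmax of random linear class scores."""
    W = [[rng.gauss(0.0, 1.0) for _ in range(n_features)] for _ in range(n_classes)]
    X = [[rng.gauss(0.0, 1.0) for _ in range(n_features)] for _ in range(n_rows)]
    y = [max(range(n_classes), key=lambda c: sum(w * v for w, v in zip(W[c], x))) for x in X]
    return X, y


def sample_mixed_prior_task(rng, n_rows=64, n_features=4, n_classes=3):
    """Pick one prior at random, then draw a whole dataset from it."""
    sampler = rng.choice([sample_tree_task, sample_linear_task])
    return sampler(rng, n_rows, n_features, n_classes)


rng = random.Random(0)
tasks = [sample_mixed_prior_task(rng) for _ in range(5)]
for X, y in tasks:
    print(len(X), len(X[0]), sorted(set(y)))
```

Each pre-training step in such a setup would see a freshly sampled task, so the model never trains on real data.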

## Architecture

Mitra is built on a 12-layer Transformer with 72M parameters and is pre-trained under an in-context learning paradigm.
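
In the in-context learning paradigm, the labelled training rows are given to the model as input at prediction time, and predictions for new rows come from a forward pass rather than from weight updates on the target dataset. The toy sketch below illustrates only that interface: a single, parameter-free softmax attention step over the context rows stands in for Mitra's 12-layer Transformer. The function name `in_context_predict_proba` and the distance-based attention scores are assumptions for illustration, not Mitra's architecture.

```python
import math


def in_context_predict_proba(context_X, context_y, query_X, n_classes, temperature=1.0):
    """Toy in-context classifier: each query row attends over the labelled context
    rows (softmax over negative squared distances) and returns the attention-weighted
    average of their one-hot labels. No parameters are updated."""
    probs = []
    for q in query_X:
        # Attention scores: closer context rows get larger scores.
        scores = [-sum((a - b) ** 2 for a, b in zip(q, x)) / temperature for x in context_X]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        row = [0.0] * n_classes
        for w, label in zip(weights, context_y):
            row[label] += w / total
        probs.append(row)
    return probs


context_X = [[0.0, 0.0], [0.1, 0.2], [5.0, 5.0], [5.2, 4.9]]
context_y = [0, 0, 1, 1]
query_X = [[0.05, 0.1], [4.9, 5.1]]

probs = in_context_predict_proba(context_X, context_y, query_X, n_classes=2)
preds = [max(range(len(p)), key=p.__getitem__) for p in probs]
print(preds)  # [0, 1]
```

Changing the context rows changes the predictions without any training step, which is why an in-context model can produce predictions without fine-tuning.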

## Usage

To use the Mitra classifier, install AutoGluon with the `mitra` extra:

```sh
pip install uv
# Quotes keep shells such as zsh from expanding the square brackets
uv pip install "autogluon.tabular[mitra]"
```

A minimal example of inference with the pre-trained Mitra classifier, without fine-tuning (`'fine_tune': False`):

```python
import pandas as pd
from autogluon.tabular import TabularDataset, TabularPredictor
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split

# Load the wine dataset into a DataFrame with a 'target' label column
wine_data = load_wine()
wine_df = pd.DataFrame(wine_data.data, columns=wine_data.feature_names)
wine_df['target'] = wine_data.target

print("Dataset shape:")
print(f"Wine: {wine_df.shape}")

# Create a stratified 80/20 train/test split
wine_train, wine_test = train_test_split(wine_df, test_size=0.2, random_state=42, stratify=wine_df['target'])

print("Training set size:")
print(f"Wine: {len(wine_train)} samples")

# Convert to TabularDataset
wine_train_data = TabularDataset(wine_train)
wine_test_data = TabularDataset(wine_test)

# Create a predictor that uses Mitra without fine-tuning
print("Fitting Mitra classifier on the wine dataset...")
mitra_predictor = TabularPredictor(label='target')
mitra_predictor.fit(
    wine_train_data,
    hyperparameters={
        'MITRA': {'fine_tune': False}
    },
)

print("\nMitra training completed!")

# Make predictions
mitra_predictions = mitra_predictor.predict(wine_test_data)
print("Sample Mitra predictions:")
print(mitra_predictions.head(10))

# Show class probabilities for the first few samples
mitra_probabilities = mitra_predictor.predict_proba(wine_test_data)
print(mitra_probabilities.head())

# Show model leaderboard (printed so it also appears when run as a script)
print("\nMitra Model Leaderboard:")
print(mitra_predictor.leaderboard(wine_test_data))
```

A minimal example of fine-tuning the Mitra classifier, reusing `wine_train_data` and `wine_test_data` from the example above:

```python
# Fine-tune Mitra on the training data for 10 steps, within a 2-minute time limit
mitra_predictor_ft = TabularPredictor(label='target')
mitra_predictor_ft.fit(
    wine_train_data,
    hyperparameters={
        'MITRA': {'fine_tune': True, 'fine_tune_steps': 10}
    },
    time_limit=120,  # 2 minutes
)

print("\nMitra fine-tuning completed!")

# Show model leaderboard
print("\nMitra Model Leaderboard:")
print(mitra_predictor_ft.leaderboard(wine_test_data))
```
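
This README does not describe what `fine_tune` changes internally. As a conceptual illustration only: one way to fine-tune an in-context learner is to run a fixed number of optimization steps (compare `fine_tune_steps`) on episodes built from the target training set, where some rows act as context and the rest as queries. The sketch below does this for a toy attention-style predictor (softmax over negative squared distances) with a single hypothetical parameter, its log-temperature, using leave-one-out episodes and finite-difference gradient descent. It is not AutoGluon's implementation, and `loo_log_loss` and `fine_tune_temperature` are hypothetical names.

```python
import math


def loo_log_loss(X, y, log_temp):
    """Leave-one-out log loss of a toy attention predictor: each training row is
    predicted using all other training rows as its context."""
    temp = math.exp(log_temp)
    total = 0.0
    for i, (query, target) in enumerate(zip(X, y)):
        # Context = every other training row; the held-out row is the query.
        context = [(x, label) for j, (x, label) in enumerate(zip(X, y)) if j != i]
        scores = [-sum((a - b) ** 2 for a, b in zip(query, x)) / temp for x, _ in context]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        p_target = sum(w for w, (_, label) in zip(weights, context) if label == target) / sum(weights)
        total -= math.log(max(p_target, 1e-12))  # clip to avoid log(0)
    return total / len(X)


def fine_tune_temperature(X, y, steps=10, lr=0.5, log_temp=0.0, eps=1e-4):
    """Run `steps` finite-difference gradient steps on the log-temperature,
    halving the step size whenever a step would increase the loss."""
    losses = [loo_log_loss(X, y, log_temp)]
    for _ in range(steps):
        grad = (loo_log_loss(X, y, log_temp + eps) - loo_log_loss(X, y, log_temp - eps)) / (2 * eps)
        candidate = log_temp - lr * grad
        candidate_loss = loo_log_loss(X, y, candidate)
        if candidate_loss <= losses[-1]:
            log_temp = candidate
            losses.append(candidate_loss)
        else:
            lr *= 0.5
            losses.append(losses[-1])
    return math.exp(log_temp), losses


X = [[0.0, 0.0], [0.3, 0.1], [0.1, 0.4], [2.0, 2.1], [2.2, 1.8], [1.9, 2.3], [1.0, 1.1], [1.2, 0.9]]
y = [0, 0, 0, 1, 1, 1, 0, 1]
temp, losses = fine_tune_temperature(X, y, steps=10)
print(f"temperature={temp:.3f}  loss {losses[0]:.4f} -> {losses[-1]:.4f}")
```

Backtracking (halving the step size when a step would raise the loss) keeps the recorded loss non-increasing, which keeps the sketch stable on tiny datasets like this one.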

## License

This project is licensed under the Apache-2.0 License.

## Reference

```bibtex
@article{zhang2025mitra,
  title={Mitra: Mixed synthetic priors for enhancing tabular foundation models},
  author={Zhang, Xiyuan and Maddix, Danielle C and Yin, Junming and Erickson, Nick and Ansari, Abdul Fatir and Han, Boran and Zhang, Shuai and Akoglu, Leman and Faloutsos, Christos and Mahoney, Michael W and others},
  journal={arXiv preprint arXiv:2510.21204},
  year={2025}
}
```

Amazon Science blog: [Mitra: Mixed synthetic priors for enhancing tabular foundation models](https://www.amazon.science/blog/mitra-mixed-synthetic-priors-for-enhancing-tabular-foundation-models?utm_campaign=mitra-mixed-synthetic-priors-for-enhancing-tabular-foundation-models&utm_medium=organic-asw&utm_source=linkedin&utm_content=2025-7-22-mitra-mixed-synthetic-priors-for-enhancing-tabular-foundation-models&utm_term=2025-july)