# train.py — (web-export header removed; size/line-count banner was viewer chrome, not source)
"""
Training Script for MHD Hybrid Nanofluid Thermal Surrogate Model

Trains a multi-output MLP on synthetic physics data generated from
governing equations of MHD hybrid nanofluid flow for EV battery cooling.

Usage:
    python train.py

Outputs:
    /app/outputs/model.pt          - Trained model weights
    /app/outputs/normalizer.json   - Data normalization parameters
    /app/outputs/config.json       - Model configuration
    /app/outputs/training_log.json - Training metrics
    /app/outputs/evaluation.json   - Final evaluation results
"""
17
18 import torch
19 import torch.nn as nn
20 from torch.utils.data import DataLoader, TensorDataset
21 import numpy as np
22 import pandas as pd
23 import json
24 import os
25 import time
26 from sklearn.model_selection import train_test_split
27 from sklearn.metrics import r2_score, mean_absolute_error
28
29 from data_generator import generate_dataset, compute_thermal_performance
30 from model import ThermalSurrogateModel, PhysicsLoss, DataNormalizer, get_model_config
31
32 # ============================================================
33 # Configuration
34 # ============================================================
35 OUTPUT_DIR = '/app/outputs'
36 os.makedirs(OUTPUT_DIR, exist_ok=True)
37
38 SEED = 42
39 torch.manual_seed(SEED)
40 np.random.seed(SEED)
41
42 config = get_model_config()
43
44 INPUT_FEATURES = config['input_features']
45 OUTPUT_FEATURES = config['output_features']
46
47
48 def prepare_data():
49 """Generate and prepare training/validation/test data."""
50 print("=" * 60)
51 print("DATA PREPARATION")
52 print("=" * 60)
53
54 # Generate large dataset
55 print("Generating 5000 LHS samples...")
56 df = generate_dataset(n_samples=5000, seed=SEED)
57
58 # Extract features
59 X = df[INPUT_FEATURES].values.astype(np.float32)
60 y = df[OUTPUT_FEATURES].values.astype(np.float32)
61
62 # Split: 70% train, 15% val, 15% test
63 X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.30, random_state=SEED)
64 X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.50, random_state=SEED)
65
66 print(f" Train: {len(X_train)} | Val: {len(X_val)} | Test: {len(X_test)}")
67
68 # Normalize
69 normalizer = DataNormalizer()
70 normalizer.fit(X_train, y_train)
71
72 X_train_n = normalizer.transform_input(X_train)
73 X_val_n = normalizer.transform_input(X_val)
74 X_test_n = normalizer.transform_input(X_test)
75 y_train_n = normalizer.transform_output(y_train)
76 y_val_n = normalizer.transform_output(y_val)
77 y_test_n = normalizer.transform_output(y_test)
78
79 # Convert to tensors
80 train_dataset = TensorDataset(
81 torch.tensor(X_train_n, dtype=torch.float32),
82 torch.tensor(y_train_n, dtype=torch.float32)
83 )
84 val_dataset = TensorDataset(
85 torch.tensor(X_val_n, dtype=torch.float32),
86 torch.tensor(y_val_n, dtype=torch.float32)
87 )
88
89 train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
90 val_loader = DataLoader(val_dataset, batch_size=config['batch_size'])
91
92 test_data = {
93 'X_test': X_test,
94 'y_test': y_test,
95 'X_test_n': torch.tensor(X_test_n, dtype=torch.float32),
96 'y_test_n': torch.tensor(y_test_n, dtype=torch.float32),
97 }
98
99 return train_loader, val_loader, test_data, normalizer, df
100
101
102 def train_model(train_loader, val_loader, config):
103 """Train the surrogate model."""
104 print("\n" + "=" * 60)
105 print("TRAINING")
106 print("=" * 60)
107
108 device = torch.device('cpu')
109
110 model = ThermalSurrogateModel(
111 input_dim=config['input_dim'],
112 hidden_dims=config['hidden_dims'],
113 output_dim=config['output_dim'],
114 dropout=config['dropout']
115 ).to(device)
116
117 total_params = sum(p.numel() for p in model.parameters())
118 print(f" Model parameters: {total_params:,}")
119 print(f" Architecture: {config['hidden_dims']}")
120 print(f" Learning rate: {config['learning_rate']}")
121 print(f" Epochs: {config['epochs']}")
122
123 optimizer = torch.optim.Adam(
124 model.parameters(),
125 lr=config['learning_rate'],
126 weight_decay=config['weight_decay']
127 )
128 scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
129 optimizer,
130 mode='min',
131 patience=config['scheduler_patience'],
132 factor=config['scheduler_factor'],
133 min_lr=1e-6
134 )
135
136 criterion = PhysicsLoss(lambda_physics=config['physics_lambda'])
137
138 # Training loop
139 best_val_loss = float('inf')
140 best_epoch = 0
141 patience_counter = 0
142 max_patience = 300
143 training_log = []
144
145 start_time = time.time()
146
147 for epoch in range(config['epochs']):
148 # Training
149 model.train()
150 train_loss_sum = 0
151 train_data_loss_sum = 0
152 n_batches = 0
153
154 for X_batch, y_batch in train_loader:
155 X_batch, y_batch = X_batch.to(device), y_batch.to(device)
156
157 pred = model(X_batch)
158 total_loss, data_loss, physics_loss = criterion(pred, y_batch, X_batch)
159
160 optimizer.zero_grad()
161 total_loss.backward()
162 torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
163 optimizer.step()
164
165 train_loss_sum += total_loss.item()
166 train_data_loss_sum += data_loss.item()
167 n_batches += 1
168
169 avg_train_loss = train_loss_sum / n_batches
170 avg_train_data = train_data_loss_sum / n_batches
171
172 # Validation
173 model.eval()
174 val_loss_sum = 0
175 val_batches = 0
176
177 with torch.no_grad():
178 for X_batch, y_batch in val_loader:
179 X_batch, y_batch = X_batch.to(device), y_batch.to(device)
180 pred = model(X_batch)
181 total_loss, _, _ = criterion(pred, y_batch, X_batch)
182 val_loss_sum += total_loss.item()
183 val_batches += 1
184
185 avg_val_loss = val_loss_sum / val_batches
186
187 scheduler.step(avg_val_loss)
188 current_lr = optimizer.param_groups[0]['lr']
189
190 # Logging
191 if (epoch + 1) % 100 == 0 or epoch == 0:
192 elapsed = time.time() - start_time
193 print(f" Epoch {epoch+1:4d}/{config['epochs']} | "
194 f"Train: {avg_train_data:.6f} | Val: {avg_val_loss:.6f} | "
195 f"LR: {current_lr:.2e} | Time: {elapsed:.0f}s")
196
197 training_log.append({
198 'epoch': epoch + 1,
199 'train_loss': avg_train_loss,
200 'train_data_loss': avg_train_data,
201 'val_loss': avg_val_loss,
202 'lr': current_lr
203 })
204
205 # Early stopping
206 if avg_val_loss < best_val_loss:
207 best_val_loss = avg_val_loss
208 best_epoch = epoch + 1
209 patience_counter = 0
210 # Save best model
211 torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, 'model_best.pt'))
212 else:
213 patience_counter += 1
214
215 if patience_counter >= max_patience:
216 print(f"\n Early stopping at epoch {epoch+1} (best: {best_epoch})")
217 break
218
219 total_time = time.time() - start_time
220 print(f"\n Training completed in {total_time:.1f}s")
221 print(f" Best validation loss: {best_val_loss:.6f} at epoch {best_epoch}")
222
223 # Load best model
224 model.load_state_dict(torch.load(os.path.join(OUTPUT_DIR, 'model_best.pt'), weights_only=True))
225
226 return model, training_log
227
228
229 def evaluate_model(model, test_data, normalizer, config):
230 """Comprehensive evaluation of the trained model."""
231 print("\n" + "=" * 60)
232 print("EVALUATION")
233 print("=" * 60)
234
235 device = torch.device('cpu')
236 model.eval()
237
238 X_test_n = test_data['X_test_n'].to(device)
239 y_test = test_data['y_test']
240
241 with torch.no_grad():
242 pred_n = model(X_test_n).cpu().numpy()
243
244 # Inverse transform predictions
245 pred = normalizer.inverse_transform_output(pred_n)
246
247 results = {}
248
249 print("\n Per-Output Metrics:")
250 print(f" {'Feature':<18} {'R²':>8} {'MAE':>10} {'MAPE%':>8}")
251 print(" " + "-" * 48)
252
253 for i, feat in enumerate(OUTPUT_FEATURES):
254 r2 = r2_score(y_test[:, i], pred[:, i])
255 mae = mean_absolute_error(y_test[:, i], pred[:, i])
256 # MAPE (avoid division by zero)
257 mask = np.abs(y_test[:, i]) > 1e-6
258 mape = np.mean(np.abs((y_test[mask, i] - pred[mask, i]) / y_test[mask, i])) * 100
259
260 results[feat] = {'R2': float(r2), 'MAE': float(mae), 'MAPE': float(mape)}
261 print(f" {feat:<18} {r2:>8.4f} {mae:>10.4f} {mape:>8.2f}")
262
263 # Overall R²
264 overall_r2 = np.mean([results[f]['R2'] for f in OUTPUT_FEATURES])
265 print(f"\n Overall R²: {overall_r2:.4f}")
266
267 # Validate against paper's key points
268 print("\n Paper Validation Points:")
269
270 # PSO optimal point
271 pso_input = np.array([[32.4, 0.038, 0.187]], dtype=np.float32)
272 pso_input_n = normalizer.transform_input(pso_input)
273 with torch.no_grad():
274 pso_pred_n = model(torch.tensor(pso_input_n, dtype=torch.float32)).cpu().numpy()
275 pso_pred = normalizer.inverse_transform_output(pso_pred_n)[0]
276
277 print(f" PSO Optimal (Ha=32.4, phi=0.038, u_in=0.187):")
278 print(f" T_max: Paper=40.8°C, Model={pso_pred[0]:.1f}°C")
279 print(f" Nu: Paper=18.7, Model={pso_pred[1]:.1f}")
280 print(f" S_gen: Paper=0.685, Model={pso_pred[2]:.3f}")
281
282 # Conventional cooling point (low phi, no MHD, moderate flow)
283 conv_input = np.array([[0.0, 0.01, 0.15]], dtype=np.float32)
284 conv_input_n = normalizer.transform_input(conv_input)
285 with torch.no_grad():
286 conv_pred_n = model(torch.tensor(conv_input_n, dtype=torch.float32)).cpu().numpy()
287 conv_pred = normalizer.inverse_transform_output(conv_pred_n)[0]
288
289 print(f"\n Conventional Cooling (Ha=0, phi=0.01, u_in=0.15):")
290 print(f" T_max: Paper≈61.3°C, Model={conv_pred[0]:.1f}°C")
291 print(f" Nu: Paper≈12.4, Model={conv_pred[1]:.1f}")
292
293 results['overall_R2'] = float(overall_r2)
294 results['pso_optimal_prediction'] = {
295 'T_max': float(pso_pred[0]),
296 'Nu': float(pso_pred[1]),
297 'S_gen': float(pso_pred[2])
298 }
299
300 return results
301
302
303 def main():
304 print("=" * 60)
305 print("MHD Hybrid Nanofluid Thermal Surrogate Model Training")
306 print("EV Battery Thermal Management System")
307 print("=" * 60)
308
309 # Prepare data
310 train_loader, val_loader, test_data, normalizer, df = prepare_data()
311
312 # Train model
313 model, training_log = train_model(train_loader, val_loader, config)
314
315 # Evaluate
316 eval_results = evaluate_model(model, test_data, normalizer, config)
317
318 # Save everything
319 print("\n" + "=" * 60)
320 print("SAVING ARTIFACTS")
321 print("=" * 60)
322
323 # Model weights
324 torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, 'model.pt'))
325 print(f" Model saved: {OUTPUT_DIR}/model.pt")
326
327 # Normalizer
328 normalizer.save(os.path.join(OUTPUT_DIR, 'normalizer.json'))
329 print(f" Normalizer saved: {OUTPUT_DIR}/normalizer.json")
330
331 # Config
332 with open(os.path.join(OUTPUT_DIR, 'config.json'), 'w') as f:
333 json.dump(config, f, indent=2)
334 print(f" Config saved: {OUTPUT_DIR}/config.json")
335
336 # Training log
337 with open(os.path.join(OUTPUT_DIR, 'training_log.json'), 'w') as f:
338 json.dump(training_log, f)
339 print(f" Training log saved: {OUTPUT_DIR}/training_log.json")
340
341 # Evaluation results
342 with open(os.path.join(OUTPUT_DIR, 'evaluation.json'), 'w') as f:
343 json.dump(eval_results, f, indent=2)
344 print(f" Evaluation saved: {OUTPUT_DIR}/evaluation.json")
345
346 # Dataset
347 df.to_csv(os.path.join(OUTPUT_DIR, 'training_data.csv'), index=False)
348 print(f" Dataset saved: {OUTPUT_DIR}/training_data.csv")
349
350 print("\n" + "=" * 60)
351 print("TRAINING COMPLETE")
352 print(f" Overall R²: {eval_results['overall_R2']:.4f}")
353 print("=" * 60)
354
355 return model, normalizer, eval_results
356
357
358 if __name__ == "__main__":
359 main()
360