# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mlp.py
# Commit id: c3b219665292c61a51153d0ded4473c494296382

# Copyright (c) 2023, Tri Dao.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ProcessGroup


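# The imports below are optional fused kernels from flash-attn. When they are not
# installed, Mlp and GatedMlp fall back to plain PyTorch ops, while the parallel
# variants (which need ColumnParallelLinear/RowParallelLinear) raise at init.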
try:
    from flash_attn.ops.activations import swiglu
except ImportError:
    swiglu = None

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
except ImportError:
    ColumnParallelLinear, RowParallelLinear = None, None

try:
    from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
except ImportError:
    FusedMLP, ParallelFusedMLP = None, None


class Mlp(nn.Module):
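    """Two-layer feed-forward block: fc1 -> activation -> fc2.

    hidden_features defaults to 4 * in_features and out_features to in_features.
    If return_residual is True, forward also returns its input (for residual fusion).
    """
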
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        bias1=True,
        bias2=True,
        return_residual=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)

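
# Illustrative usage sketch (an addition for demonstration, not part of the
# upstream module): the sizes below are assumptions chosen for the example.
def _example_mlp():
    mlp = Mlp(in_features=512, activation=F.gelu, return_residual=True)
    x = torch.randn(2, 16, 512)  # (batch, seq_len, hidden)
    y, residual = mlp(x)         # y: (2, 16, 512); residual is the original x
    return y, residual
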

class ParallelMLP(nn.Module):
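    """Tensor-parallel Mlp: fc1 is column-parallel and fc2 is row-parallel across
    process_group. Requires flash-attn's fused_dense ops to be installed.
    """
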
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        process_group: ProcessGroup = None,
        sequence_parallel=True,
        bias1=True,
        bias2=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        assert ColumnParallelLinear is not None, "Need to install fused_dense"
        assert RowParallelLinear is not None, "Need to install fused_dense"
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            process_group,
            bias=bias1,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        self.activation = activation
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            process_group,
            bias=bias2,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y

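
# Illustrative sketch (an addition, not upstream code): ParallelMLP shards fc1
# column-wise and fc2 row-wise across an initialized torch.distributed process
# group, so it is normally constructed inside a tensor-parallel launcher.
def _example_parallel_mlp(process_group):
    # The process group (and the fused_dense extension) must exist beforehand.
    return ParallelMLP(in_features=512, process_group=process_group)
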

class GatedMlp(nn.Module):
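    """Gated MLP (GLU family): fc1 produces 2 * hidden_features, which is split
    into a value and a gate. hidden_features defaults to 8/3 * in_features,
    rounded up to a multiple of `multiple_of`. activation=F.sigmoid gives GLU;
    F.silu gives SwiGLU (using the fused swiglu kernel when available).
    """
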
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=128,
        return_residual=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        if self.activation == F.sigmoid:  # Special case for GLU
            y = F.glu(y, dim=-1)
        elif self.activation == F.silu and swiglu is not None:  # Special case for SwiGLU
            y, gate = y.chunk(2, dim=-1)
            y = swiglu(gate, y)
        else:
            y, gate = y.chunk(2, dim=-1)
            y = y * self.activation(gate)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)

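
# Illustrative usage sketch (an addition, not part of the upstream module): a
# SwiGLU MLP as used in LLaMA-style blocks; the sizes are assumptions.
def _example_gated_mlp():
    # in_features=512 -> int(8 * 512 / 3) = 1365, rounded up to 1408 (multiple of 128)
    mlp = GatedMlp(in_features=512, activation=F.silu)
    x = torch.randn(2, 16, 512)
    return mlp(x)  # (2, 16, 512)
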

class ParallelGatedMlp(nn.Module):
143 """Parallel GatedMlp"""

    def __init__(
        self,
        in_features,
        process_group,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=128,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        self.fc1 = ColumnParallelLinear(
            in_features,
            2 * hidden_features,
            process_group,
            bias=bias1,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        self.activation = activation
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            process_group,
            bias=bias2,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        y = self.fc1(x)
        if self.activation == F.sigmoid:  # Special case for GLU
            y = F.glu(y, dim=-1)
        else:
            y, gate = y.chunk(2, dim=-1)
            y = y * self.activation(gate)
        y = self.fc2(y)
        return y
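

# Minimal smoke test (an addition for illustration, assuming a CPU-only
# environment; the parallel variants are skipped because they need an
# initialized torch.distributed process group and the fused_dense extension).
if __name__ == "__main__":
    out, residual = _example_mlp()
    gated_out = _example_gated_mlp()
    print(out.shape, residual.shape, gated_out.shape)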