siglip2.py
23.3 KB · 565 lines · python Raw
1 # Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
2 # you may not use this file except in compliance with the License.
3 # You may obtain a copy of the License at
4 #
5 # https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE
6 #
7 # Unless required by applicable law or agreed to in writing, software
8 # distributed under the License is distributed on an "AS IS" BASIS,
9 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 # See the License for the specific language governing permissions and
11 # limitations under the License.
12 # ==============================================================================
13 #
14 # Copyright 2025 The HuggingFace Inc. team.
15 #
16 # Licensed under the Apache License, Version 2.0 (the "License");
17 # you may not use this file except in compliance with the License.
18 # You may obtain a copy of the License at
19 #
20 # http://www.apache.org/licenses/LICENSE-2.0
21 #
22 # Unless required by applicable law or agreed to in writing, software
23 # distributed under the License is distributed on an "AS IS" BASIS,
24 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25 # See the License for the specific language governing permissions and
26 # limitations under the License.
27 # ==============================================================================
28
29 from typing import Optional, Tuple, Union
30 import warnings
31
32 import torch
33 import torch.nn as nn
34 import torch.nn.functional as F
35
36 from transformers.activations import ACT2FN
37 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
38 from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
39
40
class Config(object):
    """Lightweight attribute-style view over a configuration mapping.

    Every key of the mapping given to the constructor becomes an instance
    attribute. Item access (``cfg["key"]``) mirrors attribute access but
    returns ``None`` for keys that were never set, whereas plain attribute
    access (``cfg.key``) still raises ``AttributeError`` for them.

    Args:
        config: A mapping exposing ``.items()``, another ``Config`` instance
            (its attributes are shallow-copied, so wrapping is idempotent), or
            ``None`` for an empty config.
    """

    def __init__(self, config):
        if isinstance(config, Config):
            # A Config has no ``.items()``; copy its attribute dict instead so
            # that ``Config(Config(d))`` behaves like ``Config(d)``.
            config = vars(config)
        if config is not None:
            for key, value in config.items():
                setattr(self, key, value)

    def __getitem__(self, key):
        """Return attribute ``key``, or ``None`` if it is not set."""
        return getattr(self, key, None)

    def __setitem__(self, key, value):
        """Set attribute ``key`` to ``value``."""
        return setattr(self, key, value)

    def __repr__(self):
        fields = ", ".join(f"{key}={value!r}" for key, value in vars(self).items())
        return f"{type(self).__name__}({fields})"
52
53
class Siglip2VisionEmbeddings(nn.Module):
    """Patch embeddings plus resized learned position embeddings.

    Pixel values arrive already split into flattened patches, so the patch
    embedding is a single linear layer. The learned position table has
    ``num_patches`` rows, is viewed as a square grid, and is bilinearly
    resized to each image's patch grid before being added.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size

        flattened_patch_dim = config.num_channels * self.patch_size * self.patch_size
        self.patch_embedding = nn.Linear(
            in_features=flattened_patch_dim,
            out_features=self.embed_dim,
        )

        self.num_patches = config.num_patches
        # Side length of the square grid the position table is reshaped into.
        self.position_embedding_size = int(self.num_patches**0.5)
        self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)

    @staticmethod
    def resize_positional_embeddings(
        positional_embeddings: torch.Tensor,
        spatial_shapes: torch.LongTensor,
        max_length: int,
    ) -> torch.Tensor:
        """
        Resize the position grid to every image's patch grid and pad to ``max_length``.

        Args:
            positional_embeddings (`torch.Tensor`):
                Position table laid out as (height, width, embed_dim).
            spatial_shapes (`torch.LongTensor`):
                (batch_size, 2) tensor holding each image's (height, width) in patches.
            max_length (`int`):
                Sequence length every resized grid is padded to.

        Returns:
            `torch.Tensor`: (batch_size, max_length, embed_dim) in the dtype of
            ``positional_embeddings``. Rows past ``height * width`` repeat the
            first resized embedding of that image.
        """
        original_dtype = positional_embeddings.dtype
        num_images = spatial_shapes.shape[0]
        dim = positional_embeddings.shape[-1]

        output = torch.empty(
            (num_images, max_length, dim),
            device=positional_embeddings.device,
            dtype=original_dtype,
        )

        # F.interpolate wants channels-first input: (H, W, D) -> (1, D, H, W).
        grid = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
        # Antialiased interpolation is not supported for bfloat16/float16 on CPU.
        if grid.device.type == "cpu":
            grid = grid.to(torch.float32)

        for idx in range(num_images):
            h, w = spatial_shapes[idx]
            n_valid = h * w
            resized = F.interpolate(
                grid,
                size=(h, w),
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )
            # (1, D, h, w) -> (h * w, D), back in the caller's dtype.
            flat = resized.reshape(dim, n_valid).transpose(0, 1).to(original_dtype)
            output[idx, :n_valid] = flat
            output[idx, n_valid:] = flat[0]

        return output

    def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor) -> torch.Tensor:
        """
        Args:
            pixel_values (`torch.FloatTensor`):
                Pre-patchified pixels of shape
                (batch_size, max_num_patches, num_channels * patch_size * patch_size).
            spatial_shapes (`torch.LongTensor`):
                (batch_size, 2) per-image patch-grid (height, width) used to resize
                the position table.

        Returns:
            `torch.Tensor`: (batch_size, max_num_patches, embed_dim) embeddings.
        """
        weight_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=weight_dtype))

        side = self.position_embedding_size
        position_grid = self.position_embedding.weight.reshape(side, side, -1)
        position_embeds = self.resize_positional_embeddings(
            position_grid, spatial_shapes, max_length=pixel_values.shape[1]
        )
        return patch_embeds + position_embeds
153
154
class Siglip2Attention(nn.Module):
    """Eager multi-head self-attention ("Attention Is All You Need").

    Scores are formed explicitly with ``torch.matmul`` so the post-softmax
    attention weights can be handed back to the caller.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.embed_dim != self.head_dim * self.num_heads:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        # Creation order kept (k, v, q, out) so parameter init and state_dict order match.
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Self-attention over ``hidden_states``.

        Args:
            hidden_states: (batch, seq_len, embed_dim) input.
            attention_mask: Optional additive mask summed onto the raw scores;
                must be exactly (batch, 1, seq_len, seq_len).
            output_attentions: Accepted for interface compatibility; the weights
                are returned regardless of its value.

        Returns:
            The projected output (batch, seq_len, embed_dim) and the post-softmax,
            post-dropout weights (batch, num_heads, seq_len, seq_len).

        Raises:
            ValueError: if the scores, the mask, or the context has an unexpected shape.
        """
        bsz, seq_len, _ = hidden_states.size()
        head_shape = (bsz, seq_len, self.num_heads, self.head_dim)

        # (batch, seq, embed_dim) -> (batch, heads, seq, head_dim)
        q = self.q_proj(hidden_states).view(head_shape).transpose(1, 2)
        k = self.k_proj(hidden_states).view(head_shape).transpose(1, 2)
        v = self.v_proj(hidden_states).view(head_shape).transpose(1, 2)

        kv_len = k.shape[-2]
        scores = torch.matmul(q, k.transpose(2, 3)) * self.scale

        expected_scores = (bsz, self.num_heads, seq_len, kv_len)
        if scores.size() != expected_scores:
            raise ValueError(
                f"Attention weights should be of size {expected_scores}, but is"
                f" {scores.size()}"
            )

        if attention_mask is not None:
            expected_mask = (bsz, 1, seq_len, kv_len)
            if attention_mask.size() != expected_mask:
                raise ValueError(
                    f"Attention mask should be of size {expected_mask}, "
                    f"but is {attention_mask.size()}"
                )
            scores = scores + attention_mask

        # Softmax in float32 for numerical stability, then back to the activation dtype.
        probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
        probs = nn.functional.dropout(probs, p=self.dropout, training=self.training)
        context = torch.matmul(probs, v)

        expected_context = (bsz, self.num_heads, seq_len, self.head_dim)
        if context.size() != expected_context:
            raise ValueError(
                f"`attn_output` should be of size {expected_context}, but is"
                f" {context.size()}"
            )

        # (batch, heads, seq, head_dim) -> (batch, seq, embed_dim)
        context = context.transpose(1, 2).contiguous().reshape(bsz, seq_len, self.embed_dim)
        return self.out_proj(context), probs
229
class Siglip2SdpaAttention(Siglip2Attention):
    """
    Siglip2 attention computed with ``torch.nn.functional.scaled_dot_product_attention``.

    Parameters are inherited unchanged from ``Siglip2Attention``; only the
    forward pass differs. Because SDPA cannot return attention weights, a
    request for them is served by the eager parent implementation instead.
    """

    is_causal = False

    # Adapted from Siglip2Attention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Self-attention via SDPA.

        Args:
            hidden_states: (batch, seq_len, embed_dim) input.
            attention_mask: Optional mask passed straight to SDPA as ``attn_mask``.
            output_attentions: If truthy, emit a warning and delegate to the eager
                ``Siglip2Attention.forward``, which returns the weights.

        Returns:
            ``(output, None)`` on the SDPA path, where output is (batch, seq_len, embed_dim).
        """
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"`
            # once this is implemented.
            warnings.warn(
                "Siglip2Model is using Siglip2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` "
                "does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. '
                'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
            )

        bsz, seq_len, _ = hidden_states.size()
        head_shape = (bsz, seq_len, self.num_heads, self.head_dim)

        # (batch, seq, embed_dim) -> (batch, heads, seq, head_dim)
        q = self.q_proj(hidden_states).view(head_shape).transpose(1, 2)
        k = self.k_proj(hidden_states).view(head_shape).transpose(1, 2)
        v = self.v_proj(hidden_states).view(head_shape).transpose(1, 2)

        # The memory-efficient SDPA backend (torch==2.1.2) mishandles non-contiguous
        # inputs combined with a custom attn_mask.
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if attention_mask is not None and q.device.type == "cuda":
            q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

        # Resolve is_causal to a plain bool before the call (rather than inline in it)
        # so torch.compile can handle both dynamic shapes and fullgraph mode.
        causal = bool(self.is_causal and seq_len > 1)

        context = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attention_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=causal,
        )

        # (batch, heads, seq, head_dim) -> (batch, seq, embed_dim)
        context = context.transpose(1, 2).contiguous().view(bsz, seq_len, self.embed_dim)
        return self.out_proj(context), None
299
300
class Siglip2MLP(nn.Module):
    """Feed-forward block: Linear(hidden -> intermediate), activation, Linear(intermediate -> hidden)."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Activation looked up by name from transformers' ACT2FN registry.
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc1, the configured activation, then fc2 over the last dimension."""
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
314
315
class Siglip2EncoderLayer(nn.Module):
    """Pre-LayerNorm encoder layer: self-attention then MLP, each inside a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = Siglip2Attention(config=config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Siglip2MLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    # Ignore copy
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Layer input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`):
                Additive mask of shape `(batch, 1, q_len, k_v_seq_len)`; padded
                positions hold very large negative values.
            output_attentions (`bool`, *optional*, defaults to `False`):
                When true, the attention weights are appended to the returned tuple.

        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when
            `output_attentions` is true.
        """
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        # Residual around attention, then residual around the MLP.
        hidden_states = hidden_states + attn_out
        hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
364
365
class Siglip2Encoder(nn.Module):
    """
    Stack of `config.num_hidden_layers` [`Siglip2EncoderLayer`] modules applied in sequence.

    Args:
        config: Siglip2Config
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
        # NOTE(review): this flag is set but never read by `forward` below, so this
        # module does not itself apply gradient checkpointing.
        self.gradient_checkpointing = True

    # Ignore copy
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded sequence fed to the first layer.
            attention_mask (`torch.Tensor`, *optional*):
                Mask forwarded unchanged to every layer. The eager `Siglip2Attention`
                used by the layers requires an additive mask of shape
                `(batch_size, 1, sequence_length, sequence_length)`.
            output_attentions (`bool`, *optional*):
                Collect each layer's attention weights; falls back to
                `config.output_attentions` when `None`.
            output_hidden_states (`bool`, *optional*):
                Collect the input to every layer plus the final output; falls back to
                `config.output_hidden_states` when `None`.
            return_dict (`bool`, *optional*):
                Return a `BaseModelOutput` instead of a plain tuple; falls back to
                `config.use_return_dict` when `None`.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = layer(hidden_states, attention_mask, output_attentions=output_attentions)
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(item for item in (hidden_states, all_hidden_states, all_attentions) if item is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )
445
446
class Siglip2MultiheadAttentionPoolingHead(nn.Module):
    """Multihead attention pooling: a learned probe token attends over the sequence."""

    def __init__(self, config):
        super().__init__()

        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = Siglip2MLP(config)
        self.num_heads = config.num_attention_heads

    def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        """Pool ``hidden_state`` into one vector per batch element.

        Args:
            hidden_state: (batch, seq_len, hidden_size) sequence to pool.
            attention_mask: Optional padding mask converted with
                ``_prepare_4d_attention_mask`` and reshaped into the 3-D form
                ``nn.MultiheadAttention`` accepts.

        Returns:
            (batch, hidden_size) pooled representation taken from the probe position.
        """
        n = hidden_state.shape[0]
        query = self.probe.repeat(n, 1, 1)

        mask = attention_mask
        if mask is not None:
            q_len = query.shape[1]
            kv_len = hidden_state.shape[1]
            mask = _prepare_4d_attention_mask(mask, hidden_state.dtype, q_len)
            # Tile across heads and flatten to (-1, q_len, kv_len) for nn.MultiheadAttention;
            # q_len is 1 here because the probe is a single token.
            mask = mask.repeat(1, self.num_heads, q_len, 1).reshape(-1, q_len, kv_len)

        pooled = self.attention(query, hidden_state, hidden_state, attn_mask=mask)[0]
        pooled = pooled + self.mlp(self.layernorm(pooled))
        return pooled[:, 0]
476
477
class Siglip2VisionTransformer(nn.Module):
    """SigLIP2 vision tower: patch embeddings -> encoder -> post-LayerNorm -> optional pooling head.

    Args:
        config: Mapping of vision hyper-parameters; it is wrapped in ``Config`` so
            keys are read as attributes. ``vision_use_head`` (default ``True``) and
            ``_attn_implementation`` (default: not flash attention) are optional.
    """

    def __init__(self, config):
        super().__init__()
        config = Config(config)
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = Siglip2VisionEmbeddings(config)
        self.encoder = Siglip2Encoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        # The pooling head is built unless the config explicitly disables it.
        self.use_head = getattr(config, "vision_use_head", True)
        if self.use_head:
            self.head = Siglip2MultiheadAttentionPoolingHead(config)
        # `_attn_implementation` may be absent from the config dict; reading it as a
        # plain attribute used to raise AttributeError. Missing means "not flash".
        self._use_flash_attention_2 = getattr(config, "_attn_implementation", None) == "flash_attention_2"

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        attention_mask: torch.Tensor,
        spatial_shapes: torch.LongTensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Args:
            pixel_values: Pre-patchified pixels passed to the embedding layer.
            attention_mask: Optional `[batch_size, seq_len]` padding mask; expanded
                to a 4-D additive mask for the encoder (except on the flash path) and
                passed as-is to the pooling head.
            spatial_shapes: Per-image patch-grid (height, width) for resizing the
                position embeddings.
            output_attentions / output_hidden_states / return_dict: Fall back to the
                matching config attributes when `None`.

        Returns:
            `BaseModelOutputWithPooling`, or when `return_dict` is falsy the tuple
            `(last_hidden_state, pooler_output, *extra encoder outputs)`.
            `pooler_output` is `None` when the pooling head is disabled.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden_states = self.embeddings(pixel_values, spatial_shapes)

        if attention_mask is not None and not self._use_flash_attention_2:
            # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
            encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
        else:
            # NOTE(review): on the flash path the raw 2-D mask is forwarded, but the
            # encoder layers here build the eager Siglip2Attention, which rejects
            # non-4-D masks. Presumably a flash layer is swapped in elsewhere — verify.
            encoder_attention_mask = attention_mask

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.post_layernorm(last_hidden_state)

        pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None
        if not return_dict:
            return (last_hidden_state, pooler_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooler_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
541
542
class LightProjector(nn.Module):
    """Project `input_dim` features to `n_embed`.

    `projector_type` selects the architecture:
      * "linear": a single Linear(input_dim, n_embed).
      * "mlp_gelu": Linear(input_dim, n_embed) followed by (depth - 1) pairs of
        GELU + Linear(n_embed, n_embed), wrapped in nn.Sequential.
    Any other value raises ValueError.
    """

    def __init__(self, config):
        config = Config(config)
        super().__init__()

        kind = config.projector_type
        if kind == "linear":
            self.layers = nn.Linear(config.input_dim, config.n_embed)
        elif kind == "mlp_gelu":
            stack = [nn.Linear(config.input_dim, config.n_embed)]
            for _ in range(1, config.depth):
                stack += [nn.GELU(), nn.Linear(config.n_embed, config.n_embed)]
            self.layers = nn.Sequential(*stack)
        else:
            raise ValueError(f"Unknown projector type: {config.projector_type}")

    def forward(self, x):
        """Apply the configured projection to `x` (last dimension must be `input_dim`)."""
        return self.layers(x)
565