modeling_kimi_k25.py
# coding=utf-8
# Copyright 2025-2026 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved.
#
# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for Kimi-K2.5.
#
# Licensing Information:
# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0.
# - Other parts of the code are licensed under the MIT License.
#
# Apache License, Version 2.0:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License:
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import math
from collections.abc import Sequence
from copy import deepcopy
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import activations

try:
    from transformers.activations import PytorchGELUTanh
except ImportError:
    from transformers.activations import GELUTanh
    activations.PytorchGELUTanh = GELUTanh
    PytorchGELUTanh = GELUTanh
    from transformers.activations import PytorchGELUTanh
from transformers.cache_utils import Cache
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.models.llava.modeling_llava import \
    LlavaCausalLMOutputWithPast
from transformers.utils import is_flash_attn_2_available

from .configuration_kimi_k25 import KimiK25Config
from .modeling_deepseek import DeepseekV3ForCausalLM

# Flash attention imports
if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func
else:
    flash_attn_varlen_func = None


def multihead_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: torch.Tensor | None = None,
    k_cu_seqlens: torch.Tensor | None = None,
    max_seqlen_q: int | None = None,
    max_seqlen_k: int | None = None,
    deterministic: bool = False,
):
    """Multi-head attention using flash attention 2.

    Args:
        q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
            or (tot_seqlens, num_heads, head_dim) if packing.
        q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
            The first element should be 0 and the last element should be q.shape[0].
        k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k.
            The first element should be 0 and the last element should be k.shape[0].

    Returns:
        output: shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing,
            where dim = num_heads * head_dim
    """
    attn_out = flash_attn_varlen_func(
        q,
        k,
        v,
        q_cu_seqlens,
        k_cu_seqlens,
        max_seqlen_q,
        max_seqlen_k,
        causal=False,
        deterministic=deterministic,
    )
    if isinstance(attn_out, tuple):
        attn_out = attn_out[0]

    attn_out = attn_out.flatten(start_dim=-2)

    return attn_out


def eager_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: Optional[torch.Tensor] = None,
    k_cu_seqlens: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    seq_length = q.shape[0]
    # Block-diagonal mask: tokens may only attend within their own packed
    # sequence, delimited by the cumulative lengths in q_cu_seqlens.
    attention_mask = torch.zeros([1, seq_length, seq_length],
                                 device=q.device,
                                 dtype=torch.bool)
    for i in range(1, len(q_cu_seqlens)):
        attention_mask[
            ...,
            q_cu_seqlens[i - 1]:q_cu_seqlens[i],
            q_cu_seqlens[i - 1]:q_cu_seqlens[i],
        ] = True
    q = q.transpose(0, 1)
    k = k.transpose(0, 1)
    v = v.transpose(0, 1)

    attn_weight = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    # Mask out cross-sequence positions before the softmax (adding the boolean
    # mask directly would not suppress them).
    attn_weight = attn_weight.masked_fill(~attention_mask, float('-inf'))
    attn_weight = torch.softmax(attn_weight, dim=-1,
                                dtype=torch.float32).to(q.dtype)

    attn_output = attn_weight @ v
    attn_output = attn_output.transpose(0, 1)
    attn_output = attn_output.reshape(seq_length, -1)
    return attn_output


VL_VISION_ATTENTION_FUNCTIONS = {
    "flash_attention_2": multihead_attention,
    "eager": eager_attention,
}

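
# A minimal illustrative sketch (not part of the original file): how inputs to
# the packed attention functions above are laid out. Two sequences of lengths
# 3 and 5 are concatenated along the token dimension and described by the
# cumulative lengths [0, 3, 8]; the head count and head dim are made up.
def _example_packed_attention():
    num_heads, head_dim = 2, 8
    total_tokens = 8
    q = torch.randn(total_tokens, num_heads, head_dim)
    k = torch.randn(total_tokens, num_heads, head_dim)
    v = torch.randn(total_tokens, num_heads, head_dim)
    cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)
    # The eager path only needs the cumulative lengths; the flash-attention
    # path additionally requires max_seqlen_q / max_seqlen_k.
    out = eager_attention(q, k, v,
                          q_cu_seqlens=cu_seqlens,
                          k_cu_seqlens=cu_seqlens)
    return out.shape  # (total_tokens, num_heads * head_dim)
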
def _apply_rope_input_validation(x, freqs_cis):
    assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
    assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape, freqs_cis.shape)
    assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape, freqs_cis.shape)
    assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype


def get_rope_shape_decorate(func):
    _get_rope_shape_first_call_flag = set()

    def wrapper(org, interpolation_mode, shape):
        key = (org.requires_grad, torch.is_grad_enabled(), interpolation_mode)
        if key not in _get_rope_shape_first_call_flag:
            _get_rope_shape_first_call_flag.add(key)
            _ = func(org, interpolation_mode, shape=(64, 64))
        return func(org, interpolation_mode, shape)

    return wrapper


@get_rope_shape_decorate
@torch.compile(dynamic=True)
def get_rope_shape(org, interpolation_mode, shape):
    return (F.interpolate(
        org.permute((2, 0, 1)).unsqueeze(0),
        size=shape,
        mode=interpolation_mode,
    ).squeeze(0).permute((1, 2, 0)).flatten(end_dim=1))


def apply_rope(xq: torch.Tensor, xk: torch.Tensor,
               freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Args: (The leading dimensions of all inputs should be the same)
        xq: query, tensor of shape (..., num_heads, head_dim)
        xk: key, tensor of shape (..., num_heads, head_dim)
        freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64. It contains the precomputed cis(freqs) for each position in the 2D grid.
    Returns:
        xq_out, xk_out: tensors of shape (..., num_heads, head_dim)
    """
    _apply_rope_input_validation(xq, freqs_cis)
    _apply_rope_input_validation(xk, freqs_cis)

    freqs_cis = freqs_cis.unsqueeze(-2)  # ..., 1, head_dim/2
    # ..., num_heads, head_dim/2
    xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().view(*xk.shape[:-1], -1, 2))
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(
        -2)  # ..., num_heads, head_dim
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(
        -2)  # ..., num_heads, head_dim
    return xq_out.type_as(xq), xk_out.type_as(xk)

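
# A minimal illustrative sketch (not part of the original file): the shape
# contract of `apply_rope`. The sizes are made up; `freqs_cis` must be
# complex64 with last dimension head_dim // 2 and leading dims matching the
# query/key up to the heads dimension.
def _example_apply_rope():
    seq_len, num_heads, head_dim = 6, 2, 8
    xq = torch.randn(seq_len, num_heads, head_dim)
    xk = torch.randn(seq_len, num_heads, head_dim)
    # Unit-magnitude complex rotations, one per (position, channel-pair).
    angles = torch.randn(seq_len, head_dim // 2)
    freqs_cis = torch.polar(torch.ones_like(angles), angles)
    xq_out, xk_out = apply_rope(xq, xk, freqs_cis)
    return xq_out.shape, xk_out.shape  # both (seq_len, num_heads, head_dim)
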
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    From:
    https://github.com/OpenGVLab/InternVideo/blob/421f6d2361fc8f61a3394244571f2601a4e99e29/InternVideo2/multi_modality/models/backbones/internvideo2/pos_embed.py#L86
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def get_1d_sincos_pos_embed(embed_dim, t_size, cls_token=False):
    """
    t_size: int of the temporal size
    return:
    pos_embed: [t_size, embed_dim] or [1+t_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_t = np.arange(t_size, dtype=np.float32)
    pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_t)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed],
                                   axis=0)
    return pos_embed


class Learnable2DInterpPosEmbDivided_fixed(nn.Module):

    def __init__(self,
                 height: int,
                 width: int,
                 num_frames: int,
                 dim: int,
                 interpolation_mode: str = 'bicubic') -> None:
        super().__init__()
        self.height = height
        self.width = width
        self.num_frames = num_frames
        self.dim = dim
        self.interpolation_mode = interpolation_mode
        self.weight = nn.Parameter(torch.empty(height, width, dim))
        self.register_buffer('time_weight',
                             torch.from_numpy(
                                 get_1d_sincos_pos_embed(
                                     self.dim,
                                     self.num_frames)).float().unsqueeze(1),
                             persistent=False)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.weight)

    def forward(self, x: torch.Tensor,
                grid_thws: torch.Tensor) -> torch.Tensor:
        pos_embs = []
        for t, h, w in grid_thws.tolist():
            assert t <= self.num_frames, f't:{t} > self.num_frames:{self.num_frames}'
            if (h, w) == self.weight.shape[:-1]:
                pos_emb_2d = self.weight.flatten(end_dim=1)
            else:
                pos_emb_2d = get_rope_shape(
                    self.weight,
                    interpolation_mode=self.interpolation_mode,
                    shape=(h, w),
                )

            if t == 1:
                pos_emb_3d = pos_emb_2d
            else:
                pos_emb_3d = pos_emb_2d.unsqueeze(0).repeat(
                    t, 1, 1) + self.time_weight[0:t]

            pos_embs.append(pos_emb_3d.reshape(-1, pos_emb_3d.shape[-1]))

        out = x + torch.cat(pos_embs)
        return out

class MoonVision3dPatchEmbed(nn.Module):

    def __init__(self,
                 out_dim: int,
                 in_dim: int = 3,
                 patch_size: int | tuple[int, int] = (14, 14),
                 pos_emb_height: int = 14,
                 pos_emb_width: int = 14,
                 pos_emb_time: int = 4,
                 pos_emb_type: str = 'divided_fixed'):
        super().__init__()
        assert isinstance(
            patch_size,
            int | Sequence), f'Invalid patch_size type: {type(patch_size)}'
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        assert (len(patch_size) == 2
                ), f'Expected patch_size to be a tuple of 2, got {patch_size}'
        self.patch_size = patch_size

        self.proj = nn.Conv2d(in_dim,
                              out_dim,
                              kernel_size=patch_size,
                              stride=patch_size)

        if pos_emb_type == 'divided_fixed':
            self.pos_emb = Learnable2DInterpPosEmbDivided_fixed(
                height=pos_emb_height,
                width=pos_emb_width,
                num_frames=pos_emb_time,
                dim=out_dim)
        else:
            raise NotImplementedError(
                f'Unsupported pos_emb_type: {pos_emb_type}')

    def forward(self, x: torch.Tensor,
                grid_thws: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (L, Channels): input tensor
            grid_thws (N, 3): temporal, height and width

        Returns:
            (L, Cout) tensor
        """
        x = self.proj(x).view(x.size(0), -1)
        # apply positional embedding
        x = self.pos_emb(x, grid_thws)
        return x


class Rope2DPosEmbRepeated(nn.Module):
    """2D rotary position embedding with multi-resolution support.

    This class is intended to be used in the following way:
    1. Before training, create an instance of Rope2DPosEmb. This instance will hold the precomputed cis.
    2. Before each forward pass, call `get_freqs_cis_by_*` to get the `freqs_cis` tensor for this iteration.
    3. During the forward pass, pass the `freqs_cis` tensor to each attention layer, and call `apply` just before each attention operation.
    The rope is shared across all attention layers and all heads.

    Refs:
    - RoFormer: https://arxiv.org/abs/2104.09864
    - VisionLLaMA: https://arxiv.org/abs/2403.00522
    - https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py

    Args:
        dim (int): usually the multi-head attention dimension, should be divisible by 4 (TODO: relax this constraint if needed)
        max_height (int): the maximum height of the 2D grid
        max_width (int): the maximum width of the 2D grid
        theta_base (float): the base of the theta
        device (str): the device to store the precomputed cis
    """

    def __init__(self,
                 dim: int,
                 max_height: int,
                 max_width: int,
                 theta_base=10000):
        super().__init__()
        self.dim = dim
        assert self.dim % 4 == 0, 'dim must be divisible by 4'
        self.max_height = max_height
        self.max_width = max_width
        self.theta_base = theta_base

    def extra_repr(self):
        return f'dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}'

    def _precompute_freqs_cis(self, device: torch.device) -> torch.Tensor:
        """Calculate the cis(freqs) for each position in the 2D grid.

        Return: complex tensor of shape (max_height, max_width, dim//2) and value:
            height axis: ret[h, w, 2*i] = cis(h * theta_base**(-4*i/dim))
            width axis: ret[h, w, 2*i+1] = cis(w * theta_base**(-4*i/dim)) with (i in [0, dim//4))
        note: `cis` is a mathematical notation defined by cis x = cos x + i sin x,
        """
        N = self.max_height * self.max_width
        flat_pos = torch.arange(0, N).float().to(device)
        x_pos = flat_pos % self.max_width
        y_pos = flat_pos // self.max_width
        dim_range = (torch.arange(0, self.dim,
                                  4)[:(self.dim // 4)].float().to(device)
                     )  # C/4
        freqs = 1.0 / (self.theta_base**(dim_range / self.dim))
        x_freqs = torch.outer(x_pos, freqs).float()  # N, C/4
        y_freqs = torch.outer(y_pos, freqs).float()  # N, C/4
        x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)  # N, C/4
        y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)  # N, C/4
        # N, C/4, 2
        freqs_cis = torch.cat(
            [x_cis.unsqueeze(dim=-1),
             y_cis.unsqueeze(dim=-1)], dim=-1)
        # max_height, max_width, C/2
        freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1)
        return freqs_cis

    def get_freqs_cis(self, grid_thws: torch.Tensor,
                      device: torch.device) -> torch.Tensor:
        """
        Args:
            grid_thws (torch.Tensor): grid time, height and width

        Returns:
            freqs_cis: tensor of shape (sum(t * height * width), dim//2)
        """
        if not hasattr(self, 'freqs_cis'):
            self.register_buffer('freqs_cis',
                                 self._precompute_freqs_cis(device),
                                 persistent=False)

        shapes = grid_thws.tolist()
        assert all(1 <= h <= self.max_height and 1 <= w <= self.max_width
                   for t, h, w in shapes), (
                       shapes,
                       self.max_height,
                       self.max_width,
                   )
        freqs_cis = torch.cat(
            [
                self.freqs_cis[:h, :w].reshape(-1, self.dim // 2).repeat(t, 1)
                for t, h, w in shapes
            ],
            dim=0,
        )
        return freqs_cis

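
# A minimal illustrative sketch (not part of the original file) of the usage
# pattern described in the Rope2DPosEmbRepeated docstring. The dimensions and
# grids are made up; each grid_thws row is (t, h, w) for one image or clip.
def _example_rope_2d_usage():
    rope_2d = Rope2DPosEmbRepeated(dim=8, max_height=16, max_width=16)
    grid_thws = torch.tensor([[1, 4, 4], [2, 3, 3]])  # one image, one 2-frame clip
    freqs_cis = rope_2d.get_freqs_cis(grid_thws, device=torch.device('cpu'))
    # One row of rotations per patch: 1 * 4 * 4 + 2 * 3 * 3 = 34 rows, each of
    # length dim // 2 = 4; this tensor is then passed to every encoder layer
    # and consumed by apply_rope just before attention.
    return freqs_cis.shape  # (34, 4)
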
class MLP2(nn.Module):
    """
    Args:
        dims: [in_dim, hidden_dim, out_dim]
        bias: whether to use bias in linear layer.
    """

    def __init__(self, dims: list[int], activation, bias=True):
        super().__init__()
        assert len(dims) == 3
        self.fc0 = nn.Linear(dims[0], dims[1], bias=bias)
        self.fc1 = nn.Linear(dims[1], dims[2], bias=bias)
        self.activation = activation
        for m in [self.fc0, self.fc1]:
            nn.init.trunc_normal_(m.weight, std=math.sqrt(2 / m.in_features))
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc0(x)
        x = self.activation(x)
        return self.fc1(x)


class MoonViTEncoderLayer(nn.Module):

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        *,
        attn_implementation: str = 'flash_attention_2',
        activation=F.gelu,
        attn_bias: bool = False,
        use_deterministic_attn: bool = False,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
        self.attn_implementation = attn_implementation
        self.use_deterministic_attn = use_deterministic_attn

        self.norm0 = nn.LayerNorm(hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation)
        self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias)
        self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias)

    def attention_qkvpacked(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: torch.Tensor,
        rope_freqs_cis: torch.Tensor | None = None,
    ):
        """
        Args:
            x (torch.Tensor): (batch_size, seqlen, hidden_dim)
            cu_seqlens (torch.Tensor):
        """
        xqkv = self.wqkv(x)

        qkv_shape = xqkv.size()[:-1] + (
            3,
            self.num_heads,
            self.hidden_size_per_attention_head,
        )
        # xqkv: (batch_size, seqlen, 3, nheads, headdim)
        xqkv = xqkv.view(*qkv_shape)
        xq, xk, xv = torch.unbind(xqkv, dim=-3)

        xq, xk = apply_rope(xq, xk, rope_freqs_cis)

        attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation]
        attn_out = attn_func(xq,
                             xk,
                             xv,
                             q_cu_seqlens=cu_seqlens,
                             k_cu_seqlens=cu_seqlens,
                             max_seqlen_k=max_seqlen,
                             max_seqlen_q=max_seqlen,
                             deterministic=self.use_deterministic_attn)

        attn_out = self.wo(attn_out)
        return attn_out

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
        rope_freqs_cis: torch.Tensor | None = None,
    ):
        residual = hidden_states
        hidden_states = self.norm0(hidden_states)

        hidden_states = self.attention_qkvpacked(hidden_states, cu_seqlens,
                                                 max_seqlen, rope_freqs_cis)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class MoonViT3dEncoder(nn.Module):

    def __init__(self,
                 hidden_dim: int,
                 num_layers: int,
                 block_cfg: dict,
                 video_attn_type: str = 'spatial_temporal',
                 use_deterministic_attn: bool = False) -> None:
        super().__init__()

        assert video_attn_type == 'spatial_temporal', f'video_attn_type must be "spatial_temporal", got {video_attn_type}'
        self.video_attn_type = video_attn_type
        self.rope_2d = Rope2DPosEmbRepeated(
            block_cfg['hidden_dim'] // block_cfg['num_heads'], 512, 512)
        self.blocks = nn.ModuleList([
            MoonViTEncoderLayer(**block_cfg,
                                use_deterministic_attn=use_deterministic_attn)
            for _ in range(num_layers)
        ])
        self.final_layernorm = nn.LayerNorm(hidden_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        grid_thws: torch.Tensor,
    ) -> torch.Tensor:
        rope_freqs_cis = self.rope_2d.get_freqs_cis(
            grid_thws=grid_thws, device=hidden_states.device)

        lengths = torch.cat((
            torch.zeros(1, dtype=grid_thws.dtype, device=grid_thws.device),
            grid_thws[:, 0] * grid_thws[:, 1] * grid_thws[:, 2],
        ))

        max_seqlen = lengths.max()
        cu_seqlens = lengths.to(hidden_states.device).cumsum(dim=0,
                                                             dtype=torch.int32)
        for block in self.blocks:
            hidden_states = block(hidden_states,
                                  cu_seqlens,
                                  max_seqlen,
                                  rope_freqs_cis=rope_freqs_cis)

        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states


def tpool_patch_merger(
    x: torch.Tensor,
    grid_thws: torch.Tensor,
    merge_kernel_size: tuple[int, int] = (2, 2),
) -> list[torch.Tensor]:
    d_model = x.size(-1)

    outputs = []
    pre_sum = 0
    for t, h, w in grid_thws.tolist():
        # Get the current sequence
        seq = x[pre_sum:pre_sum + t * h * w]
        # Reshape along self.merge_kernel_size and concat to the last dimension
        kernel_height, kernel_width = merge_kernel_size
        new_height, new_width = h // kernel_height, w // kernel_width
        reshaped_seq = seq.view(t, new_height, kernel_height, new_width,
                                kernel_width, d_model)
        reshaped_seq = reshaped_seq.permute(0, 1,
                                            3, 2, 4, 5).contiguous().mean(
                                                dim=0)  # temporal pooling
        padded_seq = reshaped_seq.view(new_height * new_width,
                                       kernel_height * kernel_width, -1)
        outputs.append(padded_seq)
        pre_sum += t * h * w

    return outputs

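
# A minimal illustrative sketch (not part of the original file): what the 2x2
# spatial merge with temporal pooling above produces. The sizes are made up;
# h and w must be divisible by the merge kernel.
def _example_tpool_patch_merger():
    d_model = 16
    grid_thws = torch.tensor([[2, 4, 6]])  # one clip: 2 frames of 4x6 patches
    x = torch.randn(2 * 4 * 6, d_model)    # packed encoder output
    outputs = tpool_patch_merger(x, grid_thws, merge_kernel_size=(2, 2))
    # Frames are mean-pooled away and each 2x2 spatial block becomes one
    # merged token of 4 sub-patches: (new_h * new_w, kh * kw, d_model).
    return outputs[0].shape  # (6, 4, 16)
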
class MoonViT3dPretrainedModel(PreTrainedModel):
    config_class = None
    model_type = 'moonvit3d'
    _no_split_modules = ['PackingTransformer']
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        config = deepcopy(config)
        self.merge_kernel_size = config.merge_kernel_size
        self.patch_size = config.patch_size
        self.merge_type = config.merge_type

        self.patch_embed = MoonVision3dPatchEmbed(
            out_dim=config.hidden_size,
            patch_size=config.patch_size,
            pos_emb_height=config.init_pos_emb_height,
            pos_emb_width=config.init_pos_emb_width,
            pos_emb_time=config.init_pos_emb_time,
            pos_emb_type=config.pos_emb_type,
        )

        self.encoder = MoonViT3dEncoder(hidden_dim=config.hidden_size,
                                        num_layers=config.num_hidden_layers,
                                        block_cfg={
                                            'num_heads':
                                            config.num_attention_heads,
                                            'hidden_dim':
                                            config.hidden_size,
                                            'mlp_dim':
                                            config.intermediate_size,
                                            'activation':
                                            PytorchGELUTanh(),
                                            'attn_bias':
                                            True,
                                            'attn_implementation':
                                            config._attn_implementation,
                                        },
                                        video_attn_type=config.video_attn_type)

    def forward(self, pixel_values: torch.Tensor,
                grid_thws: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pixel_values (torch.Tensor): The input pixel values.
            grid_thws (torch.Tensor): Temporal, height and width.

        Returns:
            torch.Tensor: The output tokens.
        """
        # grid_thws = grid_thws.to('cpu')
        assert grid_thws.ndim == 2, f'grid_thws should be 2D, got {grid_thws.ndim}'
        assert grid_thws.size(1) == 3, f'No support for thw: {grid_thws}'
        hidden_states = self.patch_embed(pixel_values, grid_thws)
        hidden_states = self.encoder(hidden_states, grid_thws)
        if self.merge_type == 'sd2_tpool':  # spatial downsampling 2x with temporal pooling all
            hidden_states = tpool_patch_merger(
                hidden_states,
                grid_thws,
                merge_kernel_size=self.merge_kernel_size)
        else:
            raise NotImplementedError(
                f'Unsupported merge_type: {self.merge_type}')

        return hidden_states


# ============================================================================
# MM Projector Helper Classes (from mm_projector/modeling_mm_projectors.py)
# ============================================================================


class IdentityMap(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x


class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        # TODO, use faster LayerNorm
        self.pre_norm = nn.LayerNorm(config.mm_hidden_size)
        self.proj = nn.Sequential(
            nn.Linear(config.mm_hidden_size, config.hidden_size), nn.GELU(),
            nn.Linear(config.hidden_size, config.hidden_size))

    def forward(self, x, *args, **kwargs):
        assert isinstance(x,
                          list | tuple), f'x is not a list or tuple: {type(x)}'
        lengths = [item.shape[0] for item in x]
        x = torch.cat(x, dim=0)
        x = self.pre_norm(x)
        x = self.proj(x)
        x = torch.split(x, lengths, dim=0)

        return x


class PatchMergerMLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        eps = config.projector_ln_eps
        self.hidden_size = config.mm_hidden_size * (
            config.merge_kernel_size[0] * config.merge_kernel_size[1])
        self.pre_norm = nn.LayerNorm(config.mm_hidden_size, eps=eps)
        self.proj = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, config.hidden_size),
        )

    def forward(self, x, *args, **kwargs):
        if isinstance(x, list) or isinstance(x, tuple):
            x = [
                self.proj(self.pre_norm(item).view(item.shape[0], -1))
                for item in x
            ]
        else:
            # B, N, N_k, C = x.shape
            B = x.shape[0]
            x = self.proj(self.pre_norm(x).view(B, -1, self.hidden_size))
        return x


class KimiK25PreTrainedModel(PreTrainedModel):
    config_class = KimiK25Config
    base_model_prefix = "model"
    _no_split_modules = [
        "MoonViT3dPretrainedModel",
        "MoonViTEncoderLayer",
        "DeepseekDecoderLayer",
        "PatchMergerMLP",
    ]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = False

    def _init_weights(self, module):
        # important: this ported version of Llava isn't meant for training from scratch - only
        # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
        # https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose
        std = (self.config.initializer_range if hasattr(
            self.config, "initializer_range") else
               self.config.text_config.initializer_range)

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

class VisionTowerConfig(PretrainedConfig):
    model_type = 'moonvit3d'

    def __init__(self, config: KimiK25Config, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = config.patch_size
        self.init_pos_emb_height = config.init_pos_emb_height
        self.init_pos_emb_width = config.init_pos_emb_width
        self.init_pos_emb_time = config.init_pos_emb_time
        self.pos_emb_type = config.pos_emb_type
        self.num_attention_heads = config.vt_num_attention_heads
        self.num_hidden_layers = config.vt_num_hidden_layers
        self.hidden_size = config.vt_hidden_size
        self.intermediate_size = config.vt_intermediate_size
        self.merge_kernel_size = config.merge_kernel_size
        self.video_attn_type = config.video_attn_type
        self.merge_type = config.merge_type
        self._attn_implementation = config._attn_implementation


class ProjectorConfig:

    def __init__(self, config: KimiK25Config):
        self.mm_projector_type = config.mm_projector_type
        self.mm_hidden_size = config.mm_hidden_size
        self.hidden_size = config.text_hidden_size
        self.merge_kernel_size = config.merge_kernel_size
        self.projector_hidden_act = config.projector_hidden_act
        self.projector_ln_eps = config.projector_ln_eps

# ref https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/llava/modeling_llava.py#L240
class KimiK25ForConditionalGeneration(KimiK25PreTrainedModel):

    def __init__(self, config: KimiK25Config):
        super().__init__(config)

        vt_config = VisionTowerConfig(config.vision_config)
        self.vision_tower = MoonViT3dPretrainedModel(vt_config)

        proj_config = ProjectorConfig(config.vision_config)
        if proj_config.mm_projector_type == 'identity':
            self.mm_projector = IdentityMap()
        elif proj_config.mm_projector_type == 'mlp':
            self.mm_projector = MLP(proj_config)
        elif proj_config.mm_projector_type == 'patchmerger':
            self.mm_projector = PatchMergerMLP(proj_config)
        else:
            raise ValueError(
                f"Unsupported mm_projector_type: {proj_config.mm_projector_type}"
            )

        self.language_model = DeepseekV3ForCausalLM(config.text_config)
        self.post_init()

        if hasattr(self.language_model, 'dtype'):
            target_dtype = self.language_model.dtype
            self.vision_tower = self.vision_tower.to(dtype=target_dtype)
            self.mm_projector = self.mm_projector.to(dtype=target_dtype)

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    def tie_weights(self):
        return self.language_model.tie_weights()

    def resize_token_embeddings(self,
                                new_num_tokens: int | None = None,
                                pad_to_multiple_of=None) -> nn.Embedding:
        model_embeds = self.language_model.resize_token_embeddings(
            new_num_tokens, pad_to_multiple_of)
        # update vocab size
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds

    def _merge_input_ids_with_image_features(
        self,
        image_features: list[torch.Tensor],
        inputs_embeds: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor | None = None,
    ):
        """
        Args:
            image_features (:obj:`torch.Tensor` of shape :obj:`(num_image_tokens, embed_dim)`):
                The image features to merge with the input embeddings.
            inputs_embeds (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length, embed_dim)`):
                The input embeddings.
            input_ids (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`):
                The input ids.
            attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`):
                The attention mask.
            labels (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, *optional*):
                The labels.
        """
        _, embed_dim = image_features[0].shape
        feature_lengths = [x.shape[0] for x in image_features]
        image_features = torch.cat(image_features, dim=0)

        image_token_index: int = self.config.media_placeholder_token_id
        pad_token_id: int = self.config.pad_token_id
        ignore_index: int = self.config.ignore_index

        batch_size, sequence_length = input_ids.shape
        left_padding = not torch.sum(
            input_ids[:, -1] == torch.tensor(pad_token_id))

        # 1. Create a mask to know where special image tokens are
        _token_occupation_table = torch.ones_like(input_ids.flatten())
        _token_occupation_table[input_ids.flatten() ==
                                image_token_index] = torch.tensor(
                                    feature_lengths,
                                    dtype=torch.long,
                                    device=input_ids.device)
        _token_occupation_table = _token_occupation_table.reshape(
            input_ids.shape)

        max_embed_dim = _token_occupation_table.sum(-1).max().item()
        assert (
            max_embed_dim >= sequence_length
        ), f"The maximum embedding dimension ({max_embed_dim}) is less than the sequence length ({sequence_length})"
        batch_indices, non_image_indices = torch.where(
            input_ids != image_token_index)

        # 2. Compute the positions where text should be written
        # Calculate new positions for text tokens in merged image-text sequence.
        new_token_positions = torch.cumsum(_token_occupation_table, -1) - 1
        nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
        if left_padding:
            new_token_positions += nb_image_pad[:,
                                                None]  # offset for left padding
        text_to_overwrite = new_token_positions[batch_indices,
                                                non_image_indices]

        # 3. Create the full embedding, already padded to the maximum position
        final_embedding = torch.zeros(
            batch_size,
            max_embed_dim,
            embed_dim,
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        )
        final_attention_mask = torch.zeros(batch_size,
                                           max_embed_dim,
                                           dtype=attention_mask.dtype,
                                           device=inputs_embeds.device)
        if labels is not None:
            final_labels = torch.full(
                (batch_size, max_embed_dim),
                ignore_index,
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
        # set the corresponding tensors into their correct target device.
        target_device = inputs_embeds.device
        batch_indices, non_image_indices, text_to_overwrite = (
            batch_indices.to(target_device),
            non_image_indices.to(target_device),
            text_to_overwrite.to(target_device),
        )
        attention_mask = attention_mask.to(target_device)

        # 4. Fill the embeddings based on the mask.
        final_embedding[batch_indices,
                        text_to_overwrite] = inputs_embeds[batch_indices,
                                                           non_image_indices]
        final_attention_mask[batch_indices,
                             text_to_overwrite] = attention_mask[
                                 batch_indices, non_image_indices]
        if labels is not None:
            final_labels[batch_indices,
                         text_to_overwrite] = labels[batch_indices,
                                                     non_image_indices]

        # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
        image_to_overwrite = torch.full((batch_size, max_embed_dim),
                                        True,
                                        dtype=torch.bool,
                                        device=inputs_embeds.device)
        image_to_overwrite[batch_indices, text_to_overwrite] = False
        image_to_overwrite &= image_to_overwrite.cumsum(
            -1) - 1 >= nb_image_pad[:, None].to(target_device)

        if image_to_overwrite.sum() != image_features.shape[:-1].numel():
            raise ValueError(
                f"The inputs provided to the model are wrong. The number of image tokens is {image_to_overwrite.sum()} while"
                f" the number of image features given to the model is {image_features.shape[:-1].numel()}. "
                "This prevents correct indexing and breaks batch generation.")

        final_embedding[image_to_overwrite] = (
            image_features.contiguous().reshape(-1,
                                                embed_dim).to(target_device))
        final_attention_mask |= image_to_overwrite
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
            (final_attention_mask == 0), 1)

        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
        batch_indices, pad_indices = torch.where(input_ids == pad_token_id)
        indices_to_mask = new_token_positions[batch_indices, pad_indices]

        final_embedding[batch_indices, indices_to_mask] = 0

        if labels is None:
            final_labels = None

        return final_embedding, final_attention_mask, final_labels, position_ids

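    # Illustrative walk-through (not part of the original file) of the merge
    # above: with media_placeholder_token_id = 9 and input_ids = [[5, 9, 7]],
    # a single image feature of shape (4, embed_dim) expands the placeholder
    # in place, yielding a merged sequence of length 2 + 4 = 6:
    # [embed(5), img_0, img_1, img_2, img_3, embed(7)], with the attention
    # mask set to 1 at the image positions and labels (if provided) set to
    # ignore_index there.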
    def _extract_image_features(self, pixel_values: torch.Tensor,
                                grid_thws: torch.Tensor) -> list[torch.Tensor]:
        """
        Args:
            pixel_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_channels, height, width)`):
                The pixel values of the images processed by image processor.
            grid_thws (:obj:`torch.Tensor` of shape :obj:`(batch_size, 3)`):
                The temporal, height and width of the image grids.

        Returns:
            selected_image_feature (:obj:`torch.FloatTensor` of shape :obj:`(num_image_tokens, embed_dim)`):
                The selected image features to use as input to the projector head.

        """

        target_dtype = self.vision_tower.patch_embed.proj.weight.dtype
        pixel_values = pixel_values.to(target_dtype)

        image_features = self.vision_tower(pixel_values, grid_thws)
        return image_features

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | list[torch.FloatTensor]
        | None = None,
        grid_thws: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: list[torch.FloatTensor] | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
    ) -> tuple | LlavaCausalLMOutputWithPast:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        assert self.vision_tower is not None, "vision_tower is not loaded"
        output_attentions = (output_attentions if output_attentions is not None
                             else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            # 1. Extract the input embeddings
            inputs_embeds = self.get_input_embeddings()(input_ids)

            # 2. Merge text and images
            if pixel_values is not None and len(
                    pixel_values) > 0 and input_ids.shape[1] != 1:
                image_features = self._extract_image_features(
                    pixel_values, grid_thws)
                if self.mm_projector:
                    image_features = self.mm_projector(image_features)

                inputs_embeds = inputs_embeds.to(
                    image_features[0].dtype)  # num_tokens, embed_dim
                inputs_embeds, attention_mask, labels, position_ids = (
                    self._merge_input_ids_with_image_features(
                        image_features,
                        inputs_embeds,
                        input_ids,
                        attention_mask,
                        labels,
                    ))

            # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
            # generation with cache
            elif (past_key_values is not None and pixel_values is not None
                  and input_ids.shape[1] == 1):
                # Retrieve the first layer to inspect the logits and mask out the hidden states
                # that are set to 0
                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                batch_index, non_attended_tokens = torch.where(
                    first_layer_past_key_value.float().sum(-2) == 0)

                # Get the target length
                target_length = input_ids.shape[1]
                past_length = first_layer_past_key_value.shape[-1]

                extended_attention_mask = torch.ones(
                    (attention_mask.shape[0], past_length),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )

                # Filter out only the tokens that can be un-attended, this can happen
                # if one uses Llava + Fused modules where the cache on the
                # first iteration is already big enough, or if one passes custom cache
                valid_indices = non_attended_tokens < extended_attention_mask.size(
                    -1)
                new_batch_index = batch_index[valid_indices]
                new_non_attended_tokens = non_attended_tokens[valid_indices]

                # Zero-out the places where we don't need to attend
                extended_attention_mask[new_batch_index,
                                        new_non_attended_tokens] = 0

                attention_mask = torch.cat(
                    (extended_attention_mask, attention_mask[:,
                                                             -target_length:]),
                    dim=1)
                position_ids = torch.sum(attention_mask,
                                         dim=1).unsqueeze(-1) - 1

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = outputs[0]

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                shift_attention_mask = attention_mask[..., 1:]
                shift_logits = logits[..., :-1, :][shift_attention_mask.to(
                    logits.device) != 0].contiguous()
                shift_labels = labels[..., 1:][shift_attention_mask.to(
                    labels.device) != 0].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1).to(shift_logits.device),
            )

        if not return_dict:
            output = (logits, ) + outputs[1:]
            return (loss, ) + output if loss is not None else output

        return LlavaCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        grid_thws=None,
        attention_mask=None,
        **kwargs,
    ):
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = getattr(past_key_values, 'seen_tokens',
                                      cache_length)
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[
                    1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] -
                                           past_length):]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
            elif self.config.media_placeholder_token_id in input_ids:
                input_ids = input_ids[:, input_ids.shape[1] - 1:]
            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
            # older attention values, as their corresponding values are not part of the input.
            if cache_length < past_length and attention_mask is not None:
                attention_mask = attention_mask[:, -(cache_length +
                                                     input_ids.shape[1]):]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update({
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "grid_thws": grid_thws,
        })
        return model_inputs

    def _reorder_cache(self, *args, **kwargs):
        return self.language_model._reorder_cache(*args, **kwargs)