configuration_kimi_k25.py
5.3 KB · 124 lines · python Raw
1 from transformers.configuration_utils import PretrainedConfig
2
3 try:
4 from configuration_deepseek import DeepseekV3Config
5 except ImportError:
6 from .configuration_deepseek import DeepseekV3Config
7
8
class KimiK25VisionConfig(PretrainedConfig):
    """Configuration for the Kimi-K2.5 vision tower and multimodal projector.

    Args:
        patch_size (int): Patch size for the vision tower.
        init_pos_emb_height (int): Initial position-embedding height.
        init_pos_emb_width (int): Initial position-embedding width.
        init_pos_emb_time (int): Initial position-embedding time dimension.
        pos_emb_type (str): Type of position embedding.
        vt_num_attention_heads (int): Number of attention heads in the vision tower.
        vt_num_hidden_layers (int): Number of hidden layers in the vision tower.
        vt_hidden_size (int): Hidden size of the vision tower.
        vt_intermediate_size (int): Intermediate (FFN) size in the vision tower.
        merge_kernel_size (tuple): Kernel size for patch merging.
        video_attn_type (str): Type of video attention.
        merge_type (str): Type of merge operation.
        _attn_implementation (str): Attention implementation to use.
        mm_projector_type (str): Type of multimodal projector.
        mm_hidden_size (int | None): Hidden size fed to the projector; when
            ``None`` it falls back to ``vt_hidden_size``.
        projector_hidden_act (str): Activation function for the projector.
        projector_ln_eps (float): Layer-norm epsilon for the projector.
        ignore_index, media_placeholder_token_id, pad_token_id,
        use_unified_vision_chunk, video_placeholder: Accepted but NOT stored
            on this object. Presumably this lets a shared/flat config dict be
            passed unchanged — confirm. ``KimiK25Config`` stores these values
            itself.
        text_hidden_size (int): Hidden size of the text model, stored as-is.
            Presumably the projector's output width — confirm in the
            modeling code.
        **vision_config_kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        patch_size: int = 14,
        init_pos_emb_height: int = 64,
        init_pos_emb_width: int = 64,
        init_pos_emb_time: int = 4,
        pos_emb_type: str = 'divided_fixed',
        vt_num_attention_heads: int = 16,
        vt_num_hidden_layers: int = 27,
        vt_hidden_size: int = 1152,
        vt_intermediate_size: int = 4304,
        merge_kernel_size: tuple = (2, 2),
        video_attn_type: str = 'spatial_temporal',
        merge_type: str = 'sd2_tpool',
        _attn_implementation: str = 'flash_attention_2',
        # MM Projector parameters
        mm_projector_type: str = 'patchmerger',
        mm_hidden_size: int | None = None,
        projector_hidden_act: str = "gelu",
        projector_ln_eps: float = 1e-5,
        # Other parameters: accepted for signature compatibility, not stored.
        ignore_index: int = -100,
        media_placeholder_token_id: int = 163605,
        pad_token_id: int = 0,
        use_unified_vision_chunk: bool = True,
        video_placeholder: str = "<|kimi_k25_video_placeholder|>",
        text_hidden_size: int = 7168,
        **vision_config_kwargs,
    ):
        # Initialise the PretrainedConfig base so the generic config attributes
        # it defines exist on this object, and so extra keyword arguments are
        # handled by it rather than discarded. It runs first so the explicit
        # assignments below take precedence over anything it sets.
        super().__init__(**vision_config_kwargs)

        # Vision tower config
        self.patch_size = patch_size
        self.init_pos_emb_height = init_pos_emb_height
        self.init_pos_emb_width = init_pos_emb_width
        self.init_pos_emb_time = init_pos_emb_time
        self.pos_emb_type = pos_emb_type
        self.vt_num_attention_heads = vt_num_attention_heads
        self.vt_num_hidden_layers = vt_num_hidden_layers
        self.vt_hidden_size = vt_hidden_size
        self.vt_intermediate_size = vt_intermediate_size
        self.merge_kernel_size = merge_kernel_size
        self.video_attn_type = video_attn_type
        self.merge_type = merge_type
        self._attn_implementation = _attn_implementation

        # MM Projector config; the projector input defaults to the vision width.
        self.mm_projector_type = mm_projector_type
        self.mm_hidden_size = mm_hidden_size if mm_hidden_size is not None else vt_hidden_size
        self.projector_hidden_act = projector_hidden_act
        self.projector_ln_eps = projector_ln_eps
        self.text_hidden_size = text_hidden_size
60
61
class KimiK25Config(PretrainedConfig):
    """Kimi-K2.5 model configuration.

    Composes a text-model configuration and a vision configuration. It also
    stores a few multimodal settings that apply to the whole model.

    Args:
        text_config (dict | DeepseekV3Config | None): Configuration for the
            text model. A ``dict`` is converted via
            ``DeepseekV3Config(**text_config)``. ``None`` is stored as-is.
        vision_config (dict | KimiK25VisionConfig | None): Configuration for
            the vision tower and multimodal projector. This covers patch size,
            position embeddings, vision-tower depth/width, merge settings and
            projector settings. A ``dict`` is converted via
            ``KimiK25VisionConfig(**vision_config)``. ``None`` is stored as-is.
        ignore_index (int): The ignore index for the loss function.
        media_placeholder_token_id (int): The token ID to use for media
            placeholders.
        pad_token_id (int): The token ID to use for padding. Forwarded to
            ``PretrainedConfig.__init__``.
        use_unified_vision_chunk (bool): Flag stored as-is. It is interpreted
            by code outside this file.
        video_placeholder (str): Placeholder string for video content, stored
            as-is.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    If the text config has a non-``None`` ``quantization_config``, it is also
    exposed as ``self.quantization_config`` on this composite config.
    """

    model_type = "kimi_k25"

    def __init__(
        self,
        text_config: dict | DeepseekV3Config | None = None,
        vision_config: dict | KimiK25VisionConfig | None = None,
        # Other parameters
        ignore_index: int = -100,
        media_placeholder_token_id: int = 163605,
        pad_token_id: int = 0,
        use_unified_vision_chunk: bool = True,
        video_placeholder: str = "<|kimi_k25_video_placeholder|>",
        **kwargs,
    ):
        # Sub-configs may arrive as plain dicts (presumably when deserialised
        # from config.json), so promote them to their config classes.
        if isinstance(text_config, dict):
            text_config = DeepseekV3Config(**text_config)
        if isinstance(vision_config, dict):
            vision_config = KimiK25VisionConfig(**vision_config)
        self.text_config = text_config
        self.vision_config = vision_config

        # Other config
        self.ignore_index = ignore_index
        self.media_placeholder_token_id = media_placeholder_token_id
        self.use_unified_vision_chunk = use_unified_vision_chunk
        self.video_placeholder = video_placeholder

        # Surface the text model's quantization settings on the composite
        # config. getattr() with a default keeps this safe when text_config
        # is None.
        if getattr(self.text_config, "quantization_config", None) is not None:
            self.quantization_config = self.text_config.quantization_config

        super().__init__(pad_token_id=pad_token_id, **kwargs)
124