inference/convert.py
import os
import shutil
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm, trange

import torch
from safetensors.torch import safe_open, save_file

# Maps Hugging Face parameter names to this repo's names. Each value is
# (new_name, shard_dim): shard_dim is the dimension split across
# model-parallel ranks, or None if the tensor is replicated on every rank.
mapping = {
    "embed_tokens": ("embed", 0),
    "input_layernorm": ("attn_norm", None),
    "post_attention_layernorm": ("ffn_norm", None),
    "q_proj": ("wq", 0),
    "q_a_proj": ("wq_a", None),
    "q_a_layernorm": ("q_norm", None),
    "q_b_proj": ("wq_b", 0),
    "kv_a_proj_with_mqa": ("wkv_a", None),
    "kv_a_layernorm": ("kv_norm", None),
    "kv_b_proj": ("wkv_b", 0),
    "o_proj": ("wo", 1),
    "gate": ("gate", None),
    "gate_proj": ("w1", 0),
    "down_proj": ("w2", 1),
    "up_proj": ("w3", 0),
    "norm": ("norm", None),
    "lm_head": ("head", 0),
    "scale": ("scale", None),
    "wq_b": ("wq_b", None),
    "wk": ("wk", None),
    "k_norm": ("k_norm", None),
    "weights_proj": ("weights_proj", None),
}
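# Example of the full renaming main() performs with this table:
#   "model.layers.0.self_attn.q_proj.weight" -> "layers.0.attn.wq.weight",
#   sharded along dim 0 across the mp ranks.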

def main(hf_ckpt_path, save_path, n_experts, mp):
    """
    Convert a Hugging Face checkpoint into model-parallel shards using this
    repo's parameter naming scheme.

    Args:
        hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
        save_path (str): Path to the directory where the converted checkpoint files will be saved.
        n_experts (int): Total number of routed experts in the model.
        mp (int): Model-parallelism factor, i.e. the number of output shards.

    Returns:
        None
    """
    torch.set_num_threads(8)
    n_local_experts = n_experts // mp
    # One output state dict per model-parallel rank.
    state_dicts = [{} for _ in range(mp)]

    for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for name in f.keys():
                # Skip the extra layer (the multi-token-prediction module),
                # which this inference code does not load.
                if "model.layers.61" in name:
                    continue
                param: torch.Tensor = f.get_tensor(name)
                if name.startswith("model."):
                    name = name[len("model."):]
                # Rewrite Hugging Face module names to this repo's conventions.
                name = name.replace("self_attn", "attn")
                name = name.replace("mlp", "ffn")
                name = name.replace("weight_scale_inv", "scale")
                name = name.replace("e_score_correction_bias", "bias")
                key = name.split(".")[-2]
                assert key in mapping, f"Key {key} not found in mapping"
                new_key, dim = mapping[key]
                name = name.replace(key, new_key)
                for i in range(mp):
                    new_param = param
                    if "experts" in name and "shared_experts" not in name:
                        # Routed experts are partitioned across ranks: rank i keeps
                        # experts [i * n_local_experts, (i + 1) * n_local_experts).
                        idx = int(name.split(".")[-3])
                        if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
                            continue
                    elif dim is not None:
                        # Tensor-parallel parameters are sliced evenly along `dim`.
                        assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
                        shard_size = param.size(dim) // mp
                        new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
                    state_dicts[i][name] = new_param

    os.makedirs(save_path, exist_ok=True)

    # Write one shard file per model-parallel rank.
    for i in trange(mp):
        save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))

    # Copy tokenizer files alongside the converted weights.
    for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
        new_file_path = os.path.join(save_path, os.path.basename(file_path))
        shutil.copyfile(file_path, new_file_path)

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--hf-ckpt-path", type=str, required=True)
    parser.add_argument("--save-path", type=str, required=True)
    parser.add_argument("--n-experts", type=int, required=True)
    parser.add_argument("--model-parallel", type=int, required=True)
    args = parser.parse_args()
    assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
    main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
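# Example invocation (hypothetical paths; --n-experts and --model-parallel must
# match the source checkpoint, e.g. 256 routed experts for DeepSeek-V3):
#   python convert.py --hf-ckpt-path /path/to/hf_ckpt --save-path /path/to/converted \
#       --n-experts 256 --model-parallel 8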