qwen2.5 weight padding

Padding the weights

Qwen2.5-72B ships with intermediate_size = 29568 = 128 × 231; once the MLP is sharded over 2, 4, or 8 tensor-parallel ranks, the per-rank slice is no longer a multiple of the common 128 quantization group size (see the issue linked in the code). The function below pads it to 29696 = 128 × 232 by interleaving zero rows/columns into the MLP projections, which leaves the model's outputs unchanged.

import math

import torch
import torch.nn.functional as F


# https://github.com/QwenLM/Qwen2.5/issues/578
def padding_and_saving_weight(model: torch.nn.Module, output_dir: str):
    assert model.config.intermediate_size == 29568, "intermediate_size is not 29568"

    pad_size = 128
    old_intermediate_size = 29568
    new_intermediate_size = old_intermediate_size + pad_size  # 29568 -> 29696

    exponent = math.log(pad_size, 2)
    assert exponent.is_integer(), f"{pad_size} is not a power of 2"
    exponent = int(exponent)

    assert (old_intermediate_size / pad_size).is_integer(), f"{old_intermediate_size} is not divisible by {pad_size}"

    # Dimension sizes that correspond to the intermediate dimension (or a power-of-two
    # fraction of it) and therefore need padding.
    need_pad_values = [int(old_intermediate_size // (2 ** i)) for i in range(exponent + 1)]

    sd = model.state_dict()
    for i, k in enumerate(sd):
        v = sd[k]
        # Only the MLP projections carry the intermediate dimension:
        # up_proj / gate_proj have it on dim 0, down_proj on dim 1.
        if len(v.shape) == 2 and (('mlp.up_proj.' in k) or ('mlp.gate_proj.' in k) or ('mlp.down_proj.' in k)):
            if v.shape[0] in need_pad_values:
                need_pad_size = v.shape[0] * new_intermediate_size / old_intermediate_size - v.shape[0]
                assert need_pad_size.is_integer(), f"{need_pad_size} is not an integer"
                need_pad_size = int(need_pad_size)

                # Interleave a zero row after each of the first `need_pad_size` rows,
                # then keep the remaining rows unchanged.
                prev_v = F.pad(v.unsqueeze(1), (0, 0, 0, 1, 0, 0)).reshape(v.shape[0] * 2, -1)[:need_pad_size * 2]
                new_v = torch.cat([prev_v, v[need_pad_size:]], dim=0)
                sd[k] = new_v
                print(k, i, v.shape, '-->', new_v.shape)
            elif v.shape[1] in need_pad_values:
                need_pad_size = v.shape[1] * new_intermediate_size / old_intermediate_size - v.shape[1]
                assert need_pad_size.is_integer(), f"{need_pad_size} is not an integer"
                need_pad_size = int(need_pad_size)

                # Same trick along dim 1: interleave a zero column after each of the
                # first `need_pad_size` columns, then keep the rest unchanged.
                prev_v = F.pad(v.unsqueeze(2), (0, 1)).reshape(v.shape[0], v.shape[1] * 2)[:, :need_pad_size * 2]
                new_v = torch.cat([prev_v, v[:, need_pad_size:]], dim=1)
                sd[k] = new_v
                print(k, i, v.shape, '-->', new_v.shape)

    model.config.intermediate_size = new_intermediate_size
    model.save_pretrained(output_dir, state_dict=sd, max_shard_size="4GB", safe_serialization=True)
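A minimal usage sketch, assuming a Qwen/Qwen2.5-72B-Instruct checkpoint and enough CPU RAM to hold the full bf16 weights; the output directory name is a placeholder.

import torch
from transformers import AutoModelForCausalLM

# Load on CPU in the original precision; the padding is pure tensor surgery, no GPU needed.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-72B-Instruct",  # any Qwen2.5 checkpoint with intermediate_size 29568
    torch_dtype=torch.bfloat16,
)
padding_and_saving_weight(model, "./qwen2.5-72b-padded")
# Note: model.save_pretrained does not write the tokenizer files; copy them to the
# output directory separately.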
    
    
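Why this is safe: the interleaved gate/up rows are zero, so their activations are SiLU(0) · 0 = 0, and the matching zero columns of down_proj receive no contribution. A small self-contained check with toy dimensions (not part of the original snippet), using the same interleaving trick as above:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden, old_inter, pad = 8, 16, 4           # toy sizes; the real model pads 29568 -> 29696
gate = torch.randn(old_inter, hidden)       # like mlp.gate_proj.weight
up = torch.randn(old_inter, hidden)         # like mlp.up_proj.weight
down = torch.randn(hidden, old_inter)       # like mlp.down_proj.weight

def pad_rows(w, n):
    # interleave a zero row after each of the first n rows
    head = F.pad(w.unsqueeze(1), (0, 0, 0, 1)).reshape(w.shape[0] * 2, -1)[:n * 2]
    return torch.cat([head, w[n:]], dim=0)

def pad_cols(w, n):
    # interleave a zero column after each of the first n columns
    head = F.pad(w.unsqueeze(2), (0, 1)).reshape(w.shape[0], w.shape[1] * 2)[:, :n * 2]
    return torch.cat([head, w[:, n:]], dim=1)

def mlp(x, g, u, d):
    # SwiGLU MLP: down( silu(gate(x)) * up(x) )
    return (F.silu(x @ g.T) * (x @ u.T)) @ d.T

x = torch.randn(3, hidden)
out_old = mlp(x, gate, up, down)
out_new = mlp(x, pad_rows(gate, pad), pad_rows(up, pad), pad_cols(down, pad))
print(torch.allclose(out_old, out_new, atol=1e-6))  # True: padded channels contribute nothing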
