news 2025/12/30 19:07:13

大模型高效微调--P-Tuning v2

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
大模型高效微调--P-Tuning v2

文章目录

      • P-Tuning v2 概述
      • 核心改进
      • 关键技术细节
      • 代码示例
      • 性能对比
      • 局限性

https://github.com/THUDM/P-tuning-v2

P-Tuning v2 概述

P-Tuning v2 是清华大学团队提出的一种参数高效微调(Parameter-Efficient Fine-Tuning, PEFT)方法,旨在改进传统微调方法在大型预训练语言模型(如GPT、BERT)上的效率和性能。它是P-Tuning的升级版本,通过优化提示(Prompt)设计和参数更新策略,显著提升了模型在低资源场景下的表现。

核心改进

连续提示优化
P-Tuning v2 引入了可训练的连续提示(Continuous Prompts),取代了传统离散提示。这些提示以嵌入向量的形式插入到模型的输入层或中间层,通过梯度下降动态调整,避免了人工设计提示的局限性。

分层提示注入
与P-Tuning仅在输入层添加提示不同,P-Tuning v2 在模型的每一层(或关键层)注入提示向量,形成分层提示结构。这种设计能更深度地引导模型行为,尤其适合深层Transformer架构。

参数效率提升
P-Tuning v2 仅需微调少量额外参数(通常占模型总参数的0.1%-1%),大幅降低了计算和存储开销,同时保持了与全参数微调相近的性能。

关键技术细节

提示向量初始化
提示向量通常随机初始化或从任务相关词嵌入中采样。实验表明,合理的初始化能加速收敛并提升最终效果。

训练目标
P-Tuning v2 通过标准的下游任务损失(如交叉熵)优化提示参数,同时可结合适配器(Adapter)或LoRA等轻量级模块进一步减少可训练参数。

适用场景

  • 小样本学习(Few-shot Learning)
  • 多任务学习(通过不同提示区分任务)
  • 资源受限的设备部署

代码示例

P-Tuning v2的核心逻辑:

importtorchclassPrefixEncoder(torch.nn.Module):r''' The torch.nn model to encode the prefix Input shape: (batch-size, prefix-length) Output shape: (batch-size, prefix-length, 2*layers*hidden) '''def__init__(self,config):super().__init__()self.prefix_projection=config.prefix_projectionifself.prefix_projection:# Use a two-layer MLP to encode the prefixself.embedding=torch.nn.Embedding(config.pre_seq_len,config.hidden_size)self.trans=torch.nn.Sequential(torch.nn.Linear(config.hidden_size,config.prefix_hidden_size),torch.nn.Tanh(),torch.nn.Linear(config.prefix_hidden_size,config.num_hidden_layers*2*config.hidden_size))else:self.embedding=torch.nn.Embedding(config.pre_seq_len,config.num_hidden_layers*2*config.hidden_size)defforward(self,prefix:torch.Tensor):ifself.prefix_projection:prefix_tokens=self.embedding(prefix)past_key_values=self.trans(prefix_tokens)else:past_key_values=self.embedding(prefix)returnpast_key_values

  • https://github.com/THUDM/P-tuning-v2/blob/main/model/token_classification.py
classBertPrefixForTokenClassification(BertPreTrainedModel):def__init__(self,config):super().__init__(config)self.num_labels=config.num_labels self.bert=BertModel(config,add_pooling_layer=False)self.dropout=torch.nn.Dropout(config.hidden_dropout_prob)self.classifier=torch.nn.Linear(config.hidden_size,config.num_labels)from_pretrained=Falseiffrom_pretrained:self.classifier.load_state_dict(torch.load('model/checkpoint.pkl'))forparaminself.bert.parameters():param.requires_grad=Falseself.pre_seq_len=config.pre_seq_len self.n_layer=config.num_hidden_layers self.n_head=config.num_attention_heads self.n_embd=config.hidden_size//config.num_attention_heads self.prefix_tokens=torch.arange(self.pre_seq_len).long()self.prefix_encoder=PrefixEncoder(config)bert_param=0forname,paraminself.bert.named_parameters():bert_param+=param.numel()all_param=0forname,paraminself.named_parameters():all_param+=param.numel()total_param=all_param-bert_paramprint('total param is {}'.format(total_param))# 9860105defget_prompt(self,batch_size):prefix_tokens=self.prefix_tokens.unsqueeze(0).expand(batch_size,-1).to(self.bert.device)past_key_values=self.prefix_encoder(prefix_tokens)# bsz, seqlen, _ = past_key_values.shapepast_key_values=past_key_values.view(batch_size,self.pre_seq_len,self.n_layer*2,self.n_head,self.n_embd)past_key_values=self.dropout(past_key_values)past_key_values=past_key_values.permute([2,0,3,1,4]).split(2)returnpast_key_valuesdefforward(self,input_ids=None,attention_mask=None,token_type_ids=None,position_ids=None,head_mask=None,inputs_embeds=None,labels=None,output_attentions=None,output_hidden_states=None,return_dict=None,):return_dict=return_dictifreturn_dictisnotNoneelseself.config.use_return_dict batch_size=input_ids.shape[0]past_key_values=self.get_prompt(batch_size=batch_size)prefix_attention_mask=torch.ones(batch_size,self.pre_seq_len).to(self.bert.device)attention_mask=torch.cat((prefix_attention_mask,attention_mask),dim=1)outputs=self.bert(input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids,position_ids=position_ids,head_mask=head_mask,inputs_embeds=inputs_embeds,output_attentions=output_attentions,output_hidden_states=output_hidden_states,return_dict=return_dict,past_key_values=past_key_values,)sequence_output=outputs[0]sequence_output=self.dropout(sequence_output)logits=self.classifier(sequence_output)attention_mask=attention_mask[:,self.pre_seq_len:].contiguous()loss=NoneiflabelsisnotNone:loss_fct=CrossEntropyLoss()# Only keep active parts of the lossifattention_maskisnotNone:active_loss=attention_mask.view(-1)==1active_logits=logits.view(-1,self.num_labels)active_labels=torch.where(active_loss,labels.view(-1),torch.tensor(loss_fct.ignore_index).type_as(labels))loss=loss_fct(active_logits,active_labels)else:loss=loss_fct(logits.view(-1,self.num_labels),labels.view(-1))ifnotreturn_dict:output=(logits,)+outputs[2:]return((loss,)+output)iflossisnotNoneelseoutputreturnTokenClassifierOutput(loss=loss,logits=logits,hidden_states=outputs.hidden_states,attentions=outputs.attentions,)

性能对比

在SuperGLUE基准测试中,P-Tuning v2 仅微调0.5%参数时,性能可达全参数微调的90%以上,同时训练速度提升3-5倍。对于超大规模模型(如百亿参数),其优势更加显著。

局限性

  • 提示长度和层数需通过实验调优
  • 对某些需要全局参数调整的任务(如文本生成)可能需结合其他PEFT方法

参考: https://github.com/zejunwang1/chatglm_tuning/blob/main/train_ptuning.py

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2025/12/27 18:08:22

7、PowerShell代码签名:保障脚本安全的全面指南

PowerShell代码签名:保障脚本安全的全面指南 1. 代码签名的重要性 将代码与其创建和发布实体关联起来,能够消除运行代码的匿名性。给代码签名证书添加数字签名,就像使用品牌名称来建立信任和可靠性一样。PowerShell脚本和配置文件的用户可以根据这些信息,明智地决定是否运…

作者头像 李华
网站建设 2025/12/27 19:34:17

12、网络带宽与 Windows Server 2003 相关技术解析

网络带宽与 Windows Server 2003 相关技术解析 一、提升带宽上限的必要性 随着组织对局域网(LANs)和广域网(WANs)的依赖程度不断加深,更多的应用程序和信息被部署到网络中。对于这些组织而言,快速检索信息变得至关重要,而这也正是对额外带宽需求最为常见的体现。 传统…

作者头像 李华
网站建设 2025/12/29 21:40:38

17、Windows Server 2003 Active Directory 部署与管理全解析

Windows Server 2003 Active Directory 部署与管理全解析 1. Active Directory 规划 在使用 Windows NT 系统时,可能存在多个具有信任关系的域。理论上可以直接升级各域并保留现有信任关系,但这样会失去 Active Directory 的优势。而运行 Windows 2000 域时,也需要进行一定…

作者头像 李华
网站建设 2025/12/28 7:55:49

Linly-Talker支持多语言输出:全球化数字人布局利器

Linly-Talker:如何用一张照片打造全球化的智能数字人? 在跨境电商直播间里,一位说着流利阿拉伯语的虚拟主播正微笑着介绍产品;远在东南亚的用户无需等待翻译,就能听到母语级别的客服回应。这背后并非庞大的制作团队&am…

作者头像 李华
网站建设 2025/12/28 23:24:41

4、PowerShell 深入解析与实践指南

PowerShell 深入解析与实践指南 1. 别名使用注意事项 在 PowerShell 中定义别名时,并非所有人都与你有相同的逻辑。若想让他人理解你的脚本,使用别名时需谨慎,避免过多使用,可考虑创建可复用的函数。创建脚本别名时,应使用易于他人理解的名称,例如,除了对脚本进行编码…

作者头像 李华
网站建设 2025/12/29 2:31:29

Linly-Talker在金融客服中的应用案例分享

Linly-Talker在金融客服中的应用案例分享 在银行网点排长队咨询理财产品、深夜想查账单却找不到人工客服——这些场景正逐渐成为过去。随着金融服务向全天候、个性化和高效率演进,传统客服模式的短板日益凸显:人力成本居高不下、服务时间受限、响应延迟…

作者头像 李华