Uni-Mol微调(finetune)过程
过程 uni 微调
2023-09-14 09:14:41 时间
# 是的,上述代码可以看作是一种finetune过程。 # 该代码定义了一个分类头(ClassificationHead)用于对预训练模型进行微调,在微调过程中,使用新的训练数据对分类头进行训练,以实现新任务的分类目标。 # 具体地,register_classification_head方法用于注册一个分类头到预训练模型中,而build_model方法则是用于构建完整的微调模型。 # 在构建微调模型时,首先使用预训练模型的权重初始化微调模型,然后将分类头添加进去,使得微调模型可以同时实现预测任务和分类任务。
finetune使用的模块
class ClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(
self,
input_dim,
inner_dim,
num_classes,
activation_fn,
pooler_dropout,
):
super().__init__()
self.dense = nn.Linear(input_dim, inner_dim)
self.activation_fn = utils.get_activation_fn(activation_fn)
self.dropout = nn.Dropout(p=pooler_dropout)
self.out_proj = nn.Linear(inner_dim, num_classes)
def forward(self, features, **kwargs):
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
x = self.dropout(x)
x = self.dense(x)
x = self.activation_fn(x)
x = self.dropout(x)
x = self.out_proj(x)
return x
预训练
def register_classification_head(
self, name, num_classes=None, inner_dim=None, **kwargs
):
"""Register a classification head."""
if name in self.classification_heads:
prev_num_classes = self.classification_heads[name].out_proj.out_features
prev_inner_dim = self.classification_heads[name].dense.out_features
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
logger.warning(
're-registering head "{}" with num_classes {} (prev: {}) '
"and inner_dim {} (prev: {})".format(
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
)
)
self.classification_heads[name] = ClassificationHead(
input_dim=self.args.encoder_embed_dim,
inner_dim=inner_dim or self.args.encoder_embed_dim,
num_classes=num_classes,
activation_fn=self.args.pooler_activation_fn,
pooler_dropout=self.args.pooler_dropout,
)
def build_model(self, args):
from unicore import models
model = models.build_model(args, self)
model.register_classification_head(
self.args.classification_head_name,
num_classes=self.args.num_classes,
)
return model
相关文章
- BM3D 算法原理详细解析 按过程步骤讲解(附C++实现代码)[通俗易懂]
- uni开发app用什么调试方便_配置台式机后调试过程
- 一个请求的组成、静态页面和动态页面、HTML, CSS和JS、浏览器渲染的过程
- SQL数据库存储过程示例解析
- 最佳实践Exploring Optimal Practices for Running Oracle Stored Procedures(oracle存储过程运行)
- 如何安全高效地备份Oracle存储过程?(备份oracle存储过程)
- MSSQL端口号挖掘:探究过程中的那些谜团(mssql端口号是多少)
- Oracle存储过程提高程序执行效率(9.oracle存储过程)
- Oracle事务从执行到完成(oracle事务执行过程)
- Oracle AB表之间的切换过程(oracle ab表切换)