zl程序教程

您现在的位置是:首页 >  其他

当前栏目

加载部分神经网络预训练参数后改写网络的方法

训练方法网络神经网络 参数 加载 部分 改写
2023-09-14 09:05:38 时间

对于pytorch直接

checkpoint = 'D:/Speech-Transformer/BEST_checkpoint_85.tar'
checkpoint = torch.load(checkpoint, map_location='cpu')

model = checkpoint['model']
model.encoder.requires_grad_=False
model.decoder=decoder
model.to(device)

class model():
    def __init__(self):
        self.txt="初始值"
    def forward(self):
        print("原函数")
        print("初始值",self.txt)

def df(self):
    # 这个函数魔改前向函数
    print(self.txt,"我是替代函数")

net=model()
# 这里可以加载预训练模型
net.forward()
# 任意改变初始化定义
net.forward=df
net.forward(net)



if __name__ == '__main__':
    pass