前言

深度学习框架提供了内置函数来保存和加载整个网络。需要注意的是,这将保存模型的参数而不是整个模型。

加载和保存

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden=nn.Linear(20,256)
        self.output=nn.Linear(256,10)

    def forward(self,X):
        return self.output(F.relu(self.hidden(X)))

net=MLP()
X=torch.randn(size=(2,20))
Y=net(X)

torch.save(net.state_dict(),'mlp.param')
clone=MLP()
clone.load_state_dict(torch.load('mlp.param'))
01-10 08:49