torch.save()和torch.load()是PyTorch中用于模型保存和加载的函数。它们提供了一种方便的方式来保存和恢复模型的状态、结构和参数。可以使用它们来保存和加载整个模型或其他任意的Python对象,并且可以在加载模型时指定目标设备。

1.语法介绍

1.1 torch.save()语法

        torch.save()函数用于将PyTorch模型保存到磁盘上的文件中,以便以后可以重新加载和使用。它的基本语法如下:

torch.save(obj, f, pickle_module=<module 'pickle' from '...'>, pickle_protocol=2)

        参数说明:

                obj是要保存的对象,通常是一个模型的状态字典(state_dict())。

                f是文件的路径或文件对象,用于存储模型。

                pickle_module是用于序列化的Python模块,默认为pickle。

                pickle_protocol是序列化时使用的协议版本,默认为2。

1.2 torch.load()语法

        torch.load()函数用于从磁盘上的文件加载保存的模型。它的基本语法如下:

torch.load(f, map_location=None, pickle_module=<module 'pickle' from '...'>)

        参数说明:

                f是要加载的文件的路径或文件对象。

                map_location用于指定加载模型的设备(CPU或特定的GPU设备)。默认情况下,加载的模型将被存储在与保存模型时相同的设备上。

                pickle_module是用于反序列化的Python模块,默认为pickle。

2. 基本使用示例介绍

2.1 保存和加载整个模型

        除了保存和加载模型的状态字典外,torch.save()和torch.load()还可以用于保存和加载整个模型,包括模型的结构、参数和其他相关信息。

        要保存整个模型,使用以下代码:

torch.save(model, 'model.pth')

        要加载整个模型,使用以下代码: 

model = torch.load('model.pth')

        注意,加载整个模型时,需要确保模型的定义代码可用,因为它将用于重新创建模型的结构。

2.2 保存和加载其他对象

        torch.save()和torch.load()不仅限于保存和加载模型,还可以用于保存和加载其他任意的Python对象。只需将要保存的对象传递给torch.save(),然后使用torch.load()来加载该对象。

        例如:

data = [1, 2, 3, 4, 5]
torch.save(data, 'data.pth')

loaded_data = torch.load('data.pth')

        这样可以方便地保存和加载各种数据,如训练集、测试集、预处理数据等。 

 2.3 跨设备加载模型

        torch.load()函数允许在加载模型时指定目标设备。通过使用map_location参数,可以将模型加载到不同的设备上,例如从GPU加载到CPU或从一种GPU加载到另一种GPU。

        以下是一个示例:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 从GPU加载到CPU
model = torch.load('model.pth', map_location='cpu')

# 从一种GPU加载到另一种GPU
model = torch.load('model.pth', map_location='cuda:1')

         这对于在不同设备上运行模型或在没有GPU的机器上加载训练好的GPU模型非常有用。

2.4 序列化兼容性

        torch.save()使用Python的pickle模块进行序列化,默认使用协议版本2。这个默认版本在PyTorch 1.6及更高版本中是兼容的。如果您需要与旧版本的PyTorch或其他Python库进行兼容,您可以通过设置pickle_protocol参数来选择不同的协议版本。 

torch.save(model.state_dict(), 'model.pth', pickle_protocol=4)

        在选择协议版本时,需要权衡序列化的性能和兼容性。 

3. 模型保存和加载

        当涉及到模型保存和加载时,还有一些其他的注意事项和用法:

3.1 保存和加载模型的状态字典

        通常情况下,我们只保存和加载模型的状态字典(state_dict()),而不是整个模型。状态字典包含了模型的参数和缓冲区(如权重和偏置),但不包括模型的结构。这种做法更加灵活,因为它允许在加载模型时自由选择模型的结构,并且可以与不同的模型架构进行兼容。

#保存模型的状态字典:
torch.save(model.state_dict(), 'model.pth')

#加载模型的状态字典:
model = MyModel()
model.load_state_dict(torch.load('model.pth'))

        请确保在加载模型之前,模型的定义与保存时的模型结构相匹配。 

3.2 冻结某些层或参数

        在某些情况下,可能希望冻结模型的某些层或参数,即在加载模型后不更新它们的参数。可以通过设置参数的requires_grad属性来实现这一点。

        例如,假设模型有一个名为fc的全连接层,您可以冻结该层的参数:

model = MyModel()
model.load_state_dict(torch.load('model.pth'))

# 冻结全连接层的参数
for param in model.fc.parameters():
param.requires_grad = False

3.3 多个模型的保存和加载

        如果您需要保存和加载多个模型,您可以将它们保存为一个字典,并使用一个文件来存储整个字典。

        保存多个模型:

state = {
    'model1': model1.state_dict(),
    'model2': model2.state_dict()
}
torch.save(state, 'models.pth')

        加载多个模型: 

state = torch.load('models.pth')
model1.load_state_dict(state['model1'])
model2.load_state_dict(state['model2'])

         这种方法可以方便地保存和加载多个相关模型。

 3.3 保存和加载检查点

        在训练过程中,可以定期保存模型的检查点,以便在训练过程中发生意外情况时能够恢复模型。通过定期保存检查点,可以避免从头开始训练,并从最新的检查点继续训练。

# 训练循环中的保存检查点
if epoch % checkpoint_interval == 0:
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, f'checkpoint_{epoch}.pth')

        在发生中断或需要恢复训练时,可以加载最新的检查点: 

# 加载最新的检查点
latest_checkpoint = get_latest_checkpoint(checkpoint_dir)
checkpoint = torch.load(latest_checkpoint)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

        这样,可以从最新的检查点恢复训练。 

 

 

 

 

 

 

 

 

 

 

 

 

 

03-22 13:33