如果已经训练好了一个模型,你就可以save和load这模型。
For saving and loading models in PyTorch, there are three main methods you should be aware of.
在 PyTorch 中,pickle
是一个用于序列化和反序列化Python对象的标准库模块。它可以将Python对象转换为字节流 (即序列化),并将字节流转换回Python对象 (即反序列化)。pickle
模块在很多情况下都非常有用,特别是在保存和加载模型,保存训练中间状态等方面。
在深度学习中,经常需要保存训练好的模型或者训练过程中的中间结果,以便后续的使用或分析。PyTorch提高了方便的API来保存和加载模型,其中就包括了使用pickle
模块进行对象的序列化和反序列化。
save model
import torch
from pathlib import Path
# 1. Create models directory
MODEL_PATH = Path("models")
MODEL_PATH.mkdir(parents = True, exist_ok = True)
# 2. Create model save path
MODEL_NAME = "trained_model.pth"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME
# 3. Save the model state dict
print(f"Saving model to: {MODEL_SAVE_PATH}")
torch.save(obj = model_0.state_dict(),
f = MODEL_SAVE_PATH)
就能看到 trained_model.pth 文件下载到所属的文件夹位置。
看到这了,点个赞呗~