首先看下有关的函数:
- torch.save: 将一个文件保存到硬盘上,内部是用了pickle库
- troch.load:用的pickle的unpicking方法将存储在硬盘上的object读取到内存中
- torch.nn.Module.load_state_dict:从一个state_dict中加载一个模型的参数
什么是state_dict:
pytorch中的每个module的可学习的参数:如权重和bias等都在module.parameters()里面。
一个state_dict简单来说就是一个字典object,可以把每一层映射到他的参数上去。可学习参数以及register buffer(bn)已经优化器都有state_dict。因为state_dict是python字典对象,因此很简单就可以保存,修改。
下面是读取模型state_dict的例子:
1 | for param_tensor in model.state_dict(): |
两种方法
回到正题,有两种方法可以保存和读取模型。
第一种是通过模型的state_dict来进行读取和保存。特别是读取的时候,首先得新建一个模型object,然后加载参数。
Save:
1 | torch.save(model.state_dict(), PATH) |
Load:
1 | model = TheModelClass(*args, **kwargs) |
第二种方法之别保存和加载整个模型:
Save:
1 | torch.save(model, PATH) |
Load:
1 | # Model class must be defined somewhere |
这种方法的缺点是序列化数据绑定到特定类以及保存模型时使用的确切目录结构。这是因为pickle不保存模型类本身。相反,它会保存包含类的文件的路径,该文件在加载时使用。因此,当您在其他项目中或在重构之后使用时,您的代码可能会以各种方式中断。
保存checkpoint
可以保存checkpoint用于后续的推理和重新训练。和单独保存模型的参数不同,优化器的参数也会被保存,以便于后续的训练。
Save:
1 | torch.save({ |
Load:
1 | model = TheModelClass(*args, **kwargs) |
Reference
https://pytorch.org/tutorials/beginner/saving_loading_models.html