sunpongber

网络模型的保存与读取

模型的保存
模型的加载

完整的模型训练套路-GPU训练
完整的模型验证套路

再来看一下GitHub

PyTorch官网

模型保存

import torch
import torchvision

# vgg16_false = torchvision.models.vgg16(pretrained = False)
vgg16 = torchvision.models.vgg16(weights=None)

# 保存方式1模型结构+模型参数
# torch.save(vgg16_false, "vgg16_method0.pth")
torch.save(vgg16, "vgg16_method1.pth")

# 保存方式2模型参数官方推荐
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
# 把vgg16的状态保存成字典形式字典是python的一种数据格式
# 相当于把vgg16网络模型中的参数保存成字典现在不保存结构了就保存一些参数

对于比较大的模型,保存方式2用的空间肯定更小
在终端输入ls -all,我们使用dir

...
2025/06/03  15:38       553,450,705 vgg16_method1.pth
2025/06/03  15:38       553,441,041 vgg16_method2.pth
...

重点看这两种方式保存的这两个文件,明显可以看出方式2保存的文件比方式1保存的文件要小

模型加载

import torch

# 方式1 -> 保存方式1加载模型
model = torch.load("vgg16_method1.pth")
print(model)

# 方式2 -> 保存方式2加载模型
model = torch.load("vgg16_method2.pth")
print(model)

保存方式1保存的模型,可以比较明显的看出,加载的模型参数与保存的模型参数一样
保存方式2保存的模型,按照字典的方式保存,返回信息太长,可自行验证

那如何恢复成网络模型的形式?

import torch
import torchvision

# 方式1 -> 保存方式1加载模型
model = torch.load("vgg16_method1.pth")
print(model)

# 方式2 -> 保存方式2加载模型
vgg16 = torchvision.models.vgg16(weights=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
# model = torch.load("vgg16_method2.pth")
# print(model)
print(vgg16)

方式1还是有陷阱的,保存

import torch
import torchvision
from torch import nn

# 陷阱
class Pongber(nn.Module):
    def __init__(self):
        super(Pongber, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3)

    def forward(self, x):
        x = self.conv1(x)
        return x
pongber = Pongber()
torch.save(pongber, "pongber_method1.pth")

方式1还是有陷阱的,加载

import torch
import torchvision

# 陷阱1
model = torch.load('pongber_method1.pth')
print(model)

返回:

D:\Anaconda_python3.12\envs\py3.10\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.
  warnings.warn(msg)
Traceback (most recent call last):
  File "D:\desktop\learn_dl\pytorch_1\P26_model_load.py", line 16, in <module>
    model = torch.load('pongber_method1.pth')
  File "D:\Anaconda_python3.12\envs\py3.10\lib\site-packages\torch\serialization.py", line 809, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "D:\Anaconda_python3.12\envs\py3.10\lib\site-packages\torch\serialization.py", line 1172, in _load
    result = unpickler.load()
  File "D:\Anaconda_python3.12\envs\py3.10\lib\site-packages\torch\serialization.py", line 1165, in find_class
    return super().find_class(mod_name, name)
AttributeError: Can't get attribute 'Pongber' on <module '__main__' from 'D:\\desktop\\learn_dl\\pytorch_1\\P26_model_load.py'>

这就得把之前的网络结构复制过来,但是不需要写pongber = Pongber()

import torch
import torchvision
from torch import nn

# 陷阱1
class Pongber(nn.Module):
    def __init__(self):
        super(Pongber, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3)

    def forward(self, x):
        x = self.conv1(x)
        return x
model = torch.load('pongber_method1.pth')
print(model)

这样就会正常返回:

Pongber(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
)

其实这个陷阱也不算一个陷阱,因为真实写项目过程中,我们会把它定义在一个单独的文件里

from P26_model_save import *

# 陷阱1
# class Pongber(nn.Module):
#     def __init__(self):
#         super(Pongber, self).__init__()
#         self.conv1 = nn.Conv2d(3, 64, 3)
#
#     def forward(self, x):
#         x = self.conv1(x)
#         return x
model = torch.load('pongber_method1.pth')
print(model)

这样也会正常返回:

Pongber(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
)

原始资料地址:
网络模型的保存与读取
如有侵权联系删除 仅供学习交流使用