网络模型的保存与读取
模型的保存
模型的加载
完整的模型训练套路-GPU训练
完整的模型验证套路
再来看一下GitHub
模型保存
import torch
import torchvision
# vgg16_false = torchvision.models.vgg16(pretrained = False)
vgg16 = torchvision.models.vgg16(weights=None)
# 保存方式1,模型结构+模型参数
# torch.save(vgg16_false, "vgg16_method0.pth")
torch.save(vgg16, "vgg16_method1.pth")
# 保存方式2,模型参数(官方推荐)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
# 把vgg16的状态保存成字典形式,字典是python的一种数据格式
# 相当于把vgg16网络模型中的参数保存成字典,现在不保存结构了,就保存一些参数
对于比较大的模型,保存方式2用的空间肯定更小
在终端输入ls -all
,我们使用dir
...
2025/06/03 15:38 553,450,705 vgg16_method1.pth
2025/06/03 15:38 553,441,041 vgg16_method2.pth
...
重点看这两种方式保存的这两个文件,明显可以看出方式2保存的文件比方式1保存的文件要小
模型加载
import torch
# 方式1 -> 保存方式1,加载模型
model = torch.load("vgg16_method1.pth")
print(model)
# 方式2 -> 保存方式2,加载模型
model = torch.load("vgg16_method2.pth")
print(model)
保存方式1保存的模型,可以比较明显的看出,加载的模型参数与保存的模型参数一样
保存方式2保存的模型,按照字典的方式保存,返回信息太长,可自行验证
那如何恢复成网络模型的形式?
import torch
import torchvision
# 方式1 -> 保存方式1,加载模型
model = torch.load("vgg16_method1.pth")
print(model)
# 方式2 -> 保存方式2,加载模型
vgg16 = torchvision.models.vgg16(weights=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
# model = torch.load("vgg16_method2.pth")
# print(model)
print(vgg16)
方式1还是有陷阱的,保存
import torch
import torchvision
from torch import nn
# 陷阱
class Pongber(nn.Module):
def __init__(self):
super(Pongber, self).__init__()
self.conv1 = nn.Conv2d(3, 64, 3)
def forward(self, x):
x = self.conv1(x)
return x
pongber = Pongber()
torch.save(pongber, "pongber_method1.pth")
方式1还是有陷阱的,加载
import torch
import torchvision
# 陷阱1
model = torch.load('pongber_method1.pth')
print(model)
返回:
D:\Anaconda_python3.12\envs\py3.10\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.
warnings.warn(msg)
Traceback (most recent call last):
File "D:\desktop\learn_dl\pytorch_1\P26_model_load.py", line 16, in <module>
model = torch.load('pongber_method1.pth')
File "D:\Anaconda_python3.12\envs\py3.10\lib\site-packages\torch\serialization.py", line 809, in load
return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
File "D:\Anaconda_python3.12\envs\py3.10\lib\site-packages\torch\serialization.py", line 1172, in _load
result = unpickler.load()
File "D:\Anaconda_python3.12\envs\py3.10\lib\site-packages\torch\serialization.py", line 1165, in find_class
return super().find_class(mod_name, name)
AttributeError: Can't get attribute 'Pongber' on <module '__main__' from 'D:\\desktop\\learn_dl\\pytorch_1\\P26_model_load.py'>
这就得把之前的网络结构复制过来,但是不需要写pongber = Pongber()
import torch
import torchvision
from torch import nn
# 陷阱1
class Pongber(nn.Module):
def __init__(self):
super(Pongber, self).__init__()
self.conv1 = nn.Conv2d(3, 64, 3)
def forward(self, x):
x = self.conv1(x)
return x
model = torch.load('pongber_method1.pth')
print(model)
这样就会正常返回:
Pongber(
(conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
)
其实这个陷阱也不算一个陷阱,因为真实写项目过程中,我们会把它定义在一个单独的文件里
from P26_model_save import *
# 陷阱1
# class Pongber(nn.Module):
# def __init__(self):
# super(Pongber, self).__init__()
# self.conv1 = nn.Conv2d(3, 64, 3)
#
# def forward(self, x):
# x = self.conv1(x)
# return x
model = torch.load('pongber_method1.pth')
print(model)
这样也会正常返回:
Pongber(
(conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
)
原始资料地址:
网络模型的保存与读取
如有侵权联系删除 仅供学习交流使用