Error message here!

Hide Error message here!

Error message here!

Hide Error message here!

Error message here!

Close

## 从头学pytorch(十二):模型保存和加载

sdu20112013 2020-01-03 16:01:00 阅读数:95 评论数:0 点赞数:0 收藏数:0

# 模型读取和存储

## 读写Tensor

``````import torch
from torch import nn
x = torch.ones(3)
torch.save(x, 'x.pt')``````

``````x2 = torch.load('x.pt')
x2``````

``tensor([1., 1., 1.])``

``````y = torch.zeros(4)
torch.save([x, y], 'xy.pt')
xy_list``````

``[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]``

``````torch.save({'x': x, 'y': y}, 'xy_dict.pt')
xy``````

``{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}``

## state_dict

``````class MLP(nn.Module):
def __init__(self):
super(MLP, self).__init__()
self.hidden = nn.Linear(3, 2)
self.act = nn.ReLU()
self.output = nn.Linear(2, 1)
def forward(self, x):
a = self.act(self.hidden(x))
return self.output(a)
net = MLP()
net.state_dict()``````

``````OrderedDict([('hidden.weight', tensor([[ 0.2448, 0.1856, -0.5678],
[ 0.2030, -0.2073, -0.0104]])),
('hidden.bias', tensor([-0.3117, -0.4232])),
('output.weight', tensor([[-0.4556, 0.4084]])),
('output.bias', tensor([-0.3573]))])``````

``````optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer.state_dict()``````

``{'state': {}, 'param_groups': [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [139952370292992, 139952370293784, 139952370294144, 139952370293496]}]}``

## 保存和加载模型

PyTorch中保存和加载训练模型有两种常见的方法:

1. 仅保存和加载模型参数(`state_dict`)
2. 保存和加载整个模型

### 保存和加载`state_dict`(推荐方式)

``torch.save(model.state_dict(), PATH) # 推荐的文件后缀名是pt或pth``

``````model = TheModelClass(*args, **kwargs)

### 保存和加载整个模型

``torch.save(model, PATH)``

``model = torch.load(PATH)``

``````X = torch.randn(2, 3)
Y = net(X)
PATH = "./net.pt"
torch.save(net.state_dict(), PATH)
net2 = MLP()
Y2 = net2(X)
Y2 == Y``````

``````tensor([[1],
[1]], dtype=torch.uint8)``````

https://www.cnblogs.com/sdu20112013/p/12145341.html