深度学习
作者:
B0b
,
2023-10-17 21:14:36
,
所有人可见
,
阅读 131
pytorch 快速入门
1. 数据集
Dataset 提供一种方式去获取数据及其label
Dataloader 为后面的网络提供不同的数据形式
2. Tensorboard
画图
tensorboard -- logdir=logs --port=6007
3. transform
3.1 Transform 如何使用 totensor
3.2 为什么需要 Tensor 数据类型 包装了参数 :反向传播 梯度 设备
3.3 常见的Transform
查看: 输入 输出 作用
__call__ 内置构造函数
Totensor()
tran_totensor = taransforms.Totensor()
img_tensor = tran_totensor(img)
writer.add_img("Totensor",img_tensor)
ToPILimage()
Nomarlize()
tran_norm = taransforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
img_norm = tran_norm(img_tensor)
writer.add_img("Normalize",img_norm)
Resize()
Compose()
Randomcrop()
4. torchvision
dataset和transform 的使用
5.dateloader
dataset()
batch_size() 取出多少数据
shuffel() 打乱
sample()
num_worker() 多进程
drop_last() 除不尽剩余的数据
6. 神经网络基本骨架torch.nn
neural network
foward()
卷积层--CONV2D
最大池化层
非线性激活
线性层及其他层
损失函数
7. 模型的保存和模型的加载
7.1 方式1 (模型结构 + 参数)
import torch
torch.save(vgg16,"vgg16_method1.pth")
model = troch.laod("vgg16_method1.pth")
7.2 方式2 (模型结构 官网推荐)
torch.save(vgg16.state_dict(),"vgg16_method2.pth")
vgg16 = torchvision.nodel.vgg16(pretrained = False)
vgg16.load_state_dict("vgg16_method2.pth")
8. 模型的训练套路
准备数据集
加载数据集Dataloader
搭建神经网络
损失函数 nn.CrossEntropyLoss
优化器 torch.optin.SGD
设置训练网络的参数