||
通过一个小案例,反映出DataLoader使用完整流程!!
import torch
import torch.utils.data as Data
BATCH_SIZE = 3 #批训练数据个数
x = torch.linspace(1,10,10) #x data (torch tensor)
y = torch.linspace(10,1,10) #y data (torch tensor)
#随后我们需要把X和Y组成一个完整的数据集,并转化为pytorch能识别的数据集类型:
torch_dataset = Data.TensorDataset(x,y)
#可以看出我们把X和Y通过Data.TensorDataset() 这个函数拼装成了一个数据集,数据集的类型是【TensorDataset】。
# 把 dataset 放入 DataLoader
loader = Data.DataLoader(
dataset = torch_dataset,
batch_size = BATCH_SIZE,
shuffle = True,
num_workers = 0,
)
for epoch in range(5): #训练所有数据5次
i = 0
for batch_x, batch_y in loader:
i = i + 1
print('Epoch:{} | num:{} | batch_x:{} | batch_y:{}'
.format(epoch,i,batch_x,batch_y))
#for epoch in range(3): #训练所有!整套!数据3次
# i=0
# #,每一步 loader 释放一小批数据用来学习
# for step,(batch_x, batch_y) in enumerate(loader):
# i = i+1
# #打出来一些数据
# print('Epoch:{} | num:{} | batch_x:{} | batch_y:{}'
# .format(epoch,i,batch_x,batch_y))
【参考】
https://www.jb51.net/article/167009.htm
点滴分享,福泽你我!Add oil!
Archiver|手机版|科学网 ( 京ICP备07017567号-12 )
GMT+8, 2024-4-26 22:06
Powered by ScienceNet.cn
Copyright © 2007- 中国科学报社