|||
用pytorch完成字符识别分类任务时,发现loss = lossFunction(out, labels)报错
同样的代码在MNIST数据集上就没有报错,原因是数据载入类型不符合规范
输入labels维度应该为1维,且精度不能是Double,必须换成long
修改后的数据导入代码:
lossFunction = torch.nn.CrossEntropyLoss()
loss = lossFunction(out, labels.long()) # 修改数据精度
Archiver|手机版|科学网 ( 京ICP备07017567号-12 )
GMT+8, 2025-1-10 03:59
Powered by ScienceNet.cn
Copyright © 2007- 中国科学报社