1-问题:only one element tensors can be converted to Python scalars
2-分析问题
      for e in range(start, self.num_epochs):
           for i, (input_data, labels) in enumerate(zip(tqdm(self.data_loader))):
               iter_ctr += 1
               start = time.time()

               input_data = self.to_var(input_data)
               # print('input_data.size:',input_data.size)
               # input_data=input_data.reshape[256,-1]
               input_data=input_data.view(input_data.size(0),-1)
               total_loss,sample_energy, recon_error, cov_diag = self.dagmm_step(input_data)

在训练网络时数据输入用了torch.utils.data.DataLoader函数,用自己的数据库进行数据封装。由于先用了zip(),所以输出是tuple元组格式,用np.array将其转换为数组格式就可以将其导入,但是问题来了,每次只能转一个batch_size的数据之后的不能连续转换

3-解决方案
       for e in range(self.num_epochs):
           print('Epoch ({}/{})----------------------------------------------------------------------------'.format(e, self.num_epochs))
           batch_idxs = len(self.data_loader)// self.batch_size  #sample 367200,batch_size=256
           for i in tqdm(range(0, batch_idxs)):
               iter_ctr += 1
               start = time.time()
               batch = self.data_loader[i * self.batch_size:(i + 1) * self.batch_size]
               input_data = np.array(batch).astype(np.float32)
               # print('batch_images', input_data.shape)
               input_data=torch.tensor(input_data)

               input_data = self.to_var(input_data)
               # print('input_data.size:',input_data.size)
               # input_data=input_data.reshape[256,-1]
               input_data=input_data.view(input_data.size(0),-1)
               total_loss,sample_energy, recon_error, cov_diag = self.dagmm_step(input_data)

将上述代码改写了一下,解决了该问题,但是感觉不是最优解,我看了其他博客他们的解决方案主要是:数组和张量相互转化,经过实验证明在我的问题上并未能解决问题

要是大家有更好的解决方案欢迎共享~

GitHub 加速计划 / eleme / element
54.06 K
14.63 K
下载
A Vue.js 2.0 UI Toolkit for Web
最近提交(Master分支:2 个月前 )
c345bb45 6 个月前
a07f3a59 * Update transition.md * Update table.md * Update transition.md * Update table.md * Update transition.md * Update table.md * Update table.md * Update transition.md * Update popover.md 7 个月前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐