||
Python中的torch包中包含torch.Tensor(a)、torch.tensor(a)、torch.from_numpy(a)、torch.as_tensor(a)四个转tensor函数。区别在于:
torch.Tensor(a)是类构造函数,转出来的tensor格式数据dtype是全局默认dtype(一般为torch.float32),全局默认类型可以通过torch.get_default_dtype()函数来查询;而其它三个都是工厂函数,转出来的tensor格式数据dtype是根据输入a的dtype来推断。
因此:torch.Tensor(a)与torch.tensor(a, dtype=torch.float32)几乎一致。
2. torch.Tensor(a)、torch.tensor(a)是深拷贝,会创造一个新的内存空间,不共享内存,因此a改变时,torch.Tensor(a)、torch.tensor(a)不会改变;而torch.from_numpy(a)、torch.as_tensor(a)不会创造新的内存空间,因此a改变时,torch.from_numpy(a)、torch.as_tensor(a)也会发生改变。
3. torch.from_numpy(a)、torch.as_tensor(a)对比:torch.from_numpy(a)的输入a只能是ndarray格式,并输出一个与a的dtype、device都一样的tensor数据;torch.as_tensor(a)适用性更广,它的输入a可以是非ndarray格式,同时还可以改变dtype和device。torch.as_tensor(a)当a是ndarray格式,且dtype和device都默认时等同于torch.from_numpy(a),是浅拷贝,而当dtype和device不默认时,会创建一个新的内存空间,变为深拷贝。
Archiver|手机版|科学网 ( 京ICP备07017567号-12 )
GMT+8, 2024-10-19 22:24
Powered by ScienceNet.cn
Copyright © 2007- 中国科学报社