pytorch中怎么用plt显示tensor?

猿友 2021-07-23 15:10:53 浏览数 (4117)
反馈

在pytorch中图片的张量结构与plt可以显示的图片格式要求是不一样的,所以plt是不能直接显示tensor格式的图片的,那么pytorch怎么用plt显示tensor图片呢?这就需要涉及到数据转换了,基本思路就是将tensor转换为numpy类型的数据结构,而numpy类型的格式刚好可以被plt支持。接下来就来看具体怎么操作吧!

问题

图像的张量结构为(C,H,W),而plt可以显示的图片格式要求(H,W,C),C为颜色通道数,可以没有。

所以问题就是将Tensor(C,H,W)=> numpy(H,W,C)

解决办法

def transimg(img):
    img = img / 2 + 0.5 # unnormalize
    npimg = img.numpy()
    npimg1 = np.transpose(npimg,(1,2,0)) # C*H*W => H*W*C
    return npimg1

以上就是pytorch怎么用plt显示tensor的方法介绍了,希望能给大家一个参考,也希望大家多多支持W3Cschool


0 人点赞