torch tensor操作

torch tensor操作
foresta.yang一、张量的基本操作
Pytorch 中,张量的操作分为结构操作和数学运算,其理解就如字面意思。结构操作就是改变张量本身的结构,数学运算就是对张量的元素值完成数学运算。
- 常使用的张量结构操作:维度变换(tranpose、view 等)、合并分割(split、chunk等)、索引切片(index_select、gather 等)。
- 常使用的张量数学运算:标量运算、向量运算、矩阵运算。
二、维度变换
2.1 squeeze vs unsqueeze 维度增减
- squeeze():对 tensor 进行维度的压缩,去掉维数为 1 的维度。用法:torch.squeeze(a) 将 a 中所有为 1 的维度都删除,或者 a.squeeze(1) 是去掉 a中指定的维数为 1 的维度。
- unsqueeze():对数据维度进行扩充,给指定位置加上维数为 1 的维度。用法:torch.unsqueeze(a, N),或者 a.unsqueeze(N),在 a 中指定位置 N 加上一个维数为 1 的维度。
squeeze 用例程序如下:
1 | a = torch.rand(1,1,3,3) |
程序输出结果如下:
torch.Size([3, 3]) torch.Size([1, 3, 3])
unsqueeze 用例程序如下:
1 | x = torch.rand(3,3) |
程序输出结果如下:
torch.Size([1, 3, 3]) torch.Size([1, 3, 3])
2.2 transpose vs permute 维度交换
torch.transpose() 只能交换两个维度,而 .permute() 可以自由交换任意位置。函数定义如下:
- transpose(dim0, dim1) → Tensor # See torch.transpose()
- permute(*dims) → Tensor # dim(int). Returns a view of the original tensor with its dimensions permuted.
在 CNN 模型中,我们经常遇到交换维度的问题,举例:四个维度表示的 tensor:[batch, channel, h, w](nchw),如果想把 channel 放到最后去,形成[batch, h, w, channel](nhwc),如果使用 torch.transpose() 方法,至少要交换两次(先 1 3 交换再 1 2 交换),而使用 .permute() 方法只需一次操作,更加方便。例子程序如下:
1 | import torch |
2.3 reshape vs view
view只适合对满足连续性条件(contiguous)的tensor进行操作,而reshape同时还可以对不满足连续性条件的tensor进行操作,具有更好的鲁棒性。view能干的reshape都能干,如果view不能干就可以用reshape来处理。更多可看[1]
2.4 einsum
首先看下 einsum 实现矩阵乘法的例子:
1 | a = torch.rand(2,3) |
这个方法可以实现矩阵乘法,但是也可以用来更换维度
1 | # 语义解析: |
更多可看[2]
三、索引切片
3.1 规则索引切片方式
张量的索引切片方式和 numpy、python 多维列表几乎一致,都可以通过索引和切片对部分元素进行修改。切片时支持缺省参数和省略号。实例代码如下:
1 | 1,10,[3,3]) t = torch.randint( |
以上切片方式相对规则,对于不规则的切片提取,可以使用 torch.index_select, torch.take, torch.gather, torch.masked_select。
3.2 gather 和 torch.index_select 算子
gather 算子的用法比较难以理解,在翻阅了官方文档和网上资料后,我有了一些自己的理解。
1,gather 是不规则的切片提取算子(Gathers values along an axis specified by dim. 在指定维度上根据索引 index 来选取数据)。函数定义如下:
1 | torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor |
参数解释:
- input (Tensor) – the source tensor.
- dim (int) – the axis along which to index.
- index (LongTensor) – the indices of elements to gather.
对于 3D tensor,output 值的定义如下: gather 的官方定义如下:
1 | out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 |
下面结合 2D 和 3D tensor 的用例来直观理解算子用法。
(1)对于 2D tensor 的例子:
1 | import torch |
output 值定义如下:
1 | # 按照 index = tensor([[0, 1, 2, 3]])顺序作用在行上索引依次为0,1,2,3 |
(2)索引更复杂的 2D tensor 例子:
1 | 1, 2], [3, 4]]) t = torch.tensor([[ |
output 值的计算如下:
1 | output[i][j] = input[i][index[i][j]] # if dim = 1 |
总结:可以看到 gather 是通过将索引在指定维度 dim 上的值替换为 index 的值,但是其他维度索引不变的情况下获取 tensor 数据。直观上可以理解为对矩阵进行重排,比如对每一行(dim=1)的元素进行变换,比如 torch.gather(a, 1, torch.tensor([[1,2,0], [1,2,0]])) 的作用就是对 矩阵 a 每一行的元素,进行 permtute(1,2,0) 操作。
2,理解了 gather 再看 index_select 就很简单,函数作用是返回沿着输入张量的指定维度的指定索引号进行索引的张量子集。函数定义如下:
1 | torch.index_select(input, dim, index, *, out=None) → Tensor |
函数返回一个新的张量,它使用数据类型为 LongTensor 的 index 中的条目沿维度 dim 索引输入张量。返回的张量具有与原始张量(输入)相同的维数。 维度尺寸与索引长度相同; 其他尺寸与原始张量中的尺寸相同。实例代码如下:
1 | 3, 4) x = torch.randn( |
四、合并分割
4.1 torch.cat 和 torch.stack
可以用 torch.cat 方法和 torch.stack 方法将多个张量合并,也可以用 torch.split方法把一个张量分割成多个张量。torch.cat 和 torch.stack 有略微的区别,torch.cat 是连接,不会增加维度,而 torch.stack 是堆叠,会增加一个维度。两者函数定义如下:
1 | # Concatenates the given sequence of seq tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty. |
torch.cat 和 torch.stack 用法实例代码如下:
1 | 0,9).view(3,3) a = torch.arange( |
4.2 torch.split 和 torch.chunk
torch.split() 和 torch.chunk() 可以看作是 torch.cat() 的逆运算。split() 作用是将张量拆分为多个块,每个块都是原始张量的视图。split() 函数定义如下:
1 | """ |
chunk() 作用是将 tensor 按 dim(行或列)分割成 chunks 个 tensor 块,返回的是一个元组。chunk() 函数定义如下:
1 | torch.chunk(input, chunks, dim=0) → List of Tensors |
实例代码如下:
1 | 10).reshape(5,2) a = torch.arange( |
五、卷积相关算子
5.1 上采样方法总结
上采样大致被总结成了三个类别:
- 基于线性插值的上采样:最近邻算法(nearest)、双线性插值算法(bilinear)、双三次插值算法(bicubic)等,这是传统图像处理方法。
- 基于深度学习的上采样(转置卷积,也叫反卷积 Conv2dTranspose2d等)
- Unpooling 的方法(简单的补零或者扩充操作)
计算效果:最近邻插值算法 < 双线性插值 < 双三次插值。计算速度:最近邻插值算法 > 双线性插值 > 双三次插值。
5.2 F.interpolate 采样函数
Pytorch 老版本有 nn.Upsample 函数,新版本建议用 torch.nn.functional.interpolate,一个函数可实现定制化需求的上采样或者下采样功能,。
F.interpolate() 函数全称是 torch.nn.functional.interpolate(),函数定义如下:
1 | def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): # noqa: F811 |
参数解释如下:
- input(Tensor):输入张量数据;
- size: 输出的尺寸,数据类型为 tuple: ([optional D_out], [optional H_out], W_out),和 scale_factor 二选一;
- scale_factor:在高度、宽度和深度上面的放大倍数。数据类型既可以是 int——表明高度、宽度、深度都扩大同一倍数;也可是tuple——指定高度、宽度、深度等维度的扩大倍数;
- mode: 上采样的方法,包括最近邻(nearest),线性插值(linear),双线性插值(bilinear),三次线性插值(trilinear),默认是最近邻(nearest);
- align_corners: 如果设为True,输入图像和输出图像角点的像素将会被对齐(aligned),这只在mode = linear, bilinear, or trilinear才有效,默认为False。
例子程序如下:
1 | import torch.nn.functional as F |
5.3 nn.ConvTranspose2d 反卷积
转置卷积(有时候也称为反卷积,个人觉得这种叫法不是很规范),它是一种特殊的卷积,先 padding 来扩大图像尺寸,紧接着跟正向卷积一样,旋转卷积核 180 度,再进行卷积计算。
引用
[0] zhuanlan.zhihu.com/p/
[1] https://blog.csdn.net/Flag_ing/article/details/109129752
[2] https://zhuanlan.zhihu.com/p/361209187