Pytorch学习笔记4-张量的结构操作

由于时效问题,该文某些代码、技术可能已经过期,请注意!!!本文最后更新于:2 年前

张量的操作主要包括张量的结构操作和张量的数学运算

张量结构操作诸如:张量创建,索引切片,维度变换,合并分割。
张量数学运算主要有:标量运算,向量运算,矩阵运算。另外我们会介绍张量运算的广播机制。

创建张量

张量创建的许多方法和numpy中创建array的方法很像

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import numpy as np
import torch

a = torch.tensor([1,2,3],dtype = torch.float)
print(a)
'''
tensor([1., 2., 3.])
'''

b = torch.arange(1,10,step = 2)
print(b)
'''
tensor([1, 3, 5, 7, 9])
'''

c = torch.linspace(0.0, 2*3.14, 10)
print(c)
'''
tensor([0.0000, 0.6978, 1.3956, 2.0933, 2.7911, 3.4889, 4.1867, 4.8844, 5.5822,
6.2800])
'''

另外还有 torch.ones()、torch.zero_like()、torch.zeros() 等方法创建

1
2
3
4
5
6
7
8
# 均匀随机分布
torch.manual_seed(0)
minval, maxval = 0, 10
a = minval + (maxval - minval) * torch.rand([5])
print(a)
'''
tensor([4.9626, 7.6822, 0.8848, 1.3203, 3.0742])
'''
1
2
3
4
5
6
7
8
# 正太分布随机
b = torch.normal(mean = torch.zeros(3,3), std = torch.ones(3,3))
print(b)
'''
tensor([[ 0.5507, 0.2704, 0.6472],
[ 0.2490, -0.3354, 0.4564],
[-0.6255, 0.4539, -1.3740]])
'''
1
2
3
4
5
6
7
# 整数随机排列
d = torch.randperm(20)
print(d)
'''
tensor([ 3, 17, 9, 19, 1, 18, 4, 13, 15, 12, 0, 16, 7, 11, 2, 5, 8, 10,
6, 14])
'''

此外还有 torch.eye() (单位矩阵) 、 torch.diag() (对角矩阵) 等

索引切片

张量的索引切片方式和numpy几乎是一样的。切片时支持缺省参数和省略号。可以通过索引和切片对部分元素进行修改。
此外,对于不规则的切片提取,可以使用 torch.index_select, torch.masked_select, torch.take
如果要通过修改张量的某些元素得到新的张量,可以使用 torch.where,torch.masked_fill,torch.index_fill

  • torch.where可以理解为if的张量版本。
  • torch.index_fill的选取元素逻辑和torch.index_select相同。
  • torch.masked_fill的选取元素逻辑和torch.masked_select相同。
维度变换

维度变换相关函数主要有 torch.reshape(或者调用张量的view方法), torch.squeeze, torch.unsqueeze, torch.transpose

  • torch.reshape 可以改变张量的形状。
  • torch.squeeze 可以减少维度。
  • torch.unsqueeze 可以增加维度。
  • torch.transpose 可以交换维度。

如果张量在某个维度上只有一个元素,利用torch.squeeze可以消除这个维度。
torch.unsqueeze的作用和torch.squeeze的作用相反。

torch.transpose可以交换张量的维度,torch.transpose常用于图片存储格式的变换上。
如果是二维的矩阵,通常会调用矩阵的转置方法 matrix.t(),等价于 torch.transpose(matrix,0,1)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
minval=0
maxval=255
# Batch,Height,Width,Channel
data = torch.floor(minval + (maxval-minval)*torch.rand([100,256,256,4])).int()
print(data.shape)

# 转换成 Pytorch默认的图片格式 Batch,Channel,Height,Width
# 需要交换两次
data_t = torch.transpose(torch.transpose(data,1,2),1,3)
print(data_t.shape)
'''
torch.Size([100, 256, 256, 4])
torch.Size([100, 4, 256, 256])
'''
合并分割

可以用torch.cat方法和torch.stack方法将多个张量合并,可以用torch.split方法把一个张量分割成多个张量。
torch.cat和torch.stack有略微的区别,torch.cat是连接,不会增加维度,而torch.stack是堆叠,会增加维度。
torch.split是torch.cat的逆运算,可以指定分割份数平均分割,也可以通过指定每份的记录数量进行分割

搬运自: