Numpy中axis的理解
Numpy是个好东西,但是ndarray的轴感觉弄不太明白。可能二维三维数组还好,要是再增加几维就无法在脑海中想象这个东西,对于一些有关轴的操作就稀里糊涂,只能一个个尝试。现在准备把它彻底弄明白!
思路
首先从二维入手,然后扩展到三维以及更高的维度(从特殊到一般),然后找出普遍的规律,再进行验证(从一般到特殊)
官方文档应该是最权威的,首先看官方文档是怎么说明的,然后查找一些资料,看看其他人是怎么理解的,最后总结出自己的一套规律
import numpy as np
ndarray.shape
感受一个ndarray,最简单的方法就是打印ndarray的shape。
官方文档里面是这样写的:
the dimensions of the array. This is a tuple of integers indicating the size of the array in each dimension. For a matrix with n rows and m columns, shape
will be (n,m)
. The length of the shape
tuple is therefore the number of axes, ndim
.
只列举了矩阵的例子,尝试一下:
a1 = np.arange(15).reshape(3, 5)
print(a1,'\n',a1.shape,'\n',a1.ndim)
输出结果:
[[ 0 1 2 3 4]
[ 5 6 7 8 9]
[10 11 12 13 14]]
(3, 5)
2
- reshape成什么样,最后打印出来的shape就会是什么样,这一点可以确定。
- 官方文档里面写道“对于一个n行m列的矩阵来说,shape将会是(n,m)”。经验证,打印出来了一个3行5列的矩阵,shape是(3,5)。
- 官方文档里面写道“shape元组的长度就是轴的数量,也就是ndim”。经验证,ndim=2
简单推断:最开始有2个方括号,因此矩阵是2维的,且第1个方括号内部有3个“2级方括号”,每一个“2级方括号”内部都有5个元素,因此这个shape可能是从外向里数的。
尝试1维ndarray:
a2 = np.arange(15)
print(a2,'\n',a2.shape,'\n',a2.ndim)
输出结果:
[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14]
(15,)
1
- 打印ndim为1,最开始有1个方括号,因此数组是1维的。结论得到验证。
- 打印shape为(15,)(一维元组),第1个方括号内部没有“2级方括号”shape从外向里数只有15。结论得到验证。
尝试3维ndarray:
a3 = np.arange(24).reshape(3,2,4)
print(a3,'\n',a3.shape,'\n',a3.ndim)
输出结果:
[[[ 0 1 2 3]
[ 4 5 6 7]]
[[ 8 9 10 11]
[12 13 14 15]]
[[16 17 18 19]
[20 21 22 23]]]
(3, 2, 4)
3
- 打印ndim为3,最开始有3个方括号,因此数组是3维的。结论得到验证。
- 打印shape为(3, 2, 4),第1个方括号内部有3个“2级方括号”,“2级方括号”内部有2个“3级方括号”,“3级方括号”内部有4个元素。满足shape从外向里数,结论得到验证。
尝试4维ndarray:
a4 = np.arange(24).reshape(3,2,1,4)
print(a4,'\n',a4.shape,'\n',a4.ndim)
输出结果:
[[[[ 0 1 2 3]]
[[ 4 5 6 7]]]
[[[ 8 9 10 11]]
[[12 13 14 15]]]
[[[16 17 18 19]]
[[20 21 22 23]]]]
(3, 2, 1, 4)
4
- 打印ndim为4,最开始有4个方括号,因此数组是4维的。结论得到验证。
- 打印shape为(3, 2, 1, 4),第1个方括号内部有3个“2级方括号”,“2级方括号”内部有2个“3级方括号”,“3级方括号”内部有1个“4级方括号”,“4级方括号”内部有4个元素。满足shape从外向里数,结论得到验证。
- 有一个维度是1,也就是这个维度实际上并没有任何的作用。但是在实际中可能会有“凑维度”的操作,需要手动增加或者减少维度,会出现这种维度为1的情况。(增加维度使用reshape()实现,减小维度使用squeeze()实现)
因此可以得出结论:对于给定的ndarray,判断ndim就是计数最前面有多少个相连的方括号,判断shape就是从外向内看,每一层分别有多少个“元素”。
也可以看出,数组超过4维后,肉眼就有些难以区分了。
索引
索引就是取数组中的某些元素,官方文档有下面的举例:
>>> a = np.arange(30).reshape(2, 3, 5)
>>> a
array([[[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]],
[[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24],
[25, 26, 27, 28, 29]]])
>>> a[0, 2, :]
array([10, 11, 12, 13, 14])
>>> a[0, :, 3]
array([ 3, 8, 13])
索引操作是与shape相对应的。如上述例子,a[0]即为取数组的第1个维度(2)的第1个元素,这样原来3维的数组就降到了2维;a[0, :]就是在a[0]的基础上取数组的第2个维度(3)的全部元素,数组的维度不变,还是2维;a[0, :, 3]就是在a[0, :]的基础上取数组的第3个维度(5)的第4个元素,即可得出上面的结果。
索引操作后的维度与索引的数量以及是否有“:”相关。如果索引的数量与ndim相同,则最后取出来的是一个数。如果数量不同或者有“:”(数量不同可以看成在后面补“:”),则最终取得的数组的维度与“:”对应的原数组的维度相同。
轴
以numpy.sum为例:
官方文档:
Axis or axes along which a sum is performed. The default, axis=None, will sum all of the elements of the input array. If axis is negative it counts from the last to the first axis.
If axis is a tuple of ints, a sum is performed on all of the axes specified in the tuple instead of a single axis or all the axes as before.
以三维数组为例:
print('origin')
print(a3,a3.shape)
print('axis=0')
print(a3.sum(axis=0),a3.sum(axis=0).shape)
print('axis=1')
print(a3.sum(axis=1),a3.sum(axis=1).shape)
print('axis=2')
print(a3.sum(axis=2),a3.sum(axis=2).shape)
print('axis=(0,1)')
print(a3.sum(axis=(0,1)),a3.sum(axis=(0,1)).shape)
print('axis=(1,2)')
print(a3.sum(axis=(1,2)),a3.sum(axis=(1,2)).shape)
print('axis=(0,2)')
print(a3.sum(axis=(0,2)),a3.sum(axis=(0,2)).shape)
print('axis=(0,1,2)')
print(a3.sum(axis=(0,1,2)),a3.sum(axis=(0,1,2)).shape)
origin
[[[ 0 1 2 3]
[ 4 5 6 7]]
[[ 8 9 10 11]
[12 13 14 15]]
[[16 17 18 19]
[20 21 22 23]]] (3, 2, 4)
axis=0
[[24 27 30 33]
[36 39 42 45]] (2, 4)
axis=1
[[ 4 6 8 10]
[20 22 24 26]
[36 38 40 42]] (3, 4)
axis=2
[[ 6 22]
[38 54]
[70 86]] (3, 2)
axis=(0,1)
[60 66 72 78] (4,)
axis=(1,2)
[ 28 92 156] (3,)
axis=(0,2)
[114 162] (2,)
axis=(0,1,2)
276 ()
axis为多少,就是在这个维度上进行操作,最终的结果就是这个维度消失
不要从行列什么的去思考怎么变化,直接从shape的角度入手。设置axis为多少,这个维度就没有了!比如原来是(3,2,4)的维度,要是axis=0,第一个维度就没有了,加和得到的矩阵就是(2,4)。
如果希望保留维度,可以增加keepdims=True的选项,这样被操作的维度就会变为1而不是直接消失。
print('axis=(0,1)')
print(a3.sum(axis=(0,1),keepdims=True),a3.sum(axis=(0,1),keepdims=True).shape)
axis=(0,1)
[[[60 66 72 78]]] (1, 1, 4)
这样想应该会比较好理解,尤其是对于更高维的数组来说,行列的概念基本失效,从shape的角度思考会好。
np.concatenate
另外一个比较常用的操作是np.concatenate,可以将数组进行合并,在数据处理或者神经网络中很常用。
在np.concatenate上检验一下对于axis的理解:
ta = np.arange(24).reshape(3,2,4)
tb = np.arange(24,36).reshape(3,1,4)
print(ta,ta.shape)
print(tb,tb.shape)
[[[ 0 1 2 3]
[ 4 5 6 7]]
[[ 8 9 10 11]
[12 13 14 15]]
[[16 17 18 19]
[20 21 22 23]]] (3, 2, 4)
[[[24 25 26 27]]
[[28 29 30 31]]
[[32 33 34 35]]] (3, 1, 4)
两者合并,第2个维度不相同,应该是可以合并的,合并后的shape应该为(3,3,4)
print(np.concatenate((ta,tb),axis=1),np.concatenate((ta,tb),axis=1).shape)
[[[ 0 1 2 3]
[ 4 5 6 7]
[24 25 26 27]]
[[ 8 9 10 11]
[12 13 14 15]
[28 29 30 31]]
[[16 17 18 19]
[20 21 22 23]
[32 33 34 35]]] (3, 3, 4)
np.concatenate除了在待合并的axis上之外,必须具有相同的shape
之前的结论也得到了验证。
总结
我们处在三维空间中,二维和三维是比较直观的,可以在脑海中想象出来。因此我们会觉得axis的设计有些反直觉。以后应该从shape的角度去看待axis的设计思想,首先理解上比较直观,其次在更高维度的数组上也能合理的进行操作。不要去思考数组实际中应该是个什么样子,直接观察axis就足够了。