Numpy中axis的理解

Numpy是个好东西,但是ndarray的轴感觉弄不太明白。可能二维三维数组还好,要是再增加几维就无法在脑海中想象这个东西,对于一些有关轴的操作就稀里糊涂,只能一个个尝试。现在准备把它彻底弄明白!

思路

首先从二维入手,然后扩展到三维以及更高的维度(从特殊到一般),然后找出普遍的规律,再进行验证(从一般到特殊)

官方文档应该是最权威的,首先看官方文档是怎么说明的,然后查找一些资料,看看其他人是怎么理解的,最后总结出自己的一套规律

import numpy as np

ndarray.shape

感受一个ndarray,最简单的方法就是打印ndarray的shape。

官方文档里面是这样写的:

the dimensions of the array. This is a tuple of integers indicating the size of the array in each dimension. For a matrix with n rows and m columns, shape will be (n,m). The length of the shape tuple is therefore the number of axes, ndim.

只列举了矩阵的例子,尝试一下:

a1 = np.arange(15).reshape(3, 5)
print(a1,'\n',a1.shape,'\n',a1.ndim)

输出结果:

[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]] 
 (3, 5) 
 2
  1. reshape成什么样,最后打印出来的shape就会是什么样,这一点可以确定。
  2. 官方文档里面写道“对于一个n行m列的矩阵来说,shape将会是(n,m)”。经验证,打印出来了一个3行5列的矩阵,shape是(3,5)。
  3. 官方文档里面写道“shape元组的长度就是轴的数量,也就是ndim”。经验证,ndim=2

简单推断:最开始有2个方括号,因此矩阵是2维的,且第1个方括号内部有3个“2级方括号”,每一个“2级方括号”内部都有5个元素,因此这个shape可能是从外向里数的。

尝试1维ndarray:

a2 = np.arange(15)
print(a2,'\n',a2.shape,'\n',a2.ndim)

输出结果:

[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14] 
 (15,) 
 1
  1. 打印ndim为1,最开始有1个方括号,因此数组是1维的。结论得到验证。
  2. 打印shape为(15,)(一维元组),第1个方括号内部没有“2级方括号”shape从外向里数只有15。结论得到验证。

尝试3维ndarray:

a3 = np.arange(24).reshape(3,2,4)
print(a3,'\n',a3.shape,'\n',a3.ndim)

输出结果:

[[[ 0  1  2  3]
  [ 4  5  6  7]]

 [[ 8  9 10 11]
  [12 13 14 15]]

 [[16 17 18 19]
  [20 21 22 23]]] 
 (3, 2, 4) 
 3
  1. 打印ndim为3,最开始有3个方括号,因此数组是3维的。结论得到验证。
  2. 打印shape为(3, 2, 4),第1个方括号内部有3个“2级方括号”,“2级方括号”内部有2个“3级方括号”,“3级方括号”内部有4个元素。满足shape从外向里数,结论得到验证。

尝试4维ndarray:

a4 = np.arange(24).reshape(3,2,1,4)
print(a4,'\n',a4.shape,'\n',a4.ndim)

输出结果:

[[[[ 0  1  2  3]]

  [[ 4  5  6  7]]]


 [[[ 8  9 10 11]]

  [[12 13 14 15]]]


 [[[16 17 18 19]]

  [[20 21 22 23]]]] 
 (3, 2, 1, 4) 
 4
  1. 打印ndim为4,最开始有4个方括号,因此数组是4维的。结论得到验证。
  2. 打印shape为(3, 2, 1, 4),第1个方括号内部有3个“2级方括号”,“2级方括号”内部有2个“3级方括号”,“3级方括号”内部有1个“4级方括号”,“4级方括号”内部有4个元素。满足shape从外向里数,结论得到验证。
  3. 有一个维度是1,也就是这个维度实际上并没有任何的作用。但是在实际中可能会有“凑维度”的操作,需要手动增加或者减少维度,会出现这种维度为1的情况。(增加维度使用reshape()实现,减小维度使用squeeze()实现)

因此可以得出结论:对于给定的ndarray,判断ndim就是计数最前面有多少个相连的方括号,判断shape就是从外向内看,每一层分别有多少个“元素”。

也可以看出,数组超过4维后,肉眼就有些难以区分了。

索引

索引就是取数组中的某些元素,官方文档有下面的举例:

>>> a = np.arange(30).reshape(2, 3, 5)
>>> a
array([[[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14]],

        [[15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29]]])
>>> a[0, 2, :]
array([10, 11, 12, 13, 14])
>>> a[0, :, 3]
array([ 3,  8, 13])

索引操作是与shape相对应的。如上述例子,a[0]即为取数组的第1个维度(2)的第1个元素,这样原来3维的数组就降到了2维;a[0, :]就是在a[0]的基础上取数组的第2个维度(3)的全部元素,数组的维度不变,还是2维;a[0, :, 3]就是在a[0, :]的基础上取数组的第3个维度(5)的第4个元素,即可得出上面的结果。

索引操作后的维度与索引的数量以及是否有“:”相关。如果索引的数量与ndim相同,则最后取出来的是一个数。如果数量不同或者有“:”(数量不同可以看成在后面补“:”),则最终取得的数组的维度与“:”对应的原数组的维度相同。

以numpy.sum为例:

官方文档:

Axis or axes along which a sum is performed. The default, axis=None, will sum all of the elements of the input array. If axis is negative it counts from the last to the first axis.

If axis is a tuple of ints, a sum is performed on all of the axes specified in the tuple instead of a single axis or all the axes as before.

以三维数组为例:

print('origin')
print(a3,a3.shape)
print('axis=0')
print(a3.sum(axis=0),a3.sum(axis=0).shape)
print('axis=1')
print(a3.sum(axis=1),a3.sum(axis=1).shape)
print('axis=2')
print(a3.sum(axis=2),a3.sum(axis=2).shape)
print('axis=(0,1)')
print(a3.sum(axis=(0,1)),a3.sum(axis=(0,1)).shape)
print('axis=(1,2)')
print(a3.sum(axis=(1,2)),a3.sum(axis=(1,2)).shape)
print('axis=(0,2)')
print(a3.sum(axis=(0,2)),a3.sum(axis=(0,2)).shape)
print('axis=(0,1,2)')
print(a3.sum(axis=(0,1,2)),a3.sum(axis=(0,1,2)).shape)
origin
[[[ 0  1  2  3]
  [ 4  5  6  7]]

 [[ 8  9 10 11]
  [12 13 14 15]]

 [[16 17 18 19]
  [20 21 22 23]]] (3, 2, 4)
axis=0
[[24 27 30 33]
 [36 39 42 45]] (2, 4)
axis=1
[[ 4  6  8 10]
 [20 22 24 26]
 [36 38 40 42]] (3, 4)
axis=2
[[ 6 22]
 [38 54]
 [70 86]] (3, 2)
axis=(0,1)
[60 66 72 78] (4,)
axis=(1,2)
[ 28  92 156] (3,)
axis=(0,2)
[114 162] (2,)
axis=(0,1,2)
276 ()

axis为多少,就是在这个维度上进行操作,最终的结果就是这个维度消失

不要从行列什么的去思考怎么变化,直接从shape的角度入手。设置axis为多少,这个维度就没有了!比如原来是(3,2,4)的维度,要是axis=0,第一个维度就没有了,加和得到的矩阵就是(2,4)。

如果希望保留维度,可以增加keepdims=True的选项,这样被操作的维度就会变为1而不是直接消失。

print('axis=(0,1)')
print(a3.sum(axis=(0,1),keepdims=True),a3.sum(axis=(0,1),keepdims=True).shape)
axis=(0,1)
[[[60 66 72 78]]] (1, 1, 4)

这样想应该会比较好理解,尤其是对于更高维的数组来说,行列的概念基本失效,从shape的角度思考会好。

np.concatenate

另外一个比较常用的操作是np.concatenate,可以将数组进行合并,在数据处理或者神经网络中很常用。

在np.concatenate上检验一下对于axis的理解:

ta = np.arange(24).reshape(3,2,4)
tb = np.arange(24,36).reshape(3,1,4)
print(ta,ta.shape)
print(tb,tb.shape)
[[[ 0  1  2  3]
  [ 4  5  6  7]]

 [[ 8  9 10 11]
  [12 13 14 15]]

 [[16 17 18 19]
  [20 21 22 23]]] (3, 2, 4)
[[[24 25 26 27]]

 [[28 29 30 31]]

 [[32 33 34 35]]] (3, 1, 4)

两者合并,第2个维度不相同,应该是可以合并的,合并后的shape应该为(3,3,4)

print(np.concatenate((ta,tb),axis=1),np.concatenate((ta,tb),axis=1).shape)
[[[ 0  1  2  3]
  [ 4  5  6  7]
  [24 25 26 27]]

 [[ 8  9 10 11]
  [12 13 14 15]
  [28 29 30 31]]

 [[16 17 18 19]
  [20 21 22 23]
  [32 33 34 35]]] (3, 3, 4)

np.concatenate除了在待合并的axis上之外,必须具有相同的shape

之前的结论也得到了验证。

总结

我们处在三维空间中,二维和三维是比较直观的,可以在脑海中想象出来。因此我们会觉得axis的设计有些反直觉。以后应该从shape的角度去看待axis的设计思想,首先理解上比较直观,其次在更高维度的数组上也能合理的进行操作。不要去思考数组实际中应该是个什么样子,直接观察axis就足够了。

参考资料

Code

Numpy官方文档


Numpy中axis的理解
https://zhangzhao219.github.io/2022/08/06/ndarray-axis/
作者
Zhang Zhao
发布于
2022年8月6日
许可协议