8.2.索引
import numpy as np
import copy
import pprint
Windows 10
Python 3.8.8 @ MSC v.1928 64 bit (AMD64)
Latest build date 2021.03.03
numpy version: 1.20.1
索引规则
NumPy 数组的索引功能很强大,但多样的索引也带来了复杂性和混乱性。新手可以很快地开始使用 NumPy 索引,但可能难以区分各种索引的情况。虽然,网上有很多关于 NumPy 数组索的资料,也有很多代码示例,但几乎都是比较简单的说明,大部分代码示例只会让人在简单与混乱之间徘徊。NumPy 索引的官方用户指南当然是很好的资料,但大多都是必要或细节性的说明。这里尝试对所有情况做一个简单但不失全面性的总结。
可以使用 Python 的 array[index]
语法对 ndarray 进行索引。假设 ndarray 是 N 维数组,其索引的形式为 array[idx1, idx2, ..., idxN]
,即每个轴都有一个索引。
在 Python 中,x[(exp1, exp2, ..., EXPN)]
等同于 x[exp1, exp2, ..., EXPN]
。后者只是前者的语法糖。
这里将 (idx1,idx2,.,idxN)
称为索引元组,idxN
直接称为索引或索引对象。NumPy 数组的索引 idxN
可以是以下 7 种对象(形式):
- 整数
- 切片(
slice
、:
) - 整数序列
- 整数数组
- 布尔数组
None
、np.newaxis
- Ellipsis(
...
)
序列索引和数组索引的标量元素只能为整数,因此下面将它们简称为序列索引和数组索引。
假设 A
是一个 N
维数组,不同的索引对象具有不同的规则:
- 整数索引意味着从对应的轴取出 1 个元素,因此,每多 1 个整数索引,索引结果就减少 1 个维度。
- N 维数组最多可以指定 N 个索引对象,但
None
可以突破这个限制;允许最多省略 N-1 个索引对象,即至少指定 1 个索引,省略的索引默认补充为切片:
。 - 索引元组最多只能包含一个
...
,...
代表其余轴的索引都为切片:
,如A[1,...,2]
。这是 NumPy 的一个快捷方式。 -
np.newaxis
是None
的别名,索引结果会在索引元组中出现None
的地方创建一个新轴,该轴长度为 1。None
的数量不受N
的限制,但不能无限多,因为 NumPy 数组最多只能有 32 个维度。 -
序列索引基于整数索引:
- 给
A
指定 1 个整数序列索引,即A[[i, j, z]]
,这相当于np.array([A[i], A[j], A[z]])
; - 给
A
指定 2 个整数序列索引,即使A[[a, b, c], [i, j, z]]
,这相当于np.array([A[a, i], A[b, j], A[c, z]])
; - 以此类推 n 个序列索引的情况。可以看出,指定 n 个序列索引会导致索引结果减少 n-1 个维度。
- 给
-
切片类似于序列索引,但切片不会改变结果数组的维度,因为多个序列索引之间是内积,而多个切片之间是笛卡儿积。
A[0:2:1]
相当于A[[0, 1]]
;A[0:2:1, 0:2:1]
相当于np.array([[A[0, 0], A[0, 1]], [A[1, 0], A[1, 1]]])
。 -
数组索引基于序列索引,假设
A.shape=(s1, s2, ..., sN)
:- 给
A
指定 1 个整数数组索引,即A[I]
,并且I.shape=(a, b)
, 这相当于A[I.flatten()].reshape(a, b, s2, ..., sN)
,I.flatten()
是整数序列; - 给
A
指定 2 个整数数组索引,即A[I, J]
,- 如果
I.shape=J.shape=(a, b)
,这相当于A[I.flatten(), J.flatten()].reshape(a, b, s3, ..., sN)
, 两个序列索引首先会导致原数组减少 1 个维度; - 如果
I.shape=(a, b)
、J.shape=(c, d)
,那么I
和J
会进行广播,得到shape=(e, f)
新的数组索引I2
和J2
, 这相当于A[I2.flatten(), J2.flatten()].reshape(e, f, s3, ..., sN)
;
- 如果
- 以此类推 n 个数组索引的情况。可以看出,数组索引首先会得到和序列索引一致的结果,但结果数组的第一个维度会被替换为索引数组的
shape
。 - 假设原数组的维度数为 n,有 m 整数数组索引,其广播后的维度数为 y,则结果数组的维度数为 n - (m-1) + (y-1)。
- 给
-
布尔数组索引的维度不必和原数组的维度一致,但不能超出原数组的维度,并且对应维度的大小必须和原数组一致,得到的结果是布尔数组中
True
对应的元素。
A = copy.deepcopy(np.arange(90).reshape((10, 3, 3)))
print(f"A shape = {A.shape}")
pprint.pprint(A)
A shape = (10, 3, 3)
array([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]],
[[18, 19, 20],
[21, 22, 23],
[24, 25, 26]],
[[27, 28, 29],
[30, 31, 32],
[33, 34, 35]],
[[36, 37, 38],
[39, 40, 41],
[42, 43, 44]],
[[45, 46, 47],
[48, 49, 50],
[51, 52, 53]],
[[54, 55, 56],
[57, 58, 59],
[60, 61, 62]],
[[63, 64, 65],
[66, 67, 68],
[69, 70, 71]],
[[72, 73, 74],
[75, 76, 77],
[78, 79, 80]],
[[81, 82, 83],
[84, 85, 86],
[87, 88, 89]]])
print(A[0, 1, 2])
print(A[(0, 1, 2)])
print(A[0][1][2])
pprint.pprint(A[0])
pprint.pprint(A[0, ...])
5
5
5
array([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
array([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
pprint.pprint(A[[0, 1]])
pprint.pprint(np.array([A[0], A[1]]))
print()
pprint.pprint(A[[0, 1], [0, 1]])
pprint.pprint(np.array([A[0, 0], A[1, 1]]))
array([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]]])
array([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]]])
array([[ 0, 1, 2],
[12, 13, 14]])
array([[ 0, 1, 2],
[12, 13, 14]])
pprint.pprint(A[1, :])
print()
pprint.pprint(A[slice(0, 3, 2), 0:3:2])
pprint.pp(np.array([[A[0, 0], A[0, 2]], [A[2, 0], A[2, 2]]]))
# 与序列索引的区别
print()
pprint.pprint(A[[0,2], [0,2]])
array([[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]])
array([[[ 0, 1, 2],
[ 6, 7, 8]],
[[18, 19, 20],
[24, 25, 26]]])
array([[[ 0, 1, 2],
[ 6, 7, 8]],
[[18, 19, 20],
[24, 25, 26]]])
array([[ 0, 1, 2],
[24, 25, 26]])
print(f"A shape = {A.shape}")
index = np.array([0, 1, 2])
print(f"index shape = {index.shape}")
pprint.pprint(A[index])
print(f"result shape = {A[index].shape}")
print()
print(f"A shape = {A.shape}")
index = np.array([[[0, 1], [0, 1], [0, 1]]])
print(f"index shape = {index.shape}")
pprint.pprint(A[index])
print(f"result shape = {A[index].shape}")
A shape = (10, 3, 3)
index shape = (3,)
array([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]],
[[18, 19, 20],
[21, 22, 23],
[24, 25, 26]]])
result shape = (3, 3, 3)
A shape = (10, 3, 3)
index shape = (1, 3, 2)
array([[[[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]]],
[[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]]],
[[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]]]]])
result shape = (1, 3, 2, 3, 3)
B = A[0:2]
bool_index = B > 10
print(f"B shape = {B.shape}")
print(f"bool index shape = {bool_index.shape}")
pprint.pprint(bool_index)
pprint.pprint(B[bool_index])
print()
print(f"A shape = {A.shape}")
bool_index = np.array([True, False, False, False, False,
False, False, False, True, False])
print(f"bool index shape = {bool_index.shape}")
pprint.pprint(A[bool_index])
B shape = (2, 3, 3)
bool index shape = (2, 3, 3)
array([[[False, False, False],
[False, False, False],
[False, False, False]],
[[False, False, True],
[ True, True, True],
[ True, True, True]]])
array([11, 12, 13, 14, 15, 16, 17])
A shape = (10, 3, 3)
bool index shape = (10,)
array([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[72, 73, 74],
[75, 76, 77],
[78, 79, 80]]])
索引结果与视图
整数索引和切片得到的结果数组是视图,而其他索引对象得到的结果数组都拥有自己的数据。
使用整数索引时,得到的结果是单个元素,自然可以直接引用原数组的数据,而无需复制数据。使用切片时,得到的数据在原数组存储区域中是等间隔分布的。因此,只需要修改数组的 ndim/shape/strides 等属性以及指向数据存储区域的 data 指针就能够实现切片索引。视图和原数组共享数据存储区域。
使用序列索引、数组索引和布尔数组时,不能保证所取得的数据在原数组存储区中是等间隔的,因此无法和原数组共享数据,只能对数据进行拷贝。
print("整数索引", A[1].flags["OWNDATA"])
print("切片索引", A[0:2].flags["OWNDATA"])
print("序列索引", A[[0, 1]].flags["OWNDATA"])
print("数组索引", A[np.array([[0, 1],
[0, 1]])].flags["OWNDATA"])
print("布尔数组", A[A > 50].flags["OWNDATA"])
整数索引 False
切片索引 False
序列索引 True
数组索引 True
布尔数组 True