跳转至

8.2.索引

import numpy as np
import copy
import pprint
Windows 10
Python 3.8.8 @ MSC v.1928 64 bit (AMD64)
Latest build date 2021.03.03
numpy version:  1.20.1

索引规则

NumPy 数组的索引功能很强大,但多样的索引也带来了复杂性和混乱性。新手可以很快地开始使用 NumPy 索引,但可能难以区分各种索引的情况。虽然,网上有很多关于 NumPy 数组索的资料,也有很多代码示例,但几乎都是比较简单的说明,大部分代码示例只会让人在简单与混乱之间徘徊。NumPy 索引的官方用户指南当然是很好的资料,但大多都是必要或细节性的说明。这里尝试对所有情况做一个简单但不失全面性的总结。

可以使用 Python 的 array[index] 语法对 ndarray 进行索引。假设 ndarray 是 N 维数组,其索引的形式为 array[idx1, idx2, ..., idxN],即每个轴都有一个索引。

在 Python 中,x[(exp1, exp2, ..., EXPN)] 等同于 x[exp1, exp2, ..., EXPN]。后者只是前者的语法糖。

这里将 (idx1,idx2,.,idxN) 称为索引元组,idxN直接称为索引或索引对象。NumPy 数组的索引 idxN 可以是以下 7 种对象(形式):

  1. 整数
  2. 切片(slice:
  3. 整数序列
  4. 整数数组
  5. 布尔数组
  6. Nonenp.newaxis
  7. Ellipsis(...

序列索引和数组索引的标量元素只能为整数,因此下面将它们简称为序列索引和数组索引。

假设 A 是一个 N 维数组,不同的索引对象具有不同的规则:

  1. 整数索引意味着从对应的轴取出 1 个元素,因此,每多 1 个整数索引,索引结果就减少 1 个维度。
  2. N 维数组最多可以指定 N 个索引对象,但 None 可以突破这个限制;允许最多省略 N-1 个索引对象,即至少指定 1 个索引,省略的索引默认补充为切片 :
  3. 索引元组最多只能包含一个 ...... 代表其余轴的索引都为切片 :,如A[1,...,2]。这是 NumPy 的一个快捷方式。
  4. np.newaxisNone 的别名,索引结果会在索引元组中出现 None 的地方创建一个新轴,该轴长度为 1。None 的数量不受 N 的限制,但不能无限多,因为 NumPy 数组最多只能有 32 个维度。

  5. 序列索引基于整数索引:

    • A 指定 1 个整数序列索引,即 A[[i, j, z]],这相当于 np.array([A[i], A[j], A[z]])
    • A 指定 2 个整数序列索引,即使A[[a, b, c], [i, j, z]],这相当于 np.array([A[a, i], A[b, j], A[c, z]])
    • 以此类推 n 个序列索引的情况。可以看出,指定 n 个序列索引会导致索引结果减少 n-1 个维度。
  6. 切片类似于序列索引,但切片不会改变结果数组的维度,因为多个序列索引之间是内积,而多个切片之间是笛卡儿积。A[0:2:1] 相当于 A[[0, 1]]A[0:2:1, 0:2:1] 相当于 np.array([[A[0, 0], A[0, 1]], [A[1, 0], A[1, 1]]])

  7. 数组索引基于序列索引,假设 A.shape=(s1, s2, ..., sN)

    • A 指定 1 个整数数组索引,即 A[I],并且 I.shape=(a, b), 这相当于 A[I.flatten()].reshape(a, b, s2, ..., sN)I.flatten()是整数序列;
    • A 指定 2 个整数数组索引,即 A[I, J]
      • 如果 I.shape=J.shape=(a, b),这相当于 A[I.flatten(), J.flatten()].reshape(a, b, s3, ..., sN), 两个序列索引首先会导致原数组减少 1 个维度;
      • 如果 I.shape=(a, b)J.shape=(c, d),那么 IJ 会进行广播,得到 shape=(e, f) 新的数组索引 I2J2, 这相当于 A[I2.flatten(), J2.flatten()].reshape(e, f, s3, ..., sN)
    • 以此类推 n 个数组索引的情况。可以看出,数组索引首先会得到和序列索引一致的结果,但结果数组的第一个维度会被替换为索引数组的 shape
    • 假设原数组的维度数为 n,有 m 整数数组索引,其广播后的维度数为 y,则结果数组的维度数为 n - (m-1) + (y-1)。
  8. 布尔数组索引的维度不必和原数组的维度一致,但不能超出原数组的维度,并且对应维度的大小必须和原数组一致,得到的结果是布尔数组中 True 对应的元素。

A = copy.deepcopy(np.arange(90).reshape((10, 3, 3)))
print(f"A shape = {A.shape}")
pprint.pprint(A)
A shape = (10, 3, 3)
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]],

       [[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]],

       [[27, 28, 29],
        [30, 31, 32],
        [33, 34, 35]],

       [[36, 37, 38],
        [39, 40, 41],
        [42, 43, 44]],

       [[45, 46, 47],
        [48, 49, 50],
        [51, 52, 53]],

       [[54, 55, 56],
        [57, 58, 59],
        [60, 61, 62]],

       [[63, 64, 65],
        [66, 67, 68],
        [69, 70, 71]],

       [[72, 73, 74],
        [75, 76, 77],
        [78, 79, 80]],

       [[81, 82, 83],
        [84, 85, 86],
        [87, 88, 89]]])
print(A[0, 1, 2])
print(A[(0, 1, 2)])
print(A[0][1][2])
pprint.pprint(A[0])
pprint.pprint(A[0, ...])
5
5
5
array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]])
array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]])
pprint.pprint(A[[0, 1]])
pprint.pprint(np.array([A[0], A[1]]))
print()
pprint.pprint(A[[0, 1], [0, 1]])
pprint.pprint(np.array([A[0, 0], A[1, 1]]))
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]]])
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]]])

array([[ 0,  1,  2],
       [12, 13, 14]])
array([[ 0,  1,  2],
       [12, 13, 14]])
pprint.pprint(A[1, :])
print()
pprint.pprint(A[slice(0, 3, 2), 0:3:2])
pprint.pp(np.array([[A[0, 0], A[0, 2]], [A[2, 0], A[2, 2]]]))
# 与序列索引的区别
print()
pprint.pprint(A[[0,2], [0,2]])
array([[ 9, 10, 11],
       [12, 13, 14],
       [15, 16, 17]])

array([[[ 0,  1,  2],
        [ 6,  7,  8]],

       [[18, 19, 20],
        [24, 25, 26]]])
array([[[ 0,  1,  2],
        [ 6,  7,  8]],

       [[18, 19, 20],
        [24, 25, 26]]])

array([[ 0,  1,  2],
       [24, 25, 26]])
print(f"A shape = {A.shape}")
index = np.array([0, 1, 2])
print(f"index shape = {index.shape}")
pprint.pprint(A[index])
print(f"result shape = {A[index].shape}")
print()
print(f"A shape = {A.shape}")
index = np.array([[[0, 1], [0, 1], [0, 1]]])
print(f"index shape = {index.shape}")
pprint.pprint(A[index])
print(f"result shape = {A[index].shape}")
A shape = (10, 3, 3)
index shape = (3,)
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]],

       [[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]]])
result shape = (3, 3, 3)

A shape = (10, 3, 3)
index shape = (1, 3, 2)
array([[[[[ 0,  1,  2],
          [ 3,  4,  5],
          [ 6,  7,  8]],

         [[ 9, 10, 11],
          [12, 13, 14],
          [15, 16, 17]]],


        [[[ 0,  1,  2],
          [ 3,  4,  5],
          [ 6,  7,  8]],

         [[ 9, 10, 11],
          [12, 13, 14],
          [15, 16, 17]]],


        [[[ 0,  1,  2],
          [ 3,  4,  5],
          [ 6,  7,  8]],

         [[ 9, 10, 11],
          [12, 13, 14],
          [15, 16, 17]]]]])
result shape = (1, 3, 2, 3, 3)
B = A[0:2]
bool_index = B > 10
print(f"B shape = {B.shape}")
print(f"bool index shape = {bool_index.shape}")
pprint.pprint(bool_index)
pprint.pprint(B[bool_index])
print()
print(f"A shape = {A.shape}")
bool_index = np.array([True, False, False, False, False,
                       False, False, False, True, False])
print(f"bool index shape = {bool_index.shape}")
pprint.pprint(A[bool_index])
B shape = (2, 3, 3)
bool index shape = (2, 3, 3)
array([[[False, False, False],
        [False, False, False],
        [False, False, False]],

       [[False, False,  True],
        [ True,  True,  True],
        [ True,  True,  True]]])
array([11, 12, 13, 14, 15, 16, 17])

A shape = (10, 3, 3)
bool index shape = (10,)
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[72, 73, 74],
        [75, 76, 77],
        [78, 79, 80]]])

索引结果与视图

整数索引和切片得到的结果数组是视图,而其他索引对象得到的结果数组都拥有自己的数据。

使用整数索引时,得到的结果是单个元素,自然可以直接引用原数组的数据,而无需复制数据。使用切片时,得到的数据在原数组存储区域中是等间隔分布的。因此,只需要修改数组的 ndim/shape/strides 等属性以及指向数据存储区域的 data 指针就能够实现切片索引。视图和原数组共享数据存储区域。

使用序列索引、数组索引和布尔数组时,不能保证所取得的数据在原数组存储区中是等间隔的,因此无法和原数组共享数据,只能对数据进行拷贝。

print("整数索引", A[1].flags["OWNDATA"])
print("切片索引", A[0:2].flags["OWNDATA"])
print("序列索引", A[[0, 1]].flags["OWNDATA"])
print("数组索引", A[np.array([[0, 1],
                          [0, 1]])].flags["OWNDATA"])
print("布尔数组", A[A > 50].flags["OWNDATA"])
整数索引 False
切片索引 False
序列索引 True
数组索引 True
布尔数组 True