NumPy 使用布尔数组索引
当我们使用(整数)索引数组索引数组时,我们提供了要选择的索引列表。对于布尔索引,方法是不同的;我们明确地选择数组中的哪些项是我们想要的,哪些是我们不想要的。
对于布尔索引,人们可以想到的最自然的方法是使用与原始数组具有相同形状的布尔数组:
>>> a = np.arange(12).reshape(3, 4)
>>> b = a > 4
>>> b # `b` is a boolean with `a`'s shape
array([[False, False, False, False],
[False, True, True, True],
[ True, True, True, True]])
>>> a[b] # 1d array with the selected elements
array([ 5, 6, 7, 8, 9, 10, 11])
这个属性在赋值中非常有用:
>>> a[b] = 0 # All elements of `a` higher than 4 become 0
>>> a
array([[0, 1, 2, 3],
[4, 0, 0, 0],
[0, 0, 0, 0]])
可以查看以下示例来了解如何使用布尔索引生成Mandelbrot 集的图像:
import numpy as np
>>> import matplotlib.pyplot as plt
>>> def mandelbrot(h, w, maxit=20, r=2):
... """Returns an image of the Mandelbrot fractal of size (h,w)."""
... x = np.linspace(-2.5, 1.5, 4*h+1)
... y = np.linspace(-1.5, 1.5, 3*w+1)
... A, B = np.meshgrid(x, y)
... C = A + B*1j
... z = np.zeros_like(C)
... divtime = maxit + np.zeros(z.shape, dtype=int)
...
... for i in range(maxit):
... z = z**2 + C
... diverge = abs(z) > r # who is diverging
... div_now = diverge & (divtime == maxit) # who is diverging now
... divtime[div_now] = i # note when
... z[diverge] = r # avoid diverging too much
...
... return divtime
>>> plt.imshow(mandelbrot(400, 400))
使用布尔值索引的第二种方式更类似于整数索引;对于数组的每个维度,我们给出一个一维布尔数组,选择我们想要的切片:
>>> a = np.arange(12).reshape(3, 4)
>>> b1 = np.array([False, True, True]) # first dim selection
>>> b2 = np.array([True, False, True, False]) # second dim selection
>>>
>>> a[b1, :] # selecting rows
array([[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
>>>
>>> a[b1] # same thing
array([[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
>>>
>>> a[:, b2] # selecting columns
array([[ 0, 2],
[ 4, 6],
[ 8, 10]])
>>>
>>> a[b1, b2] # a weird thing to do
array([ 4, 10])
请注意,一维布尔数组的长度必须与要切片的维度(或轴)的长度一致。在前面的例子中,b1
具有长度为3(的数目的行中a
),和 b2
(长度4)适合于索引a
的第二轴线(列) 。