先日ありがたいマサカリを頂いたのを機に、numpyのブロードキャスティングについてあまり理解してなかったなと思い、改めてまとめてみることにする。以下自分の理解をまとめたものだが、一応初心者向け解説のつもり。わかってる人は読まなくていい。
解説
Numpyでは例えば、以下のように1次元配列とスカラー値の演算ができる。
>>> from numpy import * >>> a=arange(5) >>> a array([0, 1, 2, 3, 4]) >>> a*5 array([ 0, 5, 10, 15, 20])
こういうのをブロードキャスティングと呼ぶ。
これは、2次元と1次元の計算に限らず、また掛け算に限らず他の四則演算でも似たようなことができる。なので、以下足し算に限定して例示する。例えばこの場合。
>>> a=arange(10,130,10).reshape(4,3) >>> a array([[ 10, 20, 30], [ 40, 50, 60], [ 70, 80, 90], [100, 110, 120]]) >>> b=arange(1,4) >>> b array([1, 2, 3]) >>> a+b array([[ 11, 22, 33], [ 41, 52, 63], [ 71, 82, 93], [101, 112, 123]])
これは式で表すと、
$$ c_{i,j}= a_{i,j} + b_j (0\leq i <4, 0\leq j<3)$$
となる。この計算はaのシェイプ(4,3)の2つめの軸のインデックスの数(この場合3)がbのインデックスの数と一致してる場合に限られる。そうでないと次のようにエラーになる。
[text]
>>> a=arange(10,130,10).reshape(4,3)
>>> b=arange(5)
>>> a+b
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
ValueError: operands could not be broadcast together with shapes (4,3) (5)
[/text]
この場合、aのシェイプが(3,4,1)ならエラーにならない。
[text]
>>> a=arange(10,130,10).reshape(4,3,1)
>>> b=arange(5)
>>> c=a+b
>>> c
array([[[ 10, 11, 12, 13, 14],
[ 20, 21, 22, 23, 24],
[ 30, 31, 32, 33, 34]],
[[ 40, 41, 42, 43, 44],
[ 50, 51, 52, 53, 54],
[ 60, 61, 62, 63, 64]],
[[ 70, 71, 72, 73, 74],
[ 80, 81, 82, 83, 84],
[ 90, 91, 92, 93, 94]],
[[100, 101, 102, 103, 104],
[110, 111, 112, 113, 114],
[120, 121, 122, 123, 124]]])
>>> c.shape
(4, 3, 5)
[/text]
これは何をやっているかというと、\(\{a_{i,j,0}\}_{0\leq i <4, 0\leq j <3}\), \(\{b_k\}_{0\leq k <5}\)(aが(4,3,1)という形であることを示すためにわざと余計なインデックス0を加えている)に対し、\(\{c_{i,j,k}\}_{0\leq i <4, 0\leq j <3, 0\leq k <5}\)を、
$$c_{i,j,k}=a_{i,j,0} + b_k$$
で計算していることになる。
では、aが(4,1,3)というシェイプの時に、1次元配列bを同じように足し算したいときは、そのままではできなくて、bを(5,1)という形に変形するとできるようになる。
[text]
>>> a=arange(10,130,10).reshape(4,1,3)
>>> b=arange(5)
>>> c=a+b
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
ValueError: operands could not be broadcast together with shapes (4,1,3) (5)
>>> b=b.reshape(5,1)
>>> c=a+b
>>> c.shape
(4, 5, 3)
>>> c
array([[[ 10, 20, 30],
[ 11, 21, 31],
[ 12, 22, 32],
[ 13, 23, 33],
[ 14, 24, 34]],
[[ 40, 50, 60],
[ 41, 51, 61],
[ 42, 52, 62],
[ 43, 53, 63],
[ 44, 54, 64]],
[[ 70, 80, 90],
[ 71, 81, 91],
[ 72, 82, 92],
[ 73, 83, 93],
[ 74, 84, 94]],
[[100, 110, 120],
[101, 111, 121],
[102, 112, 122],
[103, 113, 123],
[104, 114, 124]]])
[/text]
この場合、\(a_{i,0,k}\)と\(b_{j,0}\)に対して
$$ c_{i,j,k}=a_{i,0,k}+b_{j,0}$$
を計算していることになる。
これは一般の次元で考えられるが、その場合の規則は、末尾の側から見ていって1)次元の大きさが一致している場合、または2)片方の次元の大きさが1の場合、についてブロードキャスティング演算ができることになっている。
以下の例は本家ページからの抜粋だが、OKなケースとしては以下の例がある。
A (4d array): 8 x 1 x 6 x 1
B (3d array): 7 x 1 x 5
Result (4d array): 8 x 7 x 6 x 5
A (2d array): 5 x 4
B (1d array): 1
Result (2d array): 5 x 4
A (2d array): 5 x 4
B (1d array): 4
Result (2d array): 5 x 4
A (3d array): 15 x 3 x 5
B (3d array): 15 x 1 x 5
Result (3d array): 15 x 3 x 5
A (3d array): 15 x 3 x 5
B (2d array): 3 x 5
Result (3d array): 15 x 3 x 5
A (3d array): 15 x 3 x 5
B (2d array): 3 x 1
Result (3d array): 15 x 3 x 5
ダメなケースとしては以下の例がある。
A (1d array): 3
B (1d array): 4 # 最後の次元が一致してない
A (2d array): 2 x 1
B (3d array): 8 x 4 x 3 # 最後から2つ目の次元が一致してない
こういうのを見ると、2次元配列の掛け算(*)がなぜ行列の意味の掛け算ではなくて要素ごとの掛け算になっているかがわかると思う。それは、ブロードキャスティングの特殊な場合にすぎない。
最後にちょっと複雑な例:
A : 4 x 1 x 2 x 1
B : 3 x 2 x 3
Result: 4 x 3 x 2 x 3
について計算例を示す。
これは、\(\{a_{i,0,k,0}\}_{0\leq i<4, 0\leq k<2}\)と\(\{b_{j,k,l}\}_{0\leq j<3, 0\leq k<2,\; 0\leq l<3}\)に対して $$ c_{i,j,k,l}=a_{i,0,k,0}+b_{j,k,l} $$ を計算することを意味する。ここで、3つ目のインデックスだけが両方で使われていることに注目する。 実際の計算は以下の通り。 [text] >>> a=arange(10,90,10).reshape(4,1,2,1) >>> a=arange(100,900,100).reshape(4,1,2,1) >>> b=arange(18).reshape(3,2,3) >>> a array([[[[100], [200]]], [[[300], [400]]], [[[500], [600]]], [[[700], [800]]]]) >>> b array([[[ 0, 1, 2], [ 3, 4, 5]], [[ 6, 7, 8], [ 9, 10, 11]], [[12, 13, 14], [15, 16, 17]]]) >>> c=a+b >>> c array([[[[100, 101, 102], [203, 204, 205]], [[106, 107, 108], [209, 210, 211]], [[112, 113, 114], [215, 216, 217]]], [[[300, 301, 302], [403, 404, 405]], [[306, 307, 308], [409, 410, 411]], [[312, 313, 314], [415, 416, 417]]], [[[500, 501, 502], [603, 604, 605]], [[506, 507, 508], [609, 610, 611]], [[512, 513, 514], [615, 616, 617]]], [[[700, 701, 702], [803, 804, 805]], [[706, 707, 708], [809, 810, 811]], [[712, 713, 714], [815, 816, 817]]]]) >>> c.shape (4, 3, 2, 3) [/text] これではぱっと見ではわかりづらいと思うが、3つ目のインデックスだけが重なっていることの影響は以下のように確認できる。 [text] >>> c[:,0,0,1] array([101, 301, 501, 701]) >>> c[0,0,:,1] array([101, 204]) [/text] つまりc[:,0,0,1]では、すべて一の位と十の位が同じだが、c[0,0,:,1]については一の位と十の位が要素によって異なっている。つまり1つ目のインデックスが変わっても加算されるbの要素は変わらないが、3つ目のインデックスについては、インデックスの値によって加算されるbの要素が変わっていることが確認できる。
感想
一見複雑な仕様に見えたが、実際に手を動かしてみるとさほど難しくないし、自然な仕様にも思えてくる。でもどこがどう「自然」なのかはまだうまく日本語で説明できない。
参考文献
- 本家のマニュアル:Broadcasting – NumPy v1.8 Manual
- わかりやすい解説スライド:NumPy MedKit – Stefan van der Walt(これもしましまさんに教えてもらった)