Pythonで数値計算のコツ:for文書いたら負けかなと思っている

By | 2013年12月20日

転職してから1年とちょっとが経ち、Pythonをメイン言語としてからも同じくらいが経った。最近やっとnumpy/scipyの使い方のコツがわかってきたと思うので、マサカリ飛んでくるのを覚悟でなんか書いてみようと思う。

転職して初めてPythonを使ったというわけではない(実際wafのwscriptとかは書いたことある)が、まあでもほぼ初心者同然だった。学習曲線でいうとPythonはすごく良い言語だと思う。Python本体の言語仕様については、わりとすぐに覚えることができた。だが一方、numpy/scipyについては、そう簡単ではなく習得するにはそれなりに時間がかかったと思う。

ケーススタディ

たとえば\(N\times M\)行列\(B\), \( M\times L \)行列\( C \), \( M \)次元ベクトル\(a=(a_k)_{1\leq k \leq M}\)が与えられて
$$r_{ij}=\sum_{k=1}^{M} a_k b_{ik}c_{kj}$$
により行列\(R=(r_{ij})\)を計算したいとしよう。

例えばC言語で書くなら(C#やJavaでもほぼ同様に)

for (i=0; i<N; i++) {
  for (j=0; j<L; j++) {
    d=0.0;
    for (k=0; k<M; k++) {
      d+=a[k]*b[i][k]*c[k][j];
    }
    r[i][j]=d;
  }
}

と書くのは割と自然なはずだし、最速ではなかったとしても計算速度もそれなりに出ると思うのだが、Pythonでこれを「直訳」して以下のようにしてしまうのは最悪である。

def compute1(N,M,L,a,b,c):
    r=np.empty((N,L))
    for i in xrange(N):
        for j in xrange(L):
            d=0.0
            for k in xrange(M):
                d+=a[k]*b[i,k]*c[k,j]
            r[i,j]=d
    return r

Pythonでは、for文でループしながら配列内の要素を参照するのがとても遅いので、行列積の関数などを使って配列要素への直接参照を減らしたほうが高速になる。

そこで、
$$R = \sum_{k=1}^M a_k b_{\cdot k} c_{k\cdot}$$
(ただし、ここで\(b_{\cdot k}\)は\(B\)の\(k\)列目列ベクトル、\(c_{k\cdot}\)は\(C\)の\(k\)行目行ベクトル)と同値な式を考えてやると、numpyの行列積の関数が使えて高速になる。そのコードがこちら。

def compute2(N,M,L,a,b,c):
    return sum([a[k]*np.outer(b[:,k],c[k,:]) for k in xrange(M)])

ところがもっと高速にする方法がある。
$$
A=\begin{pmatrix}
a_1 & & & \\
&a_2 & & \\
& & \ddots & \\
& & &a_L
\end{pmatrix}
$$
という行列を考えてやると、\(R\)の式は
$$R=BAC$$
と同値である。この\(A\)を疎行列として構成して計算すれば、\(L\)が大きい時でもメモリを大量消費することもない。なので、ここでscipy.sparseを使う。そのコードはこちら。

def compute3(N,M,L,a,b,c):
    aa=sparse.diags([a],[0])
    return np.dot(b,aa.dot(c))

ここで面白いのは、1次元配列を疎行列形式に詰め直すのは余計オーバーヘッドがかかるような気がするが、それをカバーしてあまりあるほど疎行列✕密行列の関数が高速だということ。

とここまで自力で考えたのだが、最初のポストのあとにこんな指摘があった。


要素積の仕様とか、全くわかってなかった。やっぱりnumpy難しい。ここらへんはまた勉強しなおしてブログでも書こうかと思う。そして、指摘していただいたしましまさん、ありがとうございました。(望みどおりマサカリが飛んできたわけですが…)

なので、こうするのが一番いいらしい。

def compute4(N,M,L,a,b,c):
    return np.dot(b*a,c)

では実際にベンチマークをしてみる。ベンチマーク用のコード(全体)はこうなる。

import numpy as np
import scipy.sparse as sparse
import time

def compute1(N,M,L,a,b,c):
    r=np.empty((N,L))
    for i in xrange(N):
        for j in xrange(L):
            d=0.0
            for k in xrange(M):
                d+=a[k]*b[i,k]*c[k,j]
            r[i,j]=d
    return r

def compute2(N,M,L,a,b,c):
    return sum([a[k]*np.outer(b[:,k],c[k,:]) for k in xrange(M)])

def compute3(N,M,L,a,b,c):
    aa=sparse.diags([a],[0])
    return np.dot(b,aa.dot(c))

def compute4(N,M,L,a,b,c):
    return np.dot(b*a,c)

def main():
    np.random.seed(0)
    N=10
    M=10000
    L=20
    a=np.random.random(M)
    N_ITER=10
    b=np.random.random((N,M))
    c=np.random.random((M,L))
    t=time.time()
    for _ in xrange(N_ITER):
        r1=compute1(N,M,L,a,b,c)
    tt=time.time()
    print "compute1 : %.3f sec" % (tt-t)
    t=time.time()
    for _ in xrange(N_ITER):
        r2=compute2(N,M,L,a,b,c)
    tt=time.time()
    print "compute2 : %.3f sec" % (tt-t)
    t=time.time()
    for _ in xrange(N_ITER):
        r3=compute3(N,M,L,a,b,c)
    tt=time.time()
    print "compute3 : %.3f sec" % (tt-t)
    t=time.time()
    for _ in xrange(N_ITER):
        r4=compute4(N,M,L,a,b,c)
    tt=time.time()
    print "compute4 : %.3f sec" % (tt-t)
    # Confirm the results are the same
    eps=1e-10
    y=(r1-r2).reshape(N*L)
    assert np.dot(y,y)<eps*N*L
    y=(r1-r3).reshape(N*L)
    assert np.dot(y,y)<eps*N*L
    y=(r1-r4).reshape(N*L)
    assert np.dot(y,y)<eps*N*L

main()

実行結果はこうなった。

compute1 : 19.018 sec
compute2 : 1.546 sec
compute3 : 0.030 sec
compute4 : 0.021 sec

このように計算時間で大差が出るのは、Pythonの数値計算系ライブラリは内部でFORTRANやCで書かれているからで、計算過程ではできるだけPython側でデータを取り出さない方が速くなる。

まとめ

Pythonで行列・ベクトル関連の計算を速くするには以下のようなことを気をつけるとよい。

  • できるだけ多次元配列や疎行列のデータ型に入れてからライブラリ関数で計算する。計算中にPython側での要素へのアクセスはでるだけ避ける。
  • そのために前処理が重くなっても、多少メモリを散らかしても気にしない。結局安くつくことが多い
  • コードを書く前に代数的に同値な変形を考え、行列の積・和だけで表現できないか考える。そのとき、疎行列もうまく活用する。

つまり、for文書いたら負けかなと思っている。リスト内包表記もできれば避けたい。

あと、そのコード誰が保守するんだ?っていう質問には、聞こえないふりで対応する。

更新履歴:
2013/12/20 23:31 しましまさんの指摘を受け、加筆しました。

One thought on “Pythonで数値計算のコツ:for文書いたら負けかなと思っている

  1. Pingback: 2013/12/21(Sat) | gattya.run

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です