Arantium Maestum

プログラミング、囲碁、読書の話題

Clojure機械学習勉強 - Gorilla REPLと線形回帰 (その2)

昨日のコードを2点修正:

  1. 前回は独立変数を一つだけ用意したが、今回は二つ。さらに、今後独立変数の数の増減はX値と母数のθ値だけアップデートすれば、他のコードは変更なしで機能するようリファクタした。
  2. core.matrixの実装を、標準実装のndarray、vectorz-cljそしてclatrixの三種類から選べるようにした。

そのコードがこれ:

Gorilla REPL viewer

一番重い処理が、30000回更新したθ値を求める以下のコード:

(doall (theta-after-n-iteration 30000))

まずはndarrayで実行してみた場合:

"Elapsed time: 16456.788619 msecs" [2.0190228761901188 5.246421586567178 2.879450293240761] (抜粋)

vectorz:

"Elapsed time: 2122.331857 msecs" [1.881619902630865,4.979437022948518,2.8889278718090208]

頭の方で

(m/set-current-implementation :vectorz)

と指定するだけで十倍近く早くなっている。

そしてclatrix。なんと全く同じコードで:

Exception thrown: clojure.lang.ExceptionInfo (Can't broadcast to a lower dimensional shape)

と例外投げて死んでくれる。

多分clatrixだと(* A B)といった処理はAの次元数がB以下(というかAがスカラー値かベクトルかBと全く同じ次元である)必要があるのに対して、他の実装ではBがスカラーかベクトルで他の次元が合えば、BをAに合わせてbroadcastしてくれていたようだ。(多分。逆かな?)

同じcore.matrixのAPIを実装しているはずなのに、こういうところで違いが出るのは面白くもあり、めんどくさくもあり。少なくとも「APIが同じだから入れ替えは一行で済むよ!」というのは盛りすぎなようである。