Clojure機械学習勉強 - Gorilla REPLと線形回帰 (その2)
昨日のコードを2点修正:
- 前回は独立変数を一つだけ用意したが、今回は二つ。さらに、今後独立変数の数の増減はX値と母数のθ値だけアップデートすれば、他のコードは変更なしで機能するようリファクタした。
- core.matrixの実装を、標準実装のndarray、vectorz-cljそしてclatrixの三種類から選べるようにした。
そのコードがこれ:
一番重い処理が、30000回更新したθ値を求める以下のコード:
(doall (theta-after-n-iteration 30000))
まずはndarrayで実行してみた場合:
"Elapsed time: 16456.788619 msecs" [2.0190228761901188 5.246421586567178 2.879450293240761] (抜粋)
vectorz:
"Elapsed time: 2122.331857 msecs" [1.881619902630865,4.979437022948518,2.8889278718090208]
頭の方で
(m/set-current-implementation :vectorz)
と指定するだけで十倍近く早くなっている。
そしてclatrix。なんと全く同じコードで:
Exception thrown: clojure.lang.ExceptionInfo (Can't broadcast to a lower dimensional shape)
と例外投げて死んでくれる。
多分clatrixだと(* A B)
といった処理はAの次元数がB以下(というかAがスカラー値かベクトルかBと全く同じ次元である)必要があるのに対して、他の実装ではBがスカラーかベクトルで他の次元が合えば、BをAに合わせてbroadcastしてくれていたようだ。(多分。逆かな?)
同じcore.matrixのAPIを実装しているはずなのに、こういうところで違いが出るのは面白くもあり、めんどくさくもあり。少なくとも「APIが同じだから入れ替えは一行で済むよ!」というのは盛りすぎなようである。