Clojure入門 - Project Eulerを解いてみる 問30
二桁以上の整数で、その数の各桁の5乗の和と等しいものを選んで足し合わせる。
基本戦略としては、以下のように問題文をほぼなぞるようなコードにする。
(->> (range 10 upper-limit) (filter #(= % (digit-power-sum %))) (apply +))
何らかの上限を設定すること以外は問題文通り。とりあえず、upper-limitとdigit-power-sumを定義できれば解決できる。
まずは定数。各桁を5乗するという部分を「N乗する」と変更して、プログラムの一番上でNを定義しておく。
(def N 5)
この数値を4とか3とか6とかにしても大丈夫なように書いていく。(4だと例文の通り19316と出るようにする)
乗数のヘルパー関数。
(defn power [a b] (if (zero? b) 1 (apply *' (repeat b a))))
Math/powだとfloatを返すので、そこからint変換するのが嫌。Overflowの場合は別の整数タイプになるし。以下のコードだと*'
を使って、Overflow時に自動的に大きな整数タイプになる。
各桁のN乗の和を計算する関数。
(defn digit-power-sum [x] (let [char-to-int (fn [c] (- (int c) 48))] (->> x str (map char-to-int) (map #(power % N)) (apply +))))
他のところでは使わないのでletでchar-to-intを定義している。ちなみに、(- (int c) 48)
と書く方が(Character/digit c 10)
に比べて30倍近く早かった。あとは、数値を文字列化、各桁の数のリスト化、N乗して足し合わせる。
さて、では上限について。すべての桁が9の数とその桁のN乗の和を比較した時に前者の方が大きかった場合、その数は上限になる。
なので、まず9、99、999といった数列を定義して、その数列の要素と要素の桁のN乗の和を比較して上限を割り出す。
(def all-digits-nines (iterate #(+ 9 (*' % 10)) 9)) (def upper-limit (->> all-digits-nines (map (juxt identity digit-power-sum)) (drop-while #(apply < %)) first first))
9、99...の数列が95×1、95×2...の数列より大きくなった最初の数を上限として定める。
(map (juxt identity f) [a b c])
で([a (f a)] [b (f b)] [c (f c)])
というリストが帰ってくるのが便利。
あとは上記の通り組み合わせるだけ…
(->> (range 10 upper-limit) (filter #(= % (digit-power-sum %))) (apply +))
なのだが、これだと5秒ほどかかる。並列化でもう少し減らせるか。
(->> (range 10 upper-limit) (pmap #(if (= % (digit-power-sum %)) %)) (remove nil?) (apply +))
超単純に並列化してみる。filterをpmap+remove nil?に変えただけ。
数字一つ一つの計算を別々のコアで走るスレッドに回していることになる。これだと所要時間は4秒ほどに。一つ一つの数の計算が独立しているのはいいのだけど、スレッドに回している作業の粒度が小さすぎてオーバーヘッドがかなりかかっている。いっぺんに1000ずつの数を回してみる。
(->> (range 10 upper-limit) (partition 1000) (pmap (fn [chunk] (->> chunk (filter #(= % (digit-power-sum %))) (apply +)))) (apply +))
これだと1.5秒。あまり複雑なこともせずに並列化のメリットが大きく出て嬉しい。
追記(2016年5月5日):
おっと危ない、書いた時はちゃんと知らなかったのだがpartitionだと端数に当たる要素は切り捨てになる。つまり、最後に1000に満たない要素が残った場合、なかったものと扱われる。上記のコードは運良くその最後の1000の数字にdigit-power-sumに該当するものがなかったから上手くいったわけだ。
切り捨てをしないためにはpartition-allを使う。
(->> (range 10 upper-limit) (partition-all 1000) (pmap (fn [chunk] (->> chunk (filter #(= % (digit-power-sum %))) (apply +)))) (apply +))