Thinking Functionally with Haskell勉強メモ: 第6章4 Maximum Segment Sum
第6章の最後はJohn BentleyのProgramming Pearlsに出てくる「Maximum Segment Sum」を今までのような式変換と証明で解く、という演習。
Maximum Segment Sum問題
あるリスト
(
は整数)の中にある連続部分の和の最大値を求める関数
mssを定義せよただし空の連続部分の和は0だとする
例えば
mss [-1, 2, -3, 5, -2, 1, 3, -2, -2, -3, 6] = sum [5, -2, 1, 3] = 7
含まれる整数は負の値を取ることがあるが、
mssは負にはならない(空リストの和が0なので)。
問題をコードで表現
愚直に問題をコードで書き移す:
mss :: [Int] -> Int mss = maximum . map sum . segments segments :: [a] -> [[a]] segments = concat . map inits . tails inits :: [a] -> [[a]] inits [] = [[]] inits (x:xs) = [] : map (x:) (inits xs) tails :: [a] -> [[a]] tails [] = [[]] tails (x:xs) = (x:xs) : tails xs
mss関数は
- 引数に
[Int]をとり - その連続部分すべてを列挙した
[[Int]]を - 各部分の和のリスト
[Int]に変換してから - そのリストの最大値
Intを返す。
segments関数は[Int]を、その連続部分全ての列挙である[[Int]]にしている。
initsは[x, y, z]を[[], [x], [x, y], [x, y, z]]に変換。(前回のブログでscanlの定義に使用したのを再掲)
tailsは[x, y, z]を[[x, y, z], [y, z], [z], []]に変換。
segmentsの挙動を追うと:
[x, y, z] -> tails [[x, y, z], [y, z], [z], []] -> map inits [[[x, y, z], [x, y], [x], []], [[y, z], [y], []], [[z], []], [[]]] -> concat [[x, y, z], [x, y], [x], [], [y, z], [y], [], [z], [], []]
空リストがやけに重複してしまうが、それ以外では過不足なく連続部分を列挙できていることがわかる。
さて、これで正しい答えが値となるコードは書けた。
しかし、効率面を見ると、tails、initsが引数の長さnに対してO(n)個のリストを返し、sumは時間がO(n)となっているので、O(n**3)になっている。
ここからEquational Reasoningで効率化していく。
O(n**2)へ
これから使用する法則を挙げる:
map f . concat = concat . map (map f) map f . map g = map (f . g) maximum . concat = maximum . map maximum -- ただし引数が「空ではないリスト」のリストの場合 scanl f e = map (foldl f e) . inits sum = foldr (+) 0 = foldl (+) 0 scanl (+) 0 = map sum . inits
これらを適用していくと:
maximum . map sum . concat . map inits . tails = {map f . concat = concat . map (map f)} maximum . concat . map (map sum) . map inits . tails = {maximum . concat = maximum . map maximum} maximum . map maximum . map (map sum) . map inits . tails = {map f . map g = map (f . g)} maximum . map (maximum . map sum . inits) . tails = {map sum . inits = scanl (+) 0} maximum . map (maximum . scanl (+) 0) . tails
initsとsumをscanlに組み合わせられたおかげでこの部分をO(n2)からO(n)に落とすことができた。リストを大量に作って一つ一つ足し合わせるのは重複した計算が大量にあり、先頭から足していって途中経過を出力していくことでその重複を避けることができる。
ここまでで全体ではO(n**2)。
maximumをfoldr1 maxに
foldr1というのはfoldrに似ているが、空リストではエラーになる関数。エラーにならない関数がすでにあるのになぜわざわざ、と思うかもしれないが、例えばmaximumなどはそもそも空リストを受け取った場合何を返していいか不明なのでエラーになるのが正しい。
foldr1の定義と、foldr1と二項演算子maxを使ったmaximumの定義:
foldr1 :: (a -> b -> b) -> [a] -> b foldr1 f [x] = x foldr1 f (x:xs) = f x (foldr1 f xs) maximum = foldr1 max
なのでmssに戻ると、計算量はかわらないがmap内のmaximumを変換しておく:
maximum . map (maximum . scanl (+) 0) . tails = {maximum = foldr1 max} maximum . map (foldr1 max . scanl (+) 0) . tails
あとで効いてくる。
scanlをfoldrに
scanl (+) 0を何か別の形に変換してみたい。
まずは具体例を入れて挙動からヒントを:
scanl (+) 0 [x, y, z] = [0, x, x+y, x+y+z] = 0 : map (x+) [0, y, y+z] = 0 : map (x+) (scanl (+) 0 [y, z]) scanl (+) 0 [x, y, z] = 0 : map (x+) (scanl (+) 0 [y, z])
これだとfoldrの形に似ている:
scanl (+) 0 (x:xs) = 0 : map (x+) (scanl (+) 0 xs) foldr f e (x:xs) = f x (foldr f e xs)
一般化すると、scanlとfoldrの間に以下の変換則が成り立つ:
scanl (@) e = foldr f [e] where f x xs = e : map (x@) xs
ただし(@)はeを単位元にもち、結合律を満たす関数。
これでmssは:
maximum . map (foldr1 max . scanl (+) 0) . tails = {scanl (@) e = foldr f [e]} maximum . map (foldr1 max . foldr f [0]) . tails where f x xs = e : map (x+) xs
となる。これで何が嬉しいかというと、mapの中身がf . foldr g aの形になった事実。
fusion law
つまりfusion law、f . foldr g a = foldr h bの出番である。
foldr1 max . foldr f [0] where f x xs = e : map (x+) xs
を対象に、fusion law適用の条件を満たしているかを確認し、またfoldr h bのh関数とb値が具体的に何になるかを探る。
fusion lawとその三条件は:
f . foldr g a = foldr h b f undefined = undefined f a = b f (g x y) = h x (f y)
fusion lawに出てくるfは現在定義しているmss内のfと対応しないので少しややこしい。
直接foldr1 max . foldr f [0]と対応させるのではなく、任意の関数(<>)と(@)を使って、どのような(<>), (@)ならfusion lawが適用可能でどのようなh, bになるかを調べてから、現在の具体例に戻る。
とりあえずわかっている対応付け:
foldr h b = foldr1 (<>) . foldr f [e] where f x xs == e : map (x@) xs f_fusion = foldr1 (<>) g_fusion = f a_fusion = [e]
第一条件はfoldr1 (<>) undefined = undefined。
foldr1がまず引数を[x]か(x:xs)にパターンマッチしないといけないので、引数がundefinedだとすぐにundefinedを返す。よって第一条件は任意の(<>)でクリア。
第二条件はfoldr1 (<>) [e] = b。
foldr1の定義からfoldr1 f [x] = xなので任意の(<>)でfoldr1 (<>) [e] = e。b_fusion = e
第三条件はfoldr1 (<>) (f x xs) = h x (foldr1 (<>) xs)。
左辺を展開:
foldr1 (<>) (f x xs) = {f x xs == e : map (x@) xs} foldr1 (<>) (e : map (x@) xs) = {foldr1の定義} e <> (foldr1 (<>) (map (x@) xs))
まずxs = [y]という要素数1のケースを考えてみる。
左辺
e <> (foldr1 (<>) (map (x@) [y])) = {map (x@) [y] = [x @ y]} e <> (foldr1 (<>) [x @ y]) = {foldr1 f [x] = x} e <> (x @ y)
右辺
h x (foldr1 (<>) [y]) = {foldr1 f [x] = x} h x y
なのでh x y = e (<>) (x @ y)が成立する。
それを一般的なfoldr1 (<>) (e : map (x@) xs) = h x (foldr1 (<>) xs)に代入すると:
foldr1 (<>) (e : map (x@) xs) = e (<>) (x @ (foldr1 (<>) xs)) -> {左辺のeを外に出す} e <> (foldr1 (<>) (map (x@) xs)) = e (<>) (x @ (foldr1 (<>) xs)) -> {両辺からe(<>)を除く} foldr1 (<>) (map (x@) xs) = x @ (foldr1 (<>) xs) -> {point-free化} foldr1 (<>) . map (x@) = (x@) . foldr1 (<>)
となる。最後の条件が満たされるためには(@)が(<>)に対して分配法則を満たす必要がある。つまり:
x @ (y <> z) = (x @ y) <> (x @ z)
が成り立つのが条件。
(@) = (+), (<>) = maxの場合は
x + (y `max` z) = (x + y) `max` (x + z)
が成り立つ。というわけで遡ってfusion lawが適用できる。
f_fusion = foldr1 (<>) g_fusion = f a_fusion = [e] b_fusion = e h_fusion x y = e (<>) (x @ y) e = 0 (<>) = max (@) = (+)
なので
foldr1 max . foldr f [0] where f x xs = e : map (x+) xs = {fusion law} foldr h 0 where h x y = 0 `max` (x + y)
これでmss関数は:
maximum . map (foldr1 max . foldr f [0]) . tails where f x xs = e : map (x+) xs = {fusion law} maximum . map (foldr h 0) . tails where h x y = 0 `max` (x + y)
scanr
map (foldr h 0) . tailsはscanlの定義:
scanl f e = map (foldl f e) . init
に非常に似ている。
scanlと似たような論理でscanr f e = map (foldr f e) . tailsが定義できる:
scanr :: (a -> b -> b) -> b -> [a] -> [b] scanr f e [] = [e] scanr f e (x:xs) = (f x (head ys)) : ys where ys = scanr f e xs
mssのmap (foldr f e) . tailsをscanrに変換:
maximum . map (foldr h 0) . tails where h x y = 0 `max` (x + y) = {map (foldr f e) . tails = scanr f e} maximum . scanr h 0 where h x y = 0 `max` (x + y)
最終的な解
mss = maximum . scanr h 0 where h x y = 0 `max` (x + y)
となった。強力だが普遍的な高階関数を使うことで記述が非常に簡潔になったが、さらに重要なのは計算量。なんとO(n)。
なにをしているかを試しにPythonで書いてみると:
def mss(a_list): e = 0 sums = [] for j in range(len(a_list)): i = -1 - j e = max(0, e + a_list[i]) sums.append(e) return max(sums)
たしかにうまくいくしO(n)だな・・・ 最初の自明かつ非効率な解と、この最適化された解が同値だと証明できるのはすごい。