Thinking Functionally with Haskell勉強メモ: 第6章4 Maximum Segment Sum
第6章の最後はJohn BentleyのProgramming Pearlsに出てくる「Maximum Segment Sum」を今までのような式変換と証明で解く、という演習。
Maximum Segment Sum問題
あるリスト(は整数)の中にある連続部分の和の最大値を求める関数
mss
を定義せよただし空の連続部分の和は0だとする
例えば
mss [-1, 2, -3, 5, -2, 1, 3, -2, -2, -3, 6] = sum [5, -2, 1, 3] = 7
含まれる整数は負の値を取ることがあるが、mss
は負にはならない(空リストの和が0なので)。
問題をコードで表現
愚直に問題をコードで書き移す:
mss :: [Int] -> Int mss = maximum . map sum . segments segments :: [a] -> [[a]] segments = concat . map inits . tails inits :: [a] -> [[a]] inits [] = [[]] inits (x:xs) = [] : map (x:) (inits xs) tails :: [a] -> [[a]] tails [] = [[]] tails (x:xs) = (x:xs) : tails xs
mss
関数は
- 引数に
[Int]
をとり - その連続部分すべてを列挙した
[[Int]]
を - 各部分の和のリスト
[Int]
に変換してから - そのリストの最大値
Int
を返す。
segments
関数は[Int]
を、その連続部分全ての列挙である[[Int]]
にしている。
inits
は[x, y, z]
を[[], [x], [x, y], [x, y, z]]
に変換。(前回のブログでscanl
の定義に使用したのを再掲)
tails
は[x, y, z]
を[[x, y, z], [y, z], [z], []]
に変換。
segments
の挙動を追うと:
[x, y, z] -> tails [[x, y, z], [y, z], [z], []] -> map inits [[[x, y, z], [x, y], [x], []], [[y, z], [y], []], [[z], []], [[]]] -> concat [[x, y, z], [x, y], [x], [], [y, z], [y], [], [z], [], []]
空リストがやけに重複してしまうが、それ以外では過不足なく連続部分を列挙できていることがわかる。
さて、これで正しい答えが値となるコードは書けた。
しかし、効率面を見ると、tails
、inits
が引数の長さnに対してO(n)個のリストを返し、sum
は時間がO(n)となっているので、O(n**3)になっている。
ここからEquational Reasoningで効率化していく。
O(n**2)へ
これから使用する法則を挙げる:
map f . concat = concat . map (map f) map f . map g = map (f . g) maximum . concat = maximum . map maximum -- ただし引数が「空ではないリスト」のリストの場合 scanl f e = map (foldl f e) . inits sum = foldr (+) 0 = foldl (+) 0 scanl (+) 0 = map sum . inits
これらを適用していくと:
maximum . map sum . concat . map inits . tails = {map f . concat = concat . map (map f)} maximum . concat . map (map sum) . map inits . tails = {maximum . concat = maximum . map maximum} maximum . map maximum . map (map sum) . map inits . tails = {map f . map g = map (f . g)} maximum . map (maximum . map sum . inits) . tails = {map sum . inits = scanl (+) 0} maximum . map (maximum . scanl (+) 0) . tails
inits
とsum
をscanl
に組み合わせられたおかげでこの部分をO(n2)からO(n)に落とすことができた。リストを大量に作って一つ一つ足し合わせるのは重複した計算が大量にあり、先頭から足していって途中経過を出力していくことでその重複を避けることができる。
ここまでで全体ではO(n**2)。
maximum
をfoldr1 max
に
foldr1
というのはfoldr
に似ているが、空リストではエラーになる関数。エラーにならない関数がすでにあるのになぜわざわざ、と思うかもしれないが、例えばmaximum
などはそもそも空リストを受け取った場合何を返していいか不明なのでエラーになるのが正しい。
foldr1
の定義と、foldr1
と二項演算子max
を使ったmaximum
の定義:
foldr1 :: (a -> b -> b) -> [a] -> b foldr1 f [x] = x foldr1 f (x:xs) = f x (foldr1 f xs) maximum = foldr1 max
なのでmss
に戻ると、計算量はかわらないがmap
内のmaximum
を変換しておく:
maximum . map (maximum . scanl (+) 0) . tails = {maximum = foldr1 max} maximum . map (foldr1 max . scanl (+) 0) . tails
あとで効いてくる。
scanl
をfoldr
に
scanl (+) 0
を何か別の形に変換してみたい。
まずは具体例を入れて挙動からヒントを:
scanl (+) 0 [x, y, z] = [0, x, x+y, x+y+z] = 0 : map (x+) [0, y, y+z] = 0 : map (x+) (scanl (+) 0 [y, z]) scanl (+) 0 [x, y, z] = 0 : map (x+) (scanl (+) 0 [y, z])
これだとfoldr
の形に似ている:
scanl (+) 0 (x:xs) = 0 : map (x+) (scanl (+) 0 xs) foldr f e (x:xs) = f x (foldr f e xs)
一般化すると、scanl
とfoldr
の間に以下の変換則が成り立つ:
scanl (@) e = foldr f [e] where f x xs = e : map (x@) xs
ただし(@)
はe
を単位元にもち、結合律を満たす関数。
これでmss
は:
maximum . map (foldr1 max . scanl (+) 0) . tails = {scanl (@) e = foldr f [e]} maximum . map (foldr1 max . foldr f [0]) . tails where f x xs = e : map (x+) xs
となる。これで何が嬉しいかというと、map
の中身がf . foldr g a
の形になった事実。
fusion law
つまりfusion law、f . foldr g a = foldr h b
の出番である。
foldr1 max . foldr f [0] where f x xs = e : map (x+) xs
を対象に、fusion law適用の条件を満たしているかを確認し、またfoldr h b
のh
関数とb
値が具体的に何になるかを探る。
fusion lawとその三条件は:
f . foldr g a = foldr h b f undefined = undefined f a = b f (g x y) = h x (f y)
fusion lawに出てくるf
は現在定義しているmss
内のf
と対応しないので少しややこしい。
直接foldr1 max . foldr f [0]
と対応させるのではなく、任意の関数(<>)
と(@)
を使って、どのような(<>), (@)
ならfusion lawが適用可能でどのようなh, b
になるかを調べてから、現在の具体例に戻る。
とりあえずわかっている対応付け:
foldr h b = foldr1 (<>) . foldr f [e] where f x xs == e : map (x@) xs f_fusion = foldr1 (<>) g_fusion = f a_fusion = [e]
第一条件はfoldr1 (<>) undefined = undefined
。
foldr1
がまず引数を[x]
か(x:xs)
にパターンマッチしないといけないので、引数がundefinedだとすぐにundefinedを返す。よって第一条件は任意の(<>)
でクリア。
第二条件はfoldr1 (<>) [e] = b
。
foldr1
の定義からfoldr1 f [x] = x
なので任意の(<>)
でfoldr1 (<>) [e] = e
。b_fusion = e
第三条件はfoldr1 (<>) (f x xs) = h x (foldr1 (<>) xs)
。
左辺を展開:
foldr1 (<>) (f x xs) = {f x xs == e : map (x@) xs} foldr1 (<>) (e : map (x@) xs) = {foldr1の定義} e <> (foldr1 (<>) (map (x@) xs))
まずxs = [y]
という要素数1のケースを考えてみる。
左辺
e <> (foldr1 (<>) (map (x@) [y])) = {map (x@) [y] = [x @ y]} e <> (foldr1 (<>) [x @ y]) = {foldr1 f [x] = x} e <> (x @ y)
右辺
h x (foldr1 (<>) [y]) = {foldr1 f [x] = x} h x y
なのでh x y = e (<>) (x @ y)
が成立する。
それを一般的なfoldr1 (<>) (e : map (x@) xs) = h x (foldr1 (<>) xs)
に代入すると:
foldr1 (<>) (e : map (x@) xs) = e (<>) (x @ (foldr1 (<>) xs)) -> {左辺のeを外に出す} e <> (foldr1 (<>) (map (x@) xs)) = e (<>) (x @ (foldr1 (<>) xs)) -> {両辺からe(<>)を除く} foldr1 (<>) (map (x@) xs) = x @ (foldr1 (<>) xs) -> {point-free化} foldr1 (<>) . map (x@) = (x@) . foldr1 (<>)
となる。最後の条件が満たされるためには(@)
が(<>)
に対して分配法則を満たす必要がある。つまり:
x @ (y <> z) = (x @ y) <> (x @ z)
が成り立つのが条件。
(@) = (+), (<>) = max
の場合は
x + (y `max` z) = (x + y) `max` (x + z)
が成り立つ。というわけで遡ってfusion lawが適用できる。
f_fusion = foldr1 (<>) g_fusion = f a_fusion = [e] b_fusion = e h_fusion x y = e (<>) (x @ y) e = 0 (<>) = max (@) = (+)
なので
foldr1 max . foldr f [0] where f x xs = e : map (x+) xs = {fusion law} foldr h 0 where h x y = 0 `max` (x + y)
これでmss
関数は:
maximum . map (foldr1 max . foldr f [0]) . tails where f x xs = e : map (x+) xs = {fusion law} maximum . map (foldr h 0) . tails where h x y = 0 `max` (x + y)
scanr
map (foldr h 0) . tails
はscanl
の定義:
scanl f e = map (foldl f e) . init
に非常に似ている。
scanl
と似たような論理でscanr f e = map (foldr f e) . tails
が定義できる:
scanr :: (a -> b -> b) -> b -> [a] -> [b] scanr f e [] = [e] scanr f e (x:xs) = (f x (head ys)) : ys where ys = scanr f e xs
mss
のmap (foldr f e) . tails
をscanr
に変換:
maximum . map (foldr h 0) . tails where h x y = 0 `max` (x + y) = {map (foldr f e) . tails = scanr f e} maximum . scanr h 0 where h x y = 0 `max` (x + y)
最終的な解
mss = maximum . scanr h 0 where h x y = 0 `max` (x + y)
となった。強力だが普遍的な高階関数を使うことで記述が非常に簡潔になったが、さらに重要なのは計算量。なんとO(n)。
なにをしているかを試しにPythonで書いてみると:
def mss(a_list): e = 0 sums = [] for j in range(len(a_list)): i = -1 - j e = max(0, e + a_list[i]) sums.append(e) return max(sums)
たしかにうまくいくしO(n)だな・・・ 最初の自明かつ非効率な解と、この最適化された解が同値だと証明できるのはすごい。