Arantium Maestum

プログラミング、囲碁、読書の話題

Thinking Functionally with Haskell勉強メモ: 第6章4 Maximum Segment Sum

第6章の最後はJohn BentleyのProgramming Pearlsに出てくる「Maximum Segment Sum」を今までのような式変換と証明で解く、という演習。

Maximum Segment Sum問題

あるリスト{A = (a_1, a_2, ... a_n)}{a_i}は整数)の中にある連続部分の和の最大値を求める関数mssを定義せよ

ただし空の連続部分の和は0だとする

例えば

  mss [-1, 2, -3, 5, -2, 1, 3, -2, -2, -3, 6]
= sum [5, -2, 1, 3]
= 7

含まれる整数{a_i}は負の値を取ることがあるが、mssは負にはならない(空リストの和が0なので)。

問題をコードで表現

愚直に問題をコードで書き移す:

mss :: [Int] -> Int
mss = maximum . map sum . segments

segments :: [a] -> [[a]]
segments = concat . map inits . tails

inits :: [a] -> [[a]]
inits []     = [[]]
inits (x:xs) = [] : map (x:) (inits xs)

tails :: [a] -> [[a]]
tails []     = [[]]
tails (x:xs) = (x:xs) : tails xs

mss関数は

  • 引数に[Int]をとり
  • その連続部分すべてを列挙した[[Int]]
  • 各部分の和のリスト[Int]に変換してから
  • そのリストの最大値Intを返す。

segments関数は[Int]を、その連続部分全ての列挙である[[Int]]にしている。

inits[x, y, z][[], [x], [x, y], [x, y, z]]に変換。(前回のブログでscanlの定義に使用したのを再掲)

tails[x, y, z][[x, y, z], [y, z], [z], []]に変換。

segmentsの挙動を追うと:

[x, y, z]

-> tails
[[x, y, z], [y, z], [z], []]

-> map inits
[[[x, y, z], [x, y], [x], []], [[y, z], [y], []], [[z], []], [[]]]

-> concat
[[x, y, z], [x, y], [x], [], [y, z], [y], [], [z], [], []]

空リストがやけに重複してしまうが、それ以外では過不足なく連続部分を列挙できていることがわかる。

さて、これで正しい答えが値となるコードは書けた。

しかし、効率面を見ると、tailsinitsが引数の長さnに対してO(n)個のリストを返し、sumは時間がO(n)となっているので、O(n**3)になっている。

ここからEquational Reasoningで効率化していく。

O(n**2)へ

これから使用する法則を挙げる:

map f . concat = concat . map (map f)

map f . map g = map (f . g)

maximum . concat = maximum . map maximum -- ただし引数が「空ではないリスト」のリストの場合

scanl f e = map (foldl f e) . inits
sum = foldr (+) 0 = foldl (+) 0
scanl (+) 0 = map sum . inits

これらを適用していくと:

maximum . map sum . concat . map inits . tails

= {map f . concat = concat . map (map f)}
maximum . concat . map (map sum) . map inits . tails

= {maximum . concat = maximum . map maximum}
maximum . map maximum . map (map sum) . map inits . tails

= {map f . map g = map (f . g)}
maximum . map (maximum . map sum . inits) . tails

= {map sum . inits = scanl (+) 0}
maximum . map (maximum . scanl (+) 0) . tails

initssumscanlに組み合わせられたおかげでこの部分をO(n2)からO(n)に落とすことができた。リストを大量に作って一つ一つ足し合わせるのは重複した計算が大量にあり、先頭から足していって途中経過を出力していくことでその重複を避けることができる。

ここまでで全体ではO(n**2)。

maximumfoldr1 max

foldr1というのはfoldrに似ているが、空リストではエラーになる関数。エラーにならない関数がすでにあるのになぜわざわざ、と思うかもしれないが、例えばmaximumなどはそもそも空リストを受け取った場合何を返していいか不明なのでエラーになるのが正しい。

foldr1の定義と、foldr1二項演算子maxを使ったmaximumの定義:

foldr1 :: (a -> b -> b) -> [a] -> b
foldr1 f [x]    = x
foldr1 f (x:xs) = f x (foldr1 f xs)

maximum = foldr1 max

なのでmssに戻ると、計算量はかわらないがmap内のmaximumを変換しておく:

maximum . map (maximum . scanl (+) 0) . tails

= {maximum = foldr1 max}
maximum . map (foldr1 max . scanl (+) 0) . tails

あとで効いてくる。

scanlfoldr

scanl (+) 0を何か別の形に変換してみたい。

まずは具体例を入れて挙動からヒントを:

scanl (+) 0 [x, y, z]
= [0, x, x+y, x+y+z]
= 0 : map (x+) [0, y, y+z]
= 0 : map (x+) (scanl (+) 0 [y, z])

scanl (+) 0 [x, y, z] = 0 : map (x+) (scanl (+) 0 [y, z])

これだとfoldrの形に似ている:

scanl (+) 0 (x:xs) = 0 : map (x+) (scanl (+) 0 xs)
foldr f e (x:xs) = f x (foldr f e xs)

一般化すると、scanlfoldrの間に以下の変換則が成り立つ:

scanl (@) e = foldr f [e]
  where f x xs = e : map (x@) xs

ただし(@)e単位元にもち、結合律を満たす関数。

これでmssは:

maximum . map (foldr1 max . scanl (+) 0) . tails

= {scanl (@) e = foldr f [e]}
maximum . map (foldr1 max . foldr f [0]) . tails
  where f x xs = e : map (x+) xs 

となる。これで何が嬉しいかというと、mapの中身がf . foldr g aの形になった事実。

fusion law

つまりfusion law、f . foldr g a = foldr h bの出番である。

foldr1 max . foldr f [0]
  where f x xs = e : map (x+) xs 

を対象に、fusion law適用の条件を満たしているかを確認し、またfoldr h bh関数とb値が具体的に何になるかを探る。

fusion lawとその三条件は:

f . foldr g a = foldr h b

f undefined   = undefined
f a           = b
f (g x y)     = h x (f y)

fusion lawに出てくるfは現在定義しているmss内のfと対応しないので少しややこしい。

直接foldr1 max . foldr f [0]と対応させるのではなく、任意の関数(<>)(@)を使って、どのような(<>), (@)ならfusion lawが適用可能でどのようなh, bになるかを調べてから、現在の具体例に戻る。

とりあえずわかっている対応付け:

foldr h b = foldr1 (<>) . foldr f [e]
            where f x xs == e : map (x@) xs

f_fusion = foldr1 (<>)
g_fusion = f
a_fusion = [e]

第一条件はfoldr1 (<>) undefined = undefined

foldr1がまず引数を[x](x:xs)にパターンマッチしないといけないので、引数がundefinedだとすぐにundefinedを返す。よって第一条件は任意の(<>)でクリア。

第二条件はfoldr1 (<>) [e] = b

foldr1の定義からfoldr1 f [x] = xなので任意の(<>)foldr1 (<>) [e] = eb_fusion = e

三条件foldr1 (<>) (f x xs) = h x (foldr1 (<>) xs)

左辺を展開:

foldr1 (<>) (f x xs)

= {f x xs == e : map (x@) xs}
foldr1 (<>) (e : map (x@) xs)

= {foldr1の定義}
e <> (foldr1 (<>) (map (x@) xs))

まずxs = [y]という要素数1のケースを考えてみる。

左辺

e <> (foldr1 (<>) (map (x@) [y]))

= {map (x@) [y] = [x @ y]}
e <> (foldr1 (<>) [x @ y])

= {foldr1 f [x] = x}
e <> (x @ y)

右辺

h x (foldr1 (<>) [y])

= {foldr1 f [x] = x}
h x y

なのでh x y = e (<>) (x @ y)が成立する。

それを一般的なfoldr1 (<>) (e : map (x@) xs) = h x (foldr1 (<>) xs)に代入すると:

foldr1 (<>) (e : map (x@) xs)    = e (<>) (x @ (foldr1 (<>) xs))

-> {左辺のeを外に出す}
e <> (foldr1 (<>) (map (x@) xs)) = e (<>) (x @ (foldr1 (<>) xs))

-> {両辺からe(<>)を除く}
foldr1 (<>) (map (x@) xs)        = x @ (foldr1 (<>) xs)

-> {point-free化}
foldr1 (<>) . map (x@)           = (x@) . foldr1 (<>)

となる。最後の条件が満たされるためには(@)(<>)に対して分配法則を満たす必要がある。つまり:

x @ (y <> z) = (x @ y) <> (x @ z)

が成り立つのが条件。

(@) = (+), (<>) = maxの場合は

x + (y `max`  z) = (x + y) `max` (x + z)

が成り立つ。というわけで遡ってfusion lawが適用できる。

f_fusion = foldr1 (<>)
g_fusion = f
a_fusion = [e]
b_fusion = e
h_fusion x y = e (<>) (x @ y) 
e = 0
(<>) = max
(@) = (+)

なので

foldr1 max . foldr f [0]
  where f x xs = e : map (x+) xs 

= {fusion law}
foldr h 0
  where h x y = 0 `max` (x + y)

これでmss関数は:

maximum . map (foldr1 max . foldr f [0]) . tails
  where f x xs = e : map (x+) xs

= {fusion law}
maximum . map (foldr h 0) . tails
  where h x y = 0 `max` (x + y)

scanr

map (foldr h 0) . tailsscanlの定義:

scanl f e = map (foldl f e) . init

に非常に似ている。

scanlと似たような論理でscanr f e = map (foldr f e) . tailsが定義できる:

scanr :: (a -> b -> b) -> b -> [a] -> [b]
scanr f e []     = [e]
scanr f e (x:xs) = (f x (head ys)) : ys
                   where ys = scanr f e xs

mssmap (foldr f e) . tailsscanrに変換:

maximum . map (foldr h 0) . tails
  where h x y = 0 `max` (x + y)

= {map (foldr f e) . tails = scanr f e}
maximum . scanr h 0
  where h x y = 0 `max` (x + y)

最終的な解

mss = maximum . scanr h 0
      where h x y = 0 `max` (x + y)

となった。強力だが普遍的な高階関数を使うことで記述が非常に簡潔になったが、さらに重要なのは計算量。なんとO(n)。

なにをしているかを試しにPythonで書いてみると:

def mss(a_list):
  e = 0
  sums = []
  for j in range(len(a_list)):
    i = -1 - j
    e = max(0, e + a_list[i])
    sums.append(e)
  return max(sums)

たしかにうまくいくしO(n)だな・・・ 最初の自明かつ非効率な解と、この最適化された解が同値だと証明できるのはすごい。