Arantium Maestum

プログラミング、囲碁、読書の話題

蟻本初級編攻略 - 2-4 Union Find木 後編

前回のUnion Find木実装を踏まえてAtCoder ABC/ARCのD問題や蟻本の元の例題に取り組んでみる。

ABC049 D問題 - 連結/Connectivity

abc049.contest.atcoder.jp

abc049.contest.atcoder.jp

道路ネットワークと鉄道ネットワークが走っている国で、街ごとに道路でも鉄道でも連結になっている街の数を出力する。

AとBが道路でも鉄道でも連結である必要かつ十分条件は「道路ネットワークでも鉄道ネットワークでも同一のグループに属する」ということなのが(少し考えると)わかった。

あとは街ごとに(道路ネットワークのルート, 鉄道ネットワークのルート)というタプルで所属するグループを表して、そのグループに所属する街の数を算出してやればいい。

というわけで実装:

from collections import Counter
 
n, k, l = map(int, input().split())
ps = [tuple(map(int, input().split())) for _ in range(k)]
rs = [tuple(map(int, input().split())) for _ in range(l)]
 
def root(n, u):
    if u[n] == n:
        return n
    u[n] = root(u[n], u)
    return u[n]
 
union1 = {x:x for x in range(1, n+1)}
for p, q in ps:
    union1[root(p, union1)] = root(q, union1)
 
union2 = {x:x for x in range(1, n+1)}
for r, s in rs:
    union2[root(r, union2)] = root(s, union2)
 
gs = {x:(root(x, union1), root(x, union2)) for x in range(1, n+1)}
c = Counter(gs.values())
print(' '.join(str(c[gs[x]]) for x in range(1, n+1)))

二回Union Findをして、最終的なグループをタプルで表し、グループごとの街数をCounterで算出し、街ごとに所属するグループの街数を出力している。

各Union FindがO(N log N)、最終的なグループを作るのもO(N log N)、カウンターはO(N)、出力もO(N)で計算量のオーダーはO(N log N)。

ARC097 D問題 - Equals

arc097.contest.atcoder.jp

arc097.contest.atcoder.jp

「いれかえが許された添字ペア」を何度でも利用して、配列の要素の値と添字を最大何個一致させることができるか。

「添字ペア」がUnion Findで言うところの連結の概念と一致することがわかれば実装は非常に楽。そのためには「(a, b), (b, c), (c, d) ... (y, z)」のような添字ペアの連なりがあれば、入れ替えを連続で行うことでa~zまでのすべての要素を任意の場所に配置できる、ということがわかればいい。

これは帰納法で考えられる。

添字ペアが(a, b)だけの場合:

  • ありえる二つの配置(a, b)と(b, a)が「入れ替えられる」という問題の定義から自明に可能

すでに入れ替えを連続で行うことでa1~amまでのすべての要素を任意の場所に配置できるグループに、(am, an)を加えた場合:

  • amの位置に任意の要素を持ってくることができる(帰納法)上、amとanを交換できるので任意の要素をanに持ってくることができる
  • anに位置にある要素をamに交換できる上、a1~amにあるすべての要素を任意の位置に入れ替えることができる(帰納法)のでanの要素を任意の位置に移せる

ということで証明終わり。

あとは「値iを含む位置jと位置iが同じグループに属している」という条件を満たすiの数を数えるだけ(単一であることが制約により保証されていることも重要な成立条件)。

n, m = map(int, input().split())
ps = [int(x) for x in input().split()]
xs = [tuple(map(int, input().split())) for _ in range(m)]
 
union = {x:x for x in range(1, n+1)}
def root(n):
    if union[n] == n:
        return n
    union[n] = root(union[n])
    return union[n]
 
for x,y in xs:
    union[root(x)] = root(y)
 
print(sum(root(i+1) == root(ps[i]) for i in range(n)))

「Union Find木の典型」と言えるような解法になる。

ARC036 D問題 - 偶数メートル

arc036.contest.atcoder.jp

街から街へ(もしかすると複数の)道が通っていて、距離の和が偶数になるような道順を通ってある街から別の街へ移動できるかを判定する問題。

各道の距離が与えられるが実際に重要なのはその距離が奇数か偶数かという点。

まず最初に浮かんだ解法がこれ:

arc036.contest.atcoder.jp

Union Find木は「偶数で繋がっているグループ」を表し、それとは別に「ある街が奇数で繋がっているグループ」をdictionaryで表す。

街A <- 奇数道-> 街B <- 偶数道 -> 街C

のようになっている場合、街Aと街Cも奇数で繋がっているので「奇数で繋がっている相手」を「偶数グループ」で表せるのがポイント。

あと

街A <- 奇数道-> 街B <- 奇数道 -> 街C

だとAとCが同じ偶数グループに属するようになる、というのも重要。

それを実装するとこうなる:

n, q = map(int, input().split())
ws = [tuple(map(int, input().split())) for _ in range(q)]
 
union = {x:x for x in range(1, n+1)}
odds = {x:None for x in range(1, n+1)}
def root(n):
    if union[n] == n:
        return n
    union[n] = root(union[n])
    return union[n]
 
for w, x, y, z in ws:
    rx, ry = root(x), root(y)
    if w == 2:
        print('YES' if rx == ry else 'NO')
    elif z % 2 == 0:
        union[rx] = ry
        if odds[rx] and odds[ry]:
            union[root(odds[rx])] = root(odds[ry])
        else:
            odds[rx] = odds[ry] = odds[rx] or odds[ry]
    else:
        odds[rx] = odds[rx] or ry
        odds[ry] = odds[ry] or rx
        union[rx] = root(odds[ry])
        union[ry] = root(odds[rx])

ちょっと奇数・偶数で場合分けした時のロジックが複雑になっている。

もう一つの解法はこれ:

arc036.contest.atcoder.jp

Union Find木で「奇数で繋がっている」「偶数で繋がっている」を一括で管理する。各ノードは「街、奇偶」のタプルで表している。

実装が簡単になるなーと思ったのだが・・・

n, q = map(int, input().split())
ws = [tuple(map(int, input().split())) for _ in range(q)]
 
union = {(x, z):(x, z) for x in range(1, n+1) for z in 'oe'}
def root(n):
    if union[n] == n:
        return n
    union[n] = root(union[n])
    return union[n]
 
for w, x, y, z in ws:
    if w == 2:
        print('YES' if root((x, 'e')) == root((y, 'e')) else 'NO')
    elif z % 2 == 0:
        rex, rey = sorted([root((x, 'e')), root((y, 'e'))])
        union[rex] = rey
        rox, roy = sorted([root((x, 'o')), root((y, 'o'))])
        union[rox] = roy
    else:
        reox, reoy = sorted([root((x, 'e')), root((y, 'o'))])
        union[reox] = reoy
        roex, roey = sorted([root((x, 'o')), root((y, 'e'))])
        union[roex] = roey

多分同じ街が二回現れることが理由なんだと思うのだが、気をつけないとroot関数が無限ループに陥ってStack Overflowを起こす。実際テストケース最後の一つだけで何回かREしてしまった・・・

かならず「ソート順が一番大きいノード」に貼るようにすればループは起こりえない。その分少しコードが冗長になってしまったが、さきほどの解に比べると場合分けのロジックの対称性がはっきりしていてやはりわかりやすいように思う。

ARC090 D問題 - People in a line

arc090.contest.atcoder.jp

arc090.contest.atcoder.jp

直線上に並んでいる人たちについて、「LはRのXm左にいる」という情報が与えられる。その情報に矛盾がないか判定する問題。

Union Find木の発展系である重みつきUnion Find木を使うと簡単に解ける:

n, m = map(int, input().split())
xs = [tuple(map(int, input().split())) for _ in range(m)]
 
union = {x:(x, 0) for x in range(1, n+1)}
def root(n):
    if union[n][0] == n:
        return union[n]
    n1, d1 = union[n]
    n2, d2 = root(n1)
    union[n] = (n2, d2+d1)
    return union[n]
 
for l, r, d in xs:
    n1, d1 = root(l)
    n2, d2 = root(r)
    if n1 != n2:
        union[n2] = (n1, d + d1 - d2)
    elif d != d2 - d1:
        print('No')
        break
else:
    print('Yes')

「ルートノードの人、その人に対しての距離」をタプルとしてUnion Find木に入れている。

新しい情報が入ってくるたびに * もしルートノードが違えばUnion Find木をアップデートできる(二つのルートノードの相対位置がわかる) * もしルートノードが同じならそのルートノードに対してのLとRの位置の差が新しい情報と一致しているかを判定する

Union Findのノードをタプル化して追加で情報を持たせるのは「偶数メートル」と同じ。ただし、root関数内でその情報をまとめる必要がでてくるのでより複雑(といってもノード間の距離の総和を追加で返すだけだが)

あと少し気の利いたところとしては最後のelse。Pythonのfor-loopには「breakしなかったら最後にこれを実行」というelse構文があることはあまり知られていないように思う。ループ中に矛盾が見つかれば"No"と出力してbreak、もし見つからなければ最後にelseで"Yes"と出力、というロジック。

POJ 1182 - 食物連鎖

「AがBを食べ、BがCを食べ、CがAを食べる」とわかっているなかで、「xとyは同じ種」「xはyを食べる」という情報が順次入ってくる。そのうちで以前来た情報や個体番号の制限と矛盾するものは無視するとして、最終的に無視した情報の数を求める問題。

入力値の一例としてはこんな感じ:

n = 100
k = 7

xs = [(1, 101, 1),
      (2, 1, 2),
      (2, 2, 3),
      (2, 3, 3),
      (1, 1, 3),
      (2, 3, 1),
      (1, 5, 5)]

nが個体数、kが情報数、xsが情報のリスト。情報は

  • 1=「同じ種」, 2=「xはyを食べる」という情報のタイプ
  • xの番号
  • yの番号

の三点からなるタプル。

上記の入力から期待される出力は3。

1 101 1 2 3 3 2 3 1

の三つの情報が矛盾する。

これもUnion Find木で解決するのだが、ノードは個体ではなく「1はAである」などの命題であり、ノードの連結は

命題Aが真の時そしてその時に限り命題Bも真 A <-> B

という関係を表している。

なので

  • 「xとyは同じ種」という情報が入ってくるとx-Aとy-Aは互いに連結、x-Bとy-B、x-Cとy-Cも同じく互いに連結
  • 「xはyを食べる」という情報が入ってくるとx-Aとy-Bは互いに連結、x-Bとy-C、x-Cとy-Aも同じく互いに連結

その上で「xはyを食べる」という情報が入ってきた時すでにx-Aとy-Aが連結であったりした場合、情報が矛盾するので無視することになる。

というのが以下の実装:

groups = ('a', 'b', 'c')
union = {(x, g):(x, g) for x in range(1, n+1)
                         for g in groups}
def root(n):
    if union[n] == n:
        return n
    union[n] = root(union[n])
    return union[n]

def is_error(i, x, y):
    if not (1 <= x <= n and 1 <= y <= n):
        return True
    if i == 1:
        return root((x, 'a')) in {root((y, g)) for g in ('b', 'c')}
    if i == 2:
        return root((x, 'a')) in {root((y, g)) for g in ('a', 'c')}
    return True

ans = 0
for i, x, y in xs:
    if is_error(i, x, y):
        ans += 1
    elif i == 1:
        for g in groups:
            union[root((x, g))] = root((y, g))
    else:
        for g1, g2 in (('a', 'b'), ('b', 'c'), ('c', 'a')):
            union[root((x, g1))] = root((y, g2))
print(ans) 

まあ実装はそこまで特筆すべきものはない・・・ 矛盾があるかをチェックする時に「x is A」と繋がったyに関する命題の矛盾だけ調べればいい、という点くらいか。

しかし解法自体は非常に面白かった。「命題がノードになる」というのはけっこう盲点で、蟻本の解説を読むまではどう書けばいいのかなかなか想像がつかなかった。

この視点はUnion Find木の応用の幅が広がりそうなのでこれからも大切にしたい。

結論

Union Find好き。

蟻本初級編攻略 - 2-4 Union Find木 前編

「ある集合の二つの要素が繋がっているか」を効率的に判定できるデータ構造であるUnion Find木の話。

非常に簡単でエレガントな実装のわりに非常に強力な概念だという印象で、個人的にはとても好き。

蟻本の例題が一番特殊で面白いので後回し。まずはAtCoder Typical Contestから:

AtCoder Typical Contest 001 B問題 Union Find

atc001.contest.atcoder.jp

atc001.contest.atcoder.jp

ザ・Union Find典型。Union Findが逐次更新できるデータ構造で、途中のいつでも「現在明らかになった情報を元に二つの要素が繋がっているか判定できる」という点も強調されている。

Python実装はこんな感じ:

n, q = map(int, input().split())
xs = [tuple(map(int, input().split())) for _ in range(q)]
 
union = {x:x for x in range(n)}
def root(n):
    if union[n] == n:
        return n
    union[n] = root(union[n])
    return union[n]
 
for p, a, b in xs:
    if p == 0:
        union[root(a)] = root(b)
    else:
        print("Yes" if root(a) == root(b) else "No")

やはり非常に簡潔に済むのがうれしい。

Union Findはデータをdictionaryに持たせ、そのdictionaryをもとにある要素が繋がっているグループのルートノードを返す関数と、二つのグループをつなげるイディオムを提供する。二つの要素が繋がっているかを判定するのは、その二つの要素の属するグループのルートノードが一致するかを調べるだけ。

Union Findは気をつけないとroot関数が最悪O(n)かかってしまう。それを回避するメジャーな最適化が二つある。ここらへんはAtCoderのUnion FindについてのSlideShareに詳しい:

www.slideshare.net

両方の最適化を使うと計算量のオーダーが逆アッカーマン関数(ほぼ定数・・・)になるということだが、一つだけでもO(logN)になるうえ、実装が簡単で定数倍ははやい。

というわけで私は「ルートノードを探る度に枝をすべてルートノード直下に貼り直す」という最適化のみ実装している。

つまりroot関数のナイーブな実装:

def root(n):
    if union[n] == n:
        return n
    return root(union[n])

それをこう変えている:

def root(n):
    if union[n] == n:
        return n
    union[n] = root(union[n]) #ここ
    return union[n]

これだけでO(N)がO(log N)になるのはうれしい。

そういえば昨日PEP572が受理されたので将来こう書けるようになる:

def root(n):
    if union[n] == n:
        return n
    return union[n] := root(union[n]) #ここ

まあこれは大した違いじゃないけど・・・。

競プロ時にPythonでClassを使わずにデータ構造を実装することについて

ちなみに私は個人的に競技プログラミングPythonを書く場合、Union Findくらいだったらインターフェースつけてclassにwrapするより、もとのデータ構造をむき出しにしておきたい。

ということをちょっと乱暴に呟いた:

そしたら速攻でツッコミが入った:

(次回出てくる重みつきUnion Find木の話)

もうちょっと丁寧に自分の考えを説明するとこんな感じ:

(最初のnamedtupleについての言及はプロダクションコードだったら、の話)

さらにいうとデータ構造がむき出しのほうが短期的にはいじりやすい。微妙にデータ構造に持たせる情報や算出させる情報を変えたいときに、すくなくともABCレベルの競技プログラミングだと変更箇所はかなり限られているから使われているその場で変えてしまったほうが楽。

数百行以上のコードのいろんなところで使われているなら状況は逆転する。やはりプロダクションコードと競技プログラミングでは細かい作法がけっこう違ってくるのは自然だと思う。

次回予告

つらつらと書いていたら長くなったので記事を二回にわける。

次回はAtCoder ABC/ARCからUnion Findを使ったD問題4つと蟻本の例題の解法を説明する。

蟻本初級編攻略 - 2-4 二分探索木

蟻本の章名である「二分探索木」というサブタイトルを付けたが、Pythonなので二分探索木使わない。setもdictもハッシュマップだ。

ABC085 B問題 Kagami Mochi

abc085.contest.atcoder.jp

abc085.contest.atcoder.jp

上に乗せる鏡餅は下の餅より直径が小さくなくてはいけないという制約のもと、何段の鏡餅が作れるかという問題。

制約は言い換えれば「同じ直径の鏡餅は使えない」ということ。なので直径の重複を除いて数えればいい。

n = int(input())
print(len(set(int(input()) for _ in range(n))))

Pythonなのでsetにしてlenをとるだけ。set化がO(n)、lenがO(1)。

ABC091 B問題 Two Colors Card Game

abc091.contest.atcoder.jp

abc091.contest.atcoder.jp

青と赤のカードに文字列が書いてあって、ある文字列を選ぶとその文字列が書いてある青いカードの数だけポイントがもらえ、赤いカードの数だけポイントが減る。ポイントを最大化する文字列は何か、という問題。

青いカードに書かれている文字列のうち、青いカードの数と赤いカードの数の差額で最も大きいもの、あるいは0が答え:

from collections import Counter
 
n = int(input())
ss = Counter(input() for _ in range(n))
m = int(input())
ts = Counter(input() for _ in range(m))

answer = max(cs - ts[s] for s, cs in ss.items())
print(answer if answer > 0 else 0)

標準ライブラリに入っている、dictを拡張したCounterクラスを使っている。上記のロジックをそのままで記述している。

Counter作成はO(n)、ss.items()のループは全部でO(n)、ts[s]はO(1)、maxはO(n)、と全体でもO(n)の計算量。

蟻本初級編攻略 - 2-4 Expedition

動的計画法関連のところはPythonでやるとTLEする問題もいくつかあり、OCamlでやることも視野に入れつつ少し寝かせておく。

というわけで先にPriority Queueを使った問題3つ。まずは蟻本に載っているPOJ問題から:

POJ - Expedition

2431 -- Expedition

現在地から目的地への直線上にある複数のガソリンスタンドに最小で何回止まって給油する必要があるか、という問題。

ガソリンが足りなくなる限界までの区間にあるガソリンスタンドの給油量をPriorityQueueにいれて最大値を取り出し、ガソリンが足りなくなる限界距離を更新して考慮に入れるスタンドをQueueにいれて最大値を取り出し、というループを限界距離が目的距離に到達するまで続ける:

from queue import PriorityQueue

n = int(input())
xs = [tuple(int(x) for x in input().split()) for _ in range(n)]
l, p = map(int, input().split())

current = p
count = 0
q = PriorityQueue()
xs = list(sorted([(l-x, v) for x, v in xs], reverse=True))

while current < l:
    while xs and xs[-1][0] <= current:
        q.put(-xs.pop()[1])
    if not q:
        break
    current -= q.get()
    count += 1

print(count if current >= l else -1)

PythonのPriorityQueueやheapqはmin-heapしか提供していない。ちょっとショック。なので「最大値を常にとってくる」という挙動のためにマイナスでかけている。その点以外は比較的素直な実装ではないだろうか。

それではAtCoderの類題2問を解いてみる。

Code Thanks Festival 2017 C問題 - Factory

code-thanks-festival-2017-open.contest.atcoder.jp

code-thanks-festival-2017-open.contest.atcoder.jp

プレゼントを作るごとにスピードが劣化していく複数の機械をどう使うと最小の時間で一定数のプレゼントが作れるか、という問題。

キューに「かかる時間」と「劣化の度合い」をタプルとして入れて、かかる時間が最小のものを選んで、そして「かかる時間+劣化」と「劣化の度合い」のタプルをキューに戻して、とループしていく。ほしいプレゼント数の分だけループが回って「かかる時間」の和が答え。

import heapq
 
n, k = map(int, input().split())
xys = [tuple(int(x) for x in input().split()) for _ in range(n)]
 
heapq.heapify(xys)
 
time = 0
for _ in range(k):
    time += xys[0][0]
    heapq.heappushpop(xys, (xys[0][0]+xys[0][1], xys[0][1]))
 
print(time)

今回はheapqで実装してみた。heapqには普通のlistをヒープ化したり、ヒープ化されたリストに対するヒープ操作をしたり、という関数が用意されている。実装がリスト準拠なので、最小の要素をpeekするだけならxs[0]でできる。queueモジュールのPriorityQueueは実はスレッドセーフな実装になっており、その分のオーバーヘッドがかかるのでheapqのほうが定数倍効率がいい。

heapqがO(N)、heappushpopがO(logN)なので全部でO(n + k logn)の計算量。

ARC074 D問題 - 3N Numbers

arc074.contest.atcoder.jp

arc074.contest.atcoder.jp

3*N個の正の整数のリストからN個要素を取り除き、「前半N個の和 - 後半N個の和」を最大化する問題。

制約から見えてきたいくつかのポイント:

  • 前半の和を最大化し、後半の和を最小化する必要がある
  • 前半はA[:2*n]の要素のうちのN個によって構成される
  • 後半はA[n:]の要素のうちのN個によって構成される
  • 前半が使える部分と後半が使える部分のオーバーラップであるA[n:2*n]のどこに線を引くかがポイント

というわけで

  • i = 0 ~ nの各iごとにA[:n+i]からN個選んだ場合の和の最大値を求める
  • j = 0 ~ nの各ごとにA[2*n-j:]からN個選んだ場合の和の最小値を求める

ということをPriorityQueueを使って一歩ずつ算出していく。その上でi+j = nで差が最大になるiとjのペアを見つける。

from heapq import heapify, heappushpop
 
n = int(input())
xs = [int(x) for x in input().split()]
 
xs1 = xs[:n]
heapify(xs1)
r1 = [sum(xs1)]
for x in xs[n:2*n]:
    r1.append(r1[-1] + x - heappushpop(xs1, x))
 
xs2 = [-x for x in xs[2*n:]]
heapify(xs2)
r2 = [sum(xs2)]
for x in reversed(xs[n:2*n]):
    r2.append(r2[-1] + (-x) - heappushpop(xs2, -x))
 
print(max(a+b for a,b in zip(r1, reversed(r2))))

前半の和の最大値を求めるのに、A[n+i]の要素をキューにいれてから最小値を出して、(A[n+i] - 最小値)をそれまでの和に足している。後半も全く同じロジック、ただし「最大値をキューから取り出す」という挙動のためにキューに入れる・キューから取り出す値をすべてマイナスでかけている。

ほとんどの部分がO(N)だが、ループしている部分だけN回ループが回ってO(logN)のheappushpopが毎回行われているので、そこが計算量O(N logN)で大きい。

Priority Queue感想

やっぱり便利。今回の問題はみんなPriority Queueが前面に出てきているが、そうでなくても解法の一部として計算量を下げるのに活躍してくれる場面は多そう。

しかしmax heapがないのはやっぱり不便だな、微妙なところで実装を間違えてバグりそうだから要注意だ。

AtCoder Beginner Contest 102に参加してみた

第1問

abc102.contest.atcoder.jp

abc102.contest.atcoder.jp

入力値nと2の最小公倍数を求める問題。nが偶数ならn、そうじゃないなら2*n

n = int(input())
print(2*n if n%2 else n)

第2問

abc102.contest.atcoder.jp

abc102.contest.atcoder.jp

Aの(添字の)異なる 2 要素の差の絶対値の最大値を求めてください。

と言われるとややこしく感じてしまうが「2要素の差が一番大きい」なのでmaxとminの差。

入力値に「Aの要素が2つ以上」という制約がある以上、要素の値がすべて同じだったとしても「添字の異なる2要素」は満たせる。

n = int(input())
xs = [int(x) for x in input().split()]
print(max(xs) - min(xs))

第3問

abc102.contest.atcoder.jp

abc102.contest.atcoder.jp

Ai と b+iの差

と考えるとちょっとややこしいかもしれないがabs(Ai - (b+i) = abs((Ai - i) - b)なのでまずAi-iの数列Bを作ってしまう。

その数列Bの各要素と任意の数bの距離の和を最小化したい。

たとえばbがBのどの要素よりも小さかったとする。そうするとb+1はBのすべての要素に1ずつ近くなるので明らかにbより勝る。

bがBの最小の要素と同じ値だったとする(Bの最小の要素が単一だと仮定する)。するとb+1はその最小の要素からの距離は1増え、n-1個の要素との距離は1ずつ減る。

この「1ずれるごとに距離が1増える要素と距離が1減る要素の数」がちょうどバランスするところが最善であり、そのポイントでの距離の和が答えとなる。

それは中央値を構成する要素が1つの場合はちょうどその要素の値となるし、もし中央値が二つの要素の平均だとしたら、その二つの要素の値の間ならどこでも最善になる。

なのでてっとりばやく中央値を使える:

from statistics import median
 
n = int(input())
xs = [int(x) for x in input().split()]
 
ys = [x-i-1 for i, x in enumerate(xs)]
m = int(median(ys))
 
print(sum(abs(y-m) for y in ys))

第4問

abc102.contest.atcoder.jp

abc102.contest.atcoder.jp

この問題はコンテスト中には解けなかった・・・ しゃくとり法や左右から端を狭めていく方法などをずっと考えていて最後に苦し紛れの実装でWAを出したりしながらタイムアップ。悔しいので解説は見ずに1日ぼーっと考えたりしていた。

累積和をとると便利なのは最初から考えていたのだが、「真ん中をまず割ってみる」というところに発想がいかなかったのが敗因。

真ん中を割ってしまえば、「左右の配列を個別にどう割るか」という二つの小さい問題になる。

まず左の配列について考えてみる:

  • 配列をどう割るのが最善か、と考えると、ここでも和の差を最小化するのが一番だとわかる
  • 和の差が最小化されるためには一つ一つの和ができるだけ平均に近ければいい
  • 累積和を平均値で二分探索すればいい
  • ただし二分探索で得られた添字の周りが答えの可能性もあるので-1, +1も調べて和の差が最小化されるポイントを選ぶ

右の配列も似た論理だ。ただし累積和がすこしややこしい(左の配列の総和を差し引く必要がある)。

あとは左右の配列の結果を組み合わせて最大と最小の和の差をとれば、あるポイントを真ん中にして割った場合の最小の和の差がわかる。

つまり真ん中が決定していれば、最小の和の差を2*O(log N/2)で計算できる。

あとは真ん中候補すべてに対してこの計算をして最小の結果が全体の答え。なので全体でO(N logN)。

from itertools import accumulate, islice
from bisect import bisect_left
 
n = int(input())
xs = [int(x) for x in input().split()]
 
ys = list(accumulate(xs))
zs, total = ys[:-1], ys[-1]
 
res = max(xs) - min(xs)
for i, z in islice(enumerate(zs), 1, n-1):
    j = bisect_left(zs, z//2)
    splits = ((zs[j+k], z - zs[j+k]) for k in (-1, 0, 1) if 0 <= j+k <= i)
    a, b = min(splits, key=lambda s:abs(s[0]-s[1]))
 
    j = bisect_left(zs, (total - z)//2 + z)
    splits = ((zs[j+k] - z, total - zs[j+k]) for k in (-1, 0, 1) if i <= j+k <= n-2)
    c, d = min(splits, key=lambda s:abs(s[0]-s[1]))
 
    minn, _, _, maxx = sorted((a,b,c,d))
    if maxx - minn < res:
        res = maxx - minn
 
print(res)

Pythonだとテストケースにかかる時間が最長で1800msとかなりギリギリだ。うっかりj = bisect_left(zs, z/2)などと浮動小数点との比較で二分探索してしまうと簡単にTLEになる。今度OCamlで実装してどう書けてどれくらい速度がでるか試してみたい。

OCamlでTypical Dynamic Programming Contest E問題を解いてみた

ここ一ヶ月ほどAtCoderPythonですすめてきて、Pythonの遅さにひやりとすることはあっても基本的にアルゴリズムがあっていて素直な最適化を施していればABCの問題であれば通せそう、という感触を得ている。

しかし少し上の問題を見てみるとそうはいかないようだ。例えばAtCoderのTypical Dynamic Programming ContestというDP問題ばかり集めたサンプルコンテストのE問題:

tdpc.contest.atcoder.jp

かなり素直な桁DPなのだが、桁数が10000でmodをとる値が最大100なので計算量のオーダーが10000 * 100 * 10(最後の10は0~9までの数で回すから)、つまり107になる。 この計算量はコンパイル言語なら余裕だがPythonだと普通にアウトのようだ。

tdpc.contest.atcoder.jp

(まあこのコードが定数で最適かといわれると違うのだが・・・)

というわけでコンパイル言語で書いてみる。

競技プログラミング的に素直な選択はC++だが、もう少しPythonに近い宣言的な表現力のある言語がいい。

表現力という意味ではHaskellが相当高いが、純粋関数型だとアルゴリズムにけっこう工夫が必要になり、Pythonで考えるのとはまったく違う世界になる。

速度もあり、適度に関数型で宣言的に記述することができ、かつ必要なところで副作用を気軽に扱える言語、ということでOCamlを選んだ。

tdpc.contest.atcoder.jp

最長251msで無事AC。

open Batteries
 
let digitList s = s |> String.to_list |> List.map (fun c -> Char.code c - 48)
let id x = x
 
let d = Scanf.scanf "%d\n" id
let n = Scanf.scanf "%s\n" id
let m = 1000000007
 
let f x n =
  let v, a = x in
  let a1 = Array.make d 0 in
  for i = 0 to d-1 do
    for j = 0 to 9 do
      a1.((i+j) mod d) <- (a1.((i+j) mod d) + a.(i)) mod m
    done
  done;
  for j = 0 to n-1 do
    a1.((v+j) mod d) <- (a1.((v+j) mod d) + 1) mod m
  done;
  ((v+n) mod d, a1)
 
let v, a = List.fold_left f (0, Array.make d 0) (digitList n)
let () = Printf.printf "%d\n" (a.(0)  + (if v == 0 then 1 else 0) - 1)

まともにOCamlを書くのははじめてだったのでまだいろいろ汚いと思うが、とりあえず自分がしたいことを比較的素直に記述できて速度が出るのは非常にうれしい。

f関数の中で動的計画法らしくArrayに対する破壊的代入を繰り返しているが、引数には変更は加えずに命令的に作成したArrayを結果として返しているので外側からみると純粋関数だ。

そのf関数を引数としてfold_leftで畳み込んだり、文字列を|>を使った関数のチェーンで一桁の数字のリストに変換したり、といったところは抽象度が高い関数型らしい記述ができる。

その上で実行速度が107のオーダーをAtCoderの時間制限の2000msでも問題なくこなせるのだからこれは強い。

習熟度が非常に低いので実際のコンテストではまだ使えないだろうが、少しずつ過去問(とくに計算量のオーダーが大きそうな問題で)をOCamlで挑戦していきたい。

トポロジカル・ソートをPythonで実装してみた

DPとはDAGの最短経路を、トポロジカルソート順に埋めていくことで計算する手法という話からそもそもトポロジカルソートってどうやるんだっけ?となり、Pythonで一つのアプローチを実装してみた:

from collections import defaultdict, deque

v, n = map(int, input().split())
es = [[int(x) for x in input().split()] for _ in range(n)]

outs = defaultdict(list)
ins = defaultdict(int)
for v1, v2 in es:
    outs[v1].append(v2)
    ins[v2] += 1

q = deque(v1 for v1 in range(v) if ins[v1] == 0)
res = []
while q:
    v1 = q.popleft()
    res.append(v1)
    for v2 in outs[v1]:
        ins[v2] -= 1
        if ins[v2] == 0:
            q.append(v2)

print(', '.join(res))

冒頭にリンクも貼った「DPの話」の記事で言われている「入次数 0 のノード (とそこから伸びる辺) をひたすら取り除きまくる」実装。

vがノードの数、nが辺の数、esが各辺を表す親ノード、小ノードのペア。

outsが各ノードから出ていく先のノードのリストの辞書。insが各ノードの入次数の辞書。

qが入次数が0のノードのキュー。

  • キューの頭のノードを取り出し、ソートされたリストの次のメンバとしてappend
  • そのノードから伸びる各辺について
    • その行き先のノードの入次数を1減らし
    • もしそのノードの入次数が0に落ちたらキューに入れる

という処理を繰り返す。

outsとinsの作成にO(E)、qの作成にO(V)、resの作成のループはO(V+E)(外側のループがO(V)回、内側のループは合計でO(E)回)ということで全体的にO(E+V)の計算量。

実際にDPでトポロジカルソートをする必要があることはあまりないらしい(入力が与えられた時点ですでにDAGとしてソートされていることが多いとのこと)。だけどとりあえず必要だったらすぐ実装して使えるようにしておきたい。今回でわかった通りロジックも実装も非常に簡単だし。