Arantium Maestum

プログラミング、囲碁、読書の話題

蟻本初級編攻略 - 2-4 Union Find木 後編

前回のUnion Find木実装を踏まえてAtCoder ABC/ARCのD問題や蟻本の元の例題に取り組んでみる。

ABC049 D問題 - 連結/Connectivity

abc049.contest.atcoder.jp

abc049.contest.atcoder.jp

道路ネットワークと鉄道ネットワークが走っている国で、街ごとに道路でも鉄道でも連結になっている街の数を出力する。

AとBが道路でも鉄道でも連結である必要かつ十分条件は「道路ネットワークでも鉄道ネットワークでも同一のグループに属する」ということなのが(少し考えると)わかった。

あとは街ごとに(道路ネットワークのルート, 鉄道ネットワークのルート)というタプルで所属するグループを表して、そのグループに所属する街の数を算出してやればいい。

というわけで実装:

from collections import Counter
 
n, k, l = map(int, input().split())
ps = [tuple(map(int, input().split())) for _ in range(k)]
rs = [tuple(map(int, input().split())) for _ in range(l)]
 
def root(n, u):
    if u[n] == n:
        return n
    u[n] = root(u[n], u)
    return u[n]
 
union1 = {x:x for x in range(1, n+1)}
for p, q in ps:
    union1[root(p, union1)] = root(q, union1)
 
union2 = {x:x for x in range(1, n+1)}
for r, s in rs:
    union2[root(r, union2)] = root(s, union2)
 
gs = {x:(root(x, union1), root(x, union2)) for x in range(1, n+1)}
c = Counter(gs.values())
print(' '.join(str(c[gs[x]]) for x in range(1, n+1)))

二回Union Findをして、最終的なグループをタプルで表し、グループごとの街数をCounterで算出し、街ごとに所属するグループの街数を出力している。

各Union FindがO(N log N)、最終的なグループを作るのもO(N log N)、カウンターはO(N)、出力もO(N)で計算量のオーダーはO(N log N)。

ARC097 D問題 - Equals

arc097.contest.atcoder.jp

arc097.contest.atcoder.jp

「いれかえが許された添字ペア」を何度でも利用して、配列の要素の値と添字を最大何個一致させることができるか。

「添字ペア」がUnion Findで言うところの連結の概念と一致することがわかれば実装は非常に楽。そのためには「(a, b), (b, c), (c, d) ... (y, z)」のような添字ペアの連なりがあれば、入れ替えを連続で行うことでa~zまでのすべての要素を任意の場所に配置できる、ということがわかればいい。

これは帰納法で考えられる。

添字ペアが(a, b)だけの場合:

  • ありえる二つの配置(a, b)と(b, a)が「入れ替えられる」という問題の定義から自明に可能

すでに入れ替えを連続で行うことでa1~amまでのすべての要素を任意の場所に配置できるグループに、(am, an)を加えた場合:

  • amの位置に任意の要素を持ってくることができる(帰納法)上、amとanを交換できるので任意の要素をanに持ってくることができる
  • anに位置にある要素をamに交換できる上、a1~amにあるすべての要素を任意の位置に入れ替えることができる(帰納法)のでanの要素を任意の位置に移せる

ということで証明終わり。

あとは「値iを含む位置jと位置iが同じグループに属している」という条件を満たすiの数を数えるだけ(単一であることが制約により保証されていることも重要な成立条件)。

n, m = map(int, input().split())
ps = [int(x) for x in input().split()]
xs = [tuple(map(int, input().split())) for _ in range(m)]
 
union = {x:x for x in range(1, n+1)}
def root(n):
    if union[n] == n:
        return n
    union[n] = root(union[n])
    return union[n]
 
for x,y in xs:
    union[root(x)] = root(y)
 
print(sum(root(i+1) == root(ps[i]) for i in range(n)))

「Union Find木の典型」と言えるような解法になる。

ARC036 D問題 - 偶数メートル

arc036.contest.atcoder.jp

街から街へ(もしかすると複数の)道が通っていて、距離の和が偶数になるような道順を通ってある街から別の街へ移動できるかを判定する問題。

各道の距離が与えられるが実際に重要なのはその距離が奇数か偶数かという点。

まず最初に浮かんだ解法がこれ:

arc036.contest.atcoder.jp

Union Find木は「偶数で繋がっているグループ」を表し、それとは別に「ある街が奇数で繋がっているグループ」をdictionaryで表す。

街A <- 奇数道-> 街B <- 偶数道 -> 街C

のようになっている場合、街Aと街Cも奇数で繋がっているので「奇数で繋がっている相手」を「偶数グループ」で表せるのがポイント。

あと

街A <- 奇数道-> 街B <- 奇数道 -> 街C

だとAとCが同じ偶数グループに属するようになる、というのも重要。

それを実装するとこうなる:

n, q = map(int, input().split())
ws = [tuple(map(int, input().split())) for _ in range(q)]
 
union = {x:x for x in range(1, n+1)}
odds = {x:None for x in range(1, n+1)}
def root(n):
    if union[n] == n:
        return n
    union[n] = root(union[n])
    return union[n]
 
for w, x, y, z in ws:
    rx, ry = root(x), root(y)
    if w == 2:
        print('YES' if rx == ry else 'NO')
    elif z % 2 == 0:
        union[rx] = ry
        if odds[rx] and odds[ry]:
            union[root(odds[rx])] = root(odds[ry])
        else:
            odds[rx] = odds[ry] = odds[rx] or odds[ry]
    else:
        odds[rx] = odds[rx] or ry
        odds[ry] = odds[ry] or rx
        union[rx] = root(odds[ry])
        union[ry] = root(odds[rx])

ちょっと奇数・偶数で場合分けした時のロジックが複雑になっている。

もう一つの解法はこれ:

arc036.contest.atcoder.jp

Union Find木で「奇数で繋がっている」「偶数で繋がっている」を一括で管理する。各ノードは「街、奇偶」のタプルで表している。

実装が簡単になるなーと思ったのだが・・・

n, q = map(int, input().split())
ws = [tuple(map(int, input().split())) for _ in range(q)]
 
union = {(x, z):(x, z) for x in range(1, n+1) for z in 'oe'}
def root(n):
    if union[n] == n:
        return n
    union[n] = root(union[n])
    return union[n]
 
for w, x, y, z in ws:
    if w == 2:
        print('YES' if root((x, 'e')) == root((y, 'e')) else 'NO')
    elif z % 2 == 0:
        rex, rey = sorted([root((x, 'e')), root((y, 'e'))])
        union[rex] = rey
        rox, roy = sorted([root((x, 'o')), root((y, 'o'))])
        union[rox] = roy
    else:
        reox, reoy = sorted([root((x, 'e')), root((y, 'o'))])
        union[reox] = reoy
        roex, roey = sorted([root((x, 'o')), root((y, 'e'))])
        union[roex] = roey

多分同じ街が二回現れることが理由なんだと思うのだが、気をつけないとroot関数が無限ループに陥ってStack Overflowを起こす。実際テストケース最後の一つだけで何回かREしてしまった・・・

かならず「ソート順が一番大きいノード」に貼るようにすればループは起こりえない。その分少しコードが冗長になってしまったが、さきほどの解に比べると場合分けのロジックの対称性がはっきりしていてやはりわかりやすいように思う。

ARC090 D問題 - People in a line

arc090.contest.atcoder.jp

arc090.contest.atcoder.jp

直線上に並んでいる人たちについて、「LはRのXm左にいる」という情報が与えられる。その情報に矛盾がないか判定する問題。

Union Find木の発展系である重みつきUnion Find木を使うと簡単に解ける:

n, m = map(int, input().split())
xs = [tuple(map(int, input().split())) for _ in range(m)]
 
union = {x:(x, 0) for x in range(1, n+1)}
def root(n):
    if union[n][0] == n:
        return union[n]
    n1, d1 = union[n]
    n2, d2 = root(n1)
    union[n] = (n2, d2+d1)
    return union[n]
 
for l, r, d in xs:
    n1, d1 = root(l)
    n2, d2 = root(r)
    if n1 != n2:
        union[n2] = (n1, d + d1 - d2)
    elif d != d2 - d1:
        print('No')
        break
else:
    print('Yes')

「ルートノードの人、その人に対しての距離」をタプルとしてUnion Find木に入れている。

新しい情報が入ってくるたびに * もしルートノードが違えばUnion Find木をアップデートできる(二つのルートノードの相対位置がわかる) * もしルートノードが同じならそのルートノードに対してのLとRの位置の差が新しい情報と一致しているかを判定する

Union Findのノードをタプル化して追加で情報を持たせるのは「偶数メートル」と同じ。ただし、root関数内でその情報をまとめる必要がでてくるのでより複雑(といってもノード間の距離の総和を追加で返すだけだが)

あと少し気の利いたところとしては最後のelse。Pythonのfor-loopには「breakしなかったら最後にこれを実行」というelse構文があることはあまり知られていないように思う。ループ中に矛盾が見つかれば"No"と出力してbreak、もし見つからなければ最後にelseで"Yes"と出力、というロジック。

POJ 1182 - 食物連鎖

「AがBを食べ、BがCを食べ、CがAを食べる」とわかっているなかで、「xとyは同じ種」「xはyを食べる」という情報が順次入ってくる。そのうちで以前来た情報や個体番号の制限と矛盾するものは無視するとして、最終的に無視した情報の数を求める問題。

入力値の一例としてはこんな感じ:

n = 100
k = 7

xs = [(1, 101, 1),
      (2, 1, 2),
      (2, 2, 3),
      (2, 3, 3),
      (1, 1, 3),
      (2, 3, 1),
      (1, 5, 5)]

nが個体数、kが情報数、xsが情報のリスト。情報は

  • 1=「同じ種」, 2=「xはyを食べる」という情報のタイプ
  • xの番号
  • yの番号

の三点からなるタプル。

上記の入力から期待される出力は3。

1 101 1 2 3 3 2 3 1

の三つの情報が矛盾する。

これもUnion Find木で解決するのだが、ノードは個体ではなく「1はAである」などの命題であり、ノードの連結は

命題Aが真の時そしてその時に限り命題Bも真 A <-> B

という関係を表している。

なので

  • 「xとyは同じ種」という情報が入ってくるとx-Aとy-Aは互いに連結、x-Bとy-B、x-Cとy-Cも同じく互いに連結
  • 「xはyを食べる」という情報が入ってくるとx-Aとy-Bは互いに連結、x-Bとy-C、x-Cとy-Aも同じく互いに連結

その上で「xはyを食べる」という情報が入ってきた時すでにx-Aとy-Aが連結であったりした場合、情報が矛盾するので無視することになる。

というのが以下の実装:

groups = ('a', 'b', 'c')
union = {(x, g):(x, g) for x in range(1, n+1)
                         for g in groups}
def root(n):
    if union[n] == n:
        return n
    union[n] = root(union[n])
    return union[n]

def is_error(i, x, y):
    if not (1 <= x <= n and 1 <= y <= n):
        return True
    if i == 1:
        return root((x, 'a')) in {root((y, g)) for g in ('b', 'c')}
    if i == 2:
        return root((x, 'a')) in {root((y, g)) for g in ('a', 'c')}
    return True

ans = 0
for i, x, y in xs:
    if is_error(i, x, y):
        ans += 1
    elif i == 1:
        for g in groups:
            union[root((x, g))] = root((y, g))
    else:
        for g1, g2 in (('a', 'b'), ('b', 'c'), ('c', 'a')):
            union[root((x, g1))] = root((y, g2))
print(ans) 

まあ実装はそこまで特筆すべきものはない・・・ 矛盾があるかをチェックする時に「x is A」と繋がったyに関する命題の矛盾だけ調べればいい、という点くらいか。

しかし解法自体は非常に面白かった。「命題がノードになる」というのはけっこう盲点で、蟻本の解説を読むまではどう書けばいいのかなかなか想像がつかなかった。

この視点はUnion Find木の応用の幅が広がりそうなのでこれからも大切にしたい。

結論

Union Find好き。