Arantium Maestum

プログラミング、囲碁、読書の話題

PythonでForth処理系を書く! その4(内部状態をPython Dictionaryに保存)

前回の終わりで書いたように、Forth処理系の内部状態、とくにコンパイラ部分とインタプリタ部分のものをデータ構造として抜き出して、保存・出力できるようにしたい。

その変更がこれ:

github.com

主要な変更は二点:

  1. 内部状態をPythonのDictionaryで表し、compileやinterpretにそのState dictionaryを渡している
  2. コンパイラインタプリタのどちらにも存在したif-elif-elseの分岐処理を、個別にState dictionaryを受け取ってひとつの処理を行う複数の小さい関数に分割する

state dictionaryの導入

とくにREPLを実装するのに役にたつと考えての変更。

一つのプログラムを最初から最後まで一気に読み込んで処理を行う場合は、こういう状態が関数内で完結していて問題ないのだが、ユーザから一行ずつプログラムを与えられて、その都度「トークン化・コンパイル・実行」するとなると、過去の状態も含めて保持する必要がある。

その方法としては

  • オブジェクト
  • Python generator
  • グローバル変数としてのState
  • 引数・戻り値としてのイミュータブルState
  • 引数としてのミュータブルState

あたりが考えられると思う。

一番素直なのはオブジェクト化してしまって状態をそのままメンバ変数にいれてしまうことだろう。あるいはPythonだとyieldキーワードを使ったgeneratorで関数に状態を手軽に持たせることもできる。ただ、デバッグなどの関係もあって、できれば状態を簡単なデータ構造に入れて外部からアクセスできるようにしたかった。(もちろんオブジェクトでもメソッドを定義するなどして実現できるが少し手間だ・・・)

グローバル変数は宗教上の理由でアウト。イミュータブルなStateを引数・戻り値にするのは好みにかなっているのだが、Pythonのデータ構造はそもそもmutationを前提としていてイミュータブルに使うにはコピーコストが大きい。

というわけでcompileやinterpretに引数としてStateを与えて、それに対して破壊的変更を加えていく形のコードになった。

if-elif-elseの分割

Stateをデータ構造につめこんだので、一つの関数内で処理が完結する必要はなくなった。

  • 分岐で行っていた各処理を個別の関数にわけて
  • 現在位置にあるトークン・コードによってディスパッチ
  • Stateも引数として与えて変更がかかる

という流れにする。このような実装にすることで今後機能を追加していくときに「一つのでかい条件分岐にelifを追加してどんどん長くしていく」という地獄を避けられるはず。

次回は今回の変更を踏まえてインタプリタにREPL機能をつける。

PythonでForth処理系を書く! その3(最低限の機能)

PythonでForth処理系を書く! その2(実装する機能) - Arantium Maestum

でも書いた通り、まずは以下の機能を最低限として実装する:

  • 整数をスタックに乗せる
  • スタックから二つの整数をとり、それらを足し合わせた整数をスタックに乗せる
  • スタックから整数をひとつとり、その数を標準出力に表示する
  • 標準出力に改行を表示する

実行できるコードはこんな感じだ:

1 2 + . CR

簡単な実装

まずはできるだけ簡単に実装してみる:

s = '1 2 + . CR'

stack = []
for token in s.split():
    try:
        stack.append(int(token))
        continue
    except ValueError:
        pass
    if token == '+':
        a, b = stack.pop(), stack.pop()
        stack.append(a+b)
    elif token == '.':
        print(stack.pop(), end='')
    elif token == 'CR':
        print()

ハードコードされたforthコードをsplitでトークンに分け、forループで一つ一つのトークンごとにtry-except/if-elseでトークンタイプをチェックして実行していく。

今回実装する機能だけだったらこれで充分かもしれない。ただ、密結合な上にフローをかなり単純視している(具体的にはトークンとインタプリタの動作に1対1の対応があることを前提としてしまっている)ので、今後の機能追加が容易ではない。

拡張しやすくなるように、どのような異なる機能・段階があるのかを明確にする形で分割してみる。

分割した実装

github.com

まずはインタプリタの全体を統括するmain.py:

import tokenizer
import compiler
import interpreter

if __name__ == '__main__':
    with open('samples/sample1.forth') as f:
        s = ' '.join(f)
    tokens = tokenizer.tokenize(s)
    code_list = compiler.compile(tokens)
    interpreter.interpret(code_list)

分割した個別のトークナイザ・コンパイラインタプリタをインポートして使っている。一つのポイントとしては、全体としてはインタプリタなのだが、トークンを直接インタプリタに渡すのではなく、インタプリタが実行しやすいような中間言語コンパイルするステップが入っていること。このコンパイルステージが今後の条件分岐や変数定義などでは非常に便利になる。

tokenizer.pyの実装:

def tokenize(s):
    return s.split()

非常に簡単。スペースで分割するだけ。トークンも特にクラス化せず、ただの文字として扱う。

compiler.py:

def compile(tokens):
    code_list = []
    for token in tokens:
        code_list.extend(compile_token(token))
    code_list.append('END')
    return code_list

def compile_token(token):
    if is_int(token):
        return ('PUSH', int(token))
    d = {'+':'PLUS', '.':'PRINT', 'CR':'CR'}
    return (d[token], )

def is_int(s):
    try:
        int(s)
    except ValueError:
        return False
    return True

トークンが数字だったらスタックに乗せるようにPUSH Xという二つのコードにコンパイル。それ以外(+, ., CR)はそれぞれ一つのコードにコンパイルする。

トークンが数字かどうかの判定もコンパイラ部分で行っているが、トーケナイザ部分で判定してその情報をトークンに持たせる手もある。あまり関係ないが、個人的にPythonの「文字列がInt化できるか判定」イディオムはあまり好きではない・・・

コンパイラの吐く中間言語を受け取り、実行していくinterpreter.py:

def interpret(code_list):
    counter = 0
    stack = []
    while True:
        if code_list[counter] == 'END':
            break
        elif code_list[counter] == 'PUSH':
            stack.append(code_list[counter+1])
            counter += 2
        elif code_list[counter] == 'PLUS':
            a, b = stack.pop(), stack.pop()
            stack.append(a+b)
            counter += 1
        elif code_list[counter] == 'PRINT':
            print(stack.pop(), end='')
            counter += 1
        elif code_list[counter] == 'CR':
            print()
            counter += 1

スタック(Pythonのリスト)に数字をappendしたりpopしたりしながら中間言語の指示を実行していく。

これで元々のサンプル:

1 2 + . CR

を実行することはもちろん、今後の機能を追加していくのもかなりやりやすくなった。だが、実はもう一つリファクタするべきポイントがある。コンパイラインタプリタの両方で、状態をデータ構造に保存するようにしたい。

次回はその理由と実装の話。

PythonでForth処理系を書く! その2(実装する機能)

処理系で実装する機能を列挙してみる。

  • 整数をスタックに乗せる
  • スタックから二つの整数をとり、それらを足し合わせた整数をスタックに乗せる
  • スタックから整数をひとつとり、その数を標準出力に表示する
  • 標準出力に改行を表示する
  • REPL機能
  • 分岐構文 IF-ELSE-THEN
  • ループ構文 DO-LOOP
  • 変数定義
  • 関数定義

第1段階 - 最低限の機能

まず手始めに実装するもの:

  • 整数をスタックに乗せる
  • スタックから二つの整数をとり、それらを足し合わせた整数をスタックに乗せる
  • スタックから整数をひとつとり、その数を標準出力に表示する
  • 標準出力に改行を表示する

以上で例えばこんなコードをファイルから読み出して実行できるようになる:

1 2 + . CR

処理としては

  • 1をスタックに乗せる
  • 2をスタックに乗せる
  • 2と1をスタックから取り出し、足した結果の3をスタックに乗せる
  • 3をスタックから取り出し、出力する
  • 改行を出力する

という流れになる。

第2段階 - REPL

上記の機能をインタラクティブなREPLを通して使えるようにする。

具体的には、一行入力されるごとにコードをインタプリタが実行していく。スタックの状態は行をまたいで維持される。

第3段階 - 分岐構文

IF-ELSE-THEN構文を実装する。

この構文は(とくにTHENが)ちょっと独特で、このように使う:

1 IF 2 ELSE 3 THEN . CR

どういう論理かというと

IFコマンドの時点でスタックの上が0以外ならELSEまで実行してからTHENにジャンプ、0ならELSEまでジャンプしてから実行を続ける

というもの。

上記のコードは2を出力してから改行する。 一番先頭の1が0なら「3を出力してから改行」になる。

第4段階 - ループ構文

DO-LOOP構文を実装する。

この構文の要素は3つあって:

  • DOでスタックに乗っている二つの整数を分岐のために使う スタックから得たはじめの整数がループインデックスの始点 スタックから得た次の整数がループインデックスの終点

  • LOOPでループインデックスをインクリメントする ループインデックスが終点未満の場合はDO直後のコマンドまで後方ジャンプ ループインデックスが終点に到達した場合はLOOP以降のコマンドを実行

  • DOとLOOPの間で使われるIは現在のループインデックスの値をスタックに乗せる

こんな使い方をする:

0 11 1 DO I + LOOP . CR

上記のコードは1以上11未満の整数をループ足し合わせて55を出力して改行する。

第5段階 - 変数定義

Forthでの変数はVARIABLEで宣言、!で代入、@で参照となる。

VARIABLE X
VARIABLE Y
5 X !
10 Y !
X @ Y @ + . CR

のように使う。上記のコードを走らせると15が出力される。

まずVARIABLE X、VARIABLE YでXとYを未使用のメモリ領域に結びつける。VARIABLE宣言の後にはXやYというコードはそのメモリ領域のアドレスをスタックに乗せるコマンドになる。

5 X !では5をスタックに乗せ、Xのアドレスをスタックに乗せ、そして!でその二つの整数をとり、最初の整数をアドレスとして次の整数をそのアドレスの値として代入する。

X @でXのアドレスをスタックに乗せ、そのアドレスをとってその領域に入っている値をスタックに乗せる。

第6段階 - 関数定義

Forthでは特定の処理が紐付いたコードのことをワードと呼ぶ。関数もワードだ。

ユーザがワードを自前で定義できるようにするための構文が:と;だ。

:がワード宣言の冒頭、;がワード宣言の終了を示す。:の次のコードは新しいワード名、それ以降で;までのコードはそのワードに紐づけられる処理だ。

: ADD2 2 + ;
1 ADD2 . CR

で3が出力される。

: ADD2 2 + :でADD2というワードがその処理2 +とともに登録される。

1 ADD2 . CRはADD2の部分で登録された処理が呼び出されるので結果としては1 2 + . CRと同じ処理となる。

その後

ここまで実装してとりあえず終わる予定。

この枠組みに四則演算やスタック上の値を操作(コピー・廃棄・上位Xを入れ替えなど)するような機能で肉付けしていくと、結構簡単に非常に表現力の高い言語処理系ができあがるようだ。こっちの実装はForthについてもっと理解が進んだら試してみたい。

AtCoder Beginner Contest 094の問題を解いた

第1問 Cats and Dogs

abc094.contest.atcoder.jp

abc094.contest.atcoder.jp

猫がA匹いることがわかっているのに加えて、追加で猫か犬かがB匹いる場合、猫がX匹いることが可能か。

不可能になるのはX < AかX > A + Bの二ケースなので、A <= X <= A+Bが成り立っているかどうかを調べる:

a, b, x = map(int, input().split())
print('YES' if (a <= x <= a+b) else 'NO')

第2問 Toll Gates

abc094.contest.atcoder.jp

abc094.contest.atcoder.jp

直線上に順番にならんでいる料金所を通るたびにコスト1かかるとして、地点Xからはじめて直線の左右どちらかの端に到達するのにかかる最小コストを求める。

料金所がもとからソートされているので、そのまま二分探索が使える。地点X自体には料金所がないことがわかっているのも処理を簡単にしてくれるポイント。

from bisect import bisect_left
 
n, m, x = map(int, input().split())
axs = [int(x) for x in input().split()]
 
i = bisect_left(axs, x)
print(min(i, len(axs) - i))

bisect_left(axs, x)で地点Xの左にいくつ関所があるかがわかる。

右にある料金所の数は「料金所の総数 - 左側にある料金所の数」なのでmin(i, len(axs) - i)で最小コストが求められる。

ちなみにこれ、最初に書いた解だと

print(min(len(axs[:i]), len(axs[i:])))

になっていた。iが持つ意味を直観的に捉えきれてなかったのだろうか。こういうところで何気なく無駄な処理をしてしまうかどうかがけっこうアルゴリズムでは重要な気がする。

第3問 Many Medians

abc094.contest.atcoder.jp

abc094.contest.atcoder.jp

与えられたリストxsのある要素が含まれていない場合の中央値を、各要素ごとに算出していく。

xsの長さが偶数であるという制約で計算が楽になる。

xsの中央値を計算する場合は、中央にある二つの数字m1とm2 (m1 <= m2)の平均をとることになる。

とりあえずxsから最小値を除いたリストzsの中央値を考えてみると、m2であることがわかる。さらに最小値でなくてもm2未満のどの数字を除いたとしてもm2が中央値になる。

そしてm2以上の場合はm1が中央値だ(最大値からはじめて全く同じロジックを考えることができる)。

なので実装はこんな感じ:

n = int(input())
xs = list(map(int, input().split()))
 
ys = list(sorted(xs))
m1 = ys[n//2 - 1]
m2 = ys[n//2]
 
medians = [m2 if x < m2 else m1 for x in xs]
print('\n'.join(map(str, medians)))

実はPython3.4以降statisticsという標準ライブラリが追加され、median_lowとmedian_highという関数が用意されている。

medianはquickselectアルゴリズムを使ったりすると平均でO(N)になったりすることが知られているので、ちょっと期待して使ってみる:

abc094.contest.atcoder.jp

from statistics import median_low, median_high
 
n = int(input())
xs = [int(x) for x in input().split()]
 
m1, m2 = median_low(xs), median_high(xs)
 
medians = (str(m2 if x < m2 else m1) for x in xs)
print('\n'.join(medians))

問題なくACするのだが、もとのコードが200msだったのがstatisticsを使うと300msほどになる。

CPython statisticsモジュールのソースにあたってみると、medianのなかでリストを複製ソートしていた。

元のコードで一番時間がかかるところをmedian_lowとmedian_highで二回やっているわけだからそれは遅くなるわけだ・・・

今度時間があったら自分でquickselectを実装してみたい。

第4問 Binomial Coefficients

abc094.contest.atcoder.jp

abc094.contest.atcoder.jp

nCrを最大化するnとrを与えられた配列xsから選択するという問題。

まず、任意のrで n1 > n2 なら n1 C r > n2 C rであることから、nはxsの最大値。

任意のnで、nCrを最大化するrはn/2に一番近いものだ。

そしてこの問題だとn > rである必要がある(nCn == nC0 == 1なのだが、この設問だと(n,0)がACの場合、(n, n)はWA)

_ = input()
 
axs = list(map(int, input().split()))
 
i = max(axs)
j = min(axs, key=lambda j: abs((i/2) - j) + (i == j))
print('{} {}'.format(i, j))

Pythonのsort/max/minがオプションでkey関数をとるので、「xに一番近い」というのをminで表せる。ついでにi == jの時のペナルティもつけた。ペナルティが必要になるのはリストに最大要素と0しかない場合のみなので、「最大要素の場合1を足す」というペナルティだけで(n, n)になることを防げる。

感想

この回は基礎数学的な素養が試されている印象が強い。典型的なアルゴリズムやデータ構造らしきものを使うことなく、考察が済めばあとはちょっとソートしたりループしたりするだけで、一番アルゴリズムしていたのはB問題の二分探索だった。

頭の体操としては大変面白いのだが、ABC103のように一旦盲点に入ってしまうとアイデアの手がかりがあまりなく思考が空転してしまいそうな怖さがある。どうしたらこの分野を強化できるか(というか安定的に問題解決に向けて思考を重ねていけるか)、ちょっと考えないといけない。

PythonでForth処理系を書く! その1(序)

Forthというプログラミング言語がある。分類としては「スタックマシン型言語」になるだろう。

スタックマシンというのは、データやアドレスをスタックに乗せたり取り出したりすることを基本の操作とするプログラム実行モデルで、有名どころでいうとJava Virtual MachinePython Bytecode Interpreterのような「いったん中間言語としてコンパイルされたものを実行する」ことによく使われている。

Forthはそのスタックマシンを直接操るプログラミング言語だ。

実行モデルがC(というかAlgol)系やML/Haskell的関数型、あるいはLispなどとも大きく違っていて面白い。個人的にもっとよく知りたい言語のひとつだ。

魅力の一つとして、言語処理系としては実装が非常に簡単かつ効率的かつ表現力が高いことが挙げられる。組み込みなどのシステムで、自前でForthを作って使う、ということも十分可能なようだ。

というわけで以前からForth処理系を作ってみたいと思っていたのだが、ある日Kindleでそのものずばりな「Forthを作ってみる」という電子書籍が売っていたのでそれを買ってみた。以下のサイトの内容を書籍化したもののようだ:

Forthを作ってみる - moiの頭の中

読んでみて大体の概要がわかったのでPythonで似たような機能を持つForth Interpreterを実装してみた。

github.com

これから何回かに分けてこの実装過程について書いていきたい。

AtCoder Beginner Contest 103に参加してみた

今回はノーミスで全完したのだが、C問題にあまりにも時間がかかってしまい、実感としてはかなりまずい印象だった。

前回のABCから三週間ぶりで感覚がおかしかった気がする。

第1問

abc103.contest.atcoder.jp

abc103.contest.atcoder.jp

「最小→真ん中→最大」あるいは「最大→真ん中→最小」のどちらかの順番ですすむのが最善で、結果は最大-最小の差になる。

xs = [int(x) for x in input().split()]
print(max(xs) - min(xs))

いきなりWAを飛ばしてびっくりした。何回見直しても大丈夫そうだったので放置しておいたら、AtCoder側の不具合があったらしくリジャッジでACになっていた。

第2問

abc103.contest.atcoder.jp

abc103.contest.atcoder.jp

ある文字列が別の文字列を回転させたものかどうか、という質問。文字列の長さが最長100なので愚直に書いて問題ない。

s = input()
t = input()
 
print('Yes' if any(s == t[i:] + t[:i] for i in range(len(t))) else 'No')

ただ、ツイッターprint('Yes' if s in t+t else 'No')を見たときには、自分で思いつかなかったのが悔しかった。

第3問

abc103.contest.atcoder.jp

abc103.contest.atcoder.jp

何故かこの問題に1時間ほどかかった。

  • 規則性?LCM?とかうだうだと考えるのに四十五分ほど
  • (a1 * a2 * ... an) - 1 だとどの数でも割り切れないし最大化されるな、しかし巨大数になりそう((10 ^ 5) ^ 3000)だからどうするか・・・ と悩むのに十分
  • 「あ」と解法に気づいた瞬間から実装するのに一分かからなかったと思う
n = input()
xs = [int(x) for x in input().split()]
 
print(sum(x-1 for x in xs))

(a1 * a2 * ... an) - 1で「すべての数字が割り切るのに1足りない数」を作れるので、その数の結果はsum(a - 1)になる。気づいてしまえばなんということはないのだけど、何故かなかなか気付かなかった。

第4問

abc103.contest.atcoder.jp

abc103.contest.atcoder.jp

西から東へ一直線に並ぶ島をつなぐ橋を最小でいくつ消せば、「島aと島bは繋がっていない」という要望をすべて満たせるか、という問題。

残り三十分ほどだし無理かな、と思って眺めていたら蟻本初級編「Saruman's Army」の類題だった。

まず要望を西側の島がもっとも小さい(西側によっている)順に要望をソートする。あとは西端から貪欲に要望を満たしていく:

n, m = map(int, input().split())
lrs = [[int(x) for x in input().split()] for _ in range(m)]
lrs.sort()
 
ans = 0
m = -1
for l, r in lrs:
    if r < m:
        m = r
    if l >= m:
        ans += 1
        m = r
 
print(ans)

最初はプライオリティキューを使うことも考えたのだけど、まったく必要なかった。

感想・結果

ノーミス全完ながらも消費時間が93:46で順位は526、パフォーマンスは1271でちょこっとだけしかレートが上がらなかった。

それも仕方ないな、というくらいにCの解法が盲点に入ってしまっていた。実装自体は全問非常に簡単だったし。

どう考えればよかったのか・・・ 解法やアルゴリズムに入るまえに「そもそも想定し得る最大はいくつだ?」「その最大になりえる条件は?」と考えていればsum(a - 1)が早い段階で見えたかもしれない。

考察のやり方のまずさと、精神力の弱さ(AがWAだったことによる動揺も大きかった)が露呈したコンテストだった。

ただ、D問題も読んでみて「歯が立たないな」という印象はなく、最終的に今持っている知識で解答に確信を持って全完できたのは良かった。ここ2ヶ月弱の競プロ精進のひとつの成果だと思う。

考察の仕方や精神力は、もっと場数を踏むことで培っていこう。

蟻本初級編攻略 - 2-4 Union Find木 後編

前回のUnion Find木実装を踏まえてAtCoder ABC/ARCのD問題や蟻本の元の例題に取り組んでみる。

ABC049 D問題 - 連結/Connectivity

abc049.contest.atcoder.jp

abc049.contest.atcoder.jp

道路ネットワークと鉄道ネットワークが走っている国で、街ごとに道路でも鉄道でも連結になっている街の数を出力する。

AとBが道路でも鉄道でも連結である必要かつ十分条件は「道路ネットワークでも鉄道ネットワークでも同一のグループに属する」ということなのが(少し考えると)わかった。

あとは街ごとに(道路ネットワークのルート, 鉄道ネットワークのルート)というタプルで所属するグループを表して、そのグループに所属する街の数を算出してやればいい。

というわけで実装:

from collections import Counter
 
n, k, l = map(int, input().split())
ps = [tuple(map(int, input().split())) for _ in range(k)]
rs = [tuple(map(int, input().split())) for _ in range(l)]
 
def root(n, u):
    if u[n] == n:
        return n
    u[n] = root(u[n], u)
    return u[n]
 
union1 = {x:x for x in range(1, n+1)}
for p, q in ps:
    union1[root(p, union1)] = root(q, union1)
 
union2 = {x:x for x in range(1, n+1)}
for r, s in rs:
    union2[root(r, union2)] = root(s, union2)
 
gs = {x:(root(x, union1), root(x, union2)) for x in range(1, n+1)}
c = Counter(gs.values())
print(' '.join(str(c[gs[x]]) for x in range(1, n+1)))

二回Union Findをして、最終的なグループをタプルで表し、グループごとの街数をCounterで算出し、街ごとに所属するグループの街数を出力している。

各Union FindがO(N log N)、最終的なグループを作るのもO(N log N)、カウンターはO(N)、出力もO(N)で計算量のオーダーはO(N log N)。

ARC097 D問題 - Equals

arc097.contest.atcoder.jp

arc097.contest.atcoder.jp

「いれかえが許された添字ペア」を何度でも利用して、配列の要素の値と添字を最大何個一致させることができるか。

「添字ペア」がUnion Findで言うところの連結の概念と一致することがわかれば実装は非常に楽。そのためには「(a, b), (b, c), (c, d) ... (y, z)」のような添字ペアの連なりがあれば、入れ替えを連続で行うことでa~zまでのすべての要素を任意の場所に配置できる、ということがわかればいい。

これは帰納法で考えられる。

添字ペアが(a, b)だけの場合:

  • ありえる二つの配置(a, b)と(b, a)が「入れ替えられる」という問題の定義から自明に可能

すでに入れ替えを連続で行うことでa1~amまでのすべての要素を任意の場所に配置できるグループに、(am, an)を加えた場合:

  • amの位置に任意の要素を持ってくることができる(帰納法)上、amとanを交換できるので任意の要素をanに持ってくることができる
  • anに位置にある要素をamに交換できる上、a1~amにあるすべての要素を任意の位置に入れ替えることができる(帰納法)のでanの要素を任意の位置に移せる

ということで証明終わり。

あとは「値iを含む位置jと位置iが同じグループに属している」という条件を満たすiの数を数えるだけ(単一であることが制約により保証されていることも重要な成立条件)。

n, m = map(int, input().split())
ps = [int(x) for x in input().split()]
xs = [tuple(map(int, input().split())) for _ in range(m)]
 
union = {x:x for x in range(1, n+1)}
def root(n):
    if union[n] == n:
        return n
    union[n] = root(union[n])
    return union[n]
 
for x,y in xs:
    union[root(x)] = root(y)
 
print(sum(root(i+1) == root(ps[i]) for i in range(n)))

「Union Find木の典型」と言えるような解法になる。

ARC036 D問題 - 偶数メートル

arc036.contest.atcoder.jp

街から街へ(もしかすると複数の)道が通っていて、距離の和が偶数になるような道順を通ってある街から別の街へ移動できるかを判定する問題。

各道の距離が与えられるが実際に重要なのはその距離が奇数か偶数かという点。

まず最初に浮かんだ解法がこれ:

arc036.contest.atcoder.jp

Union Find木は「偶数で繋がっているグループ」を表し、それとは別に「ある街が奇数で繋がっているグループ」をdictionaryで表す。

街A <- 奇数道-> 街B <- 偶数道 -> 街C

のようになっている場合、街Aと街Cも奇数で繋がっているので「奇数で繋がっている相手」を「偶数グループ」で表せるのがポイント。

あと

街A <- 奇数道-> 街B <- 奇数道 -> 街C

だとAとCが同じ偶数グループに属するようになる、というのも重要。

それを実装するとこうなる:

n, q = map(int, input().split())
ws = [tuple(map(int, input().split())) for _ in range(q)]
 
union = {x:x for x in range(1, n+1)}
odds = {x:None for x in range(1, n+1)}
def root(n):
    if union[n] == n:
        return n
    union[n] = root(union[n])
    return union[n]
 
for w, x, y, z in ws:
    rx, ry = root(x), root(y)
    if w == 2:
        print('YES' if rx == ry else 'NO')
    elif z % 2 == 0:
        union[rx] = ry
        if odds[rx] and odds[ry]:
            union[root(odds[rx])] = root(odds[ry])
        else:
            odds[rx] = odds[ry] = odds[rx] or odds[ry]
    else:
        odds[rx] = odds[rx] or ry
        odds[ry] = odds[ry] or rx
        union[rx] = root(odds[ry])
        union[ry] = root(odds[rx])

ちょっと奇数・偶数で場合分けした時のロジックが複雑になっている。

もう一つの解法はこれ:

arc036.contest.atcoder.jp

Union Find木で「奇数で繋がっている」「偶数で繋がっている」を一括で管理する。各ノードは「街、奇偶」のタプルで表している。

実装が簡単になるなーと思ったのだが・・・

n, q = map(int, input().split())
ws = [tuple(map(int, input().split())) for _ in range(q)]
 
union = {(x, z):(x, z) for x in range(1, n+1) for z in 'oe'}
def root(n):
    if union[n] == n:
        return n
    union[n] = root(union[n])
    return union[n]
 
for w, x, y, z in ws:
    if w == 2:
        print('YES' if root((x, 'e')) == root((y, 'e')) else 'NO')
    elif z % 2 == 0:
        rex, rey = sorted([root((x, 'e')), root((y, 'e'))])
        union[rex] = rey
        rox, roy = sorted([root((x, 'o')), root((y, 'o'))])
        union[rox] = roy
    else:
        reox, reoy = sorted([root((x, 'e')), root((y, 'o'))])
        union[reox] = reoy
        roex, roey = sorted([root((x, 'o')), root((y, 'e'))])
        union[roex] = roey

多分同じ街が二回現れることが理由なんだと思うのだが、気をつけないとroot関数が無限ループに陥ってStack Overflowを起こす。実際テストケース最後の一つだけで何回かREしてしまった・・・

かならず「ソート順が一番大きいノード」に貼るようにすればループは起こりえない。その分少しコードが冗長になってしまったが、さきほどの解に比べると場合分けのロジックの対称性がはっきりしていてやはりわかりやすいように思う。

ARC090 D問題 - People in a line

arc090.contest.atcoder.jp

arc090.contest.atcoder.jp

直線上に並んでいる人たちについて、「LはRのXm左にいる」という情報が与えられる。その情報に矛盾がないか判定する問題。

Union Find木の発展系である重みつきUnion Find木を使うと簡単に解ける:

n, m = map(int, input().split())
xs = [tuple(map(int, input().split())) for _ in range(m)]
 
union = {x:(x, 0) for x in range(1, n+1)}
def root(n):
    if union[n][0] == n:
        return union[n]
    n1, d1 = union[n]
    n2, d2 = root(n1)
    union[n] = (n2, d2+d1)
    return union[n]
 
for l, r, d in xs:
    n1, d1 = root(l)
    n2, d2 = root(r)
    if n1 != n2:
        union[n2] = (n1, d + d1 - d2)
    elif d != d2 - d1:
        print('No')
        break
else:
    print('Yes')

「ルートノードの人、その人に対しての距離」をタプルとしてUnion Find木に入れている。

新しい情報が入ってくるたびに * もしルートノードが違えばUnion Find木をアップデートできる(二つのルートノードの相対位置がわかる) * もしルートノードが同じならそのルートノードに対してのLとRの位置の差が新しい情報と一致しているかを判定する

Union Findのノードをタプル化して追加で情報を持たせるのは「偶数メートル」と同じ。ただし、root関数内でその情報をまとめる必要がでてくるのでより複雑(といってもノード間の距離の総和を追加で返すだけだが)

あと少し気の利いたところとしては最後のelse。Pythonのfor-loopには「breakしなかったら最後にこれを実行」というelse構文があることはあまり知られていないように思う。ループ中に矛盾が見つかれば"No"と出力してbreak、もし見つからなければ最後にelseで"Yes"と出力、というロジック。

POJ 1182 - 食物連鎖

「AがBを食べ、BがCを食べ、CがAを食べる」とわかっているなかで、「xとyは同じ種」「xはyを食べる」という情報が順次入ってくる。そのうちで以前来た情報や個体番号の制限と矛盾するものは無視するとして、最終的に無視した情報の数を求める問題。

入力値の一例としてはこんな感じ:

n = 100
k = 7

xs = [(1, 101, 1),
      (2, 1, 2),
      (2, 2, 3),
      (2, 3, 3),
      (1, 1, 3),
      (2, 3, 1),
      (1, 5, 5)]

nが個体数、kが情報数、xsが情報のリスト。情報は

  • 1=「同じ種」, 2=「xはyを食べる」という情報のタイプ
  • xの番号
  • yの番号

の三点からなるタプル。

上記の入力から期待される出力は3。

1 101 1 2 3 3 2 3 1

の三つの情報が矛盾する。

これもUnion Find木で解決するのだが、ノードは個体ではなく「1はAである」などの命題であり、ノードの連結は

命題Aが真の時そしてその時に限り命題Bも真 A <-> B

という関係を表している。

なので

  • 「xとyは同じ種」という情報が入ってくるとx-Aとy-Aは互いに連結、x-Bとy-B、x-Cとy-Cも同じく互いに連結
  • 「xはyを食べる」という情報が入ってくるとx-Aとy-Bは互いに連結、x-Bとy-C、x-Cとy-Aも同じく互いに連結

その上で「xはyを食べる」という情報が入ってきた時すでにx-Aとy-Aが連結であったりした場合、情報が矛盾するので無視することになる。

というのが以下の実装:

groups = ('a', 'b', 'c')
union = {(x, g):(x, g) for x in range(1, n+1)
                         for g in groups}
def root(n):
    if union[n] == n:
        return n
    union[n] = root(union[n])
    return union[n]

def is_error(i, x, y):
    if not (1 <= x <= n and 1 <= y <= n):
        return True
    if i == 1:
        return root((x, 'a')) in {root((y, g)) for g in ('b', 'c')}
    if i == 2:
        return root((x, 'a')) in {root((y, g)) for g in ('a', 'c')}
    return True

ans = 0
for i, x, y in xs:
    if is_error(i, x, y):
        ans += 1
    elif i == 1:
        for g in groups:
            union[root((x, g))] = root((y, g))
    else:
        for g1, g2 in (('a', 'b'), ('b', 'c'), ('c', 'a')):
            union[root((x, g1))] = root((y, g2))
print(ans) 

まあ実装はそこまで特筆すべきものはない・・・ 矛盾があるかをチェックする時に「x is A」と繋がったyに関する命題の矛盾だけ調べればいい、という点くらいか。

しかし解法自体は非常に面白かった。「命題がノードになる」というのはけっこう盲点で、蟻本の解説を読むまではどう書けばいいのかなかなか想像がつかなかった。

この視点はUnion Find木の応用の幅が広がりそうなのでこれからも大切にしたい。

結論

Union Find好き。