Arantium Maestum

プログラミング、囲碁、読書の話題

桁DPに入門してみた

昨日のAtCoder ABC101でD問題に歯が立たず非常に悔しかったので、桁関連の問題や解法を少し読んでいる。

その経緯で桁に関する動的計画法(桁DP)について下のすごい記事に遭遇した:

pekempey.hatenablog.com

今蟻本でも動的計画法の章を読んでいることだし、この機会に桁DPに入門してみる。

以下にPythonでの実装と自分なりの気づきをメモってみた:

A以下の非負整数の総数を求める

from itertools import product
from collections import defaultdict

a = input()
n = len(a)

dp = defaultdict(int)
dp[0, 0, 0] = 1

for i, less in product(range(n), (0,1)):
    max_d = 9 if less else int(a[i])
    for d in range(max_d+1):
        less_ = less or d < max_d
        dp[i + 1, less_] += dp[i, less]

print(sum(dp[n, less] for less in (0, 1)))

初期値としてdp[0, 0, 0, 0]に1を入れておくのが非直観的で面白い。

dp[i, 1]で上からi番目の桁まで比較した場合nより低い数がひとつ以上出てくるパターンの数。

dp[i, 0]は各iで必ず値が1になるはず(n以下の整数を数えているので、すべての桁がnのもの以上のパターンはnと完全一致する場合だけ)

A以下の非負整数のうち、3の付く数の総数を求める

from itertools import product
from collections import defaultdict

a = input()
n = len(a)

dp = defaultdict(int)
dp[0, 0, 0] = 1

for i, less, has3 in product(range(n), (0,1), (0,1)):
    max_d = 9 if less else int(a[i])
    for d in range(max_d+1):
        less_ = less or d < max_d
        has3_ = has3 or d == 3
        dp[i + 1, less_, has3_] += dp[i, less, has3]

print(sum(dp[n, less, 1] for less in (0, 1)))

この条件を加えるのは非常に簡単。

A以下の非負整数のうち、「3が付くまたは3の倍数」の数の総数を求める

from itertools import product
from collections import defaultdict

a = input()
n = len(a)

dp = defaultdict(int)
dp[0, 0, 0, 0] = 1

for i, less, has3, mod3 in product(range(n), (0,1), (0,1), range(3)):
    max_d = 9 if less else int(a[i])
    for d in range(max_d+1):
        less_ = less or d < max_d
        has3_ = has3 or d == 3
        mod3_ = (mod3 + d) % 3
        dp[i + 1, less_, has3_, mod3_] += dp[i, less, has3, mod3]

criteria = ((n, less, has3, mod3) for less, has3, mod3 
                                  in product((0, 1), (0,1), range(3)) 
                                  if has3 or mod3==0)

print(sum(dp[c] for c in criteria))

mod3_ = (mod3 + d) % 3の部分が「桁のトリック」っぽい。「ある数とその数の桁の和は法3で合同」という事実を使っている。

もとの記事には説明がなかったので桁関連の考え方では常識的なのかも。あまり馴染みがなかったので、最初に見たとき何をしているのか考え込んでしまった。mod3_ = (mod3*10 + d) % 3を少し短く済ませている。

A以下の非負整数のうち、「3が付くまたは3の倍数」かつ「8の倍数でない」数の総数を求める

from itertools import product
from collections import defaultdict

a = input()
n = len(a)

dp = defaultdict(int)
dp[0, 0, 0, 0, 0] = 1

for i, less, has3, mod3, mod8 in product(range(n), (0,1), (0,1), range(3), range(8)):
    max_d = 9 if less else int(a[i])
    for d in range(max_d+1):
        less_ = less or d < max_d
        has3_ = has3 or d == 3
        mod3_ = (mod3 + d) % 3
        mod8_ = (mod8*10 + d) % 8
        dp[i + 1, less_, has3_, mod3_, mod8_] += dp[i, less, has3, mod3, mod8]

criteria = ((n, less, has3, mod3, mod8) for less, has3, mod3, mod8
                                        in product((0, 1), (0,1), range(3), range(8))
                                        if (has3 or mod3==0) and mod8!=0)

print(sum(dp[c] for c in criteria))

print(sum(n % 8 and (n % 3==0 or '3' in str(n)) for n in range(int(a)+1)))

形が決まってしまえばコードの複雑さはあまり増加しないのがうれしいところ。

ちなみにテストするためには定義通りに書いた非効率な実装と、小さい数字で結果が一致するか調べるのも手だ。例えば以下のワンライナーが使える:

print(sum(n % 8 and (n % 3==0 or '3' in str(n)) for n in range(int(a)+1)))

A以下B以上の非負整数のうち、「3が付くまたは3の倍数」かつ「8の倍数でない」数の総数を求める

from itertools import product
from collections import defaultdict

a = int(input())
b = int(input())

def count(x):
    a = str(x)
    n = len(a)
    dp = defaultdict(int)
    dp[0, 0, 0, 0, 0] = 1

    for i, less, has3, mod3, mod8 in product(range(n), (0,1), (0,1), range(3), range(8)):
        max_d = 9 if less else int(a[i])
        for d in range(max_d+1):
            less_ = less or d < max_d
            has3_ = has3 or d == 3
            mod3_ = (mod3 + d) % 3
            mod8_ = (mod8*10 + d) % 8
            dp[i + 1, less_, has3_, mod3_, mod8_] += dp[i, less, has3, mod3, mod8]

    criteria = ((n, less, has3, mod3, mod8) for less, has3, mod3, mod8
                                            in product((0, 1), (0,1), range(3), range(8))
                                            if (has3 or mod3==0) and mod8!=0)

    return sum(dp[c] for c in criteria)

print(count(a) - count(b-1))

関数化すれば非常に楽。

ただし、一つのループに「B以上」という条件も組み込むほうがより効率的なはず。今度時間があったらその実装も試してみたい。

蟻本初級編攻略 - 2-2 Saruman's Army

蟻本の貪欲法で「交換しても悪化しないことがわかっている中で最大のものをとる」という考えの例題:

Saruman's Army - POJ 3069 - Virtual Judge

直線上に配置された人のうち最小何人選べば、すべての人が選ばれた人の一定距離内におさまるかを求める問題。

左端の人から一定距離以内でもっとも右側の人を選び、「その人から一定距離外でもっとも左側の人」から一定距離以内でもっとも右側の人を選び、とやり続ける:

def saruman(xs, r):
    bearer, edge = False, xs[0]
    for x, y in zip(xs, xs[1:]):
        while y > edge + r:
            if not bearer:
                yield x
                bearer, edge = True, x
            else:
                bearer, edge = False, y
    if not bearer:
        yield y

while True:
    r, n = map(int, input().split())
    if r == n == -1:
        break
    xs = sorted(int(x) for x in input().split())
    print(len(list(saruman(xs, r))))

ちょっとロジックがコードからはわかりにくい。まだ最適な記述を見つけていない気がする。やっぱり添字が一番か?

基本的にPythonでは記述上添字でループを回すことは少ないのだが(Cで実装されているCPythonインタプリタは当然裏で添字でループを回している)、こういう問題の場合は添字が自然な気がする。どういう条件下では添字がベストになるんだろう・・・

それではAtCoderの類題を解いていく。

ABC083C - Multiple Gift

abc083.contest.atcoder.jp

abc083.contest.atcoder.jp

X以上Y以下の数で構成された数列で、n個目の数はn-1個目の数の倍数かずn-1個目の数より真に大きい。この制約で最大の長さを求める。

一回ごとの増加を最低に抑えるのが貪欲ポイント。とすると倍数は2。Xに何回2をかけてY以下に収められるか。

from itertools import count, takewhile
 
x, y = map(int, input().split())
 
print(max(takewhile(lambda i: x * (2**i) <= y,  count())) + 1)

頑張ればO(1)にできたと思うのだが、とりあえずO(log N)。

ARC006C - 積み重ね

arc006.contest.atcoder.jp

arc006.contest.atcoder.jp

順番に入ってくるダンボールを床に置くか、より重いダンボールの上に置くかできる。床に直に置かれているダンボールの数を最小化する問題。

N <= 50なのでO(N2)で問題ない。とすると素直に実装できる:

n = int(input())
ws = [int(input()) for _ in range(n)]
 
xs = []
for w in ws:
    ys = [(x, i) for i, x in enumerate(xs) if x >= w]
    if ys:
        _, i = min(ys)
        xs[i] = w
    else:
        xs.append(w)
print(len(xs))

xsが各山の一番上に乗っているダンボールの重さ。入ってきたダンボールより重いものがあるか調べて、あるならその内の最小のものの上に置く(xsのその添字の値を書き換える)。「入ってきたダンボールより重い、最小のもの」が貪欲法要素。

ABC005C - おいしいたこ焼きの売り方

abc005.contest.atcoder.jp

abc005.contest.atcoder.jp

たこ焼きが出来上がる時間と客が来る時間と、たこ焼きが出来上がってから買われるまでの許容される最長の時間が与えられ、すべての客にすぐにたこ焼きを提供できるか調べる問題。

現在たこ焼きをできた順に、できた時間でキューとして保持して、客が来た時点で許容される時間外のものはキューから削除。その上でキューが空でなければキューの頭にある「許容される上で最も古い」たこ焼きを提供する。(ここが貪欲法)

from collections import deque
 
t = int(input())
n = int(input())
takos = deque(int(x) for x in input().split())
m = int(input())
customers = [int(y) for y in input().split()]
 
q = deque()
for time in customers:
    while q and q[0] < time - t:
        q.popleft()
    while takos and takos[0] <= time:
        q.append(takos.popleft())
    if q:
        q.popleft()
    else:
        print('no')
        break
else:
    print('yes')

これはO(N+M)で計算量のオーダーは最善なはず。

ABC091C - 2D Plane 2N Points

abc091.contest.atcoder.jp

abc091.contest.atcoder.jp

平面にN個ずつ赤い点と青い点がある。赤い点のx,y座標ともに青い点のものより小さい場合ペアにすることができるとして、作れるペアの最大数を求める問題。

x値が最小の青い点からはじめて順番に、ペアになり得る赤い点のうち最もy値が大きいものを選んでいく、という貪欲法:

from collections import defaultdict
 
n = int(input())
reds = [tuple(int(x) for x in input().split())for _ in range(n)]
blues = [tuple(int(x) for x in input().split())for _ in range(n)]
 
blue_dict = defaultdict(list)
for bx, by in blues:
    for rx, ry in reds:
        if rx < bx and ry < by:
            blue_dict[bx, by].append((rx, ry))

for val in blue_dict.values():
    val.sort(key=lambda x:-x[1])
 
red_dict = {}
for blue, red_list in sorted(blue_dict.items(), key=lambda x:x[0][0]):
    for red in red_list:
        if red in red_dict:
            continue
        red_dict[red] = blue
        break
 
print(len(red_dict))

うーん、これももうちょっと綺麗に書ける気がする。

計算量は

  • 各青い点ごとにペアになり得るすべての赤い点のリスト化にO(N2)
  • 各リストのソートにO(N log N)
  • 青い点ごとにペアとなる赤い点をソートされたリストから探すのにO(N2)

なのでワーストケースでO(N2)。N<=100なので問題なし。

Fence Repair問題

貪欲法の章最後の例題であるFence RepairはAtCoderの類題が見つかっていない上に、以前書いた解法を改善する余地が思い浮かばないので飛ばす。

以前書いた記事:

zehnpaard.hatenablog.com

次はようやく動的計画法について。

蟻本初級編攻略 - 2-2 ABC009C 辞書式順序ふたたび、ふたたび

zehnpaard.hatenablog.com

昨日の続き。

AtCoderの「辞書式順序ふたたび」の解法がまだややこしかったのでいろいろといじり続けたら最終的にコードが短くなり、処理速度も三倍ほど上がった。

abc009.contest.atcoder.jp

昨日の時点でのコード:

from collections import Counter
 
n, k = map(int, input().split())
s = input()
counters = {i+1:Counter(s[:i+1]) for i in range(n)}
 
def possible(solution):
    diff1 = sum(a != b for a, b in zip(solution, s))
    if diff1 > k: return False
    diff2 = sum((counters[len(solution)] - Counter(solution)).values())
    return diff1 + diff2 <= k
 
solution = []
remaining = sorted(s)
for _ in range(n):
    for j, c in enumerate(remaining):
        if c is None: continue
        solution.append(c)
        if possible(solution):
            remaining[j] = None
            break
        solution.pop()
 
print(''.join(solution))

まずpossible関数をまず分解してfor-loopの中に入れてみた。

そうすると、diff1の計算に毎回O(N)の処理をする必要はなく、「これまでのsolutionとsの不一致の数に次の文字が不一致か否かで1か0を足す」というO(1)の処理で済むことがわかった。

同様に、diff2も外側のループで一回「これまでのsolutionとs[:i+1]の文字カウントの不一致」を計算してしまえば、内側のループではO(1)で算出できる。

外側のループ一回分でO(N)で「remainingから要素を探して削除する」という処理をすれば内側のループの回数を抑えることができる上にコードもすっきりとする。

といった変更を加えていくと以下のようなコードになった:

abc009.contest.atcoder.jp

from collections import Counter
 
n, k = map(int, input().split())
s = input()
 
solution = []
remaining = sorted(s)
diff = 0
for i in range(n):
    counter = Counter(s[:i+1]) - Counter(solution)
    counts = sum(counter.values())
    for c in remaining:
        diff1 = diff + (c != s[i])
        diff2 = counts - (counter[c] > 0)
        if diff1 + diff2 <= k:
            solution.append(c)
            remaining.remove(c)
            diff = diff1
            break
 
print(''.join(solution))

テストケースは最長で24ms。Pythonインタプリタ起動に17msかかることを考えると実際のロジック部分は7msほどしかかかっていないことがわかる。

もともと170msほどかかっていたことを考えると大きな改善である。コンテスト的には無意味だが・・・

追記:

二行長くなるが、このほうがループの不変条件がすべてそろっていてわかりやすい気がする:

from collections import Counter
 
n, k = map(int, input().split())
s = input()
 
solution = []
remaining = sorted(s)
diff = 0
counter = Counter(s[:1])
counts = 1
for i in range(n):
    for c in remaining:
        diff1 = diff + (c != s[i])
        diff2 = counts - (counter[c] > 0)
        if diff1 + diff2 <= k:
            solution.append(c)
            remaining.remove(c)
            diff = diff1
            counter = Counter(s[:i+2]) - Counter(solution)
            counts = sum(counter.values())
            break
 
print(''.join(solution))

スピードは変わらず。Counterをdefaultdict(int)に変えれば

counter = Counter(s[:i+2]) - Counter(solution)

の行はO(1)にできるが

counts = sum(counter.values())

の部分はO(N)のままで、試したところスピードになんの変化もなかった。

あとはremainingを逆向きにして後ろから添字を使って「今調べているものは常にリストの最後にいるようスワップ」して見つかったらpopする、という処理はO(N)がO(1)になる。

なるのだがけっきょくO(N2)だし、コードを複雑化しての最適化は意外と効果が出にくい。

もし入力がもっと大きくて現在のコードではTLEを起こす、というような状態だと各部分でO(N)をO(1)に落とすのは意味があると思う。

すでに24msくらいで走っている場合で複雑な最適化をすると、今までインタプリタがCに落とし込んでいた部分が純Pythonで評価されるようになってでかい定数倍ペナルティをくらうことが多い。

個人的には「Pythonで最適化とはなるべくアルゴリズムを簡潔に明瞭に記述して重複している処理を減らしていく」くらいに留めるのがいい気がする。それでダメならC++で書こう・・・

蟻本初級編攻略 - 2-2 Best Cow Line

POJからのBest Cow Lineという辞書順についての問題。

Best Cow Line - POJ 3617 - Virtual Judge

牛まったく関係ない。文字列の先頭か後尾から一文字ずつとっていって、辞書順で最小の文字列を作るというもの。

例によって以前も記事を書いていた:

zehnpaard.hatenablog.com

やはりなんだかごちゃごちゃしている。添字でいろいろやっているのは各ステップで新しいデータ構造を作らないためだと思うが・・・

from collections import deque

s = deque(input())
t = []

while s:
    left = True
    for l, r in zip(s, reversed(s)):
        if l != r:
            left = l < r
            break
    c = s.popleft() if left else s.pop()
    t.append(c)

print(''.join(t))

deque使えば添字使わずに同一のデータ構造で処理が完結するし、reversedは逆順のリストではなく遅延評価的なイテレータを返す。シンプル・イズ・ベスト(でもオランダ的なシンプル)。

それではAtCoder 版!蟻本 (初級編) - Qiitaに載っているAtCoder上の類題を解いていく。

ABC076C - Dubious Document 2

abc076.contest.atcoder.jp

abc076.contest.atcoder.jp

英小文字と?でできている文字列Sと英小文字でできている文字列Tがあたえられ、Sの?を任意の文字で埋めてTを含んだ辞書順最小の文字列S'を作る問題。

辞書順最小にする戦略は簡単で、後ろの方からTに一致させられる部分文字列を探して、残りの?はすべてaで埋めればいい。

s = input()
t = input()
 
def match(s1, t):
    return all(a in (b, '?') for a, b in zip(s1, t))
 
substrings = [s[i:i+len(t)] for i in range(len(s)-len(t)+1)]
res = 'UNRESTORABLE'
for ss in reversed(substrings):
    if match(ss, t):
        left, _, right = s.rpartition(ss)
        res = left + t + right
        res = res.replace('?', 'a')
        break
print(res)

match関数で「すべての文字が一致するか?文字か」を判定してrpartitionで左側優先で文字列を分割して再合成。残った?をaに入れ替えて終了。

ABC007B - 辞書式順序

abc007.contest.atcoder.jp

abc007.contest.atcoder.jp

与えられた文字列より辞書順で小さい文字列を出力する問題。

s = input()
print('a' if 'a' < s else -1)

うんまあ・・・ あまりコメントも必要ない気がするが、'a'が最小なのでそれと比較して'a'か-1を返す。

Pythonのstring interning実装に頼ってprint(-1 if 'a' is s else 'a')だともしかするともっと効率的な可能性もあるが、実装依存なのと、問題の制限内の文字列長なら誤差なので無視。

ABC009C - 辞書式順序ふたたび

abc009.contest.atcoder.jp

abc009.contest.atcoder.jp

ある文字列の最大K個の文字の位置を入れ替えて可能なかぎり辞書順最小な文字列を作る問題。

一気に難易度がはねあがった。数分考えてあまりいい解法が思い浮かばなかったのでヒントを見てしまった・・・

ヒントを読んだら実装まではほぼ一直線:

from collections import Counter

n, k = map(int, input().split())
s = input()

def diffs(s1, s2):
    return sum(a!=b for a, b in zip(s1, s2))

def possible(i, j, c, solution, best):
    head = solution + [c]
    diff = diffs(head, s[:i+1])
    if diff > k:
        return False
    remains = Counter(d for n, d in enumerate(best) if d and n != j)
    substring = Counter(s[i+1:])
    diff2 = sum((remains - substring).values())
    return diff + diff2 <= k

best = sorted(s)
solution = []
for i in range(n):
    for j, c in enumerate(best):
        if c == None:
            continue
        if possible(i, j, c, solution, best):
            solution.append(c)
            best[j] = None
            break

print(''.join(solution))

先頭から順に「残っている文字で最小のものを当てて制約を満たすことができるか」をループで調べ続ける。「各位置において」「残っている文字を」「あてはめられるかチェック」の各パートでN分のループが回っているので最終的にO(N3)・・・ だがN<=100なのでまったく問題ない。

二つの集まりの要素の違いを数で把握したいときにCounterクラスの引き算は便利。長さが同じ文字列の比較なので一回の引き算でいくつの要素が違っているかがわかる(Counter同士の引き算ではマイナスは切り捨てになる)

追記:

dequeを使ってもう少しすっきりするよう書き直した:

from collections import Counter, deque
 
n, k = map(int, input().split())
s = input()
 
def possible(solution, s=s, k=k):
    leftstr, rightstr = s[:len(solution)], s[len(solution):]
    diff1 = sum(a != b for a, b in zip(solution, leftstr))
    if diff1 > k: return False
    diff2 = sum((Counter(s) - Counter(solution) - Counter(rightstr)).values())
    return diff1 + diff2 <= k
 
solution = []
remaining = deque(sorted(s))
for _ in range(n):
    for j, c in enumerate(remaining):
        solution.append(c)
        if possible(solution):
            del remaining[j]
            break
        solution.pop()
 
print(''.join(solution))

dequeについてはもうちょっと調べたほうがいい気がする。特殊なdoubly-linked-listのようだ。

dequeはやめて、さらに最適化してみた:

from collections import Counter
 
n, k = map(int, input().split())
s = input()
counters = {i+1:Counter(s[:i+1]) for i in range(n)}
 
def possible(solution):
    diff1 = sum(a != b for a, b in zip(solution, s))
    if diff1 > k: return False
    diff2 = sum((counters[len(solution)] - Counter(solution)).values())
    return diff1 + diff2 <= k
 
solution = []
remaining = sorted(s)
for _ in range(n):
    for j, c in enumerate(remaining):
        if c is None: continue
        solution.append(c)
        if possible(solution):
            remaining[j] = None
            break
        solution.pop()
 
print(''.join(solution))

さきほどのコードもよく見てみると読みやすくなったのはdequeの恩恵ではなくCounterをうまく使ったことによるpossible関数の整理が大きかった。

行った最適化は二点:

  • zip(solution, s[:len(solution)])zip(solution, s)に置き換えられるし後者のほうが効率的

というマイナーな点(こちらは最適化よりもコードの見通しがよくなることのほうが嬉しい)、そして:

  • Counter(s) - Counter(solution) - Counter(s[len(solution):])Counter(s) - Counter(s[len(solution):]) - Counter(solution)に置き換えられる
  • Counter(s) - Counter(s[len(solution):])Counter(s[:len(solution)])に置き換えられる
  • Counter(s[:len(solution)])はsolutionごとに計算する必要はないのでsの先頭からはじまる部分文字列について最初に計算して辞書化しておく。

これで大体二倍くらいのスピードアップが得られる。

remainingをLinked Listにするともうちょっと効率化できるか?ただ、Pythonには標準のSingly Linked List実装はないので自前で書くことになるから多分実際には逆効果だと思う。

def possible(solution, s=s, k=k):のように外側のスコープの名前を関数に束縛する手法、はやくなると言われてやってみたが計測してみるとまったく関係なかった。読みにくくなるだけだからやめよう・・・

蟻本初級編攻略 - 2-2 硬貨と区間

ようやく全探索章を終え、貪欲法の章に進む。

貪欲法に関して蟻本で出てくる最初の2問はあまり適当なAtCoderの問題がないようだ。とりあえず蟻本の問題を解いていく。

硬貨の問題

1円から500円までの硬貨を特定の数ずつ持っているとして、ある金額に合計する最小の硬貨の組み合わせを見つける問題。

*coins, total = [int(x) for x in input().split()]
values = [1, 5, 10, 50, 100, 500]

res = 0 
for max_count, value in zip(coins[::-1], values[::-1]):
    count = min(total // value, max_count)
    res += count
    total -= count * value

print(res)

大きい硬貨からはじめる貪欲法で、割り切れるだけの数とその硬貨を持っている数のminの分金額をどんどん差し引いていく。

区間スケジューリング

始まりと終わりの時間が指定されている複数のタスクがあり、時間がまったくかぶらないという条件でこなせるタスクの最大数を求める問題。

タスク終了時がもっとも早いものを貪欲法でとっていく。

この問題は三ヶ月前にも解いていた:

zehnpaard.hatenablog.com

今だったらどう書くかな、と気になったので試しに1から解いてみた:

_ = int(input())
s = [int(x) for x in input().split()]
t = [int(x) for x in input().split()]

solution = [(None, -1)]
xs = sorted(zip(s, t), key=lambda x:x[1])
for task in xs:
    if task[0] > solution[-1][1]:
        solution.append(task)

print(len(solution)-1)

ほとんど変わっていない。変数や行数が減っているがその分少しだけロジックがややこしくなっているか?

蟻本初級編攻略 - 2-1 特殊な状態の列挙

蟻本ではちゃんと例題めいたものが載っていなかった話題。

Pythonだとitertools.permutationなどで簡単に実装できる。

ABC054C - One-stroke Path

abc054.contest.atcoder.jp

abc054.contest.atcoder.jp

重み無し無向グラフを、特定の始点「1」からはじめてすべての頂点を一回だけ訪れるパスを数える問題。

「1」以外の頂点を並べる順番をすべて列挙してから始点に「1」を加え、隣同士のペアがすべて辺の集合に含まれているかを調べる。

from itertools import permutations

n, m = map(int, input().split())
edges = set()
for _ in range(m):
    a, b = map(int, input().split())
    edges.add((a, b))
    edges.add((b, a))

orderings = ((1,) + p for p in permutations(range(2, n+1)))
print(sum(all(move in edges for move in zip(p, p[1:])) for p in orderings))

ちなみに始点「1」を加えるのに、AtCoderで使えるPython3のバージョンが3.4なので(1,)+pとしている。

Python3.5からはUnpacking SyntaxがPEP448で拡張されて'(1, *p)'と書ける。

蟻本初級編攻略 - 2-1 迷路の最短路

ここらへんは全部今年の三月に記事を書いているなー。

zehnpaard.hatenablog.com

このころは

map_をdict化すればかなり綺麗になるが、競プロ的にどうなんだろう・・・

などと言っていたが、現在は躊躇なくdictionary化するなあ。実際そこのO(N)が問題になることはまずない。(すくなくともABCのD問題くらいでは)

「もしコストが気になるならはじめからdictionaryとしてパースしてしまえばいいじゃない」と心の中のIQ145の女子高生に囁かれて試してみたのが以下のコード:

from collections import deque

n, m = map(int, input().split())
maze = {(i,j):c for i in range(n) for j, c in enumerate(input())}

def bfs(coord):
    dq = deque()
    dq.appendleft((0, coord))
    while dq:
        steps, (i, j) = dq.pop()
        current = maze.get((i, j), '#')
        if current == 'G':
            return steps
        if current == '#':
            continue
        maze[(i, j)] = '#'
        dirs = ((-1, 0), (1, 0), (0, 1), (0, -1))
        dq.extendleft((steps+1, (i+di, j+dj)) for di, dj in dirs)

start = next(coord for coord, c in maze.items() if c == 'S')
print(bfs(start))

DFSをスタックでやることの隠れたメリットはBFSとDFSがほぼ同型になること。DFSではLIFOのスタックを使っていたのをBFSではFIFOのキューを使う、以外はほぼ同じコード。

ABC007C - 幅優先探索

abc007.contest.atcoder.jp

abc007.contest.atcoder.jp

蟻本の問題とほぼ同じ。スタートとゴールが地図とは別に座標で与えられることとその座標が0-indexedじゃないことが注意点。

from collections import deque
 
n, m = map(int, input().split())
start = tuple(int(x)-1 for x in input().split())
goal = tuple(int(x)-1 for x in input().split())
maze = {(i,j):c for i in range(n) for j, c in enumerate(input())}
 
def bfs(coord):
    dq = deque()
    dq.appendleft((0, coord))
    while dq:
        steps, (i, j) = dq.pop()
        if (i, j) == goal:
            return steps
        if maze.get((i, j), '#') == '#':
            continue
        maze[(i, j)] = '#'
        dirs = ((-1, 0), (1, 0), (0, 1), (0, -1))
        dq.extendleft((steps+1, (i+di, j+dj)) for di, dj in dirs)
 
print(bfs(start))

AtCoderはほぼすべての問題が1-indexedで修正するのをうっかりすると後から見つけにくい。

ABC088D - Grid Repainting

abc088.contest.atcoder.jp

abc088.contest.atcoder.jp

「BFSの結果のルート」と「元からあった黒いマス」以外のマスをすべて黒く塗れるという点に気づけば、あとはほぼ同じ問題。

from collections import deque
 
h, w = map(int, input().split())
maze = {(i,j):c for i in range(h) for j, c in enumerate(input())}
 
blacks = sum(c == '#' for c in maze.values())
 
def bfs(i, j):
    dq = deque()
    dq.appendleft((1, (i, j)))
    while dq:
        steps, (i, j) = dq.pop()
        if (i, j) == (h-1, w-1):
            return steps
        if maze.get((i, j), '#') == '#':
            continue
        maze[(i, j)] = '#'
        dirs = ((-1, 0), (1, 0), (0, 1), (0, -1))
        dq.extendleft((steps+1, (i+di, j+dj)) for di, dj in dirs)
 
res = bfs(0, 0)
print(-1 if res is None else h*w - res - blacks)

ただ、この問題ではそもそもスタート地点からゴールまでの経路が存在しないケースがあり得る。最初に提出したときはうっかりしていた・・・

ARC005C - 器物損壊!高橋君

arc005.contest.atcoder.jp

arc005.contest.atcoder.jp

壁を二回まで破壊して通過できるという条件でスタートからゴールまで到達できるか、という問題。

最初は「DFSで隣接している壁を探して、隣接している壁がダブっているかどうかで繋がっている『部屋』を見つけて2ステップでゴールまで到達できるか」的なコードを書いていたのだが、あえなくTLEした。

その後に記事の説明を読んで01-BFSなる用語を調べたところ、非常にエレガントな解法で感動した。

「距離」を「移動コスト」再定義して「通路の移動コストは0」「壁の移動コストは1」とすれば「壁を最大二回移動してゴールに到達できるか」が算出できる。

h, w = map(int, input().split())
maze = {(i,j):c for i in range(h) for j, c in enumerate(input())}
dirs = ((-1, 0), (1, 0), (0, 1), (0, -1))
 
def bfs01(start, maze=maze, dirs=dirs):
    q = [[start], [], []]
    while q:
        i, j = q[0].pop()
        maze[i, j] = 'X'
        for di, dj in dirs:
            next_ = maze.get((i+di, j+dj), 'X')
            if next_ == 'g':
                return True
            if next_ == 'X' or (next_ == '#' and len(q) == 1):
                continue
            q[next_ == '#'].append((i+di, j+dj))
        while q and not q[0]:
            q = q[1:]
    return False
 
start = next(coord for coord, c in maze.items() if c == 's')
print('YES' if bfs01(start) else 'NO')

ただ、アイディアのエレガンスの対してコード実装が泥臭いのが難点・・・ 

最初はPriority QueueでやってみたのだがやはりO(n log n)だとTLEになる。リストを束ねて「スタックのキュー」的なデータ構造にした。

これで実行時間は800ms弱。Pythonだとここらへんが上限だろうか?

ちなみにPyPy3で同じコードを走らせてみると、Python3で20msぐらいかかるテストケースで165msかかり、すこし重いテストケースは軒並みMLEを起こしていた。

すくなくともAtCoderでPyPy3は使い物にならないと言ってしまってもいいのではないか。

理由としては:

  • JITコンパイラが温まりきらない(数秒は継続して走っていないとJITの恩恵を受けない)
  • PyPyはCPythonよりスタートアップに時間がかかる
  • PyPy3はPyPy2に比べて最適化などがまだ甘い(Python2.7の仕様がもう10年もほとんど変わっていないので追いやすかった?)
  • AtCoderが使っているPyPy3のバージョンが「CPython3より遅いことも多い」と知られている古いバージョン
  • PyPyはitertoolsやlist comprehensionなどよりfor loopなどのほうが効率化しやすい

などが挙げられると思う。これらを合わせると正直PyPy3は当面無視するのが得策だろう。AtCoderでCythonが使えたらなかなか面白いことになるとは思うが・・・