Arantium Maestum

プログラミング、囲碁、読書の話題

OCamlでTypical Dynamic Programming Contest E問題を解いてみた

ここ一ヶ月ほどAtCoderPythonですすめてきて、Pythonの遅さにひやりとすることはあっても基本的にアルゴリズムがあっていて素直な最適化を施していればABCの問題であれば通せそう、という感触を得ている。

しかし少し上の問題を見てみるとそうはいかないようだ。例えばAtCoderのTypical Dynamic Programming ContestというDP問題ばかり集めたサンプルコンテストのE問題:

tdpc.contest.atcoder.jp

かなり素直な桁DPなのだが、桁数が10000でmodをとる値が最大100なので計算量のオーダーが10000 * 100 * 10(最後の10は0~9までの数で回すから)、つまり107になる。 この計算量はコンパイル言語なら余裕だがPythonだと普通にアウトのようだ。

tdpc.contest.atcoder.jp

(まあこのコードが定数で最適かといわれると違うのだが・・・)

というわけでコンパイル言語で書いてみる。

競技プログラミング的に素直な選択はC++だが、もう少しPythonに近い宣言的な表現力のある言語がいい。

表現力という意味ではHaskellが相当高いが、純粋関数型だとアルゴリズムにけっこう工夫が必要になり、Pythonで考えるのとはまったく違う世界になる。

速度もあり、適度に関数型で宣言的に記述することができ、かつ必要なところで副作用を気軽に扱える言語、ということでOCamlを選んだ。

tdpc.contest.atcoder.jp

最長251msで無事AC。

open Batteries
 
let digitList s = s |> String.to_list |> List.map (fun c -> Char.code c - 48)
let id x = x
 
let d = Scanf.scanf "%d\n" id
let n = Scanf.scanf "%s\n" id
let m = 1000000007
 
let f x n =
  let v, a = x in
  let a1 = Array.make d 0 in
  for i = 0 to d-1 do
    for j = 0 to 9 do
      a1.((i+j) mod d) <- (a1.((i+j) mod d) + a.(i)) mod m
    done
  done;
  for j = 0 to n-1 do
    a1.((v+j) mod d) <- (a1.((v+j) mod d) + 1) mod m
  done;
  ((v+n) mod d, a1)
 
let v, a = List.fold_left f (0, Array.make d 0) (digitList n)
let () = Printf.printf "%d\n" (a.(0)  + (if v == 0 then 1 else 0) - 1)

まともにOCamlを書くのははじめてだったのでまだいろいろ汚いと思うが、とりあえず自分がしたいことを比較的素直に記述できて速度が出るのは非常にうれしい。

f関数の中で動的計画法らしくArrayに対する破壊的代入を繰り返しているが、引数には変更は加えずに命令的に作成したArrayを結果として返しているので外側からみると純粋関数だ。

そのf関数を引数としてfold_leftで畳み込んだり、文字列を|>を使った関数のチェーンで一桁の数字のリストに変換したり、といったところは抽象度が高い関数型らしい記述ができる。

その上で実行速度が107のオーダーをAtCoderの時間制限の2000msでも問題なくこなせるのだからこれは強い。

習熟度が非常に低いので実際のコンテストではまだ使えないだろうが、少しずつ過去問(とくに計算量のオーダーが大きそうな問題で)をOCamlで挑戦していきたい。

トポロジカル・ソートをPythonで実装してみた

DPとはDAGの最短経路を、トポロジカルソート順に埋めていくことで計算する手法という話からそもそもトポロジカルソートってどうやるんだっけ?となり、Pythonで一つのアプローチを実装してみた:

from collections import defaultdict, deque

v, n = map(int, input().split())
es = [[int(x) for x in input().split()] for _ in range(n)]

outs = defaultdict(list)
ins = defaultdict(int)
for v1, v2 in es:
    outs[v1].append(v2)
    ins[v2] += 1

q = deque(v1 for v1 in range(v) if ins[v1] == 0)
res = []
while q:
    v1 = q.popleft()
    res.append(v1)
    for v2 in outs[v1]:
        ins[v2] -= 1
        if ins[v2] == 0:
            q.append(v2)

print(', '.join(res))

冒頭にリンクも貼った「DPの話」の記事で言われている「入次数 0 のノード (とそこから伸びる辺) をひたすら取り除きまくる」実装。

vがノードの数、nが辺の数、esが各辺を表す親ノード、小ノードのペア。

outsが各ノードから出ていく先のノードのリストの辞書。insが各ノードの入次数の辞書。

qが入次数が0のノードのキュー。

  • キューの頭のノードを取り出し、ソートされたリストの次のメンバとしてappend
  • そのノードから伸びる各辺について
    • その行き先のノードの入次数を1減らし
    • もしそのノードの入次数が0に落ちたらキューに入れる

という処理を繰り返す。

outsとinsの作成にO(E)、qの作成にO(V)、resの作成のループはO(V+E)(外側のループがO(V)回、内側のループは合計でO(E)回)ということで全体的にO(E+V)の計算量。

実際にDPでトポロジカルソートをする必要があることはあまりないらしい(入力が与えられた時点ですでにDAGとしてソートされていることが多いとのこと)。だけどとりあえず必要だったらすぐ実装して使えるようにしておきたい。今回でわかった通りロジックも実装も非常に簡単だし。

桁DPに入門してみた

昨日のAtCoder ABC101でD問題に歯が立たず非常に悔しかったので、桁関連の問題や解法を少し読んでいる。

その経緯で桁に関する動的計画法(桁DP)について下のすごい記事に遭遇した:

pekempey.hatenablog.com

今蟻本でも動的計画法の章を読んでいることだし、この機会に桁DPに入門してみる。

以下にPythonでの実装と自分なりの気づきをメモってみた:

A以下の非負整数の総数を求める

from itertools import product
from collections import defaultdict

a = input()
n = len(a)

dp = defaultdict(int)
dp[0, 0, 0] = 1

for i, less in product(range(n), (0,1)):
    max_d = 9 if less else int(a[i])
    for d in range(max_d+1):
        less_ = less or d < max_d
        dp[i + 1, less_] += dp[i, less]

print(sum(dp[n, less] for less in (0, 1)))

初期値としてdp[0, 0, 0, 0]に1を入れておくのが非直観的で面白い。

dp[i, 1]で上からi番目の桁まで比較した場合nより低い数がひとつ以上出てくるパターンの数。

dp[i, 0]は各iで必ず値が1になるはず(n以下の整数を数えているので、すべての桁がnのもの以上のパターンはnと完全一致する場合だけ)

A以下の非負整数のうち、3の付く数の総数を求める

from itertools import product
from collections import defaultdict

a = input()
n = len(a)

dp = defaultdict(int)
dp[0, 0, 0] = 1

for i, less, has3 in product(range(n), (0,1), (0,1)):
    max_d = 9 if less else int(a[i])
    for d in range(max_d+1):
        less_ = less or d < max_d
        has3_ = has3 or d == 3
        dp[i + 1, less_, has3_] += dp[i, less, has3]

print(sum(dp[n, less, 1] for less in (0, 1)))

この条件を加えるのは非常に簡単。

A以下の非負整数のうち、「3が付くまたは3の倍数」の数の総数を求める

from itertools import product
from collections import defaultdict

a = input()
n = len(a)

dp = defaultdict(int)
dp[0, 0, 0, 0] = 1

for i, less, has3, mod3 in product(range(n), (0,1), (0,1), range(3)):
    max_d = 9 if less else int(a[i])
    for d in range(max_d+1):
        less_ = less or d < max_d
        has3_ = has3 or d == 3
        mod3_ = (mod3 + d) % 3
        dp[i + 1, less_, has3_, mod3_] += dp[i, less, has3, mod3]

criteria = ((n, less, has3, mod3) for less, has3, mod3 
                                  in product((0, 1), (0,1), range(3)) 
                                  if has3 or mod3==0)

print(sum(dp[c] for c in criteria))

mod3_ = (mod3 + d) % 3の部分が「桁のトリック」っぽい。「ある数とその数の桁の和は法3で合同」という事実を使っている。

もとの記事には説明がなかったので桁関連の考え方では常識的なのかも。あまり馴染みがなかったので、最初に見たとき何をしているのか考え込んでしまった。mod3_ = (mod3*10 + d) % 3を少し短く済ませている。

A以下の非負整数のうち、「3が付くまたは3の倍数」かつ「8の倍数でない」数の総数を求める

from itertools import product
from collections import defaultdict

a = input()
n = len(a)

dp = defaultdict(int)
dp[0, 0, 0, 0, 0] = 1

for i, less, has3, mod3, mod8 in product(range(n), (0,1), (0,1), range(3), range(8)):
    max_d = 9 if less else int(a[i])
    for d in range(max_d+1):
        less_ = less or d < max_d
        has3_ = has3 or d == 3
        mod3_ = (mod3 + d) % 3
        mod8_ = (mod8*10 + d) % 8
        dp[i + 1, less_, has3_, mod3_, mod8_] += dp[i, less, has3, mod3, mod8]

criteria = ((n, less, has3, mod3, mod8) for less, has3, mod3, mod8
                                        in product((0, 1), (0,1), range(3), range(8))
                                        if (has3 or mod3==0) and mod8!=0)

print(sum(dp[c] for c in criteria))

print(sum(n % 8 and (n % 3==0 or '3' in str(n)) for n in range(int(a)+1)))

形が決まってしまえばコードの複雑さはあまり増加しないのがうれしいところ。

ちなみにテストするためには定義通りに書いた非効率な実装と、小さい数字で結果が一致するか調べるのも手だ。例えば以下のワンライナーが使える:

print(sum(n % 8 and (n % 3==0 or '3' in str(n)) for n in range(int(a)+1)))

A以下B以上の非負整数のうち、「3が付くまたは3の倍数」かつ「8の倍数でない」数の総数を求める

from itertools import product
from collections import defaultdict

a = int(input())
b = int(input())

def count(x):
    a = str(x)
    n = len(a)
    dp = defaultdict(int)
    dp[0, 0, 0, 0, 0] = 1

    for i, less, has3, mod3, mod8 in product(range(n), (0,1), (0,1), range(3), range(8)):
        max_d = 9 if less else int(a[i])
        for d in range(max_d+1):
            less_ = less or d < max_d
            has3_ = has3 or d == 3
            mod3_ = (mod3 + d) % 3
            mod8_ = (mod8*10 + d) % 8
            dp[i + 1, less_, has3_, mod3_, mod8_] += dp[i, less, has3, mod3, mod8]

    criteria = ((n, less, has3, mod3, mod8) for less, has3, mod3, mod8
                                            in product((0, 1), (0,1), range(3), range(8))
                                            if (has3 or mod3==0) and mod8!=0)

    return sum(dp[c] for c in criteria)

print(count(a) - count(b-1))

関数化すれば非常に楽。

ただし、一つのループに「B以上」という条件も組み込むほうがより効率的なはず。今度時間があったらその実装も試してみたい。

蟻本初級編攻略 - 2-2 Saruman's Army

蟻本の貪欲法で「交換しても悪化しないことがわかっている中で最大のものをとる」という考えの例題:

Saruman's Army - POJ 3069 - Virtual Judge

直線上に配置された人のうち最小何人選べば、すべての人が選ばれた人の一定距離内におさまるかを求める問題。

左端の人から一定距離以内でもっとも右側の人を選び、「その人から一定距離外でもっとも左側の人」から一定距離以内でもっとも右側の人を選び、とやり続ける:

def saruman(xs, r):
    bearer, edge = False, xs[0]
    for x, y in zip(xs, xs[1:]):
        while y > edge + r:
            if not bearer:
                yield x
                bearer, edge = True, x
            else:
                bearer, edge = False, y
    if not bearer:
        yield y

while True:
    r, n = map(int, input().split())
    if r == n == -1:
        break
    xs = sorted(int(x) for x in input().split())
    print(len(list(saruman(xs, r))))

ちょっとロジックがコードからはわかりにくい。まだ最適な記述を見つけていない気がする。やっぱり添字が一番か?

基本的にPythonでは記述上添字でループを回すことは少ないのだが(Cで実装されているCPythonインタプリタは当然裏で添字でループを回している)、こういう問題の場合は添字が自然な気がする。どういう条件下では添字がベストになるんだろう・・・

それではAtCoderの類題を解いていく。

ABC083C - Multiple Gift

abc083.contest.atcoder.jp

abc083.contest.atcoder.jp

X以上Y以下の数で構成された数列で、n個目の数はn-1個目の数の倍数かずn-1個目の数より真に大きい。この制約で最大の長さを求める。

一回ごとの増加を最低に抑えるのが貪欲ポイント。とすると倍数は2。Xに何回2をかけてY以下に収められるか。

from itertools import count, takewhile
 
x, y = map(int, input().split())
 
print(max(takewhile(lambda i: x * (2**i) <= y,  count())) + 1)

頑張ればO(1)にできたと思うのだが、とりあえずO(log N)。

ARC006C - 積み重ね

arc006.contest.atcoder.jp

arc006.contest.atcoder.jp

順番に入ってくるダンボールを床に置くか、より重いダンボールの上に置くかできる。床に直に置かれているダンボールの数を最小化する問題。

N <= 50なのでO(N2)で問題ない。とすると素直に実装できる:

n = int(input())
ws = [int(input()) for _ in range(n)]
 
xs = []
for w in ws:
    ys = [(x, i) for i, x in enumerate(xs) if x >= w]
    if ys:
        _, i = min(ys)
        xs[i] = w
    else:
        xs.append(w)
print(len(xs))

xsが各山の一番上に乗っているダンボールの重さ。入ってきたダンボールより重いものがあるか調べて、あるならその内の最小のものの上に置く(xsのその添字の値を書き換える)。「入ってきたダンボールより重い、最小のもの」が貪欲法要素。

ABC005C - おいしいたこ焼きの売り方

abc005.contest.atcoder.jp

abc005.contest.atcoder.jp

たこ焼きが出来上がる時間と客が来る時間と、たこ焼きが出来上がってから買われるまでの許容される最長の時間が与えられ、すべての客にすぐにたこ焼きを提供できるか調べる問題。

現在たこ焼きをできた順に、できた時間でキューとして保持して、客が来た時点で許容される時間外のものはキューから削除。その上でキューが空でなければキューの頭にある「許容される上で最も古い」たこ焼きを提供する。(ここが貪欲法)

from collections import deque
 
t = int(input())
n = int(input())
takos = deque(int(x) for x in input().split())
m = int(input())
customers = [int(y) for y in input().split()]
 
q = deque()
for time in customers:
    while q and q[0] < time - t:
        q.popleft()
    while takos and takos[0] <= time:
        q.append(takos.popleft())
    if q:
        q.popleft()
    else:
        print('no')
        break
else:
    print('yes')

これはO(N+M)で計算量のオーダーは最善なはず。

ABC091C - 2D Plane 2N Points

abc091.contest.atcoder.jp

abc091.contest.atcoder.jp

平面にN個ずつ赤い点と青い点がある。赤い点のx,y座標ともに青い点のものより小さい場合ペアにすることができるとして、作れるペアの最大数を求める問題。

x値が最小の青い点からはじめて順番に、ペアになり得る赤い点のうち最もy値が大きいものを選んでいく、という貪欲法:

from collections import defaultdict
 
n = int(input())
reds = [tuple(int(x) for x in input().split())for _ in range(n)]
blues = [tuple(int(x) for x in input().split())for _ in range(n)]
 
blue_dict = defaultdict(list)
for bx, by in blues:
    for rx, ry in reds:
        if rx < bx and ry < by:
            blue_dict[bx, by].append((rx, ry))

for val in blue_dict.values():
    val.sort(key=lambda x:-x[1])
 
red_dict = {}
for blue, red_list in sorted(blue_dict.items(), key=lambda x:x[0][0]):
    for red in red_list:
        if red in red_dict:
            continue
        red_dict[red] = blue
        break
 
print(len(red_dict))

うーん、これももうちょっと綺麗に書ける気がする。

計算量は

  • 各青い点ごとにペアになり得るすべての赤い点のリスト化にO(N2)
  • 各リストのソートにO(N log N)
  • 青い点ごとにペアとなる赤い点をソートされたリストから探すのにO(N2)

なのでワーストケースでO(N2)。N<=100なので問題なし。

Fence Repair問題

貪欲法の章最後の例題であるFence RepairはAtCoderの類題が見つかっていない上に、以前書いた解法を改善する余地が思い浮かばないので飛ばす。

以前書いた記事:

zehnpaard.hatenablog.com

次はようやく動的計画法について。

蟻本初級編攻略 - 2-2 ABC009C 辞書式順序ふたたび、ふたたび

zehnpaard.hatenablog.com

昨日の続き。

AtCoderの「辞書式順序ふたたび」の解法がまだややこしかったのでいろいろといじり続けたら最終的にコードが短くなり、処理速度も三倍ほど上がった。

abc009.contest.atcoder.jp

昨日の時点でのコード:

from collections import Counter
 
n, k = map(int, input().split())
s = input()
counters = {i+1:Counter(s[:i+1]) for i in range(n)}
 
def possible(solution):
    diff1 = sum(a != b for a, b in zip(solution, s))
    if diff1 > k: return False
    diff2 = sum((counters[len(solution)] - Counter(solution)).values())
    return diff1 + diff2 <= k
 
solution = []
remaining = sorted(s)
for _ in range(n):
    for j, c in enumerate(remaining):
        if c is None: continue
        solution.append(c)
        if possible(solution):
            remaining[j] = None
            break
        solution.pop()
 
print(''.join(solution))

まずpossible関数をまず分解してfor-loopの中に入れてみた。

そうすると、diff1の計算に毎回O(N)の処理をする必要はなく、「これまでのsolutionとsの不一致の数に次の文字が不一致か否かで1か0を足す」というO(1)の処理で済むことがわかった。

同様に、diff2も外側のループで一回「これまでのsolutionとs[:i+1]の文字カウントの不一致」を計算してしまえば、内側のループではO(1)で算出できる。

外側のループ一回分でO(N)で「remainingから要素を探して削除する」という処理をすれば内側のループの回数を抑えることができる上にコードもすっきりとする。

といった変更を加えていくと以下のようなコードになった:

abc009.contest.atcoder.jp

from collections import Counter
 
n, k = map(int, input().split())
s = input()
 
solution = []
remaining = sorted(s)
diff = 0
for i in range(n):
    counter = Counter(s[:i+1]) - Counter(solution)
    counts = sum(counter.values())
    for c in remaining:
        diff1 = diff + (c != s[i])
        diff2 = counts - (counter[c] > 0)
        if diff1 + diff2 <= k:
            solution.append(c)
            remaining.remove(c)
            diff = diff1
            break
 
print(''.join(solution))

テストケースは最長で24ms。Pythonインタプリタ起動に17msかかることを考えると実際のロジック部分は7msほどしかかかっていないことがわかる。

もともと170msほどかかっていたことを考えると大きな改善である。コンテスト的には無意味だが・・・

追記:

二行長くなるが、このほうがループの不変条件がすべてそろっていてわかりやすい気がする:

from collections import Counter
 
n, k = map(int, input().split())
s = input()
 
solution = []
remaining = sorted(s)
diff = 0
counter = Counter(s[:1])
counts = 1
for i in range(n):
    for c in remaining:
        diff1 = diff + (c != s[i])
        diff2 = counts - (counter[c] > 0)
        if diff1 + diff2 <= k:
            solution.append(c)
            remaining.remove(c)
            diff = diff1
            counter = Counter(s[:i+2]) - Counter(solution)
            counts = sum(counter.values())
            break
 
print(''.join(solution))

スピードは変わらず。Counterをdefaultdict(int)に変えれば

counter = Counter(s[:i+2]) - Counter(solution)

の行はO(1)にできるが

counts = sum(counter.values())

の部分はO(N)のままで、試したところスピードになんの変化もなかった。

あとはremainingを逆向きにして後ろから添字を使って「今調べているものは常にリストの最後にいるようスワップ」して見つかったらpopする、という処理はO(N)がO(1)になる。

なるのだがけっきょくO(N2)だし、コードを複雑化しての最適化は意外と効果が出にくい。

もし入力がもっと大きくて現在のコードではTLEを起こす、というような状態だと各部分でO(N)をO(1)に落とすのは意味があると思う。

すでに24msくらいで走っている場合で複雑な最適化をすると、今までインタプリタがCに落とし込んでいた部分が純Pythonで評価されるようになってでかい定数倍ペナルティをくらうことが多い。

個人的には「Pythonで最適化とはなるべくアルゴリズムを簡潔に明瞭に記述して重複している処理を減らしていく」くらいに留めるのがいい気がする。それでダメならC++で書こう・・・

蟻本初級編攻略 - 2-2 Best Cow Line

POJからのBest Cow Lineという辞書順についての問題。

Best Cow Line - POJ 3617 - Virtual Judge

牛まったく関係ない。文字列の先頭か後尾から一文字ずつとっていって、辞書順で最小の文字列を作るというもの。

例によって以前も記事を書いていた:

zehnpaard.hatenablog.com

やはりなんだかごちゃごちゃしている。添字でいろいろやっているのは各ステップで新しいデータ構造を作らないためだと思うが・・・

from collections import deque

s = deque(input())
t = []

while s:
    left = True
    for l, r in zip(s, reversed(s)):
        if l != r:
            left = l < r
            break
    c = s.popleft() if left else s.pop()
    t.append(c)

print(''.join(t))

deque使えば添字使わずに同一のデータ構造で処理が完結するし、reversedは逆順のリストではなく遅延評価的なイテレータを返す。シンプル・イズ・ベスト(でもオランダ的なシンプル)。

それではAtCoder 版!蟻本 (初級編) - Qiitaに載っているAtCoder上の類題を解いていく。

ABC076C - Dubious Document 2

abc076.contest.atcoder.jp

abc076.contest.atcoder.jp

英小文字と?でできている文字列Sと英小文字でできている文字列Tがあたえられ、Sの?を任意の文字で埋めてTを含んだ辞書順最小の文字列S'を作る問題。

辞書順最小にする戦略は簡単で、後ろの方からTに一致させられる部分文字列を探して、残りの?はすべてaで埋めればいい。

s = input()
t = input()
 
def match(s1, t):
    return all(a in (b, '?') for a, b in zip(s1, t))
 
substrings = [s[i:i+len(t)] for i in range(len(s)-len(t)+1)]
res = 'UNRESTORABLE'
for ss in reversed(substrings):
    if match(ss, t):
        left, _, right = s.rpartition(ss)
        res = left + t + right
        res = res.replace('?', 'a')
        break
print(res)

match関数で「すべての文字が一致するか?文字か」を判定してrpartitionで左側優先で文字列を分割して再合成。残った?をaに入れ替えて終了。

ABC007B - 辞書式順序

abc007.contest.atcoder.jp

abc007.contest.atcoder.jp

与えられた文字列より辞書順で小さい文字列を出力する問題。

s = input()
print('a' if 'a' < s else -1)

うんまあ・・・ あまりコメントも必要ない気がするが、'a'が最小なのでそれと比較して'a'か-1を返す。

Pythonのstring interning実装に頼ってprint(-1 if 'a' is s else 'a')だともしかするともっと効率的な可能性もあるが、実装依存なのと、問題の制限内の文字列長なら誤差なので無視。

ABC009C - 辞書式順序ふたたび

abc009.contest.atcoder.jp

abc009.contest.atcoder.jp

ある文字列の最大K個の文字の位置を入れ替えて可能なかぎり辞書順最小な文字列を作る問題。

一気に難易度がはねあがった。数分考えてあまりいい解法が思い浮かばなかったのでヒントを見てしまった・・・

ヒントを読んだら実装まではほぼ一直線:

from collections import Counter

n, k = map(int, input().split())
s = input()

def diffs(s1, s2):
    return sum(a!=b for a, b in zip(s1, s2))

def possible(i, j, c, solution, best):
    head = solution + [c]
    diff = diffs(head, s[:i+1])
    if diff > k:
        return False
    remains = Counter(d for n, d in enumerate(best) if d and n != j)
    substring = Counter(s[i+1:])
    diff2 = sum((remains - substring).values())
    return diff + diff2 <= k

best = sorted(s)
solution = []
for i in range(n):
    for j, c in enumerate(best):
        if c == None:
            continue
        if possible(i, j, c, solution, best):
            solution.append(c)
            best[j] = None
            break

print(''.join(solution))

先頭から順に「残っている文字で最小のものを当てて制約を満たすことができるか」をループで調べ続ける。「各位置において」「残っている文字を」「あてはめられるかチェック」の各パートでN分のループが回っているので最終的にO(N3)・・・ だがN<=100なのでまったく問題ない。

二つの集まりの要素の違いを数で把握したいときにCounterクラスの引き算は便利。長さが同じ文字列の比較なので一回の引き算でいくつの要素が違っているかがわかる(Counter同士の引き算ではマイナスは切り捨てになる)

追記:

dequeを使ってもう少しすっきりするよう書き直した:

from collections import Counter, deque
 
n, k = map(int, input().split())
s = input()
 
def possible(solution, s=s, k=k):
    leftstr, rightstr = s[:len(solution)], s[len(solution):]
    diff1 = sum(a != b for a, b in zip(solution, leftstr))
    if diff1 > k: return False
    diff2 = sum((Counter(s) - Counter(solution) - Counter(rightstr)).values())
    return diff1 + diff2 <= k
 
solution = []
remaining = deque(sorted(s))
for _ in range(n):
    for j, c in enumerate(remaining):
        solution.append(c)
        if possible(solution):
            del remaining[j]
            break
        solution.pop()
 
print(''.join(solution))

dequeについてはもうちょっと調べたほうがいい気がする。特殊なdoubly-linked-listのようだ。

dequeはやめて、さらに最適化してみた:

from collections import Counter
 
n, k = map(int, input().split())
s = input()
counters = {i+1:Counter(s[:i+1]) for i in range(n)}
 
def possible(solution):
    diff1 = sum(a != b for a, b in zip(solution, s))
    if diff1 > k: return False
    diff2 = sum((counters[len(solution)] - Counter(solution)).values())
    return diff1 + diff2 <= k
 
solution = []
remaining = sorted(s)
for _ in range(n):
    for j, c in enumerate(remaining):
        if c is None: continue
        solution.append(c)
        if possible(solution):
            remaining[j] = None
            break
        solution.pop()
 
print(''.join(solution))

さきほどのコードもよく見てみると読みやすくなったのはdequeの恩恵ではなくCounterをうまく使ったことによるpossible関数の整理が大きかった。

行った最適化は二点:

  • zip(solution, s[:len(solution)])zip(solution, s)に置き換えられるし後者のほうが効率的

というマイナーな点(こちらは最適化よりもコードの見通しがよくなることのほうが嬉しい)、そして:

  • Counter(s) - Counter(solution) - Counter(s[len(solution):])Counter(s) - Counter(s[len(solution):]) - Counter(solution)に置き換えられる
  • Counter(s) - Counter(s[len(solution):])Counter(s[:len(solution)])に置き換えられる
  • Counter(s[:len(solution)])はsolutionごとに計算する必要はないのでsの先頭からはじまる部分文字列について最初に計算して辞書化しておく。

これで大体二倍くらいのスピードアップが得られる。

remainingをLinked Listにするともうちょっと効率化できるか?ただ、Pythonには標準のSingly Linked List実装はないので自前で書くことになるから多分実際には逆効果だと思う。

def possible(solution, s=s, k=k):のように外側のスコープの名前を関数に束縛する手法、はやくなると言われてやってみたが計測してみるとまったく関係なかった。読みにくくなるだけだからやめよう・・・

蟻本初級編攻略 - 2-2 硬貨と区間

ようやく全探索章を終え、貪欲法の章に進む。

貪欲法に関して蟻本で出てくる最初の2問はあまり適当なAtCoderの問題がないようだ。とりあえず蟻本の問題を解いていく。

硬貨の問題

1円から500円までの硬貨を特定の数ずつ持っているとして、ある金額に合計する最小の硬貨の組み合わせを見つける問題。

*coins, total = [int(x) for x in input().split()]
values = [1, 5, 10, 50, 100, 500]

res = 0 
for max_count, value in zip(coins[::-1], values[::-1]):
    count = min(total // value, max_count)
    res += count
    total -= count * value

print(res)

大きい硬貨からはじめる貪欲法で、割り切れるだけの数とその硬貨を持っている数のminの分金額をどんどん差し引いていく。

区間スケジューリング

始まりと終わりの時間が指定されている複数のタスクがあり、時間がまったくかぶらないという条件でこなせるタスクの最大数を求める問題。

タスク終了時がもっとも早いものを貪欲法でとっていく。

この問題は三ヶ月前にも解いていた:

zehnpaard.hatenablog.com

今だったらどう書くかな、と気になったので試しに1から解いてみた:

_ = int(input())
s = [int(x) for x in input().split()]
t = [int(x) for x in input().split()]

solution = [(None, -1)]
xs = sorted(zip(s, t), key=lambda x:x[1])
for task in xs:
    if task[0] > solution[-1][1]:
        solution.append(task)

print(len(solution)-1)

ほとんど変わっていない。変数や行数が減っているがその分少しだけロジックがややこしくなっているか?