Arantium Maestum

プログラミング、囲碁、読書の話題

蟻本初級編攻略 - 2-2 ABC009C 辞書式順序ふたたび、ふたたび

zehnpaard.hatenablog.com

昨日の続き。

AtCoderの「辞書式順序ふたたび」の解法がまだややこしかったのでいろいろといじり続けたら最終的にコードが短くなり、処理速度も三倍ほど上がった。

abc009.contest.atcoder.jp

昨日の時点でのコード:

from collections import Counter
 
n, k = map(int, input().split())
s = input()
counters = {i+1:Counter(s[:i+1]) for i in range(n)}
 
def possible(solution):
    diff1 = sum(a != b for a, b in zip(solution, s))
    if diff1 > k: return False
    diff2 = sum((counters[len(solution)] - Counter(solution)).values())
    return diff1 + diff2 <= k
 
solution = []
remaining = sorted(s)
for _ in range(n):
    for j, c in enumerate(remaining):
        if c is None: continue
        solution.append(c)
        if possible(solution):
            remaining[j] = None
            break
        solution.pop()
 
print(''.join(solution))

まずpossible関数をまず分解してfor-loopの中に入れてみた。

そうすると、diff1の計算に毎回O(N)の処理をする必要はなく、「これまでのsolutionとsの不一致の数に次の文字が不一致か否かで1か0を足す」というO(1)の処理で済むことがわかった。

同様に、diff2も外側のループで一回「これまでのsolutionとs[:i+1]の文字カウントの不一致」を計算してしまえば、内側のループではO(1)で算出できる。

外側のループ一回分でO(N)で「remainingから要素を探して削除する」という処理をすれば内側のループの回数を抑えることができる上にコードもすっきりとする。

といった変更を加えていくと以下のようなコードになった:

abc009.contest.atcoder.jp

from collections import Counter
 
n, k = map(int, input().split())
s = input()
 
solution = []
remaining = sorted(s)
diff = 0
for i in range(n):
    counter = Counter(s[:i+1]) - Counter(solution)
    counts = sum(counter.values())
    for c in remaining:
        diff1 = diff + (c != s[i])
        diff2 = counts - (counter[c] > 0)
        if diff1 + diff2 <= k:
            solution.append(c)
            remaining.remove(c)
            diff = diff1
            break
 
print(''.join(solution))

テストケースは最長で24ms。Pythonインタプリタ起動に17msかかることを考えると実際のロジック部分は7msほどしかかかっていないことがわかる。

もともと170msほどかかっていたことを考えると大きな改善である。コンテスト的には無意味だが・・・

追記:

二行長くなるが、このほうがループの不変条件がすべてそろっていてわかりやすい気がする:

from collections import Counter
 
n, k = map(int, input().split())
s = input()
 
solution = []
remaining = sorted(s)
diff = 0
counter = Counter(s[:1])
counts = 1
for i in range(n):
    for c in remaining:
        diff1 = diff + (c != s[i])
        diff2 = counts - (counter[c] > 0)
        if diff1 + diff2 <= k:
            solution.append(c)
            remaining.remove(c)
            diff = diff1
            counter = Counter(s[:i+2]) - Counter(solution)
            counts = sum(counter.values())
            break
 
print(''.join(solution))

スピードは変わらず。Counterをdefaultdict(int)に変えれば

counter = Counter(s[:i+2]) - Counter(solution)

の行はO(1)にできるが

counts = sum(counter.values())

の部分はO(N)のままで、試したところスピードになんの変化もなかった。

あとはremainingを逆向きにして後ろから添字を使って「今調べているものは常にリストの最後にいるようスワップ」して見つかったらpopする、という処理はO(N)がO(1)になる。

なるのだがけっきょくO(N2)だし、コードを複雑化しての最適化は意外と効果が出にくい。

もし入力がもっと大きくて現在のコードではTLEを起こす、というような状態だと各部分でO(N)をO(1)に落とすのは意味があると思う。

すでに24msくらいで走っている場合で複雑な最適化をすると、今までインタプリタがCに落とし込んでいた部分が純Pythonで評価されるようになってでかい定数倍ペナルティをくらうことが多い。

個人的には「Pythonで最適化とはなるべくアルゴリズムを簡潔に明瞭に記述して重複している処理を減らしていく」くらいに留めるのがいい気がする。それでダメならC++で書こう・・・