蟻本初級編攻略 - 2-1 部分和
定期的に「蟻本を全部読むぞ!」と思って読み始めては放置、を繰り返している。
そもそもAtCoderやCodeForceに参加することなく「ある程度読んでから競プロ試してみよう」と思っていたのが間違いだったのではないか。
ABC100にリアルタイム参加もしたし、今度こそ蟻本を攻略したい。
幸いこういう素晴らしい記事もある:
本を読み、AtCoderの問題を解き、その解法をブログに書いていこうと思う。
部分和問題
蟻本の冒頭あたりについては今年の三月にも記事を書いていた。
というわけで書いたブログとコードを引っ張り出してみる:
が、なんだこれは・・・ ひどいねコレは。Pythonで添字をあまり使いたくないのはわかるがそれでこんなにコードが長く読みにくくなっては本末転倒も甚だしい。
と、完全に「数ヶ月前の自分のコードはひどく見える」を実践してしまった。
蟻本の題意である「Yes/Noを返す」を素直に実装するなら:
n = int(input()) xs = [int(x) for x in input().split()] k = int(input()) def dfs(i, total): if i == n: return k == total return dfs(i+1, total+xs[i]) or dfs(i+1, total) print('Yes' if dfs(0, 0) else 'No')
if i == n
をif i == n or k <= total
として途中で抜け出すような最適化も可能なことが一見してわかる。
もし「条件を満たす集合を返す」という実装をするとしてもこんな感じだろう:
n = int(input()) xs = [int(x) for x in input().split()] k = int(input()) def dfs(i, total): if k == total: yield [] if i == n or k < total: return for res in dfs(i+1, total+xs[i]): res.append(i) yield res for res in dfs(i+1, total): yield res res = list(dfs(0, 0)) sum_string = ' + '.join(str(xs[i]) for i in reversed(res[0])) if res else '' print('Yes ({} = {})'.format(k, sum_string) if res else 'No')
ちなみにPythonで枝刈りなしの全列挙ならitertoolsを使ったほうがいい:
from itertools import compress, product n = int(input()) xs = [int(x) for x in input().split()] k = int(input()) subsets = (list(compress(xs, selectors)) for selectors in product((0,1), repeat=n)) res = [subset for subset in subsets if sum(subset) == k] sum_string = ' + '.join(str(xs[i]) for i in reversed(res[0])) if res else '' print('Yes ({} = {})'.format(k, sum_string) if res else 'No')
再帰リミットにぶつかる心配をする必要もないし、より抽象度の高い概念で記述できる。
ABC061C - たくさんの数式
というわけでAtCoderの類題を解いてみる。
これは枝刈りなしの全列挙なのでitertoolsを使う:
from itertools import product s = input() s2 = '{}'.join(s) betweenss = product(('', '+'), repeat=len(s)-1) print(sum(eval(s2.format(*betweens)) for betweens in betweenss))
sの各文字の間に{}
を埋めて、そこに''と'+'のセットをあてはめて評価していく。
ABC079C - Train Ticket
ロジックはまったく同じ:
from itertools import product s = input() s2 = '{}'.join(s) opss = product('+-', repeat=3) lhss = (s2.format(*ops) for ops in opss) eqs = (lhs+'=7' for lhs in lhss if eval(lhs)==7) print(next(eqs))
「かならず解がひとつ以上存在する」「複数解がある場合はどれを返してもいい」という条件なのでジェネレータを作ってnextで最初の要素を取り出している。