PythonでForth処理系を書く! その1(序)
Forthというプログラミング言語がある。分類としては「スタックマシン型言語」になるだろう。
スタックマシンというのは、データやアドレスをスタックに乗せたり取り出したりすることを基本の操作とするプログラム実行モデルで、有名どころでいうとJava Virtual MachineやPython Bytecode Interpreterのような「いったん中間言語としてコンパイルされたものを実行する」ことによく使われている。
Forthはそのスタックマシンを直接操るプログラミング言語だ。
実行モデルがC(というかAlgol)系やML/Haskell的関数型、あるいはLispなどとも大きく違っていて面白い。個人的にもっとよく知りたい言語のひとつだ。
魅力の一つとして、言語処理系としては実装が非常に簡単かつ効率的かつ表現力が高いことが挙げられる。組み込みなどのシステムで、自前でForthを作って使う、ということも十分可能なようだ。
というわけで以前からForth処理系を作ってみたいと思っていたのだが、ある日Kindleでそのものずばりな「Forthを作ってみる」という電子書籍が売っていたのでそれを買ってみた。以下のサイトの内容を書籍化したもののようだ:
読んでみて大体の概要がわかったのでPythonで似たような機能を持つForth Interpreterを実装してみた。
これから何回かに分けてこの実装過程について書いていきたい。
AtCoder Beginner Contest 103に参加してみた
今回はノーミスで全完したのだが、C問題にあまりにも時間がかかってしまい、実感としてはかなりまずい印象だった。
前回のABCから三週間ぶりで感覚がおかしかった気がする。
第1問
「最小→真ん中→最大」あるいは「最大→真ん中→最小」のどちらかの順番ですすむのが最善で、結果は最大-最小の差になる。
xs = [int(x) for x in input().split()] print(max(xs) - min(xs))
いきなりWAを飛ばしてびっくりした。何回見直しても大丈夫そうだったので放置しておいたら、AtCoder側の不具合があったらしくリジャッジでACになっていた。
第2問
ある文字列が別の文字列を回転させたものかどうか、という質問。文字列の長さが最長100なので愚直に書いて問題ない。
s = input() t = input() print('Yes' if any(s == t[i:] + t[:i] for i in range(len(t))) else 'No')
ただ、ツイッターでprint('Yes' if s in t+t else 'No')
を見たときには、自分で思いつかなかったのが悔しかった。
第3問
何故かこの問題に1時間ほどかかった。
- 規則性?LCM?とかうだうだと考えるのに四十五分ほど
- (a1 * a2 * ... an) - 1 だとどの数でも割り切れないし最大化されるな、しかし巨大数になりそう((10 ^ 5) ^ 3000)だからどうするか・・・ と悩むのに十分
- 「あ」と解法に気づいた瞬間から実装するのに一分かからなかったと思う
n = input() xs = [int(x) for x in input().split()] print(sum(x-1 for x in xs))
(a1 * a2 * ... an) - 1で「すべての数字が割り切るのに1足りない数」を作れるので、その数の結果はsum(a - 1)になる。気づいてしまえばなんということはないのだけど、何故かなかなか気付かなかった。
第4問
西から東へ一直線に並ぶ島をつなぐ橋を最小でいくつ消せば、「島aと島bは繋がっていない」という要望をすべて満たせるか、という問題。
残り三十分ほどだし無理かな、と思って眺めていたら蟻本初級編「Saruman's Army」の類題だった。
まず要望を西側の島がもっとも小さい(西側によっている)順に要望をソートする。あとは西端から貪欲に要望を満たしていく:
n, m = map(int, input().split()) lrs = [[int(x) for x in input().split()] for _ in range(m)] lrs.sort() ans = 0 m = -1 for l, r in lrs: if r < m: m = r if l >= m: ans += 1 m = r print(ans)
最初はプライオリティキューを使うことも考えたのだけど、まったく必要なかった。
感想・結果
ノーミス全完ながらも消費時間が93:46で順位は526、パフォーマンスは1271でちょこっとだけしかレートが上がらなかった。
それも仕方ないな、というくらいにCの解法が盲点に入ってしまっていた。実装自体は全問非常に簡単だったし。
どう考えればよかったのか・・・ 解法やアルゴリズムに入るまえに「そもそも想定し得る最大はいくつだ?」「その最大になりえる条件は?」と考えていればsum(a - 1)が早い段階で見えたかもしれない。
考察のやり方のまずさと、精神力の弱さ(AがWAだったことによる動揺も大きかった)が露呈したコンテストだった。
ただ、D問題も読んでみて「歯が立たないな」という印象はなく、最終的に今持っている知識で解答に確信を持って全完できたのは良かった。ここ2ヶ月弱の競プロ精進のひとつの成果だと思う。
考察の仕方や精神力は、もっと場数を踏むことで培っていこう。
蟻本初級編攻略 - 2-4 Union Find木 後編
前回のUnion Find木実装を踏まえてAtCoder ABC/ARCのD問題や蟻本の元の例題に取り組んでみる。
ABC049 D問題 - 連結/Connectivity
道路ネットワークと鉄道ネットワークが走っている国で、街ごとに道路でも鉄道でも連結になっている街の数を出力する。
AとBが道路でも鉄道でも連結である必要かつ十分条件は「道路ネットワークでも鉄道ネットワークでも同一のグループに属する」ということなのが(少し考えると)わかった。
あとは街ごとに(道路ネットワークのルート, 鉄道ネットワークのルート)というタプルで所属するグループを表して、そのグループに所属する街の数を算出してやればいい。
というわけで実装:
from collections import Counter n, k, l = map(int, input().split()) ps = [tuple(map(int, input().split())) for _ in range(k)] rs = [tuple(map(int, input().split())) for _ in range(l)] def root(n, u): if u[n] == n: return n u[n] = root(u[n], u) return u[n] union1 = {x:x for x in range(1, n+1)} for p, q in ps: union1[root(p, union1)] = root(q, union1) union2 = {x:x for x in range(1, n+1)} for r, s in rs: union2[root(r, union2)] = root(s, union2) gs = {x:(root(x, union1), root(x, union2)) for x in range(1, n+1)} c = Counter(gs.values()) print(' '.join(str(c[gs[x]]) for x in range(1, n+1)))
二回Union Findをして、最終的なグループをタプルで表し、グループごとの街数をCounterで算出し、街ごとに所属するグループの街数を出力している。
各Union FindがO(N log N)、最終的なグループを作るのもO(N log N)、カウンターはO(N)、出力もO(N)で計算量のオーダーはO(N log N)。
ARC097 D問題 - Equals
「いれかえが許された添字ペア」を何度でも利用して、配列の要素の値と添字を最大何個一致させることができるか。
「添字ペア」がUnion Findで言うところの連結の概念と一致することがわかれば実装は非常に楽。そのためには「(a, b), (b, c), (c, d) ... (y, z)」のような添字ペアの連なりがあれば、入れ替えを連続で行うことでa~zまでのすべての要素を任意の場所に配置できる、ということがわかればいい。
これは帰納法で考えられる。
添字ペアが(a, b)だけの場合:
- ありえる二つの配置(a, b)と(b, a)が「入れ替えられる」という問題の定義から自明に可能
すでに入れ替えを連続で行うことでa1~amまでのすべての要素を任意の場所に配置できるグループに、(am, an)を加えた場合:
ということで証明終わり。
あとは「値iを含む位置jと位置iが同じグループに属している」という条件を満たすiの数を数えるだけ(単一であることが制約により保証されていることも重要な成立条件)。
n, m = map(int, input().split()) ps = [int(x) for x in input().split()] xs = [tuple(map(int, input().split())) for _ in range(m)] union = {x:x for x in range(1, n+1)} def root(n): if union[n] == n: return n union[n] = root(union[n]) return union[n] for x,y in xs: union[root(x)] = root(y) print(sum(root(i+1) == root(ps[i]) for i in range(n)))
「Union Find木の典型」と言えるような解法になる。
ARC036 D問題 - 偶数メートル
街から街へ(もしかすると複数の)道が通っていて、距離の和が偶数になるような道順を通ってある街から別の街へ移動できるかを判定する問題。
各道の距離が与えられるが実際に重要なのはその距離が奇数か偶数かという点。
まず最初に浮かんだ解法がこれ:
Union Find木は「偶数で繋がっているグループ」を表し、それとは別に「ある街が奇数で繋がっているグループ」をdictionaryで表す。
街A <- 奇数道-> 街B <- 偶数道 -> 街C
のようになっている場合、街Aと街Cも奇数で繋がっているので「奇数で繋がっている相手」を「偶数グループ」で表せるのがポイント。
あと
街A <- 奇数道-> 街B <- 奇数道 -> 街C
だとAとCが同じ偶数グループに属するようになる、というのも重要。
それを実装するとこうなる:
n, q = map(int, input().split()) ws = [tuple(map(int, input().split())) for _ in range(q)] union = {x:x for x in range(1, n+1)} odds = {x:None for x in range(1, n+1)} def root(n): if union[n] == n: return n union[n] = root(union[n]) return union[n] for w, x, y, z in ws: rx, ry = root(x), root(y) if w == 2: print('YES' if rx == ry else 'NO') elif z % 2 == 0: union[rx] = ry if odds[rx] and odds[ry]: union[root(odds[rx])] = root(odds[ry]) else: odds[rx] = odds[ry] = odds[rx] or odds[ry] else: odds[rx] = odds[rx] or ry odds[ry] = odds[ry] or rx union[rx] = root(odds[ry]) union[ry] = root(odds[rx])
ちょっと奇数・偶数で場合分けした時のロジックが複雑になっている。
もう一つの解法はこれ:
Union Find木で「奇数で繋がっている」「偶数で繋がっている」を一括で管理する。各ノードは「街、奇偶」のタプルで表している。
実装が簡単になるなーと思ったのだが・・・
n, q = map(int, input().split()) ws = [tuple(map(int, input().split())) for _ in range(q)] union = {(x, z):(x, z) for x in range(1, n+1) for z in 'oe'} def root(n): if union[n] == n: return n union[n] = root(union[n]) return union[n] for w, x, y, z in ws: if w == 2: print('YES' if root((x, 'e')) == root((y, 'e')) else 'NO') elif z % 2 == 0: rex, rey = sorted([root((x, 'e')), root((y, 'e'))]) union[rex] = rey rox, roy = sorted([root((x, 'o')), root((y, 'o'))]) union[rox] = roy else: reox, reoy = sorted([root((x, 'e')), root((y, 'o'))]) union[reox] = reoy roex, roey = sorted([root((x, 'o')), root((y, 'e'))]) union[roex] = roey
多分同じ街が二回現れることが理由なんだと思うのだが、気をつけないとroot関数が無限ループに陥ってStack Overflowを起こす。実際テストケース最後の一つだけで何回かREしてしまった・・・
かならず「ソート順が一番大きいノード」に貼るようにすればループは起こりえない。その分少しコードが冗長になってしまったが、さきほどの解に比べると場合分けのロジックの対称性がはっきりしていてやはりわかりやすいように思う。
ARC090 D問題 - People in a line
直線上に並んでいる人たちについて、「LはRのXm左にいる」という情報が与えられる。その情報に矛盾がないか判定する問題。
Union Find木の発展系である重みつきUnion Find木を使うと簡単に解ける:
n, m = map(int, input().split()) xs = [tuple(map(int, input().split())) for _ in range(m)] union = {x:(x, 0) for x in range(1, n+1)} def root(n): if union[n][0] == n: return union[n] n1, d1 = union[n] n2, d2 = root(n1) union[n] = (n2, d2+d1) return union[n] for l, r, d in xs: n1, d1 = root(l) n2, d2 = root(r) if n1 != n2: union[n2] = (n1, d + d1 - d2) elif d != d2 - d1: print('No') break else: print('Yes')
「ルートノードの人、その人に対しての距離」をタプルとしてUnion Find木に入れている。
新しい情報が入ってくるたびに * もしルートノードが違えばUnion Find木をアップデートできる(二つのルートノードの相対位置がわかる) * もしルートノードが同じならそのルートノードに対してのLとRの位置の差が新しい情報と一致しているかを判定する
Union Findのノードをタプル化して追加で情報を持たせるのは「偶数メートル」と同じ。ただし、root関数内でその情報をまとめる必要がでてくるのでより複雑(といってもノード間の距離の総和を追加で返すだけだが)
あと少し気の利いたところとしては最後のelse。Pythonのfor-loopには「breakしなかったら最後にこれを実行」というelse構文があることはあまり知られていないように思う。ループ中に矛盾が見つかれば"No"と出力してbreak、もし見つからなければ最後にelseで"Yes"と出力、というロジック。
POJ 1182 - 食物連鎖
「AがBを食べ、BがCを食べ、CがAを食べる」とわかっているなかで、「xとyは同じ種」「xはyを食べる」という情報が順次入ってくる。そのうちで以前来た情報や個体番号の制限と矛盾するものは無視するとして、最終的に無視した情報の数を求める問題。
入力値の一例としてはこんな感じ:
n = 100 k = 7 xs = [(1, 101, 1), (2, 1, 2), (2, 2, 3), (2, 3, 3), (1, 1, 3), (2, 3, 1), (1, 5, 5)]
nが個体数、kが情報数、xsが情報のリスト。情報は
- 1=「同じ種」, 2=「xはyを食べる」という情報のタイプ
- xの番号
- yの番号
の三点からなるタプル。
上記の入力から期待される出力は3。
1 101 1 2 3 3 2 3 1
の三つの情報が矛盾する。
これもUnion Find木で解決するのだが、ノードは個体ではなく「1はAである」などの命題であり、ノードの連結は
命題Aが真の時そしてその時に限り命題Bも真 A <-> B
という関係を表している。
なので
- 「xとyは同じ種」という情報が入ってくるとx-Aとy-Aは互いに連結、x-Bとy-B、x-Cとy-Cも同じく互いに連結
- 「xはyを食べる」という情報が入ってくるとx-Aとy-Bは互いに連結、x-Bとy-C、x-Cとy-Aも同じく互いに連結
その上で「xはyを食べる」という情報が入ってきた時すでにx-Aとy-Aが連結であったりした場合、情報が矛盾するので無視することになる。
というのが以下の実装:
groups = ('a', 'b', 'c') union = {(x, g):(x, g) for x in range(1, n+1) for g in groups} def root(n): if union[n] == n: return n union[n] = root(union[n]) return union[n] def is_error(i, x, y): if not (1 <= x <= n and 1 <= y <= n): return True if i == 1: return root((x, 'a')) in {root((y, g)) for g in ('b', 'c')} if i == 2: return root((x, 'a')) in {root((y, g)) for g in ('a', 'c')} return True ans = 0 for i, x, y in xs: if is_error(i, x, y): ans += 1 elif i == 1: for g in groups: union[root((x, g))] = root((y, g)) else: for g1, g2 in (('a', 'b'), ('b', 'c'), ('c', 'a')): union[root((x, g1))] = root((y, g2)) print(ans)
まあ実装はそこまで特筆すべきものはない・・・ 矛盾があるかをチェックする時に「x is A」と繋がったyに関する命題の矛盾だけ調べればいい、という点くらいか。
しかし解法自体は非常に面白かった。「命題がノードになる」というのはけっこう盲点で、蟻本の解説を読むまではどう書けばいいのかなかなか想像がつかなかった。
この視点はUnion Find木の応用の幅が広がりそうなのでこれからも大切にしたい。
結論
Union Find好き。
蟻本初級編攻略 - 2-4 Union Find木 前編
「ある集合の二つの要素が繋がっているか」を効率的に判定できるデータ構造であるUnion Find木の話。
非常に簡単でエレガントな実装のわりに非常に強力な概念だという印象で、個人的にはとても好き。
蟻本の例題が一番特殊で面白いので後回し。まずはAtCoder Typical Contestから:
AtCoder Typical Contest 001 B問題 Union Find
ザ・Union Find典型。Union Findが逐次更新できるデータ構造で、途中のいつでも「現在明らかになった情報を元に二つの要素が繋がっているか判定できる」という点も強調されている。
Python実装はこんな感じ:
n, q = map(int, input().split()) xs = [tuple(map(int, input().split())) for _ in range(q)] union = {x:x for x in range(n)} def root(n): if union[n] == n: return n union[n] = root(union[n]) return union[n] for p, a, b in xs: if p == 0: union[root(a)] = root(b) else: print("Yes" if root(a) == root(b) else "No")
やはり非常に簡潔に済むのがうれしい。
Union Findはデータをdictionaryに持たせ、そのdictionaryをもとにある要素が繋がっているグループのルートノードを返す関数と、二つのグループをつなげるイディオムを提供する。二つの要素が繋がっているかを判定するのは、その二つの要素の属するグループのルートノードが一致するかを調べるだけ。
Union Findは気をつけないとroot関数が最悪O(n)かかってしまう。それを回避するメジャーな最適化が二つある。ここらへんはAtCoderのUnion FindについてのSlideShareに詳しい:
www.slideshare.net
両方の最適化を使うと計算量のオーダーが逆アッカーマン関数(ほぼ定数・・・)になるということだが、一つだけでもO(logN)になるうえ、実装が簡単で定数倍ははやい。
というわけで私は「ルートノードを探る度に枝をすべてルートノード直下に貼り直す」という最適化のみ実装している。
つまりroot関数のナイーブな実装:
def root(n): if union[n] == n: return n return root(union[n])
それをこう変えている:
def root(n): if union[n] == n: return n union[n] = root(union[n]) #ここ return union[n]
これだけでO(N)がO(log N)になるのはうれしい。
そういえば昨日PEP572が受理されたので将来こう書けるようになる:
def root(n): if union[n] == n: return n return union[n] := root(union[n]) #ここ
まあこれは大した違いじゃないけど・・・。
競プロ時にPythonでClassを使わずにデータ構造を実装することについて
ちなみに私は個人的に競技プログラミングでPythonを書く場合、Union Findくらいだったらインターフェースつけてclassにwrapするより、もとのデータ構造をむき出しにしておきたい。
ということをちょっと乱暴に呟いた:
競プロで絶対class書きたくないマン
— zehnpaard (@zehnpaard) 2018年7月5日
そしたら速攻でツッコミが入った:
タプル使うよりClass使ったほうがよくないですか?
— すとまと (@stmtk_01) 2018年7月5日
(次回出てくる重みつきUnion Find木の話)
もうちょっと丁寧に自分の考えを説明するとこんな感じ:
競プロだと「0から読み直して頭の中でフローを再構築できるコード」である必要はなくて「ある程度頭の中でコンテキストが出来上がっている中でフローが追いやすくエラーが見つけやすいコード」が最適なので、記述の仕方もやっぱりけっこう違うなーと思う
— zehnpaard (@zehnpaard) 2018年7月5日
(最初のnamedtupleについての言及はプロダクションコードだったら、の話)
さらにいうとデータ構造がむき出しのほうが短期的にはいじりやすい。微妙にデータ構造に持たせる情報や算出させる情報を変えたいときに、すくなくともABCレベルの競技プログラミングだと変更箇所はかなり限られているから使われているその場で変えてしまったほうが楽。
数百行以上のコードのいろんなところで使われているなら状況は逆転する。やはりプロダクションコードと競技プログラミングでは細かい作法がけっこう違ってくるのは自然だと思う。
次回予告
つらつらと書いていたら長くなったので記事を二回にわける。
次回はAtCoder ABC/ARCからUnion Findを使ったD問題4つと蟻本の例題の解法を説明する。
蟻本初級編攻略 - 2-4 二分探索木
蟻本の章名である「二分探索木」というサブタイトルを付けたが、Pythonなので二分探索木使わない。setもdictもハッシュマップだ。
ABC085 B問題 Kagami Mochi
上に乗せる鏡餅は下の餅より直径が小さくなくてはいけないという制約のもと、何段の鏡餅が作れるかという問題。
制約は言い換えれば「同じ直径の鏡餅は使えない」ということ。なので直径の重複を除いて数えればいい。
n = int(input()) print(len(set(int(input()) for _ in range(n))))
Pythonなのでsetにしてlenをとるだけ。set化がO(n)、lenがO(1)。
ABC091 B問題 Two Colors Card Game
青と赤のカードに文字列が書いてあって、ある文字列を選ぶとその文字列が書いてある青いカードの数だけポイントがもらえ、赤いカードの数だけポイントが減る。ポイントを最大化する文字列は何か、という問題。
青いカードに書かれている文字列のうち、青いカードの数と赤いカードの数の差額で最も大きいもの、あるいは0が答え:
from collections import Counter n = int(input()) ss = Counter(input() for _ in range(n)) m = int(input()) ts = Counter(input() for _ in range(m)) answer = max(cs - ts[s] for s, cs in ss.items()) print(answer if answer > 0 else 0)
標準ライブラリに入っている、dictを拡張したCounterクラスを使っている。上記のロジックをそのままで記述している。
Counter作成はO(n)、ss.items()のループは全部でO(n)、ts[s]はO(1)、maxはO(n)、と全体でもO(n)の計算量。
蟻本初級編攻略 - 2-4 Expedition
動的計画法関連のところはPythonでやるとTLEする問題もいくつかあり、OCamlでやることも視野に入れつつ少し寝かせておく。
というわけで先にPriority Queueを使った問題3つ。まずは蟻本に載っているPOJ問題から:
POJ - Expedition
現在地から目的地への直線上にある複数のガソリンスタンドに最小で何回止まって給油する必要があるか、という問題。
ガソリンが足りなくなる限界までの区間にあるガソリンスタンドの給油量をPriorityQueueにいれて最大値を取り出し、ガソリンが足りなくなる限界距離を更新して考慮に入れるスタンドをQueueにいれて最大値を取り出し、というループを限界距離が目的距離に到達するまで続ける:
from queue import PriorityQueue n = int(input()) xs = [tuple(int(x) for x in input().split()) for _ in range(n)] l, p = map(int, input().split()) current = p count = 0 q = PriorityQueue() xs = list(sorted([(l-x, v) for x, v in xs], reverse=True)) while current < l: while xs and xs[-1][0] <= current: q.put(-xs.pop()[1]) if not q: break current -= q.get() count += 1 print(count if current >= l else -1)
PythonのPriorityQueueやheapqはmin-heapしか提供していない。ちょっとショック。なので「最大値を常にとってくる」という挙動のためにマイナスでかけている。その点以外は比較的素直な実装ではないだろうか。
それではAtCoderの類題2問を解いてみる。
Code Thanks Festival 2017 C問題 - Factory
code-thanks-festival-2017-open.contest.atcoder.jp
code-thanks-festival-2017-open.contest.atcoder.jp
プレゼントを作るごとにスピードが劣化していく複数の機械をどう使うと最小の時間で一定数のプレゼントが作れるか、という問題。
キューに「かかる時間」と「劣化の度合い」をタプルとして入れて、かかる時間が最小のものを選んで、そして「かかる時間+劣化」と「劣化の度合い」のタプルをキューに戻して、とループしていく。ほしいプレゼント数の分だけループが回って「かかる時間」の和が答え。
import heapq n, k = map(int, input().split()) xys = [tuple(int(x) for x in input().split()) for _ in range(n)] heapq.heapify(xys) time = 0 for _ in range(k): time += xys[0][0] heapq.heappushpop(xys, (xys[0][0]+xys[0][1], xys[0][1])) print(time)
今回はheapqで実装してみた。heapqには普通のlistをヒープ化したり、ヒープ化されたリストに対するヒープ操作をしたり、という関数が用意されている。実装がリスト準拠なので、最小の要素をpeekするだけならxs[0]でできる。queueモジュールのPriorityQueueは実はスレッドセーフな実装になっており、その分のオーバーヘッドがかかるのでheapqのほうが定数倍効率がいい。
heapqがO(N)、heappushpopがO(logN)なので全部でO(n + k logn)の計算量。
ARC074 D問題 - 3N Numbers
3*N個の正の整数のリストからN個要素を取り除き、「前半N個の和 - 後半N個の和」を最大化する問題。
制約から見えてきたいくつかのポイント:
- 前半の和を最大化し、後半の和を最小化する必要がある
- 前半はA[:2*n]の要素のうちのN個によって構成される
- 後半はA[n:]の要素のうちのN個によって構成される
- 前半が使える部分と後半が使える部分のオーバーラップであるA[n:2*n]のどこに線を引くかがポイント
というわけで
- i = 0 ~ nの各iごとにA[:n+i]からN個選んだ場合の和の最大値を求める
- j = 0 ~ nの各ごとにA[2*n-j:]からN個選んだ場合の和の最小値を求める
ということをPriorityQueueを使って一歩ずつ算出していく。その上でi+j = nで差が最大になるiとjのペアを見つける。
from heapq import heapify, heappushpop n = int(input()) xs = [int(x) for x in input().split()] xs1 = xs[:n] heapify(xs1) r1 = [sum(xs1)] for x in xs[n:2*n]: r1.append(r1[-1] + x - heappushpop(xs1, x)) xs2 = [-x for x in xs[2*n:]] heapify(xs2) r2 = [sum(xs2)] for x in reversed(xs[n:2*n]): r2.append(r2[-1] + (-x) - heappushpop(xs2, -x)) print(max(a+b for a,b in zip(r1, reversed(r2))))
前半の和の最大値を求めるのに、A[n+i]の要素をキューにいれてから最小値を出して、(A[n+i] - 最小値)をそれまでの和に足している。後半も全く同じロジック、ただし「最大値をキューから取り出す」という挙動のためにキューに入れる・キューから取り出す値をすべてマイナスでかけている。
ほとんどの部分がO(N)だが、ループしている部分だけN回ループが回ってO(logN)のheappushpopが毎回行われているので、そこが計算量O(N logN)で大きい。
Priority Queue感想
やっぱり便利。今回の問題はみんなPriority Queueが前面に出てきているが、そうでなくても解法の一部として計算量を下げるのに活躍してくれる場面は多そう。
しかしmax heapがないのはやっぱり不便だな、微妙なところで実装を間違えてバグりそうだから要注意だ。
AtCoder Beginner Contest 102に参加してみた
第1問
入力値nと2の最小公倍数を求める問題。nが偶数ならn、そうじゃないなら2*n
n = int(input()) print(2*n if n%2 else n)
第2問
Aの(添字の)異なる 2 要素の差の絶対値の最大値を求めてください。
と言われるとややこしく感じてしまうが「2要素の差が一番大きい」なのでmaxとminの差。
入力値に「Aの要素が2つ以上」という制約がある以上、要素の値がすべて同じだったとしても「添字の異なる2要素」は満たせる。
n = int(input()) xs = [int(x) for x in input().split()] print(max(xs) - min(xs))
第3問
Ai と b+iの差
と考えるとちょっとややこしいかもしれないがabs(Ai - (b+i) = abs((Ai - i) - b)
なのでまずAi-i
の数列Bを作ってしまう。
その数列Bの各要素と任意の数bの距離の和を最小化したい。
たとえばbがBのどの要素よりも小さかったとする。そうするとb+1はBのすべての要素に1ずつ近くなるので明らかにbより勝る。
bがBの最小の要素と同じ値だったとする(Bの最小の要素が単一だと仮定する)。するとb+1はその最小の要素からの距離は1増え、n-1個の要素との距離は1ずつ減る。
この「1ずれるごとに距離が1増える要素と距離が1減る要素の数」がちょうどバランスするところが最善であり、そのポイントでの距離の和が答えとなる。
それは中央値を構成する要素が1つの場合はちょうどその要素の値となるし、もし中央値が二つの要素の平均だとしたら、その二つの要素の値の間ならどこでも最善になる。
なのでてっとりばやく中央値を使える:
from statistics import median n = int(input()) xs = [int(x) for x in input().split()] ys = [x-i-1 for i, x in enumerate(xs)] m = int(median(ys)) print(sum(abs(y-m) for y in ys))
第4問
この問題はコンテスト中には解けなかった・・・ しゃくとり法や左右から端を狭めていく方法などをずっと考えていて最後に苦し紛れの実装でWAを出したりしながらタイムアップ。悔しいので解説は見ずに1日ぼーっと考えたりしていた。
累積和をとると便利なのは最初から考えていたのだが、「真ん中をまず割ってみる」というところに発想がいかなかったのが敗因。
真ん中を割ってしまえば、「左右の配列を個別にどう割るか」という二つの小さい問題になる。
まず左の配列について考えてみる:
- 配列をどう割るのが最善か、と考えると、ここでも和の差を最小化するのが一番だとわかる
- 和の差が最小化されるためには一つ一つの和ができるだけ平均に近ければいい
- 累積和を平均値で二分探索すればいい
- ただし二分探索で得られた添字の周りが答えの可能性もあるので-1, +1も調べて和の差が最小化されるポイントを選ぶ
右の配列も似た論理だ。ただし累積和がすこしややこしい(左の配列の総和を差し引く必要がある)。
あとは左右の配列の結果を組み合わせて最大と最小の和の差をとれば、あるポイントを真ん中にして割った場合の最小の和の差がわかる。
つまり真ん中が決定していれば、最小の和の差を2*O(log N/2)で計算できる。
あとは真ん中候補すべてに対してこの計算をして最小の結果が全体の答え。なので全体でO(N logN)。
from itertools import accumulate, islice from bisect import bisect_left n = int(input()) xs = [int(x) for x in input().split()] ys = list(accumulate(xs)) zs, total = ys[:-1], ys[-1] res = max(xs) - min(xs) for i, z in islice(enumerate(zs), 1, n-1): j = bisect_left(zs, z//2) splits = ((zs[j+k], z - zs[j+k]) for k in (-1, 0, 1) if 0 <= j+k <= i) a, b = min(splits, key=lambda s:abs(s[0]-s[1])) j = bisect_left(zs, (total - z)//2 + z) splits = ((zs[j+k] - z, total - zs[j+k]) for k in (-1, 0, 1) if i <= j+k <= n-2) c, d = min(splits, key=lambda s:abs(s[0]-s[1])) minn, _, _, maxx = sorted((a,b,c,d)) if maxx - minn < res: res = maxx - minn print(res)
Pythonだとテストケースにかかる時間が最長で1800msとかなりギリギリだ。うっかりj = bisect_left(zs, z/2)
などと浮動小数点との比較で二分探索してしまうと簡単にTLEになる。今度OCamlで実装してどう書けてどれくらい速度がでるか試してみたい。