蟻本初級編攻略 - 2-4 Union Find木 前編
「ある集合の二つの要素が繋がっているか」を効率的に判定できるデータ構造であるUnion Find木の話。
非常に簡単でエレガントな実装のわりに非常に強力な概念だという印象で、個人的にはとても好き。
蟻本の例題が一番特殊で面白いので後回し。まずはAtCoder Typical Contestから:
AtCoder Typical Contest 001 B問題 Union Find
ザ・Union Find典型。Union Findが逐次更新できるデータ構造で、途中のいつでも「現在明らかになった情報を元に二つの要素が繋がっているか判定できる」という点も強調されている。
Python実装はこんな感じ:
n, q = map(int, input().split()) xs = [tuple(map(int, input().split())) for _ in range(q)] union = {x:x for x in range(n)} def root(n): if union[n] == n: return n union[n] = root(union[n]) return union[n] for p, a, b in xs: if p == 0: union[root(a)] = root(b) else: print("Yes" if root(a) == root(b) else "No")
やはり非常に簡潔に済むのがうれしい。
Union Findはデータをdictionaryに持たせ、そのdictionaryをもとにある要素が繋がっているグループのルートノードを返す関数と、二つのグループをつなげるイディオムを提供する。二つの要素が繋がっているかを判定するのは、その二つの要素の属するグループのルートノードが一致するかを調べるだけ。
Union Findは気をつけないとroot関数が最悪O(n)かかってしまう。それを回避するメジャーな最適化が二つある。ここらへんはAtCoderのUnion FindについてのSlideShareに詳しい:
www.slideshare.net
両方の最適化を使うと計算量のオーダーが逆アッカーマン関数(ほぼ定数・・・)になるということだが、一つだけでもO(logN)になるうえ、実装が簡単で定数倍ははやい。
というわけで私は「ルートノードを探る度に枝をすべてルートノード直下に貼り直す」という最適化のみ実装している。
つまりroot関数のナイーブな実装:
def root(n): if union[n] == n: return n return root(union[n])
それをこう変えている:
def root(n): if union[n] == n: return n union[n] = root(union[n]) #ここ return union[n]
これだけでO(N)がO(log N)になるのはうれしい。
そういえば昨日PEP572が受理されたので将来こう書けるようになる:
def root(n): if union[n] == n: return n return union[n] := root(union[n]) #ここ
まあこれは大した違いじゃないけど・・・。
競プロ時にPythonでClassを使わずにデータ構造を実装することについて
ちなみに私は個人的に競技プログラミングでPythonを書く場合、Union Findくらいだったらインターフェースつけてclassにwrapするより、もとのデータ構造をむき出しにしておきたい。
ということをちょっと乱暴に呟いた:
競プロで絶対class書きたくないマン
— zehnpaard (@zehnpaard) 2018年7月5日
そしたら速攻でツッコミが入った:
タプル使うよりClass使ったほうがよくないですか?
— すとまと (@stmtk_01) 2018年7月5日
(次回出てくる重みつきUnion Find木の話)
もうちょっと丁寧に自分の考えを説明するとこんな感じ:
競プロだと「0から読み直して頭の中でフローを再構築できるコード」である必要はなくて「ある程度頭の中でコンテキストが出来上がっている中でフローが追いやすくエラーが見つけやすいコード」が最適なので、記述の仕方もやっぱりけっこう違うなーと思う
— zehnpaard (@zehnpaard) 2018年7月5日
(最初のnamedtupleについての言及はプロダクションコードだったら、の話)
さらにいうとデータ構造がむき出しのほうが短期的にはいじりやすい。微妙にデータ構造に持たせる情報や算出させる情報を変えたいときに、すくなくともABCレベルの競技プログラミングだと変更箇所はかなり限られているから使われているその場で変えてしまったほうが楽。
数百行以上のコードのいろんなところで使われているなら状況は逆転する。やはりプロダクションコードと競技プログラミングでは細かい作法がけっこう違ってくるのは自然だと思う。
次回予告
つらつらと書いていたら長くなったので記事を二回にわける。
次回はAtCoder ABC/ARCからUnion Findを使ったD問題4つと蟻本の例題の解法を説明する。
蟻本初級編攻略 - 2-4 二分探索木
蟻本の章名である「二分探索木」というサブタイトルを付けたが、Pythonなので二分探索木使わない。setもdictもハッシュマップだ。
ABC085 B問題 Kagami Mochi
上に乗せる鏡餅は下の餅より直径が小さくなくてはいけないという制約のもと、何段の鏡餅が作れるかという問題。
制約は言い換えれば「同じ直径の鏡餅は使えない」ということ。なので直径の重複を除いて数えればいい。
n = int(input()) print(len(set(int(input()) for _ in range(n))))
Pythonなのでsetにしてlenをとるだけ。set化がO(n)、lenがO(1)。
ABC091 B問題 Two Colors Card Game
青と赤のカードに文字列が書いてあって、ある文字列を選ぶとその文字列が書いてある青いカードの数だけポイントがもらえ、赤いカードの数だけポイントが減る。ポイントを最大化する文字列は何か、という問題。
青いカードに書かれている文字列のうち、青いカードの数と赤いカードの数の差額で最も大きいもの、あるいは0が答え:
from collections import Counter n = int(input()) ss = Counter(input() for _ in range(n)) m = int(input()) ts = Counter(input() for _ in range(m)) answer = max(cs - ts[s] for s, cs in ss.items()) print(answer if answer > 0 else 0)
標準ライブラリに入っている、dictを拡張したCounterクラスを使っている。上記のロジックをそのままで記述している。
Counter作成はO(n)、ss.items()のループは全部でO(n)、ts[s]はO(1)、maxはO(n)、と全体でもO(n)の計算量。
蟻本初級編攻略 - 2-4 Expedition
動的計画法関連のところはPythonでやるとTLEする問題もいくつかあり、OCamlでやることも視野に入れつつ少し寝かせておく。
というわけで先にPriority Queueを使った問題3つ。まずは蟻本に載っているPOJ問題から:
POJ - Expedition
現在地から目的地への直線上にある複数のガソリンスタンドに最小で何回止まって給油する必要があるか、という問題。
ガソリンが足りなくなる限界までの区間にあるガソリンスタンドの給油量をPriorityQueueにいれて最大値を取り出し、ガソリンが足りなくなる限界距離を更新して考慮に入れるスタンドをQueueにいれて最大値を取り出し、というループを限界距離が目的距離に到達するまで続ける:
from queue import PriorityQueue n = int(input()) xs = [tuple(int(x) for x in input().split()) for _ in range(n)] l, p = map(int, input().split()) current = p count = 0 q = PriorityQueue() xs = list(sorted([(l-x, v) for x, v in xs], reverse=True)) while current < l: while xs and xs[-1][0] <= current: q.put(-xs.pop()[1]) if not q: break current -= q.get() count += 1 print(count if current >= l else -1)
PythonのPriorityQueueやheapqはmin-heapしか提供していない。ちょっとショック。なので「最大値を常にとってくる」という挙動のためにマイナスでかけている。その点以外は比較的素直な実装ではないだろうか。
それではAtCoderの類題2問を解いてみる。
Code Thanks Festival 2017 C問題 - Factory
code-thanks-festival-2017-open.contest.atcoder.jp
code-thanks-festival-2017-open.contest.atcoder.jp
プレゼントを作るごとにスピードが劣化していく複数の機械をどう使うと最小の時間で一定数のプレゼントが作れるか、という問題。
キューに「かかる時間」と「劣化の度合い」をタプルとして入れて、かかる時間が最小のものを選んで、そして「かかる時間+劣化」と「劣化の度合い」のタプルをキューに戻して、とループしていく。ほしいプレゼント数の分だけループが回って「かかる時間」の和が答え。
import heapq n, k = map(int, input().split()) xys = [tuple(int(x) for x in input().split()) for _ in range(n)] heapq.heapify(xys) time = 0 for _ in range(k): time += xys[0][0] heapq.heappushpop(xys, (xys[0][0]+xys[0][1], xys[0][1])) print(time)
今回はheapqで実装してみた。heapqには普通のlistをヒープ化したり、ヒープ化されたリストに対するヒープ操作をしたり、という関数が用意されている。実装がリスト準拠なので、最小の要素をpeekするだけならxs[0]でできる。queueモジュールのPriorityQueueは実はスレッドセーフな実装になっており、その分のオーバーヘッドがかかるのでheapqのほうが定数倍効率がいい。
heapqがO(N)、heappushpopがO(logN)なので全部でO(n + k logn)の計算量。
ARC074 D問題 - 3N Numbers
3*N個の正の整数のリストからN個要素を取り除き、「前半N個の和 - 後半N個の和」を最大化する問題。
制約から見えてきたいくつかのポイント:
- 前半の和を最大化し、後半の和を最小化する必要がある
- 前半はA[:2*n]の要素のうちのN個によって構成される
- 後半はA[n:]の要素のうちのN個によって構成される
- 前半が使える部分と後半が使える部分のオーバーラップであるA[n:2*n]のどこに線を引くかがポイント
というわけで
- i = 0 ~ nの各iごとにA[:n+i]からN個選んだ場合の和の最大値を求める
- j = 0 ~ nの各ごとにA[2*n-j:]からN個選んだ場合の和の最小値を求める
ということをPriorityQueueを使って一歩ずつ算出していく。その上でi+j = nで差が最大になるiとjのペアを見つける。
from heapq import heapify, heappushpop n = int(input()) xs = [int(x) for x in input().split()] xs1 = xs[:n] heapify(xs1) r1 = [sum(xs1)] for x in xs[n:2*n]: r1.append(r1[-1] + x - heappushpop(xs1, x)) xs2 = [-x for x in xs[2*n:]] heapify(xs2) r2 = [sum(xs2)] for x in reversed(xs[n:2*n]): r2.append(r2[-1] + (-x) - heappushpop(xs2, -x)) print(max(a+b for a,b in zip(r1, reversed(r2))))
前半の和の最大値を求めるのに、A[n+i]の要素をキューにいれてから最小値を出して、(A[n+i] - 最小値)をそれまでの和に足している。後半も全く同じロジック、ただし「最大値をキューから取り出す」という挙動のためにキューに入れる・キューから取り出す値をすべてマイナスでかけている。
ほとんどの部分がO(N)だが、ループしている部分だけN回ループが回ってO(logN)のheappushpopが毎回行われているので、そこが計算量O(N logN)で大きい。
Priority Queue感想
やっぱり便利。今回の問題はみんなPriority Queueが前面に出てきているが、そうでなくても解法の一部として計算量を下げるのに活躍してくれる場面は多そう。
しかしmax heapがないのはやっぱり不便だな、微妙なところで実装を間違えてバグりそうだから要注意だ。
AtCoder Beginner Contest 102に参加してみた
第1問
入力値nと2の最小公倍数を求める問題。nが偶数ならn、そうじゃないなら2*n
n = int(input()) print(2*n if n%2 else n)
第2問
Aの(添字の)異なる 2 要素の差の絶対値の最大値を求めてください。
と言われるとややこしく感じてしまうが「2要素の差が一番大きい」なのでmaxとminの差。
入力値に「Aの要素が2つ以上」という制約がある以上、要素の値がすべて同じだったとしても「添字の異なる2要素」は満たせる。
n = int(input()) xs = [int(x) for x in input().split()] print(max(xs) - min(xs))
第3問
Ai と b+iの差
と考えるとちょっとややこしいかもしれないがabs(Ai - (b+i) = abs((Ai - i) - b)
なのでまずAi-i
の数列Bを作ってしまう。
その数列Bの各要素と任意の数bの距離の和を最小化したい。
たとえばbがBのどの要素よりも小さかったとする。そうするとb+1はBのすべての要素に1ずつ近くなるので明らかにbより勝る。
bがBの最小の要素と同じ値だったとする(Bの最小の要素が単一だと仮定する)。するとb+1はその最小の要素からの距離は1増え、n-1個の要素との距離は1ずつ減る。
この「1ずれるごとに距離が1増える要素と距離が1減る要素の数」がちょうどバランスするところが最善であり、そのポイントでの距離の和が答えとなる。
それは中央値を構成する要素が1つの場合はちょうどその要素の値となるし、もし中央値が二つの要素の平均だとしたら、その二つの要素の値の間ならどこでも最善になる。
なのでてっとりばやく中央値を使える:
from statistics import median n = int(input()) xs = [int(x) for x in input().split()] ys = [x-i-1 for i, x in enumerate(xs)] m = int(median(ys)) print(sum(abs(y-m) for y in ys))
第4問
この問題はコンテスト中には解けなかった・・・ しゃくとり法や左右から端を狭めていく方法などをずっと考えていて最後に苦し紛れの実装でWAを出したりしながらタイムアップ。悔しいので解説は見ずに1日ぼーっと考えたりしていた。
累積和をとると便利なのは最初から考えていたのだが、「真ん中をまず割ってみる」というところに発想がいかなかったのが敗因。
真ん中を割ってしまえば、「左右の配列を個別にどう割るか」という二つの小さい問題になる。
まず左の配列について考えてみる:
- 配列をどう割るのが最善か、と考えると、ここでも和の差を最小化するのが一番だとわかる
- 和の差が最小化されるためには一つ一つの和ができるだけ平均に近ければいい
- 累積和を平均値で二分探索すればいい
- ただし二分探索で得られた添字の周りが答えの可能性もあるので-1, +1も調べて和の差が最小化されるポイントを選ぶ
右の配列も似た論理だ。ただし累積和がすこしややこしい(左の配列の総和を差し引く必要がある)。
あとは左右の配列の結果を組み合わせて最大と最小の和の差をとれば、あるポイントを真ん中にして割った場合の最小の和の差がわかる。
つまり真ん中が決定していれば、最小の和の差を2*O(log N/2)で計算できる。
あとは真ん中候補すべてに対してこの計算をして最小の結果が全体の答え。なので全体でO(N logN)。
from itertools import accumulate, islice from bisect import bisect_left n = int(input()) xs = [int(x) for x in input().split()] ys = list(accumulate(xs)) zs, total = ys[:-1], ys[-1] res = max(xs) - min(xs) for i, z in islice(enumerate(zs), 1, n-1): j = bisect_left(zs, z//2) splits = ((zs[j+k], z - zs[j+k]) for k in (-1, 0, 1) if 0 <= j+k <= i) a, b = min(splits, key=lambda s:abs(s[0]-s[1])) j = bisect_left(zs, (total - z)//2 + z) splits = ((zs[j+k] - z, total - zs[j+k]) for k in (-1, 0, 1) if i <= j+k <= n-2) c, d = min(splits, key=lambda s:abs(s[0]-s[1])) minn, _, _, maxx = sorted((a,b,c,d)) if maxx - minn < res: res = maxx - minn print(res)
Pythonだとテストケースにかかる時間が最長で1800msとかなりギリギリだ。うっかりj = bisect_left(zs, z/2)
などと浮動小数点との比較で二分探索してしまうと簡単にTLEになる。今度OCamlで実装してどう書けてどれくらい速度がでるか試してみたい。
OCamlでTypical Dynamic Programming Contest E問題を解いてみた
ここ一ヶ月ほどAtCoderをPythonですすめてきて、Pythonの遅さにひやりとすることはあっても基本的にアルゴリズムがあっていて素直な最適化を施していればABCの問題であれば通せそう、という感触を得ている。
しかし少し上の問題を見てみるとそうはいかないようだ。例えばAtCoderのTypical Dynamic Programming ContestというDP問題ばかり集めたサンプルコンテストのE問題:
かなり素直な桁DPなのだが、桁数が10000でmodをとる値が最大100なので計算量のオーダーが10000 * 100 * 10(最後の10は0~9までの数で回すから)、つまり107になる。 この計算量はコンパイル言語なら余裕だがPythonだと普通にアウトのようだ。
(まあこのコードが定数で最適かといわれると違うのだが・・・)
というわけでコンパイル言語で書いてみる。
競技プログラミング的に素直な選択はC++だが、もう少しPythonに近い宣言的な表現力のある言語がいい。
表現力という意味ではHaskellが相当高いが、純粋関数型だとアルゴリズムにけっこう工夫が必要になり、Pythonで考えるのとはまったく違う世界になる。
速度もあり、適度に関数型で宣言的に記述することができ、かつ必要なところで副作用を気軽に扱える言語、ということでOCamlを選んだ。
最長251msで無事AC。
open Batteries let digitList s = s |> String.to_list |> List.map (fun c -> Char.code c - 48) let id x = x let d = Scanf.scanf "%d\n" id let n = Scanf.scanf "%s\n" id let m = 1000000007 let f x n = let v, a = x in let a1 = Array.make d 0 in for i = 0 to d-1 do for j = 0 to 9 do a1.((i+j) mod d) <- (a1.((i+j) mod d) + a.(i)) mod m done done; for j = 0 to n-1 do a1.((v+j) mod d) <- (a1.((v+j) mod d) + 1) mod m done; ((v+n) mod d, a1) let v, a = List.fold_left f (0, Array.make d 0) (digitList n) let () = Printf.printf "%d\n" (a.(0) + (if v == 0 then 1 else 0) - 1)
まともにOCamlを書くのははじめてだったのでまだいろいろ汚いと思うが、とりあえず自分がしたいことを比較的素直に記述できて速度が出るのは非常にうれしい。
f関数の中で動的計画法らしくArrayに対する破壊的代入を繰り返しているが、引数には変更は加えずに命令的に作成したArrayを結果として返しているので外側からみると純粋関数だ。
そのf関数を引数としてfold_leftで畳み込んだり、文字列を|>
を使った関数のチェーンで一桁の数字のリストに変換したり、といったところは抽象度が高い関数型らしい記述ができる。
その上で実行速度が107のオーダーをAtCoderの時間制限の2000msでも問題なくこなせるのだからこれは強い。
習熟度が非常に低いので実際のコンテストではまだ使えないだろうが、少しずつ過去問(とくに計算量のオーダーが大きそうな問題で)をOCamlで挑戦していきたい。
トポロジカル・ソートをPythonで実装してみた
DPとはDAGの最短経路を、トポロジカルソート順に埋めていくことで計算する手法という話からそもそもトポロジカルソートってどうやるんだっけ?となり、Pythonで一つのアプローチを実装してみた:
from collections import defaultdict, deque v, n = map(int, input().split()) es = [[int(x) for x in input().split()] for _ in range(n)] outs = defaultdict(list) ins = defaultdict(int) for v1, v2 in es: outs[v1].append(v2) ins[v2] += 1 q = deque(v1 for v1 in range(v) if ins[v1] == 0) res = [] while q: v1 = q.popleft() res.append(v1) for v2 in outs[v1]: ins[v2] -= 1 if ins[v2] == 0: q.append(v2) print(', '.join(res))
冒頭にリンクも貼った「DPの話」の記事で言われている「入次数 0 のノード (とそこから伸びる辺) をひたすら取り除きまくる」実装。
vがノードの数、nが辺の数、esが各辺を表す親ノード、小ノードのペア。
outsが各ノードから出ていく先のノードのリストの辞書。insが各ノードの入次数の辞書。
qが入次数が0のノードのキュー。
- キューの頭のノードを取り出し、ソートされたリストの次のメンバとしてappend
- そのノードから伸びる各辺について
- その行き先のノードの入次数を1減らし
- もしそのノードの入次数が0に落ちたらキューに入れる
という処理を繰り返す。
outsとinsの作成にO(E)、qの作成にO(V)、resの作成のループはO(V+E)(外側のループがO(V)回、内側のループは合計でO(E)回)ということで全体的にO(E+V)の計算量。
実際にDPでトポロジカルソートをする必要があることはあまりないらしい(入力が与えられた時点ですでにDAGとしてソートされていることが多いとのこと)。だけどとりあえず必要だったらすぐ実装して使えるようにしておきたい。今回でわかった通りロジックも実装も非常に簡単だし。
桁DPに入門してみた
昨日のAtCoder ABC101でD問題に歯が立たず非常に悔しかったので、桁関連の問題や解法を少し読んでいる。
その経緯で桁に関する動的計画法(桁DP)について下のすごい記事に遭遇した:
今蟻本でも動的計画法の章を読んでいることだし、この機会に桁DPに入門してみる。
以下にPythonでの実装と自分なりの気づきをメモってみた:
A以下の非負整数の総数を求める
from itertools import product from collections import defaultdict a = input() n = len(a) dp = defaultdict(int) dp[0, 0, 0] = 1 for i, less in product(range(n), (0,1)): max_d = 9 if less else int(a[i]) for d in range(max_d+1): less_ = less or d < max_d dp[i + 1, less_] += dp[i, less] print(sum(dp[n, less] for less in (0, 1)))
初期値としてdp[0, 0, 0, 0]に1を入れておくのが非直観的で面白い。
dp[i, 1]
で上からi番目の桁まで比較した場合nより低い数がひとつ以上出てくるパターンの数。
dp[i, 0]
は各iで必ず値が1になるはず(n以下の整数を数えているので、すべての桁がnのもの以上のパターンはnと完全一致する場合だけ)
A以下の非負整数のうち、3の付く数の総数を求める
from itertools import product from collections import defaultdict a = input() n = len(a) dp = defaultdict(int) dp[0, 0, 0] = 1 for i, less, has3 in product(range(n), (0,1), (0,1)): max_d = 9 if less else int(a[i]) for d in range(max_d+1): less_ = less or d < max_d has3_ = has3 or d == 3 dp[i + 1, less_, has3_] += dp[i, less, has3] print(sum(dp[n, less, 1] for less in (0, 1)))
この条件を加えるのは非常に簡単。
A以下の非負整数のうち、「3が付くまたは3の倍数」の数の総数を求める
from itertools import product from collections import defaultdict a = input() n = len(a) dp = defaultdict(int) dp[0, 0, 0, 0] = 1 for i, less, has3, mod3 in product(range(n), (0,1), (0,1), range(3)): max_d = 9 if less else int(a[i]) for d in range(max_d+1): less_ = less or d < max_d has3_ = has3 or d == 3 mod3_ = (mod3 + d) % 3 dp[i + 1, less_, has3_, mod3_] += dp[i, less, has3, mod3] criteria = ((n, less, has3, mod3) for less, has3, mod3 in product((0, 1), (0,1), range(3)) if has3 or mod3==0) print(sum(dp[c] for c in criteria))
mod3_ = (mod3 + d) % 3
の部分が「桁のトリック」っぽい。「ある数とその数の桁の和は法3で合同」という事実を使っている。
もとの記事には説明がなかったので桁関連の考え方では常識的なのかも。あまり馴染みがなかったので、最初に見たとき何をしているのか考え込んでしまった。mod3_ = (mod3*10 + d) % 3
を少し短く済ませている。
A以下の非負整数のうち、「3が付くまたは3の倍数」かつ「8の倍数でない」数の総数を求める
from itertools import product from collections import defaultdict a = input() n = len(a) dp = defaultdict(int) dp[0, 0, 0, 0, 0] = 1 for i, less, has3, mod3, mod8 in product(range(n), (0,1), (0,1), range(3), range(8)): max_d = 9 if less else int(a[i]) for d in range(max_d+1): less_ = less or d < max_d has3_ = has3 or d == 3 mod3_ = (mod3 + d) % 3 mod8_ = (mod8*10 + d) % 8 dp[i + 1, less_, has3_, mod3_, mod8_] += dp[i, less, has3, mod3, mod8] criteria = ((n, less, has3, mod3, mod8) for less, has3, mod3, mod8 in product((0, 1), (0,1), range(3), range(8)) if (has3 or mod3==0) and mod8!=0) print(sum(dp[c] for c in criteria)) print(sum(n % 8 and (n % 3==0 or '3' in str(n)) for n in range(int(a)+1)))
形が決まってしまえばコードの複雑さはあまり増加しないのがうれしいところ。
ちなみにテストするためには定義通りに書いた非効率な実装と、小さい数字で結果が一致するか調べるのも手だ。例えば以下のワンライナーが使える:
print(sum(n % 8 and (n % 3==0 or '3' in str(n)) for n in range(int(a)+1)))
A以下B以上の非負整数のうち、「3が付くまたは3の倍数」かつ「8の倍数でない」数の総数を求める
from itertools import product from collections import defaultdict a = int(input()) b = int(input()) def count(x): a = str(x) n = len(a) dp = defaultdict(int) dp[0, 0, 0, 0, 0] = 1 for i, less, has3, mod3, mod8 in product(range(n), (0,1), (0,1), range(3), range(8)): max_d = 9 if less else int(a[i]) for d in range(max_d+1): less_ = less or d < max_d has3_ = has3 or d == 3 mod3_ = (mod3 + d) % 3 mod8_ = (mod8*10 + d) % 8 dp[i + 1, less_, has3_, mod3_, mod8_] += dp[i, less, has3, mod3, mod8] criteria = ((n, less, has3, mod3, mod8) for less, has3, mod3, mod8 in product((0, 1), (0,1), range(3), range(8)) if (has3 or mod3==0) and mod8!=0) return sum(dp[c] for c in criteria) print(count(a) - count(b-1))
関数化すれば非常に楽。
ただし、一つのループに「B以上」という条件も組み込むほうがより効率的なはず。今度時間があったらその実装も試してみたい。