Arantium Maestum

プログラミング、囲碁、読書の話題

桁DPに入門してみた

昨日のAtCoder ABC101でD問題に歯が立たず非常に悔しかったので、桁関連の問題や解法を少し読んでいる。

その経緯で桁に関する動的計画法(桁DP)について下のすごい記事に遭遇した:

pekempey.hatenablog.com

今蟻本でも動的計画法の章を読んでいることだし、この機会に桁DPに入門してみる。

以下にPythonでの実装と自分なりの気づきをメモってみた:

A以下の非負整数の総数を求める

from itertools import product
from collections import defaultdict

a = input()
n = len(a)

dp = defaultdict(int)
dp[0, 0, 0] = 1

for i, less in product(range(n), (0,1)):
    max_d = 9 if less else int(a[i])
    for d in range(max_d+1):
        less_ = less or d < max_d
        dp[i + 1, less_] += dp[i, less]

print(sum(dp[n, less] for less in (0, 1)))

初期値としてdp[0, 0, 0, 0]に1を入れておくのが非直観的で面白い。

dp[i, 1]で上からi番目の桁まで比較した場合nより低い数がひとつ以上出てくるパターンの数。

dp[i, 0]は各iで必ず値が1になるはず(n以下の整数を数えているので、すべての桁がnのもの以上のパターンはnと完全一致する場合だけ)

A以下の非負整数のうち、3の付く数の総数を求める

from itertools import product
from collections import defaultdict

a = input()
n = len(a)

dp = defaultdict(int)
dp[0, 0, 0] = 1

for i, less, has3 in product(range(n), (0,1), (0,1)):
    max_d = 9 if less else int(a[i])
    for d in range(max_d+1):
        less_ = less or d < max_d
        has3_ = has3 or d == 3
        dp[i + 1, less_, has3_] += dp[i, less, has3]

print(sum(dp[n, less, 1] for less in (0, 1)))

この条件を加えるのは非常に簡単。

A以下の非負整数のうち、「3が付くまたは3の倍数」の数の総数を求める

from itertools import product
from collections import defaultdict

a = input()
n = len(a)

dp = defaultdict(int)
dp[0, 0, 0, 0] = 1

for i, less, has3, mod3 in product(range(n), (0,1), (0,1), range(3)):
    max_d = 9 if less else int(a[i])
    for d in range(max_d+1):
        less_ = less or d < max_d
        has3_ = has3 or d == 3
        mod3_ = (mod3 + d) % 3
        dp[i + 1, less_, has3_, mod3_] += dp[i, less, has3, mod3]

criteria = ((n, less, has3, mod3) for less, has3, mod3 
                                  in product((0, 1), (0,1), range(3)) 
                                  if has3 or mod3==0)

print(sum(dp[c] for c in criteria))

mod3_ = (mod3 + d) % 3の部分が「桁のトリック」っぽい。「ある数とその数の桁の和は法3で合同」という事実を使っている。

もとの記事には説明がなかったので桁関連の考え方では常識的なのかも。あまり馴染みがなかったので、最初に見たとき何をしているのか考え込んでしまった。mod3_ = (mod3*10 + d) % 3を少し短く済ませている。

A以下の非負整数のうち、「3が付くまたは3の倍数」かつ「8の倍数でない」数の総数を求める

from itertools import product
from collections import defaultdict

a = input()
n = len(a)

dp = defaultdict(int)
dp[0, 0, 0, 0, 0] = 1

for i, less, has3, mod3, mod8 in product(range(n), (0,1), (0,1), range(3), range(8)):
    max_d = 9 if less else int(a[i])
    for d in range(max_d+1):
        less_ = less or d < max_d
        has3_ = has3 or d == 3
        mod3_ = (mod3 + d) % 3
        mod8_ = (mod8*10 + d) % 8
        dp[i + 1, less_, has3_, mod3_, mod8_] += dp[i, less, has3, mod3, mod8]

criteria = ((n, less, has3, mod3, mod8) for less, has3, mod3, mod8
                                        in product((0, 1), (0,1), range(3), range(8))
                                        if (has3 or mod3==0) and mod8!=0)

print(sum(dp[c] for c in criteria))

print(sum(n % 8 and (n % 3==0 or '3' in str(n)) for n in range(int(a)+1)))

形が決まってしまえばコードの複雑さはあまり増加しないのがうれしいところ。

ちなみにテストするためには定義通りに書いた非効率な実装と、小さい数字で結果が一致するか調べるのも手だ。例えば以下のワンライナーが使える:

print(sum(n % 8 and (n % 3==0 or '3' in str(n)) for n in range(int(a)+1)))

A以下B以上の非負整数のうち、「3が付くまたは3の倍数」かつ「8の倍数でない」数の総数を求める

from itertools import product
from collections import defaultdict

a = int(input())
b = int(input())

def count(x):
    a = str(x)
    n = len(a)
    dp = defaultdict(int)
    dp[0, 0, 0, 0, 0] = 1

    for i, less, has3, mod3, mod8 in product(range(n), (0,1), (0,1), range(3), range(8)):
        max_d = 9 if less else int(a[i])
        for d in range(max_d+1):
            less_ = less or d < max_d
            has3_ = has3 or d == 3
            mod3_ = (mod3 + d) % 3
            mod8_ = (mod8*10 + d) % 8
            dp[i + 1, less_, has3_, mod3_, mod8_] += dp[i, less, has3, mod3, mod8]

    criteria = ((n, less, has3, mod3, mod8) for less, has3, mod3, mod8
                                            in product((0, 1), (0,1), range(3), range(8))
                                            if (has3 or mod3==0) and mod8!=0)

    return sum(dp[c] for c in criteria)

print(count(a) - count(b-1))

関数化すれば非常に楽。

ただし、一つのループに「B以上」という条件も組み込むほうがより効率的なはず。今度時間があったらその実装も試してみたい。