読者です 読者をやめる 読者になる 読者になる

Arantium Maestum

プログラミング、囲碁、読書の話題

Python 覆面式ソルバー 終

前回からの続き。

そうこうしているうちに、ついにルール追加である。

以下の二つのルールを追加する。

  • 二つの数の和の一番上の桁に関するルール
  • 二つの数とその和の同じ桁に関するルール

その前に、足し算と引き算でルールをケース分けするのはめんどうなので、数字の順番をひっくり返して引き算を足し算に変える。(掛け算・割り算は今回はやらない)

    if operator == '-':
        num1, num3 = num3, num1

x - y = zが成り立つ場合、z + y = xも成り立つ。

さらに、足し算の場合も足し合わせる数の桁が合わない場合、どちらが多いかケース分けするのは大変なので、必ず最初の数字のほうが桁が多いようにする。

    if len(num1) < len(num2):
        num1, num2 = num2, num1

まず第一に、和の一番上の桁についてのルール。

もし和の桁がほかの二つの数字よりも多い場合、一番上の桁は1しかありえない。

    if len(num3) > max(len(num1), len(num2)):
        def one_only(alpha_dict, value={1}):
            return value
        rule_dict[num3[0]].append(one_only)

第二のルールとして、同じ桁の二つの数がわかっている場合、もう一つの数字はほぼ絞れる(繰り越しがあるので二つの可能性がある)という点がある。

下一桁だと繰り上がりがないので完全に一つの数に絞られる。例えばABC + XYZ = RSTという式で、CZの値がわかっていればT(C+Z)%10に絞り切れる。

    c1 = num1[-1]
    c2 = num2[-1]
    c3 = num3[-1]

    def check_c1(alpha_dict, c1=c1, c2=c2, c3=c3,
            default=set(range(10))):
        n2 = alpha_dict[c2]
        n3 = alpha_dict[c3]
        if not n2 or not n3:
            return default
        return {(n3 - n2) % 10}
    rule_dict[c1].append(check_c1)

    def check_c2(alpha_dict, c1=c1, c2=c2, c3=c3,
            default=set(range(10))):
        n1 = alpha_dict[c1]
        n3 = alpha_dict[c3]
        if not n1 or not n3:
            return default
        return {(n3 - n1) % 10}
    rule_dict[c2].append(check_c2)

    def check_c3(alpha_dict, c1=c1, c2=c2, c3=c3,
            default=set(range(10))):
        n1 = alpha_dict[c1]
        n2 = alpha_dict[c2]
        if not n1 or not n2:
            return default
        return {(n1 + n2) % 10}
    rule_dict[c3].append(check_c3)

それ以外の桁の場合、二つの数字に絞れる。ただし、num2がnum1より桁が少ない場合を考慮する必要がある。ABC + XYZ = RSTで、BYがわかっていればS(B+Y)%10(B+Y)%10 + 1のどちらかになる。

    for i, c1 in enumerate(num1[-1::-1]):
        try:
            c2 = num2[-1::-1][i]
            c3 = num3[-1::-1][i]

            def check_c1(alpha_dict, c1=c1, c2=c2, c3=c3,
                    default=set(range(10))):
                n2 = alpha_dict[c2]
                n3 = alpha_dict[c3]
                if not n2 or not n3:
                    return default
                return {(n3 - n2) % 10, (n3 - n2 - 1) % 10}
            rule_dict[c1].append(check_c1)

            def check_c2(alpha_dict, c1=c1, c2=c2, c3=c3,
                    default=set(range(10))):
                n1 = alpha_dict[c1]
                n3 = alpha_dict[c3]
                if not n1 or not n3:
                    return default
                return {(n3 - n1) % 10, (n3 - n1 - 1) % 10}
            rule_dict[c2].append(check_c2)

            def check_c3(alpha_dict, c1=c1, c2=c2, c3=c3,
                    default=set(range(10))):
                n1 = alpha_dict[c1]
                n2 = alpha_dict[c2]
                if not n1 or not n2:
                    return default
                return {(n1 + n2) % 10, (n1 + n2 + 1) % 10}
            rule_dict[c3].append(check_c3)

        except IndexError:
            c3 = num3[-1::-1][i]

            def check_c1(alpha_dict, c1=c1, c3=c3,
                    default=set(range(10))):
                n3 = alpha_dict[c3]
                if not n2:
                    return default
                return {n3 % 10, (n3 - 1) % 10}
            rule_dict[c1].append(check_c1)

            def check_c3(alpha_dict, c1=c1, c3=c3,
                    default=set(range(10))):
                n1 = alpha_dict[c1]
                if not n1:
                    return default
                return {n1 % 10, (n1 + 1) % 10}
            rule_dict[c3].append(check_c3)

これらをcreate_rulesに組み入れると、その時点で実行速度は0.15秒~0.3秒、0.7秒~5秒となる。これで終わってもいいのだが、これらのルールを踏まえてcharactersの順番を組み替えてみる。

    if operator == '-':
        num1, num3 = num3, num1
    if len(num1) < len(num2):
        num1, num2 = num2, num1

    characters = []
    if len(num3) > len(num1):
        characters.append(num3[0])
    for tup in zip_longest(num1[::-1], num2[::-1],num3[::-1]):
        for c in tup:
            if c and (c not in characters):
                characters.append(c)

num3の桁数が多い場合、その一番上の桁を先頭に持ってくる。それ以外は、下のほうから同じ桁の文字が並ぶようにする。これで効率よく絞り込みが利くようになる。

最終的なコードは(長くなるが)以下のようになる。

from operator import add, sub
from string import ascii_letters
from collections import defaultdict
from itertools import zip_longest

def solve(input_string):
    rules = create_rules(input_string)

    bracket_list = [bracket_letter(c) for c in input_string]
    bracket_string = ''.join(bracket_list)

    num1, operator, num2, _, num3 = input_string.split()
    b1, _, b2, _, b3 = bracket_string.split()

    if operator == '-':
        num1, num3 = num3, num1
        b1, b3 = b3, b1
    if len(num1) < len(num2):
        num1, num2 = num2, num1
        b1, b2 = b2, b1

    characters = []
    if len(num3) > len(num1):
        characters.append(num3[0])
    for tup in zip_longest(num1[::-1], num2[::-1],num3[::-1]):
        for c in tup:
            if c and (c not in characters):
                characters.append(c)

    alpha_dict = {c:None for c in characters}


    for alpha_dict in recurse(0,
            characters, alpha_dict, b1, b2, b3, rules):
        print(bracket_string.format(**alpha_dict))


def bracket_letter(c):
    if c in ascii_letters:
        c = '{'+c+'}'
    return c


def recurse(i, 
            characters, alpha_dict, b1, b2, b3, rules):
    c = characters[i]
    possible_numbers = rules(c, alpha_dict)
    for num1 in possible_numbers:
        alpha_dict[c] = num1
        if i == len(characters)-1:
            n1 = int(b1.format(**alpha_dict))
            n2 = int(b2.format(**alpha_dict))
            n3 = int(b3.format(**alpha_dict))
            if n1 + n2 == n3:
                yield alpha_dict.copy()
        else:
            yield from recurse(i+1,
                    characters, alpha_dict, b1, b2, b3, rules)
    alpha_dict[c] = None


def create_rules(input_string):
    num1, operator, num2, _, num3 = input_string.split()
    if operator == '-':
        num1, num3 = num3, num1
    if len(num1) < len(num2):
        num1, num2 = num2, num1

    rule_dict = defaultdict(list)

    for first_character in (n[0] for n in (num1, num2, num3)):
        def not_zero(alpha_dict, value=set(range(1,10))):
            return value
        rule_dict[first_character].append(not_zero)

    if len(num3) > max(len(num1), len(num2)):
        def one_only(alpha_dict, value={1}):
            return value
        rule_dict[num3[0]].append(one_only)


    c1 = num1[-1]
    c2 = num2[-1]
    c3 = num3[-1]

    def check_c1(alpha_dict, c1=c1, c2=c2, c3=c3,
            default=set(range(10))):
        n2 = alpha_dict[c2]
        n3 = alpha_dict[c3]
        if not n2 or not n3:
            return default
        return {(n3 - n2) % 10}
    rule_dict[c1].append(check_c1)

    def check_c2(alpha_dict, c1=c1, c2=c2, c3=c3,
            default=set(range(10))):
        n1 = alpha_dict[c1]
        n3 = alpha_dict[c3]
        if not n1 or not n3:
            return default
        return {(n3 - n1) % 10}
    rule_dict[c2].append(check_c2)

    def check_c3(alpha_dict, c1=c1, c2=c2, c3=c3,
            default=set(range(10))):
        n1 = alpha_dict[c1]
        n2 = alpha_dict[c2]
        if not n1 or not n2:
            return default
        return {(n1 + n2) % 10}
    rule_dict[c3].append(check_c3)

    for i, c1 in enumerate(num1[-1::-1]):
        try:
            c2 = num2[-1::-1][i]
            c3 = num3[-1::-1][i]

            def check_c1(alpha_dict, c1=c1, c2=c2, c3=c3,
                    default=set(range(10))):
                n2 = alpha_dict[c2]
                n3 = alpha_dict[c3]
                if not n2 or not n3:
                    return default
                return {(n3 - n2) % 10, (n3 - n2 - 1) % 10}
            rule_dict[c1].append(check_c1)

            def check_c2(alpha_dict, c1=c1, c2=c2, c3=c3,
                    default=set(range(10))):
                n1 = alpha_dict[c1]
                n3 = alpha_dict[c3]
                if not n1 or not n3:
                    return default
                return {(n3 - n1) % 10, (n3 - n1 - 1) % 10}
            rule_dict[c2].append(check_c2)

            def check_c3(alpha_dict, c1=c1, c2=c2, c3=c3,
                    default=set(range(10))):
                n1 = alpha_dict[c1]
                n2 = alpha_dict[c2]
                if not n1 or not n2:
                    return default
                return {(n1 + n2) % 10, (n1 + n2 + 1) % 10}
            rule_dict[c3].append(check_c3)

        except IndexError:
            c3 = num3[-1::-1][i]

            def check_c1(alpha_dict, c1=c1, c3=c3,
                    default=set(range(10))):
                n3 = alpha_dict[c3]
                if not n2:
                    return default
                return {n3 % 10, (n3 - 1) % 10}
            rule_dict[c1].append(check_c1)

            def check_c3(alpha_dict, c1=c1, c3=c3,
                    default=set(range(10))):
                n1 = alpha_dict[c1]
                if not n1:
                    return default
                return {n1 % 10, (n1 + 1) % 10}
            rule_dict[c3].append(check_c3)


    def rules(c, alpha_dict, rule_dict=rule_dict):
        if c in rule_dict:
            possible_sets = [rule(alpha_dict) for rule 
                                in rule_dict[c]]
        else:
            possible_sets = set()

        output_set = set(n for n in range(10)
                            if n not in alpha_dict.values())
        for s in possible_sets:
            output_set = output_set.intersection(s)

        return output_set

    return rules

もともとはあんなにシンプルだったコードが200行近くになってしまった。が、実行速度は'SEND + MORE = MONEY'で0.15秒、'WWWDOT - GOOGLE = DOTCOM'で0.7秒でどちらもまず1秒を超えることはない。

ルールをさらに追加していったり、あるいはプロファイルしてミクロ的なスピードアップを図ったり、とさらに数倍は速くできそうだが、とりあえずここで一旦終わりにしておく。どちらかというと、ルールまわりのコードをより簡潔に表記できないか、という方向には興味があるがそれも置いておこう。