Python 覆面式ソルバー 終
前回からの続き。
そうこうしているうちに、ついにルール追加である。
以下の二つのルールを追加する。
- 二つの数の和の一番上の桁に関するルール
- 二つの数とその和の同じ桁に関するルール
その前に、足し算と引き算でルールをケース分けするのはめんどうなので、数字の順番をひっくり返して引き算を足し算に変える。(掛け算・割り算は今回はやらない)
if operator == '-': num1, num3 = num3, num1
x - y = z
が成り立つ場合、z + y = x
も成り立つ。
さらに、足し算の場合も足し合わせる数の桁が合わない場合、どちらが多いかケース分けするのは大変なので、必ず最初の数字のほうが桁が多いようにする。
if len(num1) < len(num2): num1, num2 = num2, num1
まず第一に、和の一番上の桁についてのルール。
もし和の桁がほかの二つの数字よりも多い場合、一番上の桁は1しかありえない。
if len(num3) > max(len(num1), len(num2)): def one_only(alpha_dict, value={1}): return value rule_dict[num3[0]].append(one_only)
第二のルールとして、同じ桁の二つの数がわかっている場合、もう一つの数字はほぼ絞れる(繰り越しがあるので二つの可能性がある)という点がある。
下一桁だと繰り上がりがないので完全に一つの数に絞られる。例えばABC + XYZ = RST
という式で、C
とZ
の値がわかっていればT
は(C+Z)%10
に絞り切れる。
c1 = num1[-1] c2 = num2[-1] c3 = num3[-1] def check_c1(alpha_dict, c1=c1, c2=c2, c3=c3, default=set(range(10))): n2 = alpha_dict[c2] n3 = alpha_dict[c3] if not n2 or not n3: return default return {(n3 - n2) % 10} rule_dict[c1].append(check_c1) def check_c2(alpha_dict, c1=c1, c2=c2, c3=c3, default=set(range(10))): n1 = alpha_dict[c1] n3 = alpha_dict[c3] if not n1 or not n3: return default return {(n3 - n1) % 10} rule_dict[c2].append(check_c2) def check_c3(alpha_dict, c1=c1, c2=c2, c3=c3, default=set(range(10))): n1 = alpha_dict[c1] n2 = alpha_dict[c2] if not n1 or not n2: return default return {(n1 + n2) % 10} rule_dict[c3].append(check_c3)
それ以外の桁の場合、二つの数字に絞れる。ただし、num2がnum1より桁が少ない場合を考慮する必要がある。ABC + XYZ = RST
で、B
とY
がわかっていればS
は(B+Y)%10
か(B+Y)%10 + 1
のどちらかになる。
for i, c1 in enumerate(num1[-1::-1]): try: c2 = num2[-1::-1][i] c3 = num3[-1::-1][i] def check_c1(alpha_dict, c1=c1, c2=c2, c3=c3, default=set(range(10))): n2 = alpha_dict[c2] n3 = alpha_dict[c3] if not n2 or not n3: return default return {(n3 - n2) % 10, (n3 - n2 - 1) % 10} rule_dict[c1].append(check_c1) def check_c2(alpha_dict, c1=c1, c2=c2, c3=c3, default=set(range(10))): n1 = alpha_dict[c1] n3 = alpha_dict[c3] if not n1 or not n3: return default return {(n3 - n1) % 10, (n3 - n1 - 1) % 10} rule_dict[c2].append(check_c2) def check_c3(alpha_dict, c1=c1, c2=c2, c3=c3, default=set(range(10))): n1 = alpha_dict[c1] n2 = alpha_dict[c2] if not n1 or not n2: return default return {(n1 + n2) % 10, (n1 + n2 + 1) % 10} rule_dict[c3].append(check_c3) except IndexError: c3 = num3[-1::-1][i] def check_c1(alpha_dict, c1=c1, c3=c3, default=set(range(10))): n3 = alpha_dict[c3] if not n2: return default return {n3 % 10, (n3 - 1) % 10} rule_dict[c1].append(check_c1) def check_c3(alpha_dict, c1=c1, c3=c3, default=set(range(10))): n1 = alpha_dict[c1] if not n1: return default return {n1 % 10, (n1 + 1) % 10} rule_dict[c3].append(check_c3)
これらをcreate_rules
に組み入れると、その時点で実行速度は0.15秒~0.3秒、0.7秒~5秒となる。これで終わってもいいのだが、これらのルールを踏まえてcharacters
の順番を組み替えてみる。
if operator == '-': num1, num3 = num3, num1 if len(num1) < len(num2): num1, num2 = num2, num1 characters = [] if len(num3) > len(num1): characters.append(num3[0]) for tup in zip_longest(num1[::-1], num2[::-1],num3[::-1]): for c in tup: if c and (c not in characters): characters.append(c)
num3
の桁数が多い場合、その一番上の桁を先頭に持ってくる。それ以外は、下のほうから同じ桁の文字が並ぶようにする。これで効率よく絞り込みが利くようになる。
最終的なコードは(長くなるが)以下のようになる。
from operator import add, sub from string import ascii_letters from collections import defaultdict from itertools import zip_longest def solve(input_string): rules = create_rules(input_string) bracket_list = [bracket_letter(c) for c in input_string] bracket_string = ''.join(bracket_list) num1, operator, num2, _, num3 = input_string.split() b1, _, b2, _, b3 = bracket_string.split() if operator == '-': num1, num3 = num3, num1 b1, b3 = b3, b1 if len(num1) < len(num2): num1, num2 = num2, num1 b1, b2 = b2, b1 characters = [] if len(num3) > len(num1): characters.append(num3[0]) for tup in zip_longest(num1[::-1], num2[::-1],num3[::-1]): for c in tup: if c and (c not in characters): characters.append(c) alpha_dict = {c:None for c in characters} for alpha_dict in recurse(0, characters, alpha_dict, b1, b2, b3, rules): print(bracket_string.format(**alpha_dict)) def bracket_letter(c): if c in ascii_letters: c = '{'+c+'}' return c def recurse(i, characters, alpha_dict, b1, b2, b3, rules): c = characters[i] possible_numbers = rules(c, alpha_dict) for num1 in possible_numbers: alpha_dict[c] = num1 if i == len(characters)-1: n1 = int(b1.format(**alpha_dict)) n2 = int(b2.format(**alpha_dict)) n3 = int(b3.format(**alpha_dict)) if n1 + n2 == n3: yield alpha_dict.copy() else: yield from recurse(i+1, characters, alpha_dict, b1, b2, b3, rules) alpha_dict[c] = None def create_rules(input_string): num1, operator, num2, _, num3 = input_string.split() if operator == '-': num1, num3 = num3, num1 if len(num1) < len(num2): num1, num2 = num2, num1 rule_dict = defaultdict(list) for first_character in (n[0] for n in (num1, num2, num3)): def not_zero(alpha_dict, value=set(range(1,10))): return value rule_dict[first_character].append(not_zero) if len(num3) > max(len(num1), len(num2)): def one_only(alpha_dict, value={1}): return value rule_dict[num3[0]].append(one_only) c1 = num1[-1] c2 = num2[-1] c3 = num3[-1] def check_c1(alpha_dict, c1=c1, c2=c2, c3=c3, default=set(range(10))): n2 = alpha_dict[c2] n3 = alpha_dict[c3] if not n2 or not n3: return default return {(n3 - n2) % 10} rule_dict[c1].append(check_c1) def check_c2(alpha_dict, c1=c1, c2=c2, c3=c3, default=set(range(10))): n1 = alpha_dict[c1] n3 = alpha_dict[c3] if not n1 or not n3: return default return {(n3 - n1) % 10} rule_dict[c2].append(check_c2) def check_c3(alpha_dict, c1=c1, c2=c2, c3=c3, default=set(range(10))): n1 = alpha_dict[c1] n2 = alpha_dict[c2] if not n1 or not n2: return default return {(n1 + n2) % 10} rule_dict[c3].append(check_c3) for i, c1 in enumerate(num1[-1::-1]): try: c2 = num2[-1::-1][i] c3 = num3[-1::-1][i] def check_c1(alpha_dict, c1=c1, c2=c2, c3=c3, default=set(range(10))): n2 = alpha_dict[c2] n3 = alpha_dict[c3] if not n2 or not n3: return default return {(n3 - n2) % 10, (n3 - n2 - 1) % 10} rule_dict[c1].append(check_c1) def check_c2(alpha_dict, c1=c1, c2=c2, c3=c3, default=set(range(10))): n1 = alpha_dict[c1] n3 = alpha_dict[c3] if not n1 or not n3: return default return {(n3 - n1) % 10, (n3 - n1 - 1) % 10} rule_dict[c2].append(check_c2) def check_c3(alpha_dict, c1=c1, c2=c2, c3=c3, default=set(range(10))): n1 = alpha_dict[c1] n2 = alpha_dict[c2] if not n1 or not n2: return default return {(n1 + n2) % 10, (n1 + n2 + 1) % 10} rule_dict[c3].append(check_c3) except IndexError: c3 = num3[-1::-1][i] def check_c1(alpha_dict, c1=c1, c3=c3, default=set(range(10))): n3 = alpha_dict[c3] if not n2: return default return {n3 % 10, (n3 - 1) % 10} rule_dict[c1].append(check_c1) def check_c3(alpha_dict, c1=c1, c3=c3, default=set(range(10))): n1 = alpha_dict[c1] if not n1: return default return {n1 % 10, (n1 + 1) % 10} rule_dict[c3].append(check_c3) def rules(c, alpha_dict, rule_dict=rule_dict): if c in rule_dict: possible_sets = [rule(alpha_dict) for rule in rule_dict[c]] else: possible_sets = set() output_set = set(n for n in range(10) if n not in alpha_dict.values()) for s in possible_sets: output_set = output_set.intersection(s) return output_set return rules
もともとはあんなにシンプルだったコードが200行近くになってしまった。が、実行速度は'SEND + MORE = MONEY'で0.15秒、'WWWDOT - GOOGLE = DOTCOM'で0.7秒でどちらもまず1秒を超えることはない。
ルールをさらに追加していったり、あるいはプロファイルしてミクロ的なスピードアップを図ったり、とさらに数倍は速くできそうだが、とりあえずここで一旦終わりにしておく。どちらかというと、ルールまわりのコードをより簡潔に表記できないか、という方向には興味があるがそれも置いておこう。