Sean McLemon | Advent of Code

Home | Czech | Blog | GitHub | Advent Of Code | Notes


2021-12-18 - Snailfish

(original .ipynb)

The Day 18 puzzle input is a list of "snailfish numbers", each written as nested pairs inside square brackets (mine is here). Part 1 involves adding them all up, applying the snailfish reduction rules after each addition, and reporting the magnitude of the final sum. Part 2 involves finding which two numbers, added together, give the largest magnitude.

from math import ceil, floor
from functools import reduce

# Read the puzzle input with a context manager so the file handle is closed
# deterministically instead of whenever the garbage collector gets to it.
with open("puzzle_input/day18.txt", encoding="utf-8") as puzzle_file:
    puzzle_input_str = puzzle_file.read()

test_input_str = """[[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]]
[[[5,[2,8]],4],[5,[[9,9],0]]]
[6,[[[6,2],[5,6]],[[7,6],[4,7]]]]
[[[6,[0,7]],[0,9]],[4,[9,[9,0]]]]
[[[7,[6,4]],[3,[1,3]]],[[[5,5],1],9]]
[[6,[[7,3],[3,2]]],[[[3,8],[5,7]],4]]
[[[[5,4],[7,7]],8],[[8,3],8]]
[[9,3],[[9,9],[6,[4,9]]]]
[[2,[[7,7],7]],[[5,8],[[9,3],[0,2]]]]
[[[[5,2],5],[8,[3,7]]],[[5,[7,5]],[4,4]]]"""


def parse_number(number):
    """Tokenize a snailfish-number string into a flat, comma-free token list.

    "[" and "]" become string tokens and each regular number becomes an int.
    Multi-digit numbers are supported to make test cases easier to write,
    even though the puzzle input itself only has single digits. Commas only
    separate tokens and are dropped. Whitespace, including a trailing
    carriage return from CRLF line endings, is ignored.

    Raises:
        ValueError: on any other character, or if ``number`` ends in the
            middle of a regular number (e.g. a bare "10").
    """
    parsed = []
    current_number = []  # digits of the regular number currently being read

    def flush_number():
        # Emit the pending digits (if any) as a single int token.
        if current_number:
            parsed.append(int("".join(current_number)))
            current_number.clear()

    for c in number:
        if c in "0123456789":
            current_number.append(c)
        elif c in "[]":
            flush_number()
            parsed.append(c)
        elif c == "," or c.isspace():
            flush_number()
        else:
            raise ValueError(f"unexpected character {c!r} in snailfish number {number!r}")

    if current_number:
        raise ValueError(f"snailfish number {number!r} ends in the middle of a regular number")
    return parsed


def explode_number(number):
    """Explode the leftmost pair nested inside four pairs, if there is one.

    ``number`` is a comma-free token list as produced by ``parse_number``:
    "[" / "]" strings and int regular numbers. The exploding pair's left value
    is added to the nearest regular number on its left, its right value to
    the nearest regular number on its right, and the pair is replaced by 0.

    Returns a new token list (three tokens shorter) when an explosion
    happened, or ``number`` itself when nothing needed to explode. The input
    list is never modified.
    """
    level = 0
    last_digit = -1  # index of the most recent regular number seen, -1 if none

    for i, token in enumerate(number):
        if token == "[":
            level += 1
        elif token == "]":
            level -= 1
        elif level > 4:
            # ``token`` is the left value of a pair nested inside four pairs;
            # number[i + 1] is its right value and number[i + 2] its "]".
            result = list(number)  # copy so the caller's list stays untouched
            if last_digit >= 0:
                result[last_digit] += token
            next_digit = next(
                (j for j in range(i + 2, len(result)) if isinstance(result[j], int)),
                None,
            )
            if next_digit is not None:
                result[next_digit] += number[i + 1]
            # Replace "[", left, right, "]" (indices i-1 .. i+2) with a single 0.
            return result[:i - 1] + [0] + result[i + 3:]
        else:
            last_digit = i

    return number
                
                
def split_number(number):
    """Split the leftmost regular number that is 10 or larger into a pair.

    Returns a new token list where that number is replaced by the tokens from
    ``new_split_pair``. If no regular number is 10 or larger, the very same
    ``number`` list is returned unchanged.
    """
    position = next(
        (idx for idx, tok in enumerate(number) if type(tok) == int and tok >= 10),
        None,
    )
    if position is None:
        return number
    return number[:position] + new_split_pair(number[position]) + number[position + 1:]
    
    
def new_split_pair(n):
    """Return the comma-free tokens of the pair that regular number ``n`` splits into.

    The left half is rounded down and the right half rounded up. Integer
    floor division keeps the halves exact for arbitrarily large ints; going
    through float division would lose precision above 2**53.
    """
    return ["[", n // 2, (n + 1) // 2, "]"]
    

def reduce_step(number):
    """Perform at most one reduction action on a comma-free token list.

    An explosion is tried first; only when nothing exploded is a split tried.
    An explosion always shortens the list by three tokens, which is how it is
    detected here.
    """
    exploded = explode_number(number)
    if len(exploded) == len(number):
        return split_number(number)
    return exploded


def reduce_number(number):
    """Apply ``reduce_step`` until a step no longer changes the token count.

    Explosions remove three tokens and splits add three, so an unchanged
    length means no action applied and the number is fully reduced.
    """
    current = number
    while True:
        candidate = reduce_step(current)
        if len(candidate) == len(current):
            return candidate
        current = candidate
    

def add_numbers(num1, num2):
    """Snailfish-add two comma-free token lists: wrap them in a new pair, then reduce."""
    pair = ["["] + num1 + num2 + ["]"]
    return reduce_number(pair)


def add_and_sum(numbers):
    """Reduce each number, then add them left to right, reducing after every addition.

    An empty ``numbers`` makes ``functools.reduce`` raise TypeError.
    """
    return reduce(add_numbers, map(reduce_number, numbers))


def create_tree(number, pos=1):
    """Build a nested ``[left, right]`` tree from a comma-free token list.

    ``pos`` is the index just past the pair's opening "[" (1 for a whole
    number). Each element is either a nested pair, parsed recursively, or a
    single token taken as-is.

    Returns:
        (tree, next_pos): the pair as a two-element list and the index just
        past its closing "]".
    """
    def read_element(at):
        # A "[" opens a nested pair; any other token is a regular number.
        if number[at] != "[":
            return number[at], at + 1
        return create_tree(number, at + 1)

    left, cursor = read_element(pos)
    right, cursor = read_element(cursor)
    assert "]" == number[cursor]
    return [left, right], cursor + 1


def magnitude(value):
    """Return the snailfish magnitude of a tree from ``create_tree``.

    A regular number's magnitude is itself; a pair's is 3 times its left
    element's magnitude plus 2 times its right element's.
    """
    if type(value) is int:
        return value
    left, right = value
    return 3 * magnitude(left) + 2 * magnitude(right)

    
def test_explode(test_number, expected):
    """Assert that a single explode step turns ``test_number`` into ``expected``.

    Both are snailfish-number strings. Commas are removed from ``expected`` so
    it can be compared with the comma-free tokens from ``parse_number``. On
    failure the message shows both sides.
    """
    tokens = explode_number(parse_number(test_number))
    actual = "".join(str(s) if type(s) == int else s for s in tokens)
    expected = "".join(expected).replace(",", "")
    assert actual == expected, f"explode {test_number}: got {actual}, expected {expected}"
       
    
def test_split(test_number, expected):
    """Assert that a single split step turns ``test_number`` into ``expected``.

    Both are snailfish-number strings. Commas are removed from ``expected`` to
    match the comma-free tokens produced by ``parse_number`` and
    ``new_split_pair``. On failure the message shows both sides.
    """
    tokens = split_number(parse_number(test_number))
    actual = "".join(str(s) if type(s) == int else s for s in tokens)
    expected = "".join(expected).replace(",", "")
    assert actual == expected, f"split {test_number}: got {actual}, expected {expected}"
    
    
def test_reduce(test_number, expected):
    """Assert that fully reducing ``test_number`` yields ``expected``.

    Both are snailfish-number strings. Commas are removed from ``expected`` to
    match the comma-free token list. On failure the message shows both sides.
    """
    tokens = reduce_number(parse_number(test_number))
    actual = "".join(str(s) if type(s) == int else s for s in tokens)
    expected = "".join(expected).replace(",", "")
    assert actual == expected, f"reduce {test_number}: got {actual}, expected {expected}"


def test_add_numbers(test_numbers, expected):
    """Assert that snailfish-summing the strings in ``test_numbers`` gives ``expected``.

    Commas are removed from ``expected`` to match the comma-free token list.
    On failure the message shows both sides.
    """
    parsed_numbers = [parse_number(n) for n in test_numbers]
    tokens = add_and_sum(parsed_numbers)
    actual = "".join(str(s) if type(s) == int else s for s in tokens)
    expected = "".join(expected).replace(",", "")
    assert actual == expected, f"sum of {test_numbers}: got {actual}, expected {expected}"
    
    
def test_magnitude(test_number, expected):
    """Assert that the snailfish-number string ``test_number`` has magnitude ``expected``.

    On failure the message shows both values.
    """
    tree, _ = create_tree(parse_number(test_number))
    actual = magnitude(tree)
    assert expected == actual, f"magnitude {test_number}: got {actual}, expected {expected}"
    
    
# Single explode steps: each call checks exactly one explosion.
test_explode("[[[[[9,8],1],2],3],4]", "[[[[0,9],2],3],4]")
test_explode("[7,[6,[5,[4,[3,2]]]]]", "[7,[6,[5,[7,0]]]]")
test_explode("[[6,[5,[4,[3,2]]]],1]", "[[6,[5,[7,0]]],3]")
test_explode("[[3,[2,[1,[7,3]]]],[6,[5,[4,[3,2]]]]]", "[[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]")
test_explode("[[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]", "[[3,[2,[8,0]]],[9,[5,[7,0]]]]")    
# Single split steps. parse_number rejects a bare number (trailing digits are
# never flushed), so the input is wrapped in [] and expected gets an extra pair of brackets.
test_split("[10]", "[[5,5]]") # input wrapped in [] for parse_number, so expected has extra brackets
test_split("[11]", "[[5,6]]") # input wrapped in [] for parse_number, so expected has extra brackets
test_split("[12]", "[[6,6]]") # input wrapped in [] for parse_number, so expected has extra brackets

# manually run the first full reduction test, one explode/split step per call
test_explode("[[[[[4,3],4],4],[7,[[8,4],9]]],[1,1]]", "[[[[0,7],4],[7,[[8,4],9]]],[1,1]]")
test_explode("[[[[0,7],4],[7,[[8,4],9]]],[1,1]]", "[[[[0,7],4],[15,[0,13]]],[1,1]]")
test_split("[[[[0,7],4],[15,[0,13]]],[1,1]]", "[[[[0,7],4],[[7,8],[0,13]]],[1,1]]")
test_split("[[[[0,7],4],[[7,8],[0,13]]],[1,1]]", "[[[[0,7],4],[[7,8],[0,[6,7]]]],[1,1]]")
test_explode("[[[[0,7],4],[[7,8],[0,[6,7]]]],[1,1]]", "[[[[0,7],4],[[7,8],[6,0]]],[8,1]]")

# The same reduction done in one go by reduce_number.
test_reduce("[[[[[4,3],4],4],[7,[[8,4],9]]],[1,1]]", "[[[[0,7],4],[[7,8],[6,0]]],[8,1]]")

# Summing lists of numbers, left to right with a reduction after each addition.
test_add_numbers(["[1,1]","[2,2]","[3,3]","[4,4]"], "[[[[1,1],[2,2]],[3,3]],[4,4]]") 
test_add_numbers(["[1,1]","[2,2]","[3,3]","[4,4]","[5,5]"], "[[[[3,0],[5,3]],[4,4]],[5,5]]") 
test_add_numbers(["[1,1]","[2,2]","[3,3]","[4,4]","[5,5]","[6,6]"], "[[[[5,0],[7,4]],[5,5]],[6,6]]") 

larger_sum_test_case = [
    "[[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]]",
    "[7,[[[3,7],[4,3]],[[6,3],[8,8]]]]",
    "[[2,[[0,8],[3,4]]],[[[6,7],1],[7,[1,6]]]]",
    "[[[[2,4],7],[6,[0,5]]],[[[6,8],[2,8]],[[2,1],[4,5]]]]",
    "[7,[5,[[3,8],[1,4]]]]",
    "[[2,[2,2]],[8,[8,1]]]",
    "[2,9]",
    "[1,[[[9,3],9],[[9,0],[0,7]]]]",
    "[[[5,[7,4]],7],1]",
    "[[[[4,2],2],6],[8,7]]",
]
test_add_numbers(larger_sum_test_case, "[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]")
test_add_numbers(test_input_str.split("\n"), "[[[[6,6],[7,6]],[[7,7],[7,0]]],[[[7,7],[7,7]],[[7,8],[9,9]]]]")

# Magnitudes of single (already reduced) numbers.
test_magnitude("[[1,2],[[3,4],5]]", 143)
test_magnitude("[[[[0,7],4],[[7,8],[6,0]]],[8,1]]", 1384)
test_magnitude("[[[[1,1],[2,2]],[3,3]],[4,4]]", 445)
test_magnitude("[[[[3,0],[5,3]],[4,4]],[5,5]]", 791)
test_magnitude("[[[[5,0],[7,4]],[5,5]],[6,6]]", 1137)
test_magnitude("[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]", 3488)


def part_one(input_str):
    """Return the magnitude of the snailfish sum of every number in ``input_str``.

    ``input_str`` holds one snailfish number per line. Blank lines are
    skipped. These include the empty string left after the trailing newline
    of a file read with ``open(...).read()``; ``parse_number("")`` yields an
    empty token list that would corrupt the addition.
    """
    parsed_numbers = [parse_number(line) for line in input_str.splitlines() if line.strip()]
    tree, _ = create_tree(add_and_sum(parsed_numbers))
    return magnitude(tree)


# Check against the test input (expected magnitude 4140) before running on the real puzzle input.
assert 4140 == part_one(test_input_str)
print("part one:", part_one(puzzle_input_str))
part one: 4184
from itertools import permutations


def part_two(input_str):
    """Return the largest magnitude of any sum of two different input numbers.

    Snailfish addition is not commutative (the left element of a pair is
    weighted 3 and the right 2 in the magnitude), so both a + b and b + a are
    tried via ``permutations``. Blank lines are skipped, e.g. the empty string
    left after a trailing newline, since ``parse_number("")`` yields an empty
    token list that would corrupt the addition. Returns 0 when fewer than two
    numbers are given.
    """
    parsed_numbers = [parse_number(line) for line in input_str.splitlines() if line.strip()]
    return max(
        (magnitude(create_tree(add_and_sum(pair))[0])
         for pair in permutations(parsed_numbers, 2)),
        default=0,
    )


# Check against the test input (expected magnitude 3993) before running on the real puzzle input.
assert 3993 == part_two(test_input_str)
print("part two:", part_two(puzzle_input_str))
part two: 4731

OK, the task is done. But what you see above doesn't show the full picture. I first wrote the code below and fiddled with it for a couple of hours before throwing up my hands and starting over with very aggressive and thorough testing. That first attempt behaved fine for the other test cases I threw at it, but it didn't add up the "example homework assignment" correctly. I'm glad I went back to the drawing board - it went very smoothly once I broke everything down into nice little testable units. And the very first thing I wrote - the function to parse a string into a tree - could be reused with a little modification :D

import math
from functools import reduce

# Read the puzzle input with a context manager so the file handle is closed
# deterministically instead of whenever the garbage collector gets to it.
with open("puzzle_input/day18.txt", encoding="utf-8") as puzzle_file:
    puzzle_input_str = puzzle_file.read()

test_sum_input_str = """[[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]]
[7,[[[3,7],[4,3]],[[6,3],[8,8]]]]
[[2,[[0,8],[3,4]]],[[[6,7],1],[7,[1,6]]]]
[[[[2,4],7],[6,[0,5]]],[[[6,8],[2,8]],[[2,1],[4,5]]]]
[7,[5,[[3,8],[1,4]]]]
[[2,[2,2]],[8,[8,1]]]
[2,9]
[1,[[[9,3],9],[[9,0],[0,7]]]]
[[[5,[7,4]],7],1]
[[[[4,2],2],6],[8,7]]"""

test_magnitude_input_str = """[1,2]
[[1,2],3]
[9,[8,7]]
[[1,9],[8,5]]
[[[[1,2],[3,4]],[[5,6],[7,8]]],9]
[[[9,[3,8]],[[0,9],6]],[[[3,7],[4,9]],3]]
[[[[1,3],[5,3]],[[1,3],[8,7]]],[[[4,9],[6,9]],[[8,2],[7,3]]]]"""

def parse_pair(string, pos=1):
    """Recursively parse one snailfish pair out of its string form.

    ``pos`` is the index just past the pair's opening "[" (1 for a whole
    number). Each element is either a nested pair or a non-negative regular
    number of one or more digits.

    Returns:
        (pair, next_pos): the pair as a nested ``[left, right]`` list and the
        index just past its closing "]".

    Raises:
        AssertionError: if the "," or closing "]" is not where expected.
    """
    def parse_element(at):
        # A "[" starts a nested pair; otherwise read a run of digits.
        if string[at] == "[":
            return parse_pair(string, at + 1)
        end = at
        while end < len(string) and string[end].isdigit():
            end += 1
        return int(string[at:end]), end

    left, pos = parse_element(pos)
    assert "," == string[pos]
    pos += 1

    right, pos = parse_element(pos)
    assert "]" == string[pos], string[pos]
    return [left, right], pos + 1
        

def parse_input(input_str):
    """Tokenize each non-blank line of ``input_str`` with ``tokenize``.

    Blank lines, such as the one left by a trailing newline in a file, are
    skipped. Otherwise they would tokenize to an empty list and get glued
    into a malformed "[...,]" string by the summing code.
    """
    return [tokenize(line) for line in input_str.splitlines() if line.strip()]


def tokenize(line_str):
    """Split a snailfish-number string into single-character tokens.

    Each digit character becomes an int; every other character ("[", "]",
    ",") stays a one-character string. A multi-digit run therefore becomes
    separate single-digit ints.
    """
    return [int(ch) if ch.isdigit() else ch for ch in line_str]
    

def reduce_step(line):
    """Apply one reduction action: an explosion if possible, otherwise a split.

    Returns:
        (new_line, changed): the resulting token list and whether any action
        was applied.
    """
    exploded_line, exploded = try_explode(line)
    if exploded:
        return exploded_line, True
    return try_split(exploded_line)


def dump_line(line):
    """Debugging hook called before each reduction step; currently a no-op.

    To trace reductions, make it print ``recombine(line)``.
    """

def recombine(line):
    """Join a token list back into a string, converting int tokens with ``str``."""
    return "".join([token if type(token) is not int else str(token) for token in line])

def reduce_pairs(line):
    """Repeatedly apply ``reduce_step`` until no action applies, then return the result.

    ``dump_line`` is called before every step, including the final step that
    changes nothing.
    """
    while True:
        dump_line(line)
        line, changed = reduce_step(line)
        if not changed:
            return line


def try_explode(line):
    """Explode the leftmost pair nested inside four pairs, if there is one.

    ``line`` is a token list from ``tokenize``: "[", "]", "," strings and int
    regular numbers. The exploding pair's left value is added to the nearest
    regular number on its left, its right value to the nearest regular number
    on its right, and the pair is replaced by a single 0.

    Returns:
        (new_line, exploded): a new token list and whether an explosion
        happened. When nothing explodes, new_line has the same tokens as line.
    """
    last_numeric_position = None  # index in next_line of the last regular number copied
    next_line = []
    level = 0  # number of currently open pairs (excluding the exploding one)
    exploded = False
    to_explode = None  # values of the exploding pair, while scanning inside it
    adjustment = 0  # right value still waiting to be added to the next number

    for c in line:
        if exploded:
            # After the explosion: copy the rest, adding the pending right
            # value to the first regular number encountered.
            if type(c) == int:
                next_line.append(c + adjustment)
                adjustment = 0
            else:
                next_line.append(c)
        elif c == "[":
            if level == 4:
                # This pair is nested inside four pairs: collect it instead of copying.
                to_explode = []
            else:
                level += 1
                next_line.append(c)
        elif c == "]":
            if to_explode is not None:
                assert len(to_explode) == 2
                l, r = to_explode
                # 0 is a valid index, so compare with None rather than testing truthiness.
                if last_numeric_position is not None:
                    next_line[last_numeric_position] += l
                adjustment = r
                next_line.append(0)
                exploded = True
                to_explode = None
            else:
                next_line.append(c)
                level -= 1
        elif type(c) == int:
            if to_explode is not None:
                to_explode.append(c)
            else:
                last_numeric_position = len(next_line)
                next_line.append(c)
        elif to_explode is None:
            # Separators outside the exploding pair are copied; the comma
            # inside it disappears along with the pair.
            next_line.append(c)

    return next_line, exploded

# Single explode steps on comma-preserving token lists from tokenize;
# recombine turns the result back into a string for comparison.
assert "[[[[0,9],2],3],4]" == recombine(try_explode(tokenize("[[[[[9,8],1],2],3],4]"))[0])
assert "[7,[6,[5,[7,0]]]]" == recombine(try_explode(tokenize("[7,[6,[5,[4,[3,2]]]]]"))[0])
assert "[[6,[5,[7,0]]],3]" == recombine(try_explode(tokenize("[[6,[5,[4,[3,2]]]],1]"))[0])
assert "[[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]" == recombine(try_explode(tokenize("[[3,[2,[1,[7,3]]]],[6,[5,[4,[3,2]]]]]"))[0])
assert "[[3,[2,[8,0]]],[9,[5,[7,0]]]]" == recombine(try_explode(tokenize("[[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]"))[0])


# To split a regular number, replace it with a pair;
# - the left  element of the pair should be the regular number divided by two and rounded down
# - the right element of the pair should be the regular number divided by two and rounded up. 
# For example, 10 becomes [5,5], 11 becomes [5,6], 12 becomes [6,6], and so on.    
from math import ceil, floor
def create_split_pair(number):
    """Return the comma-separated pair tokens that regular number ``number`` splits into.

    The left half is rounded down and the right half rounded up. Integer
    floor division keeps the halves exact for arbitrarily large ints; going
    through float division would lose precision above 2**53.
    """
    return ["[", number // 2, ",", (number + 1) // 2, "]"]
    
def try_split(line):
    """Split the leftmost regular number that is 10 or larger.

    Returns:
        (new_line, split): a new token list where the first number >= 10 is
        replaced by the tokens from ``create_split_pair``, and whether a
        split happened. All other tokens are copied unchanged.
    """
    result = []
    did_split = False
    for token in line:
        if not did_split and type(token) == int and token >= 10:
            result.extend(create_split_pair(token))
            did_split = True
        else:
            result.append(token)
    return result, did_split

# Split examples: try_split is given a bare [n] token list, so the result is just the new pair.
assert "[5,5]" == recombine(try_split([10])[0])
assert "[5,6]" == recombine(try_split([11])[0])
assert "[6,6]" == recombine(try_split([12])[0])

def snailfish_sum(numbers):
    """Add token lists left to right, reducing after each addition.

    Each addition builds the string "[a,b]", re-tokenizes it and reduces it.
    The final total is reduced once more and returned as a string.
    ``numbers`` is left unmodified.

    Raises:
        IndexError: if ``numbers`` is empty.
    """
    number = numbers[0]
    for addend in numbers[1:]:
        number = reduce_pairs(tokenize(f"[{recombine(number)},{recombine(addend)}]"))
    return recombine(reduce_pairs(number))


def snailfish_sum2(numbers):
    """Add token lists left to right, reducing after each addition.

    Each addition builds the string "[a,b]", re-tokenizes it and reduces it.
    The total is returned as a string. A single-element ``numbers`` is
    returned as-is, without reduction. ``numbers`` is left unmodified.

    Raises:
        IndexError: if ``numbers`` is empty.
    """
    total = numbers[0]
    for addend in numbers[1:]:
        a = recombine(total)
        b = recombine(addend)
        total = reduce_pairs(tokenize(f"[{a},{b}]"))
    return recombine(total)

# The magnitude of a pair is 
# - 3 times the magnitude of its left element, plus 
# - 2 times the magnitude of its right element. 
# The magnitude of a regular number is just that number.

def magnitude(value):
    """Return the snailfish magnitude of a nested ``[left, right]`` tree.

    A pair scores 3 times its left element's magnitude plus 2 times its right
    element's; a regular number scores itself.
    """
    if type(value) is not int:
        left, right = value
        return 3 * magnitude(left) + 2 * magnitude(right)
    return value

def calc_magnitude(number):
    """Return the magnitude of the snailfish-number string ``number``.

    The assertion checks that ``parse_pair`` consumed the entire string.
    """
    parsed = parse_pair(number)
    assert parsed[1] == len(number)
    return magnitude(parsed[0])

def parse_and_sum(input_str):
    """Snailfish-sum the numbers in ``input_str`` (one per line) and return the total as a string."""
    tokenized = parse_input(input_str)
    return snailfish_sum2(tokenized)

def part_one(input_str):
    """Return the magnitude of the snailfish sum of the numbers in ``input_str``."""
    total = parse_and_sum(input_str)
    return calc_magnitude(total)


# Addition checks: each string is summed line by line and compared as a string.
assert "[[[[1,1],[2,2]],[3,3]],[4,4]]" == parse_and_sum("[1,1]\n[2,2]\n[3,3]\n[4,4]")
assert "[[[[3,0],[5,3]],[4,4]],[5,5]]" == parse_and_sum("[1,1]\n[2,2]\n[3,3]\n[4,4]\n[5,5]")
assert "[[[[5,0],[7,4]],[5,5]],[6,6]]" == parse_and_sum("[1,1]\n[2,2]\n[3,3]\n[4,4]\n[5,5]\n[6,6]")
assert "[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]" == parse_and_sum(test_sum_input_str)
print("ok!")


# Magnitude checks: with a single line, snailfish_sum2 returns the number
# unreduced, so part_one just computes that number's magnitude.
assert 143 == part_one("[[1,2],[[3,4],5]]")
assert 1384 == part_one("[[[[0,7],4],[[7,8],[6,0]]],[8,1]]")
assert 445 == part_one("[[[[1,1],[2,2]],[3,3]],[4,4]]")
assert 791 == part_one("[[[[3,0],[5,3]],[4,4]],[5,5]]")
assert 1137 == part_one("[[[[5,0],[7,4]],[5,5]],[6,6]]")
assert 3488 == part_one("[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]")

# NOTE(review): disabled; the expected value 7 does not match the 4140 that the
# first solution asserts for test_input_str.
#assert 7 == part_one(test_input_str)
# print("part one:", part_one(puzzle_input_str))
ok!
[[[[6,7],[7,7]],[[7,7],[7,7]]],[[[7,7],[7,9]],[[7,8],[0,9]]]]