mirror of
https://github.com/Crocmagnon/advent-of-code.git
synced 2024-11-22 06:28:11 +01:00
Solve day 14 part 2
This commit is contained in:
parent
adc8a5ee42
commit
12e608f853
3 changed files with 110 additions and 13 deletions
|
@ -1,8 +1,6 @@
|
||||||
import enum
|
|
||||||
import functools
|
import functools
|
||||||
import math
|
|
||||||
import re
|
import re
|
||||||
from typing import List, Dict, Iterable
|
from typing import List, Dict, Iterable, Union
|
||||||
|
|
||||||
|
|
||||||
def main(filename: str, expected_part_1: int = None, expected_part_2: int = None):
|
def main(filename: str, expected_part_1: int = None, expected_part_2: int = None):
|
||||||
|
@ -18,20 +16,21 @@ def main(filename: str, expected_part_1: int = None, expected_part_2: int = None
|
||||||
if expected_part_1:
|
if expected_part_1:
|
||||||
assert expected_part_1 == counter_part_1
|
assert expected_part_1 == counter_part_1
|
||||||
|
|
||||||
counter_part_2 = solve_part_2(instructions)
|
program = ProgramPart2(instructions)
|
||||||
|
program.run()
|
||||||
|
counter_part_2 = program.compute_memory_sum()
|
||||||
print(f"2. Found {counter_part_2}")
|
print(f"2. Found {counter_part_2}")
|
||||||
if expected_part_2:
|
if expected_part_2:
|
||||||
assert expected_part_2 == counter_part_2
|
assert expected_part_2 == counter_part_2
|
||||||
|
|
||||||
|
|
||||||
Memory = Dict[str, str]
|
Memory = Dict[Union[str, int], Union[str, int]]
|
||||||
|
|
||||||
|
|
||||||
class ProgramPart1:
|
class Program:
|
||||||
def __init__(self, instructions: Iterable[str]):
|
def __init__(self, instructions: Iterable[str]):
|
||||||
self.memory = dict() # type: Memory
|
self.memory = dict() # type: Memory
|
||||||
self.instructions = instructions
|
self.instructions = instructions
|
||||||
self.mask = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
|
|
||||||
self.mem_line_regex = re.compile(r"^mem\[(?P<address>\d+)] = (?P<value>\d+)$")
|
self.mem_line_regex = re.compile(r"^mem\[(?P<address>\d+)] = (?P<value>\d+)$")
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
@ -39,8 +38,7 @@ class ProgramPart1:
|
||||||
self.run_line(line)
|
self.run_line(line)
|
||||||
|
|
||||||
def compute_memory_sum(self):
|
def compute_memory_sum(self):
|
||||||
int_base_2 = functools.partial(int, base=2)
|
raise NotImplementedError
|
||||||
return sum(map(int_base_2, self.memory.values()))
|
|
||||||
|
|
||||||
def run_line(self, line: str):
|
def run_line(self, line: str):
|
||||||
if "mask" in line:
|
if "mask" in line:
|
||||||
|
@ -51,6 +49,19 @@ class ProgramPart1:
|
||||||
def update_mask(self, line: str):
|
def update_mask(self, line: str):
|
||||||
self.mask = line.split(" = ")[1]
|
self.mask = line.split(" = ")[1]
|
||||||
|
|
||||||
|
def update_memory(self, line: str):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class ProgramPart1(Program):
|
||||||
|
def __init__(self, instructions: Iterable[str]):
|
||||||
|
super().__init__(instructions)
|
||||||
|
self.mask = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
|
||||||
|
|
||||||
|
def compute_memory_sum(self):
|
||||||
|
int_base_2 = functools.partial(int, base=2)
|
||||||
|
return sum(map(int_base_2, self.memory.values()))
|
||||||
|
|
||||||
def update_memory(self, line: str):
|
def update_memory(self, line: str):
|
||||||
match = self.mem_line_regex.match(line)
|
match = self.mem_line_regex.match(line)
|
||||||
if not match:
|
if not match:
|
||||||
|
@ -72,10 +83,54 @@ class ProgramPart1:
|
||||||
return "".join(masked_value)
|
return "".join(masked_value)
|
||||||
|
|
||||||
|
|
||||||
def solve_part_2(program: Iterable[str]) -> int:
|
class ProgramPart2(Program):
|
||||||
return 0
|
def __init__(self, instructions: Iterable[str]):
|
||||||
|
super().__init__(instructions)
|
||||||
|
self.mask = "000000000000000000000000000000000000"
|
||||||
|
|
||||||
|
def compute_memory_sum(self):
|
||||||
|
return sum(self.memory.values())
|
||||||
|
|
||||||
|
def update_memory(self, line: str):
|
||||||
|
match = self.mem_line_regex.match(line)
|
||||||
|
if not match:
|
||||||
|
raise RuntimeError("Memory line regex doesn't match")
|
||||||
|
groups = match.groupdict()
|
||||||
|
address = int(groups["address"])
|
||||||
|
value = int(groups["value"])
|
||||||
|
for masked_address in self.get_masked_addresses(address):
|
||||||
|
self.memory[masked_address] = value
|
||||||
|
|
||||||
|
def get_masked_addresses(self, address: int) -> Iterable[int]:
|
||||||
|
binary_address = "{:036b}".format(address)
|
||||||
|
corrected_binary_address = ""
|
||||||
|
for binary_bit, mask_bit in zip(binary_address, self.mask):
|
||||||
|
if mask_bit == "1":
|
||||||
|
corrected_binary_address += "1"
|
||||||
|
else:
|
||||||
|
corrected_binary_address += binary_bit
|
||||||
|
addresses = self.get_floating_addresses("", corrected_binary_address, self.mask)
|
||||||
|
int_base_2 = functools.partial(int, base=2)
|
||||||
|
return map(int_base_2, addresses)
|
||||||
|
|
||||||
|
def get_floating_addresses(self, prefix: str, address: str, mask: str) -> List[str]:
|
||||||
|
if "X" not in mask:
|
||||||
|
return [prefix + address]
|
||||||
|
first_x = mask.index("X")
|
||||||
|
collector = []
|
||||||
|
new_address = address[first_x + 1 :]
|
||||||
|
new_prefix = prefix + address[:first_x]
|
||||||
|
new_mask = mask[first_x + 1 :]
|
||||||
|
collector.extend(
|
||||||
|
self.get_floating_addresses(new_prefix + "0", new_address, new_mask)
|
||||||
|
)
|
||||||
|
collector.extend(
|
||||||
|
self.get_floating_addresses(new_prefix + "1", new_address, new_mask)
|
||||||
|
)
|
||||||
|
return collector
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main("inputs/day14-test1", 165)
|
# main("inputs/day14-test1", 165) # too slow for part 2
|
||||||
main("inputs/day14")
|
main("inputs/day14-test2", 51, 208)
|
||||||
|
main("inputs/day14", 6559449933360, 3369767240513)
|
||||||
|
|
4
2020/inputs/day14-test2
Normal file
4
2020/inputs/day14-test2
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
mask = 000000000000000000000000000000X1001X
|
||||||
|
mem[42] = 100
|
||||||
|
mask = 00000000000000000000000000000000X0XX
|
||||||
|
mem[26] = 1
|
38
2020/test_day14_masked_addresses.py
Normal file
38
2020/test_day14_masked_addresses.py
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from day14_docking import ProgramPart2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def program():
|
||||||
|
return ProgramPart2([])
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_floating(program):
|
||||||
|
program.mask = "0" * 36
|
||||||
|
assert list(program.get_masked_addresses(8)) == [8]
|
||||||
|
|
||||||
|
|
||||||
|
def test_last_floating(program):
|
||||||
|
program.mask = "0" * 35 + "X"
|
||||||
|
assert list(program.get_masked_addresses(8)) == [8, 9]
|
||||||
|
|
||||||
|
|
||||||
|
def test_second_to_last_floating(program):
|
||||||
|
program.mask = "0" * 34 + "X0"
|
||||||
|
assert list(program.get_masked_addresses(8)) == [8, 10]
|
||||||
|
|
||||||
|
|
||||||
|
def test_last_two_floating(program):
|
||||||
|
program.mask = "0" * 34 + "XX"
|
||||||
|
assert list(program.get_masked_addresses(8)) == [8, 9, 10, 11]
|
||||||
|
|
||||||
|
|
||||||
|
def test_one_bit_replacement(program):
|
||||||
|
program.mask = "0" * 35 + "1"
|
||||||
|
assert list(program.get_masked_addresses(8)) == [9]
|
||||||
|
|
||||||
|
|
||||||
|
def test_two_bits_replacement(program):
|
||||||
|
program.mask = "0" * 34 + "11"
|
||||||
|
assert list(program.get_masked_addresses(8)) == [11]
|
Loading…
Reference in a new issue