2020-12-14 08:20:53 +01:00
|
|
|
import functools
|
|
|
|
import re
|
2022-12-02 13:35:44 +01:00
|
|
|
from collections.abc import Iterable
|
|
|
|
from typing import Union
|
2020-12-14 08:20:53 +01:00
|
|
|
|
|
|
|
|
|
|
|
def main(
    filename: str,
    expected_part_1: Union[int, None] = None,
    expected_part_2: Union[int, None] = None,
):
    """Run both puzzle parts on ``filename`` and print each memory sum.

    Args:
        filename: Path to a file with one ``mask = ...`` or ``mem[a] = v``
            instruction per line.
        expected_part_1: If given, the answer part 1 must produce.
        expected_part_2: If given, the answer part 2 must produce.

    Raises:
        AssertionError: If a computed answer differs from its expected value.
    """
    print(f"\n+ Running on {filename}")

    with open(filename, encoding="utf-8") as f:
        # splitlines() yields [] (not [""]) for an empty file, so no bogus
        # empty instruction is produced.
        instructions = f.read().strip().splitlines()

    for part, program_class, expected in (
        (1, ProgramPart1, expected_part_1),
        (2, ProgramPart2, expected_part_2),
    ):
        program = program_class(instructions)
        program.run()
        counter = program.compute_memory_sum()
        print(f"{part}. Found {counter}")
        # Compare against None explicitly so an expected answer of 0 is still
        # checked, and raise instead of `assert` so the check survives -O.
        if expected is not None and counter != expected:
            raise AssertionError(
                f"Part {part}: expected {expected}, found {counter}"
            )
|
|
|
|
|
|
|
|
|
2022-12-02 13:35:44 +01:00
|
|
|
# Type of a program's memory. Both addresses (keys) and stored values may be
# either ``str`` or ``int``, depending on which program variant fills it.
_MemoryEntry = Union[str, int]
Memory = dict[_MemoryEntry, _MemoryEntry]
|
2020-12-14 08:20:53 +01:00
|
|
|
|
|
|
|
|
2020-12-14 09:07:17 +01:00
|
|
|
class Program:
    """Base interpreter for a list of ``mask = ...`` / ``mem[a] = v`` lines.

    ``run`` feeds every instruction to ``run_line``, which routes mask lines to
    ``update_mask`` and every other non-blank line to ``update_memory``.
    Subclasses implement ``update_memory`` and ``compute_memory_sum`` and set
    the initial ``mask``.
    """

    # Current bitmask string; subclasses give it an initial value and every
    # ``mask = ...`` instruction replaces it.
    mask: str

    # Parses ``mem[<address>] = <value>``; compiled once and shared by all
    # instances (still reachable as ``self.mem_line_regex``).
    mem_line_regex = re.compile(r"^mem\[(?P<address>\d+)] = (?P<value>\d+)$")

    def __init__(self, instructions: Iterable[str]):
        self.memory = dict()  # type: Memory
        self.instructions = instructions

    def run(self):
        """Execute every instruction in order."""
        for line in self.instructions:
            self.run_line(line)

    def compute_memory_sum(self):
        """Return the sum of the values in memory; must be overridden."""
        raise NotImplementedError

    def run_line(self, line: str):
        """Dispatch one instruction; blank lines are ignored."""
        line = line.strip()
        if not line:
            # A blank line (e.g. the lone "" from splitting an empty string on
            # "\n") carries no instruction; skip it rather than failing the
            # memory-line regex in update_memory.
            return
        if line.startswith("mask"):
            self.update_mask(line)
        else:
            self.update_memory(line)

    def update_mask(self, line: str):
        """Store the bits of a ``mask = <bits>`` line as the current mask."""
        self.mask = line.split(" = ")[1]

    def update_memory(self, line: str):
        """Apply a ``mem[...] = ...`` line; must be overridden."""
        raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
class ProgramPart1(Program):
    """Docking program, part 1: the mask rewrites each *value* being stored.

    Mask bits ``0``/``1`` overwrite the corresponding value bit and ``X`` leaves
    it unchanged. Memory maps integer addresses to the masked value, kept as a
    binary string.
    """

    def __init__(self, instructions: Iterable[str]):
        super().__init__(instructions)
        # All-X mask: values are stored unchanged until a mask line is read.
        self.mask = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"

    def compute_memory_sum(self) -> int:
        """Return the sum of all stored values, decoded from binary."""
        return sum(int(bits, 2) for bits in self.memory.values())

    def update_memory(self, line: str):
        """Store the masked value of a ``mem[address] = value`` line.

        Raises:
            RuntimeError: If ``line`` does not match ``mem_line_regex``.
            ValueError: If the value is wider than the current mask.
        """
        match = self.mem_line_regex.match(line)
        if not match:
            raise RuntimeError(f"Memory line regex doesn't match: {line!r}")
        # Parse the address as an int so "mem[07]" and "mem[7]" hit the same
        # cell (consistent with ProgramPart2).
        address = int(match["address"])
        value = int(match["value"])
        self.memory[address] = self.get_masked_value(value)

    def get_masked_value(self, value: int) -> str:
        """Return ``value`` with the current mask applied, as a binary string.

        Raises:
            ValueError: If ``value`` is negative or has more bits than the
                mask; its binary form would otherwise be misaligned with the
                mask and ``zip`` would silently drop the low-order bits.
        """
        width = len(self.mask)
        if value < 0 or value.bit_length() > width:
            raise ValueError(f"Value {value} does not fit in a {width}-bit mask")
        binary_value = f"{value:0{width}b}"
        return "".join(
            value_bit if mask_bit == "X" else mask_bit
            for value_bit, mask_bit in zip(binary_value, self.mask)
        )
|
|
|
|
|
|
|
|
|
2020-12-14 09:07:17 +01:00
|
|
|
class ProgramPart2(Program):
    """Docking program, part 2: the mask rewrites the *address* of each write.

    For each mask bit:
      * ``1`` forces the address bit to 1;
      * ``0`` keeps the address bit;
      * ``X`` is floating and takes both 0 and 1.

    A single write therefore reaches ``2 ** mask.count("X")`` addresses.
    Memory maps integer addresses to integer values.
    """

    def __init__(self, instructions: Iterable[str]):
        super().__init__(instructions)
        # All-zero mask leaves addresses unchanged until a mask line is read.
        self.mask = "000000000000000000000000000000000000"

    def compute_memory_sum(self) -> int:
        """Return the sum of every value currently stored in memory."""
        return sum(self.memory.values())

    def update_memory(self, line: str):
        """Write the value of a ``mem[address] = value`` line to every decoded address.

        Raises:
            RuntimeError: If ``line`` does not match ``mem_line_regex``.
            ValueError: If the address is wider than the current mask.
        """
        match = self.mem_line_regex.match(line)
        if not match:
            raise RuntimeError(f"Memory line regex doesn't match: {line!r}")
        address = int(match["address"])
        value = int(match["value"])
        for masked_address in self.get_masked_addresses(address):
            self.memory[masked_address] = value

    def get_masked_addresses(self, address: int) -> Iterable[int]:
        """Return every address that ``address`` decodes to under the current mask.

        Raises:
            ValueError: If ``address`` is negative or has more bits than the
                mask; its binary form would otherwise be misaligned with the
                mask and ``zip`` would silently drop the low-order bits.
        """
        width = len(self.mask)
        if address < 0 or address.bit_length() > width:
            raise ValueError(f"Address {address} does not fit in a {width}-bit mask")
        binary_address = f"{address:0{width}b}"
        # Mask "1" forces the bit. Both "0" and "X" keep the address bit for
        # now; X positions are expanded by get_floating_addresses.
        corrected_binary_address = "".join(
            "1" if mask_bit == "1" else address_bit
            for address_bit, mask_bit in zip(binary_address, self.mask)
        )
        floating = self.get_floating_addresses("", corrected_binary_address, self.mask)
        return [int(bits, 2) for bits in floating]

    def get_floating_addresses(self, prefix: str, address: str, mask: str) -> list[str]:
        """Expand every ``X`` position of ``mask`` into both ``0`` and ``1``.

        ``address`` and ``mask`` are aligned bit strings. The result is
        ``prefix + address`` with each position where ``mask`` has ``X``
        replaced by ``0`` and by ``1``, giving ``2 ** mask.count("X")``
        strings. Variants with ``0`` come first.
        """
        if "X" not in mask:
            return [prefix + address]
        first_x = mask.index("X")
        # Bits before the first floating position are fixed; recurse on the rest.
        new_prefix = prefix + address[:first_x]
        rest_address = address[first_x + 1 :]
        rest_mask = mask[first_x + 1 :]
        return [
            *self.get_floating_addresses(new_prefix + "0", rest_address, rest_mask),
            *self.get_floating_addresses(new_prefix + "1", rest_address, rest_mask),
        ]
|
2020-12-14 08:20:53 +01:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # "inputs/day14-test1" (part 1 answer: 165) is not run: part 2 is too
    # slow on it.
    for case in (
        ("inputs/day14-test2", 51, 208),
        ("inputs/day14", 6559449933360, 3369767240513),
    ):
        main(*case)
|