Solve day 16 part 2

This commit is contained in:
Gabriel Augendre 2020-12-16 09:45:46 +01:00
parent 4efe6231e0
commit 322563a7cf
No known key found for this signature in database
GPG key ID: 1E693F4CE4AEE7B4

View file

@ -1,6 +1,6 @@
import re import re
from functools import lru_cache from collections import defaultdict
from typing import List, Tuple from typing import Iterable, List, Tuple
def main(filename: str, expected_part_1: int = None, expected_part_2: int = None): def main(filename: str, expected_part_1: int = None, expected_part_2: int = None):
@ -8,13 +8,15 @@ def main(filename: str, expected_part_1: int = None, expected_part_2: int = None
with open(filename) as f: with open(filename) as f:
blocks = f.read().strip().split("\n\n") blocks = f.read().strip().split("\n\n")
counter_part_1 = solve_part_1(blocks) analyser = TicketAnalyserPart1(blocks)
counter_part_1 = analyser.get_error_rate()
print(f"1. Found {counter_part_1}") print(f"1. Found {counter_part_1}")
if expected_part_1: if expected_part_1:
assert expected_part_1 == counter_part_1 assert expected_part_1 == counter_part_1
counter_part_2 = solve_part_2(blocks) columns_assignation = analyser.compute_class_assignation()
counter_part_2 = analyser.get_departure_value(columns_assignation)
print(f"2. Found {counter_part_2}") print(f"2. Found {counter_part_2}")
if expected_part_2: if expected_part_2:
assert expected_part_2 == counter_part_2 assert expected_part_2 == counter_part_2
@ -24,70 +26,111 @@ Range = Tuple[int, int]
Ranges = List[Range] Ranges = List[Range]
def solve_part_1(blocks):
analyser = TicketAnalyserPart1(blocks)
return analyser.get_error_rate()
class TicketAnalyserPart1: class TicketAnalyserPart1:
def __init__(self, blocks): def __init__(self, blocks):
named_ranges = blocks[0].split("\n") named_ranges = blocks[0].split("\n")
ranges = [] self.ranges = {}
for named_range in named_ranges: for named_range in named_ranges:
ranges.extend(self.extract_ranges(named_range)) name, ranges = self.extract_ranges(named_range)
self.ranges = self.merge_ranges(ranges) self.ranges[name] = ranges
self.nearby_tickets = blocks[2].split("\n")[1:] self.my_ticket = list(map(int, blocks[1].split("\n")[1].split(",")))
self.nearby_tickets = [
list(map(int, ticket.split(","))) for ticket in blocks[2].split("\n")[1:]
]
self.valid_tickets = []
@staticmethod @staticmethod
def extract_ranges(named_range: str) -> Ranges: def extract_ranges(named_range: str) -> Tuple[str, Ranges]:
reg = re.compile(r"^.*: (\d+)-(\d+) or (\d+)-(\d+)$") name, ranges = named_range.split(": ")
matches = reg.match(named_range) reg = re.compile(r"(\d+)-(\d+) or (\d+)-(\d+)$")
matches = reg.match(ranges)
groups = [int(group) for group in matches.groups()] groups = [int(group) for group in matches.groups()]
return [(groups[0], groups[1]), (groups[2], groups[3])] return name, [(groups[0], groups[1]), (groups[2], groups[3])]
@staticmethod def get_error_rate(self) -> int:
def merge_ranges(times) -> Ranges:
ranges = []
saved = list(times[0])
for st, en in sorted([sorted(t) for t in times]):
if st <= saved[1]:
saved[1] = max(saved[1], en)
else:
ranges.append(tuple(saved))
saved[0] = st
saved[1] = en
ranges.append(tuple(saved))
return ranges
def get_error_rate(self):
error_rate = 0 error_rate = 0
for ticket in self.nearby_tickets: for ticket in self.nearby_tickets:
error_rate += sum(self.get_invalid_values(ticket)) invalid_values = self.get_invalid_values(ticket)
if invalid_values:
error_rate += sum(invalid_values)
else:
self.valid_tickets.append(ticket)
return error_rate return error_rate
def get_invalid_values(self, ticket: str) -> List[int]: def get_invalid_values(self, ticket: List[int]) -> List[int]:
ticket = map(int, ticket.split(","))
invalid_values = [] invalid_values = []
for value in ticket: for value in ticket:
if self.value_is_invalid(value): if self.value_is_invalid(value):
invalid_values.append(value) invalid_values.append(value)
return invalid_values return invalid_values
@lru_cache(None) def value_is_invalid(self, value: int, ranges: Iterable[Range] = None) -> bool:
def value_is_invalid(self, value: int) -> bool: return not self.value_is_valid(value, ranges)
return not self.value_is_valid(value)
def value_is_valid(self, value: int) -> bool: def value_is_valid(self, value: int, ranges: Iterable[Range] = None) -> bool:
for rng in self.ranges: if ranges is None:
ranges = self.iter_ranges()
for rng in ranges:
if value in range(rng[0], rng[1] + 1): if value in range(rng[0], rng[1] + 1):
return True return True
return False return False
def iter_ranges(self) -> Iterable[Range]:
for ranges in self.ranges.values():
for rng in ranges:
yield rng
def solve_part_2(blocks): def compute_class_assignation(self):
return 0 possible_columns_for_range = defaultdict(list)
for name, ranges in self.ranges.items():
for column in range(0, len(self.ranges)):
if self.column_is_possible(ranges, column):
possible_columns_for_range[name].append(column)
columns_assignation = {}
sorted_keys = self.get_sorted_keys(possible_columns_for_range)
for key in sorted_keys:
assigned_column = possible_columns_for_range[key][0]
columns_assignation[key] = assigned_column
possible_columns_for_range = self.delete_assigned_column(
possible_columns_for_range, assigned_column
)
return columns_assignation
def get_departure_value(self, columns_assignation):
total = 1
for name, column in columns_assignation.items():
if name.startswith("departure"):
total *= self.my_ticket[column]
return total
def column_is_possible(self, ranges, column):
for ticket in self.valid_tickets:
if self.value_is_invalid(ticket[column], ranges):
return False
return True
@staticmethod
def get_sorted_keys(possible_columns_for_range):
return [
item[0]
for item in sorted(
[
(name, len(columns))
for name, columns in possible_columns_for_range.items()
],
key=lambda x: x[1],
)
]
def delete_assigned_column(self, possible_columns_for_range, assigned_column):
new_possible_columns = {}
for name, columns in possible_columns_for_range.items():
new_possible_columns[name] = [
col for col in columns if col != assigned_column
]
return new_possible_columns
if __name__ == "__main__": if __name__ == "__main__":
main("inputs/day16-test1", 71) main("inputs/day16-test1", 71)
main("inputs/day16", 32835) main("inputs/day16", 32835, 514662805187)