day9: use point array instead of knot pairs

more efficient way to model > 2 knots
main
Bryce Allen 3 years ago
parent e026efdce4
commit 516ba0c389

@ -7,11 +7,9 @@ struct Move
count :: Int count :: Int
end end
mutable struct KnotPair mutable struct Point{T}
head_x :: Int x::T
head_y :: Int y::T
tail_x :: Int
tail_y :: Int
end end
function parse_move(line) function parse_move(line)
@ -19,71 +17,54 @@ function parse_move(line)
return Move(parts[1][1], parse(Int, parts[2])) return Move(parts[1][1], parse(Int, parts[2]))
end end
function state_changed(old_state, new_state)
return (new_state.head_x != old_state.head_x
|| new_state.head_y != old_state.head_y
|| new_state.tail_x != old_state.tail_x
|| new_state.tail_y != old_state.tail_y)
end
function read_moves(io) function read_moves(io)
return [parse_move(line) for line in readlines(io)] return [parse_move(line) for line in readlines(io)]
end end
function apply_move(state, direction) function apply_move(head, tail, direction)
if direction == 'U' if direction == 'U'
state.head_y += 1 head.y += 1
elseif direction == 'D' elseif direction == 'D'
state.head_y -= 1 head.y -= 1
elseif direction == 'L' elseif direction == 'L'
state.head_x -= 1 head.x -= 1
elseif direction == 'R' elseif direction == 'R'
state.head_x += 1 head.x += 1
elseif direction != 'O' elseif direction != 'O'
throw(DomainError("Unknown direction")) throw(DomainError("Unknown direction"))
end end
x_diff = state.head_x - state.tail_x x_diff = head.x - tail.x
y_diff = state.head_y - state.tail_y y_diff = head.y - tail.y
if abs(x_diff) <= 1 && abs(y_diff) <= 1 if abs(x_diff) <= 1 && abs(y_diff) <= 1
return state return head, tail
end end
if abs(x_diff) <= 1 if abs(x_diff) <= 1
state.tail_x += x_diff tail.x += x_diff
state.tail_y += sign(y_diff) * (abs(y_diff) - 1) tail.y += sign(y_diff) * (abs(y_diff) - 1)
elseif abs(y_diff) <= 1 elseif abs(y_diff) <= 1
state.tail_y += y_diff tail.y += y_diff
state.tail_x += sign(x_diff) * (abs(x_diff) - 1) tail.x += sign(x_diff) * (abs(x_diff) - 1)
elseif abs(x_diff) == 2 && abs(y_diff) == 2 elseif abs(x_diff) == 2 && abs(y_diff) == 2
# possible with more than 2 knots # possible with more than 2 knots
state.tail_x += sign(x_diff) * (abs(x_diff) - 1) tail.x += sign(x_diff) * (abs(x_diff) - 1)
state.tail_y += sign(y_diff) * (abs(y_diff) - 1) tail.y += sign(y_diff) * (abs(y_diff) - 1)
else else
throw(DomainError("invalid state " * string(state))) throw(DomainError("invalid knots " * string(head) * " " * string(tail)))
end end
return state return head, tail
end end
function get_visited(moves; npairs::Int=1) function get_visited(moves; nknots::Int=2)
visited = Set{NTuple{2, Int}}() visited = Set{NTuple{2, Int}}()
states = [KnotPair(0, 0, 0, 0) for _ in 1:npairs] knots = [Point(0, 0) for _ in 1:nknots]
push!(visited, (states[npairs].tail_x, states[npairs].tail_y)) push!(visited, (knots[nknots].x, knots[nknots].y))
for move in moves for move in moves
for m in 1:move.count for _ in 1:move.count
# println(move, " ", m) knots[1], knots[2] = apply_move(knots[1], knots[2], move.direction)
# print(" 1 ", states[1], " -> ") for i in 2:nknots
states[1] = apply_move(states[1], move.direction) _, knots[i] = apply_move(knots[i-1], knots[i], 'O')
# println(states[1])
for i in 2:npairs
states[i].head_x = states[i-1].tail_x
states[i].head_y = states[i-1].tail_y
# old_state = deepcopy(states[i])
states[i] = apply_move(states[i], 'O')
# if state_changed(states[i], old_state)
# println(" ", i, " ", old_state, " -> ", states[i])
# end
end end
push!(visited, (states[npairs].tail_x, states[npairs].tail_y)) push!(visited, (knots[nknots].x, knots[nknots].y))
end end
end end
return visited return visited
@ -107,8 +88,8 @@ function test()
@test length(get_visited(read_moves("input.txt"))) == 6037 @test length(get_visited(read_moves("input.txt"))) == 6037
end end
@testset "visited count 10 knot" verbose=true begin @testset "visited count 10 knot" verbose=true begin
@test length(get_visited(read_moves("example.txt"), npairs=9)) == 1 @test length(get_visited(read_moves("example.txt"), nknots=10)) == 1
@test length(get_visited(read_moves("input.txt"), npairs=9)) == 2485 @test length(get_visited(read_moves("input.txt"), nknots=10)) == 2485
end end
end end
@ -126,7 +107,7 @@ function main()
display(visited_map(visited)) display(visited_map(visited))
println() println()
visited10 = get_visited(moves, npairs=9) visited10 = get_visited(moves, nknots=10)
println("visited10: ", visited10) println("visited10: ", visited10)
println("count10 : ", length(visited10)) println("count10 : ", length(visited10))
display(visited_map(visited10)) display(visited_map(visited10))

Loading…
Cancel
Save