Chapter 12 Random Walk

begin
    using ReinforcementLearning
    using Flux
    using Statistics
    using Plots
end

N = 21

true_values = -1:0.1:1

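Why `-1:0.1:1`: assuming the default RandomWalk1D rewards of -1 at the left terminal and +1 at the right, the true value of the i-th of the 19 non-terminal states under the uniform random policy is 2i/20 - 1, so the interior entries of this range run -0.9, -0.8, …, 0.9; the two endpoints stand in for the terminal states and are dropped from the RMS error computed below. A quick check in plain Julia:

# analytic values of the 19 non-terminal states under the uniform random policy
# (assumes terminal rewards of -1 and +1, the RandomWalk1D defaults)
@assert true_values[2:end-1] ≈ [2i / 20 - 1 for i in 1:19]
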
# hook that stores one RMS error per episode
Base.@kwdef struct RecordRMS <: AbstractHook
    rms::Vector{Float64} = []
end

# after each episode, record the RMS error of the 19 non-terminal state-value estimates
(h::RecordRMS)(::PostEpisodeStage, agent, env) = push!(h.rms, sqrt(mean((agent.policy.learner.approximator.table[2:end-1] - true_values[2:end-1]).^2)))

function create_agent_env(α, λ)
    env = RandomWalk1D(N=21)
    ns, na = length(state_space(env)), length(action_space(env))
    agent = Agent(
        policy=VBasedPolicy(
            learner=TDλReturnLearner(
                approximator=TabularVApproximator(;n_state=ns, opt=Descent(α)),
                γ=1.0,
                λ=λ
            ),
            # actions are chosen uniformly at random, independent of the value estimates
            mapping = (env, V) -> rand(1:na)
        ),
        trajectory=VectorSARTTrajectory()
    )
    agent, env
end

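For context, TDλReturnLearner appears to implement the forward-view (off-line) λ-return algorithm: at the end of an episode every visited state is nudged toward its λ-return, a (1-λ), (1-λ)λ, (1-λ)λ², … weighted mixture of the n-step returns (Sutton & Barto 2018, Eq. 12.3). The sketch below only illustrates that target; the function `lambda_returns` is ours and not part of ReinforcementLearning.jl:

# A sketch (not library code) of the forward-view λ-return target.
# rewards[t] is the reward of step t, v[t] the current estimate V(S_t)
# for the t-th visited state, γ the discount factor.
function lambda_returns(rewards, v, λ; γ=1.0)
    T = length(rewards)                 # number of steps in the episode
    G = zeros(T)
    for t in 1:T
        m  = T - t + 1                  # steps remaining until termination
        Gn = 0.0                        # running discounted reward sum
        Gλ = 0.0
        for n in 1:m-1                  # n-step returns that bootstrap on v
            Gn += γ^(n - 1) * rewards[t + n - 1]
            Gλ += (1 - λ) * λ^(n - 1) * (Gn + γ^n * v[t + n])
        end
        Gmc  = Gn + γ^(m - 1) * rewards[T]   # full Monte-Carlo return from t
        G[t] = Gλ + λ^(m - 1) * Gmc          # remaining weight goes to Gmc
    end
    G
end

With λ = 0 this collapses to the one-step TD target and with λ = 1 to the Monte-Carlo return, which is exactly the spectrum swept over in the plot at the end of this notebook.
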
# average RMS error over the first 10 episodes, averaged again over `nruns` independent runs
function records(α, λ, nruns=10)
    rms = []
    for _ in 1:nruns
        hook = RecordRMS()
        run(create_agent_env(α, λ)..., StopAfterEpisode(10, is_show_progress=false), hook)
        push!(rms, mean(hook.rms))
    end
    mean(rms)
end

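As a quick sanity check before the full sweep, `records` can be called for a single (α, λ) pair; the exact value varies from run to run because both the environment transitions and the policy are random:

# average RMS error over 10 runs of 10 episodes with α = 0.4, λ = 0.8
records(0.4, 0.8)
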
begin
    # sweep α over a λ-dependent range and plot the average RMS error for each λ
    As = [0:0.1:1, 0:0.1:1, 0:0.1:1, 0:0.1:1, 0:0.1:1, 0:0.05:0.5, 0:0.02:0.2, 0:0.01:0.1]
    Λ = [0., 0.4, 0.8, 0.9, 0.95, 0.975, 0.99, 1.]
    p = plot(legend=:topright)
    for (A, λ) in zip(As, Λ)
        plot!(p, A, [records(α, λ) for α in A], label="lambda = $λ")
    end
    ylims!(p, (0.25, 0.55))
    p
end