# Chapter 6 Temporal-Difference Learning (Cliff Walking)
```julia
begin
    using ReinforcementLearning
    using Plots
    using Flux
    using Statistics
end
```
In Example 6.6, the Cliff Walking gridworld is introduced to compare on-policy control (SARSA) with off-policy control (Q-learning). Although there is a dedicated package for 2-D environments, GridWorlds.jl, we write an independent implementation here as a showcase.
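For reference, these are the standard tabular update rules being compared (standard notation from the book; this aside is not part of the original notebook):

$$Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha \left[ R_{t+1} + \gamma\, Q(S_{t+1}, A_{t+1}) - Q(S_t, A_t) \right] \quad \text{(SARSA)}$$

$$Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha \left[ R_{t+1} + \gamma \max_a Q(S_{t+1}, a) - Q(S_t, A_t) \right] \quad \text{(Q-learning)}$$

SARSA bootstraps from the action actually taken by the ε-greedy behavior policy, while Q-learning bootstraps from the greedy action regardless of what is executed.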
```julia
begin
    const NX = 4
    const NY = 12
    const Start = CartesianIndex(4, 1)
    const Goal = CartesianIndex(4, 12)
    const LRUD = [
        CartesianIndex(0, -1),  # left
        CartesianIndex(0, 1),   # right
        CartesianIndex(-1, 0),  # up
        CartesianIndex(1, 0),   # down
    ]
    const LinearInds = LinearIndices((NX, NY))
end
```

`LinearInds` maps each grid position to an integer state index:

```
4×12 LinearIndices{2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}:
 1  5   9  13  17  21  25  29  33  37  41  45
 2  6  10  14  18  22  26  30  34  38  42  46
 3  7  11  15  19  23  27  31  35  39  43  47
 4  8  12  16  20  24  28  32  36  40  44  48
```
```julia
# A cell is a cliff cell if it lies in the bottom row, strictly between Start and Goal.
function iscliff(p::CartesianIndex{2})
    x, y = Tuple(p)
    x == 4 && y > 1 && y < NY
end
```
```julia
# take a look at the world map
heatmap((!iscliff).(CartesianIndices((NX, NY))); yflip = true)
```
```julia
Base.@kwdef mutable struct CliffWalkingEnv <: AbstractEnv
    position::CartesianIndex{2} = Start
end
```
```julia
function (env::CliffWalkingEnv)(a::Int)
    # apply the action, clamping the new position so it stays inside the grid
    x, y = Tuple(env.position + LRUD[a])
    env.position = CartesianIndex(min(max(x, 1), NX), min(max(y, 1), NY))
end
```
```julia
# Remaining methods of the RLBase interface:
RLBase.state(env::CliffWalkingEnv) = LinearInds[env.position]
RLBase.state_space(env::CliffWalkingEnv) = Base.OneTo(length(LinearInds))
RLBase.action_space(env::CliffWalkingEnv) = Base.OneTo(length(LRUD))

# -1 per step, -100 for falling off the cliff, 0 on reaching the goal
RLBase.reward(env::CliffWalkingEnv) = env.position == Goal ? 0.0 : (iscliff(env.position) ? -100.0 : -1.0)

# an episode ends at the goal or on any cliff cell
RLBase.is_terminated(env::CliffWalkingEnv) = env.position == Goal || iscliff(env.position)

RLBase.reset!(env::CliffWalkingEnv) = env.position = Start
```
```julia
world = CliffWalkingEnv()
```

Instantiating the environment displays its summary:

# CliffWalkingEnv

## Traits

| Trait Type | Value |
|:----------------- | ------------------------------------------------:|
| NumAgentStyle | ReinforcementLearningBase.SingleAgent() |
| DynamicStyle | ReinforcementLearningBase.Sequential() |
| InformationStyle | ReinforcementLearningBase.ImperfectInformation() |
| ChanceStyle | ReinforcementLearningBase.Stochastic() |
| RewardStyle | ReinforcementLearningBase.StepReward() |
| UtilityStyle | ReinforcementLearningBase.GeneralSum() |
| ActionStyle | ReinforcementLearningBase.MinimalActionSet() |
| StateStyle | ReinforcementLearningBase.Observation{Any}() |
| DefaultStateStyle | ReinforcementLearningBase.Observation{Any}() |

## Is Environment Terminated?

No

## State Space

`Base.OneTo(48)`

## Action Space

`Base.OneTo(4)`

## Current State

```
4
```
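Before training anything, here is a quick hand-driven check of the interface (not part of the original notebook); stepping right from the start immediately falls off the cliff:

```julia
# Hand-driven sanity check (assumes the cells above have been evaluated).
reset!(world)
state(world)          # 4, the linear index of Start at (4, 1)
world(2)              # action 2 = move right, which lands on a cliff cell
reward(world)         # -100.0
is_terminated(world)  # true
reset!(world)         # back to Start for the experiments below
```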
Now we have a workable environment. Next, we create a factory function that generates agents with different TD methods, together with a helper that runs repeated experiments for comparison.
```julia
begin
    NS = length(state_space(world))   # 48 states
    NA = length(action_space(world))  # 4 actions
end
```
`create_agent` builds an `Agent` with a tabular TD learner and an ε-greedy explorer; `method` selects the TD target (`:SARS` for Q-learning, `:SARSA`, or `:ExpectedSARSA`).

```julia
create_agent(α, method) = Agent(
    policy = QBasedPolicy(
        learner = TDLearner(
            approximator = TabularQApproximator(
                n_state = NS,
                n_action = NA,
                opt = Descent(α),
            ),
            method = method,
            γ = 1.0,
            n = 0,
        ),
        explorer = EpsilonGreedyExplorer(0.1),
    ),
    trajectory = VectorSARTTrajectory(),
)
```
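As a minimal usage sketch (not in the original notebook, and assuming the cells above have been evaluated), a single agent can be trained directly with `run`, using the same hook and stop condition as the experiments below:

```julia
# Train one Q-learning agent for a handful of episodes and inspect its returns.
hook = TotalRewardPerEpisode()
run(
    create_agent(0.5, :SARS),  # Q-learning agent
    CliffWalkingEnv(),
    StopAfterEpisode(10; is_show_progress = false),
    hook,
)
hook.rewards  # vector of per-episode returns
```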
```julia
# Train N independent agents with step size α and the given TD method for
# n_episode episodes each. With is_mean = true, return the average return;
# otherwise return the per-episode returns averaged across the N runs.
function repeated_run(α, method, N, n_episode, is_mean=true)
    env = CliffWalkingEnv()
    rewards = []
    for _ in 1:N
        h = TotalRewardPerEpisode()
        run(
            create_agent(α, method),
            env,
            StopAfterEpisode(n_episode; is_show_progress=false),
            h
        )
        push!(rewards, is_mean ? mean(h.rewards) : h.rewards)
    end
    mean(rewards)
end
```
First we reproduce the online-performance comparison of Example 6.6: the return of each episode, averaged over 1000 independent runs, for Q-learning and SARSA with α = 0.5. With an ε-greedy behavior policy, Q-learning learns the optimal path along the cliff edge but occasionally falls off while exploring, so its online return is worse than that of SARSA, which learns the longer but safer route.

```julia
begin
    p = plot(legend = :bottomright)
    plot!(p, repeated_run(0.5, :SARS, 1000, 500, false), label = "QLearning")
    plot!(p, repeated_run(0.5, :SARSA, 1000, 500, false), label = "SARSA")
    p
end
```
Next we reproduce Figure 6.3: the performance of Q-learning, SARSA, and Expected SARSA as a function of the step size α, both interim (return averaged over the first 100 episodes and 100 runs) and asymptotic (averaged over 5000 episodes and 10 runs).

```julia
begin
    A = 0.1:0.05:0.95
    fig_6_3 = plot(; legend = :bottomright)

    plot!(fig_6_3, A, [repeated_run(α, :SARS, 100, 100) for α in A], linestyle = :dash, markershape = :rect, label = "Interim Q")
    plot!(fig_6_3, A, [repeated_run(α, :SARSA, 100, 100) for α in A], linestyle = :dash, markershape = :dtriangle, label = "Interim SARSA")
    plot!(fig_6_3, A, [repeated_run(α, :ExpectedSARSA, 100, 100) for α in A], linestyle = :dash, markershape = :cross, label = "Interim ExpectedSARSA")

    plot!(fig_6_3, A, [repeated_run(α, :SARS, 10, 5000) for α in A], linestyle = :solid, markershape = :rect, label = "Asymptotic Q")
    plot!(fig_6_3, A, [repeated_run(α, :SARSA, 10, 5000) for α in A], linestyle = :solid, markershape = :dtriangle, label = "Asymptotic SARSA")
    plot!(fig_6_3, A, [repeated_run(α, :ExpectedSARSA, 10, 5000) for α in A], linestyle = :solid, markershape = :cross, label = "Asymptotic ExpectedSARSA")
    fig_6_3
end
```