# Chapter 13 Short Corridor

This notebook works through Example 13.1 of Sutton & Barto's *Reinforcement Learning: An Introduction* (2nd edition): the short corridor with switched actions, evaluated first with fixed stochastic policies and then trained with REINFORCE.
```julia
begin
    using ReinforcementLearning
    using Flux
    using Statistics
    using Plots
    using LinearAlgebra: dot
end
```
First, implement the short corridor as a custom `AbstractEnv`: a four-position corridor with a reward of `-1` on every step until the terminal position 4 is reached, where, crucially, the two actions have reversed effects in position 2.

```julia
begin
    Base.@kwdef mutable struct ShortCorridorEnv <: AbstractEnv
        position::Int = 1
    end

    RLBase.state_space(env::ShortCorridorEnv) = Base.OneTo(4)
    RLBase.action_space(env::ShortCorridorEnv) = Base.OneTo(2)

    function (env::ShortCorridorEnv)(a)
        if env.position == 1 && a == 2
            env.position += 1                 # action 2 moves right; action 1 stays at the wall
        elseif env.position == 2
            env.position += a == 1 ? 1 : -1   # switched: the actions have reversed effects here
        elseif env.position == 3
            env.position += a == 1 ? -1 : 1   # action 2 reaches the terminal position 4
        end
        nothing
    end

    function RLBase.reset!(env::ShortCorridorEnv)
        env.position = 1
        nothing
    end

    RLBase.state(env::ShortCorridorEnv) = env.position
    RLBase.is_terminated(env::ShortCorridorEnv) = env.position == 4
    RLBase.reward(env::ShortCorridorEnv) = env.position == 4 ? 0.0 : -1.0
end
```
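To make the switched dynamics concrete, here is a minimal sanity check (not part of the original notebook): `env(a)` applies action `a`, and in position 2 the effects of the two actions are flipped.

```julia
begin
    demo_env = ShortCorridorEnv()
    reset!(demo_env)                                         # start at position 1
    demo_env(2); @assert state(demo_env) == 2                # action 2 moves right
    demo_env(2); @assert state(demo_env) == 1                # position 2 is switched: action 2 moves left
    demo_env(2); demo_env(1); @assert state(demo_env) == 3   # ...and action 1 moves right
    demo_env(2); @assert is_terminated(demo_env)             # from position 3, action 2 reaches the terminal state
end
```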
```julia
world = ShortCorridorEnv()
```

Instantiating the environment displays its trait summary:

| Trait Type        | Value                                            |
|:----------------- | ------------------------------------------------:|
| NumAgentStyle     | ReinforcementLearningBase.SingleAgent()          |
| DynamicStyle      | ReinforcementLearningBase.Sequential()           |
| InformationStyle  | ReinforcementLearningBase.ImperfectInformation() |
| ChanceStyle       | ReinforcementLearningBase.Stochastic()           |
| RewardStyle       | ReinforcementLearningBase.StepReward()           |
| UtilityStyle      | ReinforcementLearningBase.GeneralSum()           |
| ActionStyle       | ReinforcementLearningBase.MinimalActionSet()     |
| StateStyle        | ReinforcementLearningBase.Observation{Any}()     |
| DefaultStateStyle | ReinforcementLearningBase.Observation{Any}()     |

with state space `Base.OneTo(4)`, action space `Base.OneTo(2)`, current state `1`, and the environment not yet terminated.

```julia
ns, na = length(state_space(world)), length(action_space(world))  # (4, 2)
```
As a baseline, evaluate fixed stochastic policies: for each probability `ϵ` of choosing action 2, run 1000 episodes and record the average total reward over the final episodes.

```julia
function run_once(A)
    avg_rewards = []
    for ϵ in A
        # In every state, choose action 1 with probability 1-ϵ and action 2 with probability ϵ.
        p = TabularRandomPolicy(; table=Dict(s => [1 - ϵ, ϵ] for s in 1:ns))
        env = ShortCorridorEnv()
        hook = TotalRewardPerEpisode()
        run(p, env, StopAfterEpisode(1000), hook)
        push!(avg_rewards, mean(hook.rewards[end-100:end]))
    end
    avg_rewards
end
```
```julia
X = 0.05:0.05:0.95
```

```julia
plot(X, mean([run_once(X) for _ in 1:10]), legend=nothing)
```
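For reference, solving the Bellman equations for a policy that picks action 2 with probability `p`, under the book's convention of a `-1` reward on every transition, gives the start-state value `v(s1) = (2p - 4) / (p * (1 - p))`, which peaks at about `-11.6` near `p ≈ 0.59`. Below is a small sketch (names are illustrative, not from the original notebook) to plot that analytic curve; note that this implementation returns `0` reward on the step into the terminal position, so the simulated totals should sit roughly one unit above it.

```julia
# Closed-form value of the start state when action 2 is chosen with probability p,
# under the book's reward convention of -1 on every transition.
true_value(p) = (2p - 4) / (p * (1 - p))

plot(X, true_value.(X), label="analytic value of the start state", legend=:bottomright)
```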
## REINFORCE Policy

Based on the description in Section 13.1 of the book, we need to define a customized approximator that parameterizes the policy as a softmax over linear action preferences.
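For reference, the parameterization and update are the ones from Sections 13.1 and 13.3 of the book: a softmax over linear action preferences, and the REINFORCE step written with the softmax eligibility vector.

$$
\pi(a \mid s, \boldsymbol\theta) \doteq \frac{e^{h(s,a,\boldsymbol\theta)}}{\sum_b e^{h(s,b,\boldsymbol\theta)}},
\qquad h(s,a,\boldsymbol\theta) \doteq \boldsymbol\theta^\top \mathbf{x}(s,a)
$$

$$
\boldsymbol\theta_{t+1} \doteq \boldsymbol\theta_t + \alpha \gamma^t G_t \, \nabla_{\boldsymbol\theta} \ln \pi(A_t \mid S_t, \boldsymbol\theta_t),
\qquad
\nabla_{\boldsymbol\theta} \ln \pi(a \mid s, \boldsymbol\theta) = \mathbf{x}(s,a) - \sum_b \pi(b \mid s, \boldsymbol\theta)\, \mathbf{x}(s,b)
$$

With γ = 1, as in this example, the γ^t factor disappears. The approximator below stores θ in `weight`, evaluates the softmax of the preferences, and applies the return-scaled gradient through a Flux optimizer (negated, since Flux optimizers take descent steps).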
```julia
begin
    Base.@kwdef struct LinearPreferenceApproximator{F,O} <: AbstractApproximator
        weight::Vector{Float64}
        feature_func::F
        actions::Int
        opt::O
    end

    # π(a|s, θ): softmax over the linear action preferences h(s, a, θ) = θ ⋅ x(s, a).
    function (A::LinearPreferenceApproximator)(s)
        h = [dot(A.weight, A.feature_func(s, a)) for a in 1:A.actions]
        softmax(h)
    end

    # Apply the REINFORCE correction, where Δ is the return G.
    # ∇ln π(a|s, θ) = x(s, a) - Σ_b π(b|s, θ) x(s, b); it is negated
    # because Flux optimizers take a descent step.
    function RLBase.update!(A::LinearPreferenceApproximator, correction::Pair)
        (s, a), Δ = correction
        w, x = A.weight, A.feature_func
        w̄ = -Δ .* (x(s, a) .- sum(A(s) .* [x(s, b) for b in 1:A.actions]))
        Flux.Optimise.update!(A.opt, w, w̄)
    end
end
```
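A quick standalone check of the approximator (this cell is not in the original notebook; the weights are the initial values used later for training):

```julia
begin
    demo_approximator = LinearPreferenceApproximator(
        weight=[-1.47, 1.47],
        feature_func=(s, a) -> a == 1 ? [0, 1] : [1, 0],
        actions=2,
        opt=Descent(2.0^-13),
    )
    # The features ignore the state, so every state yields the same distribution:
    demo_approximator(1)   # ≈ [0.95, 0.05], i.e. action 2 is chosen with probability ≈ 0.05
end
```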
```julia
begin
    Base.@kwdef struct ReinforcePolicy{A<:AbstractApproximator} <: AbstractPolicy
        approximator::A
        γ::Float64
    end

    (p::ReinforcePolicy)(env::AbstractEnv) =
        prob(p, state(env)) |> WeightedExplorer(; is_normalized=true)

    RLBase.prob(p::ReinforcePolicy, s) = p.approximator(s)

    # After each episode, walk the trajectory backwards, accumulating the
    # return G and updating the approximator for every (state, action) pair.
    function RLBase.update!(
        p::ReinforcePolicy,
        t::AbstractTrajectory,
        ::AbstractEnv,
        ::PostEpisodeStage,
    )
        S, A, R = t[:state], t[:action], t[:reward]
        Q, γ = p.approximator, p.γ
        G = 0.0

        for i in 1:length(R)
            s, a, r = S[end-i], A[end-i], R[end-i+1]
            G = r + γ * G

            update!(Q, (s, a) => G)
        end
    end

    # REINFORCE is a Monte Carlo method: start every episode with an empty trajectory.
    function RLBase.update!(
        t::AbstractTrajectory,
        ::ReinforcePolicy,
        ::AbstractEnv,
        ::PreEpisodeStage,
    )
        empty!(t)
    end
end
```
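The post-episode update walks the trajectory backwards, so the return can be accumulated in a single pass via `G = r + γ*G`. A standalone illustration with a toy reward sequence (not tied to the environment):

```julia
let R = [-1.0, -1.0, -1.0, 0.0], γ = 1.0, G = 0.0
    Gs = similar(R)
    for i in 1:length(R)
        G = R[end-i+1] + γ * G   # same recursion as in the update! loop above
        Gs[end-i+1] = G
    end
    Gs   # [-3.0, -2.0, -1.0, 0.0]
end
```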
Wrap training into a helper: build a REINFORCE agent with step size `α`, run it for 1000 episodes, and return the total reward per episode.

```julia
function run_once_RL(α)
    agent = Agent(
        policy=ReinforcePolicy(
            approximator=LinearPreferenceApproximator(
                weight=[-1.47, 1.47],  # initial weights: action 2 starts with probability ≈ 0.05
                feature_func=(s, a) -> a == 1 ? [0, 1] : [1, 0],
                actions=na,
                opt=Descent(α),
            ),
            γ=1.0,
        ),
        trajectory=VectorSARTTrajectory(),
    )

    env = ShortCorridorEnv()
    hook = TotalRewardPerEpisode()
    run(agent, env, StopAfterEpisode(1000; is_show_progress=false), hook)
    hook.rewards
end
```
```julia
begin
    fig_13_1 = plot(legend=:bottomright)
    for x in [-13, -14]  # with α = 2^-12 the runs converge less reliably within 1000 episodes
        plot!(fig_13_1, mean(run_once_RL(2.0^x) for _ in 1:50), label="alpha = 2^$x")
    end
    fig_13_1
end
```
Interested in how to reproduce Figure 13.2? A PR is warmly welcomed! See you there!