xxxxxxxxxx
7
1
begin
2
using ReinforcementLearning
3
using Flux
4
using Statistics
5
using Plots
6
using Distributions
7
end
Again, we use an environment model to describe the Grid World in Chapter 4.2.
xxxxxxxxxx
35
1
begin
2
"""
    isterminal(s::CartesianIndex{2})

Return `true` when `s` is one of the two terminal cells of the 4×4 grid,
i.e. the corner `(1, 1)` or the opposite corner `(4, 4)`.
"""
isterminal(s::CartesianIndex{2}) = s in (CartesianIndex(1, 1), CartesianIndex(4, 4))
3
4
"""
    nextstep(s::CartesianIndex{2}, a::CartesianIndex{2})

Apply the offset `a` to the grid cell `s` on the 4×4 grid.

The agent stays in `s` when `s` is terminal or when the move would leave the
grid. The reward is `0.0` when starting from a terminal cell and `-1.0`
otherwise.

Returns a one-element vector holding
`(reward, isterminal(next), linear index of next) => 1.0`.
The `1.0` is presumably the transition probability of this deterministic move.
"""
function nextstep(s::CartesianIndex{2}, a::CartesianIndex{2})
    moved = s + a
    # A move is rejected from a terminal cell or when it falls outside the grid.
    blocked = isterminal(s) || !(moved in CartesianIndices((4, 4)))
    s′ = blocked ? s : moved
    reward = isterminal(s) ? 0.0 : -1.0
    return [(reward, isterminal(s′), LinearIndices((4, 4))[s′]) => 1.0]
end
12
13
# The four unit moves, indexed 1:4: -1 / +1 along the first grid index,
# then +1 / -1 along the second grid index.
const ACTIONS = CartesianIndex.([(-1, 0), (1, 0), (0, 1), (0, -1)])
19
20
"""
    GridWorldEnvModel(cache)

Environment model of the 4×4 Grid World backed by a precomputed lookup table.

`cache` is indexed with `(state, action)` tuples. The zero-argument
constructor fills it with the result of `nextstep` for every state in `1:16`
and every action in `1:4`.

The struct is parametrised on the type of `cache` so that the field is
concretely typed; an untyped field is stored as `Any` and makes every lookup
through the model type-unstable. `GridWorldEnvModel(cache)` and dispatch on
`::GridWorldEnvModel` keep working unchanged.
"""
struct GridWorldEnvModel{C} <: AbstractEnvironmentModel
    cache::C
end
23
24
"""
    GridWorldEnvModel()

Build the model by tabulating `nextstep` for every `(state, action)` pair,
with states `1:16` (linear indices into the 4×4 grid) and actions `1:4`
(indices into `ACTIONS`).
"""
function GridWorldEnvModel()
    cells = CartesianIndices((4, 4))
    transitions = Dict(
        (s, a) => nextstep(cells[s], ACTIONS[a])
        for s in 1:16 for a in 1:4
    )
    return GridWorldEnvModel(transitions)
end
30
31
# Calling the model with a state and an action looks up the stored entry
# for that `(s, a)` pair in the cache.
function (m::GridWorldEnvModel)(s, a)
    return m.cache[(s, a)]
end
32
33
# States are the linear indices 1:16 of the 4×4 grid.
function RLBase.state_space(m::GridWorldEnvModel)
    return Base.OneTo(16)
end
34
# Actions are the indices 1:4 into `ACTIONS`.
function RLBase.action_space(m::GridWorldEnvModel)
    return Base.OneTo(4)
end
35
end
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
1.0
xxxxxxxxxx
1
1
# Tabular value estimate with one entry per grid state (16 = 4×4),
# paired with a `Descent` optimiser constructed with 1.0.
V = TabularVApproximator(; n_state=16, opt=Descent(1.0))
TabularRandomPolicy
├─ table => Dict
└─ rng => Random._GLOBAL_RNG
xxxxxxxxxx
1
1
# Every state 1:16 maps to four equal entries of 0.25, one per action.
p = TabularRandomPolicy(; table=Dict(state => fill(0.25, 4) for state in 1:16))
12
4
-1.0
false
8
1.0
4
4
-1.0
false
4
1.0
9
4
-1.0
false
5
1.0
11
1
-1.0
false
10
1.0
9
1
-1.0
false
9
1.0
3
1
-1.0
false
2
1.0
16
3
0.0
true
16
1.0
14
4
-1.0
false
10
1.0
15
4
-1.0
false
11
1.0
2
4
-1.0
false
2
1.0
8
4
-1.0
false
4
1.0
1
2
0.0
true
1
1.0
9
2
-1.0
false
10
1.0
5
1
-1.0
false
5
1.0
14
2
-1.0
false
15
1.0
16
2
0.0
true
16
1.0
3
4
-1.0
false
3
1.0
8
2
-1.0
false
8
1.0
2
3
-1.0
false
6
1.0
16
1
0.0
true
16
1.0
13
3
-1.0
false
13
1.0
10
4
-1.0
false
6
1.0
11
4
-1.0
false
7
1.0
5
4
-1.0
true
1
1.0
2
1
-1.0
true
1
1.0
1
4
0.0
true
1
1.0
15
2
-1.0
true
16
1.0
15
1
-1.0
false
14
1.0
1
1
0.0
true
1
1.0
13
1
-1.0
false
13
1.0
xxxxxxxxxx
1
1
# Instantiate the Grid World model; the zero-argument constructor
# precomputes the transition for every (state, action) pair.
model = GridWorldEnvModel()
0.0
-13.9993
-19.999
-21.9989
-13.9993
-17.9992
-19.9991
-19.9991
-19.999
-19.9991
-17.9992
-13.9994
-21.9989
-19.9991
-13.9994
0.0
1.0
xxxxxxxxxx
1
1
# Evaluate policy `p` on `model` with γ = 1.0.
# NOTE(review): presumably updates `V` in place (the `!` suffix and the
# non-zero table shown by the next cell suggest so) — confirm against the
# ReinforcementLearning.jl documentation.
policy_evaluation!(V=V, π=p, model=model, γ=1.0)
4×4 Array{Float64,2}:
0.0 -13.9993 -19.999 -21.9989
-13.9993 -17.9992 -19.9991 -19.9991
-19.999 -19.9991 -17.9992 -13.9994
-21.9989 -19.9991 -13.9994 0.0
xxxxxxxxxx
1
1
# Show the 16 table entries laid out as the 4×4 grid.
reshape(V.table, (4, 4))