# Chapter 5: The Blackjack Environment
In this notebook, we'll study Monte Carlo based methods for playing the game of Blackjack.
```julia
begin
    using ReinforcementLearning
    using Flux
    using Statistics
    using Plots
end
```
As usual, let's define the environment first. The implementation of the Blackjack environment is mainly adapted from openai/gym, with some necessary modifications for the follow-up experiments.
```julia
begin
    # 1 = Ace, 2-9 = number cards, 10/Jack/Queen/King = 10
    DECK = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10]

    mutable struct BlackjackEnv <: AbstractEnv
        dealer_hand::Vector{Int}
        player_hand::Vector{Int}
        done::Bool
        reward::Int
        init::Union{Tuple{Vector{Int},Vector{Int}},Nothing}
    end

    function BlackjackEnv(; init = nothing)
        env = BlackjackEnv(Int[], Int[], false, 0, init)
        reset!(env)
        env
    end

    function RLBase.reset!(env::BlackjackEnv)
        empty!(env.dealer_hand)
        empty!(env.player_hand)
        if isnothing(env.init)
            push!(env.dealer_hand, rand(DECK))
            push!(env.dealer_hand, rand(DECK))
            # keep dealing until the player's sum is at least 12; below 12 the player
            # can always hit without risk of busting, so those states are skipped
            while sum_hand(env.player_hand) < 12
                push!(env.player_hand, rand(DECK))
            end
        else
            # start every episode from a fixed (player_hand, dealer_hand) configuration
            append!(env.player_hand, env.init[1])
            append!(env.dealer_hand, env.init[2])
        end
        env.done = false
        env.reward = 0
    end

    RLBase.state_space(env::BlackjackEnv) =
        Space([Base.OneTo(31), Base.OneTo(10), Base.OneTo(2)])
    RLBase.action_space(env::BlackjackEnv) = Base.OneTo(2)  # 1 = hit, 2 = stick

    usable_ace(hand) = (1 in hand) && (sum(hand) + 10 <= 21)
    sum_hand(hand) = usable_ace(hand) ? sum(hand) + 10 : sum(hand)
    is_bust(hand) = sum_hand(hand) > 21
    score(hand) = is_bust(hand) ? 0 : sum_hand(hand)

    RLBase.state(env::BlackjackEnv) =
        (sum_hand(env.player_hand), env.dealer_hand[1], usable_ace(env.player_hand) + 1)
    RLBase.reward(env::BlackjackEnv) = env.reward
    RLBase.is_terminated(env::BlackjackEnv) = env.done

    function (env::BlackjackEnv)(action)
        if action == 1  # hit: draw another card
            push!(env.player_hand, rand(DECK))
            if is_bust(env.player_hand)
                env.done = true
                env.reward = -1
            else
                env.done = false
                env.reward = 0
            end
        elseif action == 2  # stick: the dealer draws until reaching at least 17
            env.done = true
            while sum_hand(env.dealer_hand) < 17
                push!(env.dealer_hand, rand(DECK))
            end
            env.reward = cmp(score(env.player_hand), score(env.dealer_hand))
        else
            error("unknown action: $action")
        end
    end
end
```
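To make sure the pieces fit together, here is a small sanity check of the interface we just defined (an illustrative snippet, assuming the cell above has been evaluated): create an environment, inspect its state, stick immediately, and read off the reward.

```julia
let env = BlackjackEnv()
    @show state(env)          # (player sum, dealer's showing card, usable_ace + 1)
    env(2)                    # action 2 = stick; the dealer then plays out its hand
    @show reward(env)         # 1 = win, 0 = draw, -1 = loss
    @show is_terminated(env)  # true: sticking always ends the episode
end
```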
# BlackjackEnv
## Traits
| Trait Type | Value |
|:----------------- | ------------------------------------------------:|
| NumAgentStyle | ReinforcementLearningBase.SingleAgent() |
| DynamicStyle | ReinforcementLearningBase.Sequential() |
| InformationStyle | ReinforcementLearningBase.ImperfectInformation() |
| ChanceStyle | ReinforcementLearningBase.Stochastic() |
| RewardStyle | ReinforcementLearningBase.StepReward() |
| UtilityStyle | ReinforcementLearningBase.GeneralSum() |
| ActionStyle | ReinforcementLearningBase.MinimalActionSet() |
| StateStyle | ReinforcementLearningBase.Observation{Any}() |
| DefaultStateStyle | ReinforcementLearningBase.Observation{Any}() |
## Is Environment Terminated?
No
## State Space
`ReinforcementLearningBase.Space{Array{Base.OneTo{Int64},1}}(Base.OneTo{Int64}[Base.OneTo(31), Base.OneTo(10), Base.OneTo(2)])`
## Action Space
`Base.OneTo(2)`
## Current State
```
(13, 6, 2)
```
The summary above is the default display of the environment instance created here:
```julia
game = BlackjackEnv()
```
`Space([Base.OneTo(31), Base.OneTo(10), Base.OneTo(2)])`
```julia
state_space(game)
```
As you can see, the `state_space` of the Blackjack environment has three discrete features. To reuse the tabular algorithms in `ReinforcementLearning.jl`, we need to flatten the state into a single integer and wrap the environment in a `StateOverriddenEnv`.
```julia
STATE_MAPPING = s -> LinearIndices((31, 10, 2))[CartesianIndex(s)]
```
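To see what the mapping does: a state tuple `(player_sum, dealer_card, usable_ace + 1)` is interpreted as a Cartesian index into a 31×10×2 array and converted to the corresponding linear index, i.e. `s[1] + (s[2] - 1) * 31 + (s[3] - 1) * 310`. A quick, purely illustrative check:

```julia
let s = (13, 6, 2)
    # 13 + (6 - 1) * 31 + (2 - 1) * 31 * 10 == 478
    STATE_MAPPING(s)
end
```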
# BlackjackEnv |> StateOverriddenEnv
The traits, state space, and action space reported for the wrapped environment are the same as for the plain `BlackjackEnv` above; the only visible difference is that the current state is now a single integer:
## Current State
```
233
```
```julia
world = StateOverriddenEnv(
    BlackjackEnv(),
    STATE_MAPPING
)
```
```julia
# the flattened state space: 31 × 10 × 2 = 620 states
RLBase.state_space(x::typeof(world)) = Base.OneTo(31 * 10 * 2)
```
Base.OneTo(620)
```julia
NS = state_space(world)
```
Figure 5.1: approximate state-value functions for the policy that sticks only on a player sum of 20 or 21, computed by first-visit Monte Carlo policy evaluation (after 10,000 and then 500,000 episodes).
Agent
├─ policy => VBasedPolicy
│ ├─ learner => MonteCarloLearner
│ │ ├─ approximator => TabularApproximator
│ │ │ ├─ table => 620-element Array{Float64,1}
│ │ │ └─ optimizer => InvDecay
│ │ │ ├─ gamma => 1.0
│ │ │ └─ state => IdDict
│ │ ├─ γ => 1.0
│ │ ├─ kind => ReinforcementLearningZoo.FirstVisit
│ │ └─ sampling => ReinforcementLearningZoo.NoSampling
│ └─ mapping => Main.var"#3#4"
└─ trajectory => Trajectory
└─ traces => NamedTuple
├─ state => 0-element Array{Int64,1}
├─ action => 0-element Array{Int64,1}
├─ reward => 0-element Array{Float32,1}
└─ terminal => 0-element Array{Bool,1}
```julia
agent = Agent(
    policy = VBasedPolicy(
        learner = MonteCarloLearner(;
            approximator = TabularVApproximator(; n_state = NS, opt = InvDecay(1.0)),
            γ = 1.0
        ),
        # the policy under evaluation: stick (2) on 20 or 21, otherwise hit (1)
        mapping = (env, V) -> sum_hand(env.env.player_hand) in (20, 21) ? 2 : 1
    ),
    trajectory = VectorSARTTrajectory()
)
```
```julia
run(agent, world, StopAfterEpisode(10_000))
```
`620-element Array{Float64,1}` (the estimated state values; truncated display)
```julia
VT = agent.policy.learner.approximator.table
```
```julia
X, Y = 12:21, 1:10
```
Estimated state values after 10,000 episodes, for hands without a usable ace:
```julia
plot(X, Y, [VT[STATE_MAPPING((i, j, 1))] for i in X, j in Y], linetype = :wireframe)
```
and for hands with a usable ace:
```julia
plot(X, Y, [VT[STATE_MAPPING((i, j, 2))] for i in X, j in Y], linetype = :wireframe)
```
```julia
# now run more simulations
run(agent, world, StopAfterEpisode(500_000))
```
After the additional 500,000 episodes, without a usable ace:
```julia
plot(X, Y, [VT[STATE_MAPPING((i, j, 1))] for i in X, j in Y], linetype = :wireframe)
```
and with a usable ace:
```julia
plot(X, Y, [VT[STATE_MAPPING((i, j, 2))] for i in X, j in Y], linetype = :wireframe)
```
Figure 5.2: the optimal policy and state-value function for Blackjack, found by Monte Carlo Exploring Starts.
In Chapter 5.3, a Monte Carlo Exploring Starts method is used to solve the Blackjack game. Although several variants of Monte Carlo methods are provided by the `ReinforcementLearning.jl` package, none of them supports exploring starts out of the box. Nevertheless, we can define such a policy wrapper very easily.
```julia
begin
    # A thin wrapper that forces the very first action of each episode to be sampled
    # uniformly at random (the "exploring start") and otherwise delegates to the
    # wrapped policy.
    Base.@kwdef mutable struct ExploringStartPolicy{P} <: AbstractPolicy
        policy::P
        is_start::Bool = true
    end

    function (p::ExploringStartPolicy)(env::AbstractEnv)
        if p.is_start
            p.is_start = false
            rand(action_space(env))
        else
            p.policy(env)
        end
    end

    # forward the stage callbacks to the wrapped policy
    (p::ExploringStartPolicy)(s::AbstractStage, env::AbstractEnv) = p.policy(s, env)

    function (p::ExploringStartPolicy)(s::PreEpisodeStage, env::AbstractEnv)
        p.is_start = true
        p.policy(s, env)
    end

    function (p::ExploringStartPolicy)(s::PreActStage, env::AbstractEnv, action)
        p.policy(s, env, action)
    end
end
```
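A small illustration of how the wrapper behaves (the anonymous always-stick "policy" below is only for demonstration): within an episode, the first call returns a uniformly random action and every later call delegates to the wrapped policy; the `PreEpisodeStage` method resets the flag so that the next episode again starts with a random action.

```julia
let p = ExploringStartPolicy(policy = _ -> 2)  # the wrapped "policy" always sticks
    env = BlackjackEnv()
    a1 = p(env)  # random: 1 or 2
    a2 = p(env)  # delegates to the wrapped policy, so always 2
    (a1, a2)
end
```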
Agent
├─ policy => QBasedPolicy
│ ├─ learner => MonteCarloLearner
│ │ ├─ approximator => TabularApproximator
│ │ │ ├─ table => 2×620 Array{Float64,2}
│ │ │ └─ optimizer => InvDecay
│ │ │ ├─ gamma => 1.0
│ │ │ └─ state => IdDict
│ │ ├─ γ => 1.0
│ │ ├─ kind => ReinforcementLearningZoo.FirstVisit
│ │ └─ sampling => ReinforcementLearningZoo.NoSampling
│ └─ explorer => GreedyExplorer
└─ trajectory => Trajectory
└─ traces => NamedTuple
├─ state => 0-element Array{Int64,1}
├─ action => 0-element Array{Int64,1}
├─ reward => 0-element Array{Float32,1}
└─ terminal => 0-element Array{Bool,1}
```julia
solver = Agent(
    policy = QBasedPolicy(
        learner = MonteCarloLearner(;
            approximator = TabularQApproximator(;
                n_state = NS,
                n_action = 2,
                opt = InvDecay(1.0)
            ),
            γ = 1.0,
        ),
        # act greedily with respect to the current action-value estimates
        explorer = GreedyExplorer()
    ),
    trajectory = VectorSARTTrajectory()
)
```
```julia
run(ExploringStartPolicy(policy = solver), world, StopAfterEpisode(10_000_000))
```
`2×620 Array{Float64,2}` (the estimated action values; truncated display)
```julia
QT = solver.policy.learner.approximator.table
```
The greedy action in each state (1 = hit, 2 = stick) gives the learned policy. Without a usable ace:
```julia
heatmap([argmax(QT[:, STATE_MAPPING((i, j, 1))]) for i in 11:21, j in Y])
```
and with a usable ace:
```julia
heatmap([argmax(QT[:, STATE_MAPPING((i, j, 2))]) for i in 11:21, j in Y])
```
Agent
├─ policy => VBasedPolicy
│ ├─ learner => MonteCarloLearner
│ │ ├─ approximator => TabularApproximator
│ │ │ ├─ table => 620-element Array{Float64,1}
│ │ │ └─ optimizer => InvDecay
│ │ │ ├─ gamma => 1.0
│ │ │ └─ state => IdDict
│ │ ├─ γ => 1.0
│ │ ├─ kind => ReinforcementLearningZoo.FirstVisit
│ │ └─ sampling => ReinforcementLearningZoo.NoSampling
│ └─ mapping => Main.var"#17#18"
└─ trajectory => Trajectory
└─ traces => NamedTuple
├─ state => 0-element Array{Int64,1}
├─ action => 0-element Array{Int64,1}
├─ reward => 0-element Array{Float32,1}
└─ terminal => 0-element Array{Bool,1}
```julia
V_agent = Agent(
    policy = VBasedPolicy(
        learner = MonteCarloLearner(;
            approximator = TabularVApproximator(; n_state = NS, opt = InvDecay(1.0)),
            γ = 1.0
        ),
        # evaluate the greedy policy found by `solver` above
        mapping = (env, V) -> solver.policy(env)
    ),
    trajectory = VectorSARTTrajectory()
)
```
```julia
run(V_agent, world, StopAfterEpisode(500_000))
```
```julia
V_agent_T = V_agent.policy.learner.approximator.table;
```
The state-value function of the learned policy, without a usable ace:
```julia
plot(X, Y, [V_agent_T[STATE_MAPPING((i, j, 1))] for i in X, j in Y], linetype = :wireframe)
```
and with a usable ace:
```julia
plot(X, Y, [V_agent_T[STATE_MAPPING((i, j, 2))] for i in X, j in Y], linetype = :wireframe)
```
Figure 5.3: mean squared error of ordinary versus weighted importance sampling when estimating the value of a single Blackjack state (the player holds an ace and a 2, i.e. a sum of 13 with a usable ace, and the dealer shows a 2) from episodes generated by a uniformly random behavior policy.
```julia
static_env = StateOverriddenEnv(
    BlackjackEnv(; init = ([1, 2], [2])),
    STATE_MAPPING
);
```
354
```julia
INIT_STATE = state(static_env)
```
```julia
# the true value of INIT_STATE under the target policy (≈ -0.27726, determined in the
# book by averaging the returns of a very large number of episodes under that policy)
GOLD_VAL = -0.27726
```
```julia
# a hook that records, after every episode, the squared error of the
# current estimate of the initial state's value
Base.@kwdef struct StoreMSE <: AbstractHook
    mse::Vector{Float64} = Float64[]
end
```
```julia
# the first element of the approximator tuple holds the value estimates
function (f::StoreMSE)(::PostEpisodeStage, agent, env)
    V̂ = agent.policy.π_target.learner.approximator[1](INIT_STATE)
    push!(f.mse, (GOLD_VAL - V̂)^2)
end
```
```julia
# the target policy to be evaluated: stick (2) only on 20 or 21, otherwise hit (1)
target_policy_mapping = (env, V) -> sum_hand(env.env.player_hand) in (20, 21) ? 2 : 1
```
```julia
# Off-policy learning needs the target policy's action probabilities π(a|s) to form
# the importance-sampling ratios, so we extend `RLBase.prob` for this deterministic policy.
function RLBase.prob(
    p::VBasedPolicy{<:Any,typeof(target_policy_mapping)},
    env::AbstractEnv,
    a
)
    s = sum_hand(env.env.player_hand)
    if s in (20, 21)
        Int(a == 2)  # stick with probability 1
    else
        Int(a == 1)  # hit with probability 1
    end
end
```
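As a quick check of these probabilities (the `VBasedPolicy` constructed below is only for illustration): in the fixed starting state the player's sum is 13, so the target policy hits with probability 1. Against the uniformly random behavior policy used later, the per-step importance-sampling ratio π(a|s)/b(a|s) is therefore either 0 or 2.

```julia
let π_target = VBasedPolicy(
        learner = MonteCarloLearner(;
            approximator = TabularVApproximator(; n_state = NS, opt = InvDecay(1.0)),
            γ = 1.0
        ),
        mapping = target_policy_mapping
    )
    # in `static_env` the player's fixed starting hand is an ace and a 2 (sum 13)
    prob(π_target, static_env, 1), prob(π_target, static_env, 2)  # (1, 0)
end
```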
```julia
function ordinary_mse()
    agent = Agent(
        policy = OffPolicy(
            π_target = VBasedPolicy(
                learner = MonteCarloLearner(
                    approximator = (
                        TabularVApproximator(; n_state = NS, opt = Descent(1.0)),
                        TabularVApproximator(; n_state = NS, opt = InvDecay(1.0))
                    ),
                    kind = FIRST_VISIT,
                    sampling = ORDINARY_IMPORTANCE_SAMPLING
                ),
                mapping = target_policy_mapping
            ),
            # behavior policy: pick hit/stick uniformly at random
            π_behavior = RandomPolicy(Base.OneTo(2))
        ),
        # the trajectory also records the importance-sampling weight (the W in WSART)
        trajectory = VectorWSARTTrajectory()
    )
    h = StoreMSE()
    run(agent, static_env, StopAfterEpisode(10_000), h)
    h.mse
end
```
```julia
function weighted_mse()
    agent = Agent(
        policy = OffPolicy(
            π_target = VBasedPolicy(
                learner = MonteCarloLearner(
                    # same as ordinary_mse, but weighted importance sampling
                    # requires one more tabular approximator
                    approximator = (
                        TabularVApproximator(; n_state = NS, opt = Descent(1.0)),
                        TabularVApproximator(; n_state = NS, opt = InvDecay(1.0)),
                        TabularVApproximator(; n_state = NS, opt = InvDecay(1.0))
                    ),
                    kind = FIRST_VISIT,
                    sampling = WEIGHTED_IMPORTANCE_SAMPLING
                ),
                mapping = target_policy_mapping
            ),
            π_behavior = RandomPolicy(Base.OneTo(2))
        ),
        trajectory = VectorWSARTTrajectory()
    )
    h = StoreMSE()
    run(agent, static_env, StopAfterEpisode(10_000), h)
    h.mse
end
```
```julia
begin
    fig_5_3 = plot()
    plot!(fig_5_3, mean(ordinary_mse() for _ in 1:100), xscale = :log10, label = "Ordinary Importance Sampling")
    plot!(fig_5_3, mean(weighted_mse() for _ in 1:100), xscale = :log10, label = "Weighted Importance Sampling")
    fig_5_3
end
```