
Commit a8b146f

antmarakis authored and norvig committed
Implementation: Transition Model for MDP (aimacode#445)
* Update test_mdp.py
* Update mdp.py
1 parent 0cd40a8 commit a8b146f


2 files changed: +33, -14 lines


mdp.py

Lines changed: 19 additions & 14 deletions
```diff
@@ -1,9 +1,9 @@
 """Markov Decision Processes (Chapter 17)
 
 First we define an MDP, and the special case of a GridMDP, in which
-states are laid out in a 2-dimensional grid. We also represent a policy
+states are laid out in a 2-dimensional grid. We also represent a policy
 as a dictionary of {state:action} pairs, and a Utility function as a
-dictionary of {state:number} pairs. We then define the value_iteration
+dictionary of {state:number} pairs. We then define the value_iteration
 and policy_iteration algorithms."""
 
 from utils import argmax, vector_add
@@ -17,32 +17,37 @@ class MDP:
     """A Markov Decision Process, defined by an initial state, transition model,
     and reward function. We also keep track of a gamma value, for use by
     algorithms. The transition model is represented somewhat differently from
-    the text. Instead of P(s' | s, a) being a probability number for each
+    the text. Instead of P(s' | s, a) being a probability number for each
     state/state/action triplet, we instead have T(s, a) return a
-    list of (p, s') pairs. We also keep track of the possible states,
+    list of (p, s') pairs. We also keep track of the possible states,
     terminal states, and actions for each state. [page 646]"""
 
-    def __init__(self, init, actlist, terminals, gamma=.9):
+    def __init__(self, init, actlist, terminals, transitions={}, states=set(), gamma=.9):
+        if not (0 <= gamma < 1):
+            raise ValueError("An MDP must have 0 <= gamma < 1")
+
         self.init = init
         self.actlist = actlist
         self.terminals = terminals
-        if not (0 <= gamma < 1):
-            raise ValueError("An MDP must have 0 <= gamma < 1")
+        self.transitions = transitions
+        self.states = states
         self.gamma = gamma
-        self.states = set()
         self.reward = {}
 
     def R(self, state):
-        "Return a numeric reward for this state."
+        """Return a numeric reward for this state."""
         return self.reward[state]
 
     def T(self, state, action):
-        """Transition model. From a state and an action, return a list
+        """Transition model. From a state and an action, return a list
         of (probability, result-state) pairs."""
-        raise NotImplementedError
+        if(self.transitions == {}):
+            raise ValueError("Transition model is missing")
+        else:
+            return self.transitions[state][action]
 
     def actions(self, state):
-        """Set of actions that can be performed in this state. By default, a
+        """Set of actions that can be performed in this state. By default, a
         fixed list of actions, except for terminal states. Override this
         method if you need to specialize by state."""
         if state in self.terminals:
@@ -53,9 +58,9 @@ def actions(self, state):
 
 class GridMDP(MDP):
 
-    """A two-dimensional grid MDP, as in [Figure 17.1]. All you have to do is
+    """A two-dimensional grid MDP, as in [Figure 17.1]. All you have to do is
     specify the grid as a list of lists of rewards; use None for an obstacle
-    (unreachable state). Also, you should specify the terminal states.
+    (unreachable state). Also, you should specify the terminal states.
     An action is an (x, y) unit vector; e.g. (1, 0) means move east."""
 
     def __init__(self, grid, terminals, init=(0, 0), gamma=.9):
```
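
The point of this representation is that algorithms can compute expected utilities by iterating directly over the (probability, result-state) pairs that T(s, a) returns, rather than querying P(s' | s, a) for every candidate successor. A minimal sketch of that pattern against the new constructor; the two-state model and the utility estimates `U` are made up for illustration and are not part of this commit:

```python
from mdp import MDP

# Hypothetical transition model in the shape the new constructor expects:
# transitions[state][action] -> outcomes as (probability, result-state) pairs.
t = {"s0": {"go": [(0.8, "s1"), (0.2, "s0")]},
     "s1": {"go": [(1.0, "s1")]}}

m = MDP(init="s0", actlist={"go"}, terminals={"s1"},
        transitions=t, states={"s0", "s1"})

# The sum that value_iteration-style updates compute from T(s, a),
# given some utility estimates U (illustrative values):
U = {"s0": 0.0, "s1": 1.0}
print(sum(p * U[s1] for (p, s1) in m.T("s0", "go")))  # 0.8*1.0 + 0.2*0.0 = 0.8
```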

tests/test_mdp.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -25,3 +25,17 @@ def test_best_policy():
     assert sequential_decision_environment.to_arrows(pi) == [['>', '>', '>', '.'],
                                                              ['^', None, '^', '.'],
                                                              ['^', '>', '^', '<']]
+
+
+def test_transition_model():
+    transition_model = {
+        "A": {"a1": (0.3, "B"), "a2": (0.7, "C")},
+        "B": {"a1": (0.5, "B"), "a2": (0.5, "A")},
+        "C": {"a1": (0.9, "A"), "a2": (0.1, "B")},
+    }
+
+    mdp = MDP(init="A", actlist={"a1","a2"}, terminals={"C"}, states={"A","B","C"}, transitions=transition_model)
+
+    assert mdp.T("A","a1") == (0.3, "B")
+    assert mdp.T("B","a2") == (0.5, "A")
+    assert mdp.T("C","a1") == (0.9, "A")
```
