Overview
This is a very basic implementation of the 3×4 grid world as used in AI-Class Week 5, Unit 9.
It uses the version of the Value Iteration equation given at the end of Unit 9.15, with minor modifications to conform to the algorithm as specified in Russell & Norvig, “Artificial Intelligence: A Modern Approach”, 3rd ed., Figure 17.4, p. 653.
With it, you can test the various scenarios outlined by Sebastian Thrun in the class. At the top of the code, you will see the following variables you can adjust:
- Ra: reward in non-terminal states (used to initialise the reward matrix R)
- gamma: discount factor
- pGood: probability of taking the intended action (the remaining probability, 1 − pGood, is split equally between the two actions orthogonal to the intended one).
Implementation
It works as follows:
- Initialise R to Ra (−3 by default) for non-terminal states, +100 and −100 for the two terminal states, and 0 for the blocked state
- Initialise V′ to 0 for all states (in the code, V and V′ are the arrays U and Up)
- Repeat until convergence:
- V = V’
- Loop over all states:
- In each state s:
- V'(s) := R(s) if a terminal state
- V'(s) := The Bellman equation otherwise (computed using V(s), not V'(s))
Convergence is achieved when the maximum absolute difference between V′ and V, taken over all states, is less than the specified tolerance (deltaMin). I also set a maximum number of iterations (N) as a safeguard.
Testing
This gives the same values for all scenarios shown by Sebastian Thrun, to the number of significant figures he shows. It also gives the same results as reported in the Russell & Norvig textbook for the settings they use, which are a little different.
The Code
You can copy the code below and paste it into a file called GridWorld2.java.
This code may be used freely without restriction, though attribution of my authorship would be appreciated.
/**
 * AI-Class Unit 9 simple grid world Value Iteration.
 * By Michael Madden, Nov 2011.
 * Further details: see https://galweejit.wordpress.com.
 *
 * <p>This version (GridWorld2) uses simultaneous updates,
 * as shown in AIMA 3ed Figure 17.4 p653: every U'(s) computed in an
 * iteration is based on the previous iteration's U, never on U' itself.
 *
 * <p>The world is a 3x4 grid indexed [row][col], with row 0 at the top
 * (moving north decrements the row). Cell (0,3) is the +100 terminal
 * state, (1,3) is the -100 terminal state and (1,1) is a blocked cell.
 *
 * <p>This code may be used freely without restriction,
 * though attribution of my authorship would be appreciated.
 */
public class GridWorld2 {
    // General settings: edit these to try the different scenarios.
    private static final double Ra = -3;          // reward in non-terminal states (used to initialise R[][])
    private static final double gamma = 1;        // discount factor
    private static final double pGood = 0.8;      // probability of taking intended action
    private static final double pBad = (1 - pGood) / 2; // 2 bad (orthogonal) actions split the remaining prob
    private static final int N = 10000;           // max number of iterations of Value Iteration
    private static final double deltaMin = 1e-9;  // convergence tolerance on max |U'(s) - U(s)|

    // Grid dimensions.
    private static final int rMax = 3, cMax = 4;

    /** Policy symbols, in the same order as the action values in updateUPrime: N, S, W, E. */
    private static final char[] ACTION_SYMBOLS = {'N', 'S', 'W', 'E'};

    // Main data structures, all indexed [row][col].
    private static double[][] U;   // long-term utility (previous iteration)
    private static double[][] Up;  // UPrime, written during the current iteration
    private static double[][] R;   // instantaneous reward
    private static char[][] Pi;    // greedy policy with respect to U

    /**
     * Runs Value Iteration on the grid world and prints the utilities and the best policy.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        initialise();
        int n = valueIteration();
        printResults(n);
    }

    /** Allocates U, U', R and the policy; U' starts at 0, R at Ra except for the fixed states. */
    private static void initialise() {
        // Policy: filled in by updateUPrime for non-fixed states.
        Pi = new char[rMax][cMax];
        // Java zero-initialises new arrays, so U' already holds 0 everywhere.
        Up = new double[rMax][cMax];
        // U needs no initial values: it is overwritten from U' at the start of every iteration.
        U = new double[rMax][cMax];

        R = new double[rMax][cMax];
        for (double[] row : R) {
            java.util.Arrays.fill(row, Ra);
        }
        R[0][3] = 100;   // positive sink state
        R[1][3] = -100;  // negative sink state
        R[1][1] = 0;     // unreachable state
    }

    /**
     * Iterates simultaneous Bellman updates until the largest change falls to deltaMin
     * or N iterations have run.
     *
     * @return the number of iterations performed
     */
    private static int valueIteration() {
        int n = 0;
        double delta;
        do {
            // Simultaneous updates: set U = U', then compute the new U' from U only.
            duplicate(Up, U); // src, dest
            n++;
            delta = 0;
            for (int r = 0; r < rMax; r++) {
                for (int c = 0; c < cMax; c++) {
                    updateUPrime(r, c);
                    delta = Math.max(delta, Math.abs(Up[r][c] - U[r][c]));
                }
            }
        } while (delta > deltaMin && n < N);

        // Hitting the iteration cap is not convergence; say so rather than report silently.
        if (delta > deltaMin) {
            System.err.printf("Warning: no convergence after %d iterations (max change %g > %g)%n",
                    n, delta, deltaMin);
        }
        return n;
    }

    /** Prints the final utility matrix U and the greedy policy, marking sinks and the block. */
    private static void printResults(int n) {
        System.out.println("After " + n + " iterations:\n");
        for (int r = 0; r < rMax; r++) {
            for (int c = 0; c < cMax; c++) {
                System.out.printf("% 6.1f\t", U[r][c]);
            }
            System.out.print("\n");
        }

        // Before displaying the best policy, insert chars in the sinks and the non-moving block.
        Pi[0][3] = '+';
        Pi[1][3] = '-';
        Pi[1][1] = '#';
        System.out.println("\nBest policy:\n");
        for (int r = 0; r < rMax; r++) {
            for (int c = 0; c < cMax; c++) {
                System.out.print(Pi[r][c] + " ");
            }
            System.out.print("\n");
        }
    }

    /** Returns true for the two terminal (sink) states and the blocked cell, whose value stays R. */
    private static boolean isFixedState(int r, int c) {
        return (r == 0 && c == 3) || (r == 1 && c == 3) || (r == 1 && c == 1);
    }

    /**
     * Computes U'(r,c) with the Bellman update and records the greedy action in the policy.
     * IMPORTANT: this writes Up (and Pi) using only values read from U.
     *
     * @param r row of the cell
     * @param c column of the cell
     */
    public static void updateUPrime(int r, int c) {
        // Sink states and the unreachable cell keep their fixed reward.
        if (isFixedState(r, c)) {
            Up[r][c] = R[r][c];
            return;
        }

        // Expected next-state utility for each action (N, S, W, E): the intended move
        // succeeds with pGood; each orthogonal move happens with pBad.
        double[] a = new double[4];
        a[0] = aNorth(r, c) * pGood + aWest(r, c) * pBad + aEast(r, c) * pBad;
        a[1] = aSouth(r, c) * pGood + aWest(r, c) * pBad + aEast(r, c) * pBad;
        a[2] = aWest(r, c) * pGood + aSouth(r, c) * pBad + aNorth(r, c) * pBad;
        a[3] = aEast(r, c) * pGood + aSouth(r, c) * pBad + aNorth(r, c) * pBad;

        int best = maxindex(a);
        Up[r][c] = R[r][c] + gamma * a[best];
        Pi[r][c] = ACTION_SYMBOLS[best];
    }

    /**
     * Returns the index of the largest element; on ties the later index wins.
     *
     * @param a non-empty array of values
     * @return index of the maximum
     */
    public static int maxindex(double[] a) {
        int b = 0;
        for (int i = 1; i < a.length; i++) {
            b = (a[b] > a[i]) ? b : i;
        }
        return b;
    }

    /** Utility (from U) of the cell reached by moving north; bumping a wall leaves you in place. */
    public static double aNorth(int r, int c) {
        // can't go north if at row 0 or if in cell (2,1)
        if ((r == 0) || (r == 2 && c == 1)) {
            return U[r][c];
        }
        return U[r - 1][c];
    }

    /** Utility (from U) of the cell reached by moving south; bumping a wall leaves you in place. */
    public static double aSouth(int r, int c) {
        // can't go south if at row 2 or if in cell (0,1)
        if ((r == rMax - 1) || (r == 0 && c == 1)) {
            return U[r][c];
        }
        return U[r + 1][c];
    }

    /** Utility (from U) of the cell reached by moving west; bumping a wall leaves you in place. */
    public static double aWest(int r, int c) {
        // can't go west if at col 0 or if in cell (1,2)
        if ((c == 0) || (r == 1 && c == 2)) {
            return U[r][c];
        }
        return U[r][c - 1];
    }

    /** Utility (from U) of the cell reached by moving east; bumping a wall leaves you in place. */
    public static double aEast(int r, int c) {
        // can't go east if at col 3 or if in cell (1,0)
        if ((c == cMax - 1) || (r == 1 && c == 0)) {
            return U[r][c];
        }
        return U[r][c + 1];
    }

    /**
     * Copies every element of src into dst, row by row.
     *
     * @param src source matrix
     * @param dst destination matrix; each row must be at least as long as the matching src row
     */
    public static void duplicate(double[][] src, double[][] dst) {
        for (int x = 0; x < src.length; x++) {
            System.arraycopy(src[x], 0, dst[x], 0, src[x].length);
        }
    }
}
Leave a comment