📦 mguzdial3 / Simplanimo

📄 Qlearner.py · 102 lines
from mdp import *
import random, pickle, copy
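# mdp (imported wholesale above) supplies Environment, State, the MoveRight/
# MoveLeft/MoveUp/MoveDown transition functions, and CalculateReward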

# Hyperparameters
totalEpisodes = 100     # training episodes
maxRolloutLength = 50   # step cap per episode
learningRate = 0.1      # alpha in the Q-update
discountFactor = 0.95   # gamma in the Q-update
epsilon = 0.85          # probability of acting greedily (anneals upward toward 1)
epsilonDecay = 0.01     # added to epsilon after every step
random.seed(1)          # fixed seed for reproducible rollouts

actions = ["r", "l", "u", "d"]
qTable = {}  # state -> action -> Q-value
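
# Each episode rolls out up to maxRolloutLength steps under an epsilon-greedy
# policy, records (state, action, reward) triples, then sweeps them backwards
# with the tabular Q-learning update:
#   Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))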

for i in range(0, totalEpisodes):
	mapInit = [["C","-","-"],["F","-","-"],["A","-","-"]]
	currEnv = Environment(copy.deepcopy(mapInit))  # fresh copy of the starting map each episode
	SARs = []  # (state, action, reward) triples recorded for the Q-update below
	rolloutComplete = False
	rolloutIndex = 0
	totalReward = 0
	while rolloutIndex < maxRolloutLength and not rolloutComplete:
		rolloutIndex+=1

		state = State(currEnv)

		# Action selection: greedy with respect to the Q-table for states seen
		# before; the random default covers unseen states
		action = random.choice(actions)
		if state.state in qTable:
			maxAction = action
			maxValue = -1000

			for a in actions:
				if a in qTable[state.state]:
					if maxValue < qTable[state.state][a]:
						maxValue = qTable[state.state][a]
						maxAction = a
			action = maxAction


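		# Epsilon-greedy exploration: with probability (1 - epsilon) override the
		# greedy pick with a random action; epsilon anneals upward each step, so
		# exploration shrinks over training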
		if random.random() > epsilon:
			action = random.choice(actions)
		if epsilon < 1:
			epsilon += epsilonDecay


		# s_t -> s_(t+1): start from a copy of the current environment, then apply the chosen action
		nextEnvironment = currEnv.child()

		if action=="r":
			nextEnvironment = MoveRight(nextEnvironment)
		elif action =="l":
			nextEnvironment = MoveLeft(nextEnvironment)
		elif action =="u":
			nextEnvironment = MoveUp(nextEnvironment)
		elif action =="d":
			nextEnvironment = MoveDown(nextEnvironment)

		reward = CalculateReward(currEnv, nextEnvironment)
		totalReward+=reward

		SARs.append([state.state, action, reward])
		currEnv = nextEnvironment

		# Episode ends when the player is gone or every crystal has been collected
		if len(currEnv.playerPos)==0 or currEnv.totalCrystals==currEnv.collectedCrystals:
			rolloutComplete = True

	print ("Episode: "+str(i)+" total reward: "+str(totalReward))

	# Q-update, swept backwards through the rollout so each step's target can
	# use the successor value that was just refreshed
	for j in range(len(SARs)-1, -1, -1):
		oldQValue = 0  # unseen (state, action) pairs are assumed to start at 0
		if SARs[j][0] in qTable and SARs[j][1] in qTable[SARs[j][0]]:
			oldQValue = qTable[SARs[j][0]][SARs[j][1]]

		# max_a' Q(s', a') over the successor state; 0 for the rollout's final step
		optimalFutureValue = 0
		if j+1 < len(SARs):
			nextValues = qTable.get(SARs[j+1][0], {})
			optimalFutureValue = max([nextValues.get(a, 0) for a in actions])

		newQValue = oldQValue + learningRate*(SARs[j][2] + discountFactor*optimalFutureValue - oldQValue)

		if SARs[j][0] not in qTable:
			qTable[SARs[j][0]] = {}

		qTable[SARs[j][0]][SARs[j][1]] = newQValue

# Persist the learned table so a later run can replay the policy
with open("qTable.pickle", "wb") as qTableFile:
	pickle.dump(qTable, qTableFile)
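
# A minimal, hypothetical usage sketch (not part of the original script): load
# the saved table and replay the greedy policy, falling back to a random action
# for states the table has never seen. The moves dict and the local variable
# names below are illustrative only.
#
# qTable = pickle.load(open("qTable.pickle", "rb"))
# moves = {"r": MoveRight, "l": MoveLeft, "u": MoveUp, "d": MoveDown}
# env = Environment(copy.deepcopy(mapInit))
# for _ in range(maxRolloutLength):
# 	s = State(env).state
# 	knownValues = qTable.get(s, {})
# 	a = max(knownValues, key=knownValues.get) if knownValues else random.choice(actions)
# 	env = moves[a](env.child())
# 	if len(env.playerPos) == 0 or env.totalCrystals == env.collectedCrystals:
# 		break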