Q1.ipynb - Colab
Q1.ipynb - Colab
ipynb - Colab
1 import numpy as np
2 import time
class Agent:
    """Six-armed bandit player that chooses actions with a KL-UCB index.

    Arms are the actions ``0`` .. ``5``.  The reward credited to an arm after
    each ball is ``1 - wicket``, so the empirical mean of an arm estimates the
    probability of *not* losing a wicket with that action.  The runs reported
    to :meth:`get_action` are only accumulated in ``runs_scored``; they do not
    influence which action is chosen.
    """

    # Number of available actions (arms).
    NUM_ACTIONS = 6

    def __init__(self):
        # Episode-level running totals.
        self.balls_played = 0
        self.runs_scored = 0
        self.wickets_down = 0
        self.last_played = 0
        self.total_reward = 0

        # Per-arm statistics: how often each arm was played and the sum of
        # the rewards it earned.
        self.pulls = np.zeros(self.NUM_ACTIONS, dtype=np.int32)
        self.arm_rewards = np.zeros(self.NUM_ACTIONS, dtype=np.int32)

        # Most recently computed upper confidence bound of every arm.
        self.ucb_arms = np.zeros(self.NUM_ACTIONS, dtype=np.float32)

    def kl_divergence(self, p, q):
        """Return the KL divergence between Bernoulli(p) and Bernoulli(q).

        The usual convention ``0 * log(0) = 0`` is applied, so the end points
        ``p == 0`` and ``p == 1`` yield ``-log(1 - q)`` and ``-log(q)``
        respectively.  The result is ``inf`` when ``q`` is 0 or 1 (or outside
        that range) while ``p != q``.
        """
        if p == q:
            return 0.0
        if q <= 0 or q >= 1:
            return np.inf

        divergence = 0.0
        # Each term is skipped when its weight is zero (0 * log 0 := 0).
        if p > 0:
            divergence += p * np.log(p / q)
        if p < 1:
            divergence += (1 - p) * np.log((1 - p) / (1 - q))
        return divergence

    def solve_q(self, rhs, p_a):
        """Return the upper confidence bound for an arm with mean ``p_a``.

        Candidates ``q`` are taken from a grid starting at ``p_a`` with step
        0.01 (upper end 1 excluded).  The result is the candidate whose
        divergence ``KL(p_a, q)`` exceeds ``rhs`` by the smallest margin.
        When no candidate exceeds ``rhs`` the whole grid is still plausible,
        so the largest candidate is returned.
        """
        if p_a >= 1:
            return 1

        q_s = np.arange(p_a, 1, 0.01)
        lhs = np.array([self.kl_divergence(p_a, q) for q in q_s])

        excess = lhs - rhs
        exceeds = excess > 0
        if not exceeds.any():
            # The threshold is never crossed on the grid: the bound lies at
            # (or beyond) the top of the grid, not at the empirical mean.
            return q_s[-1]

        excess[~exceeds] = np.inf
        return q_s[np.argmin(excess)]

    def calculate_ucb(self):
        """Refresh ``ucb_arms`` for every arm.

        The threshold passed to :meth:`solve_q` for an arm with ``n`` pulls
        is ``(log t + 3 * log(log t)) / n`` with ``t = balls_played``.  An arm
        that has never been pulled gets an infinite bound so that it is tried
        first.  ``get_action`` only calls this once ``balls_played`` has
        reached ``NUM_ACTIONS``, which keeps ``log(log t)`` well defined.
        """
        log_t = np.log(self.balls_played)
        exploration = log_t + 3 * np.log(log_t)

        for action in range(self.NUM_ACTIONS):
            n_pulls = self.pulls[action]
            if n_pulls == 0:
                self.ucb_arms[action] = np.inf
                continue
            p_a = self.arm_rewards[action] / n_pulls
            self.ucb_arms[action] = self.solve_q(exploration / n_pulls, p_a)

    def get_action(self, wicket, runs_scored):
        """Record the outcome of the previous ball and pick the next action.

        Args:
            wicket: 1 if the previous ball lost a wicket, otherwise 0.
                Ignored on the very first call.
            runs_scored: Runs scored off the previous ball.  Ignored on the
                very first call.

        Returns:
            The index (0-5) of the action to play next.
        """
        if self.balls_played == 0:
            # Nothing has been played yet, so there is no outcome to record.
            action = 0
        else:
            self.runs_scored += runs_scored
            self.wickets_down += wicket
            self.total_reward += 1 - wicket

            self.arm_rewards[self.last_played] += 1 - wicket
            self.pulls[self.last_played] += 1

            if self.balls_played < self.NUM_ACTIONS:
                # Warm-up: play every arm once, in order.
                action = self.balls_played
            else:
                self.calculate_ucb()
                best = np.flatnonzero(self.ucb_arms == np.amax(self.ucb_arms))
                # Ties are broken in favour of the highest action index.
                action = int(best[-1])

        self.last_played = action
        self.balls_played += 1
        return action
class Environment:
    """Simulated innings in which an agent picks one of six actions per ball.

    For every action the environment holds the probability of losing a
    wicket (``__p_out``), the probability that the action's runs are scored
    when no wicket falls (``__p_run``) and the runs the action is worth
    (``__action_runs_map``).  Regret totals and the time spent inside the
    agent are accumulated over the lifetime of the object; they are not
    reset when :meth:`innings` is called again.
    """

    def __init__(self, num_balls, agent):
        """Store the innings length and the agent and set up the model.

        Args:
            num_balls: Number of balls played by one call to ``innings``.
            agent: Object exposing ``get_action(wicket, runs_scored)``.
        """
        self.num_balls = num_balls
        self.agent = agent

        # Timing of the agent's decisions.
        self.__start_time = 0
        self.__end_time = 0
        self.__run_time = 0

        # Outcome of the most recent ball and per-innings totals.
        self.__wicket = 0
        self.__runs_scored = 0
        self.__total_runs = 0
        self.__total_wickets = 0

        # Cumulative regret with respect to wickets, expected runs and the
        # runs-per-wicket ratio.
        self.__regret_w = 0
        self.__regret_s = 0
        self.__regret_rho = 0

        # Per-action model parameters.
        self.__p_out = np.array([0.001, 0.01, 0.02, 0.03, 0.1, 0.3])
        self.__p_run = np.array([1, 0.9, 0.85, 0.8, 0.75, 0.7])
        self.__action_runs_map = np.array([0, 1, 2, 3, 4, 6])
        # Expected runs per ball, and that expectation per unit of wicket
        # probability.
        self.__s = (1 - self.__p_out) * self.__p_run * self.__action_runs_map
        self.__rho = self.__s / self.__p_out

    def __get_action(self):
        """Ask the agent for its next action and time how long it takes."""
        self.__start_time = time.time()
        chosen = self.agent.get_action(self.__wicket, self.__runs_scored)
        self.__end_time = time.time()
        self.__run_time = self.__run_time + self.__end_time - self.__start_time
        return chosen

    def __get_outcome(self, action):
        """Sample ``(wicket, runs)`` for one ball played with ``action``."""
        pout = self.__p_out[action]
        prun = self.__p_run[action]

        wicket = np.random.choice(2, 1, p=[1 - pout, pout])[0]
        if wicket != 0:
            # No runs are scored off a ball that loses a wicket.
            return wicket, 0

        connected = np.random.choice(2, 1, p=[1 - prun, prun])[0]
        return wicket, self.__action_runs_map[action] * connected

    def innings(self):
        """Play ``num_balls`` balls and report the accumulated statistics.

        Returns:
            A tuple ``(regret_w, regret_s, regret_rho, total_runs,
            total_wickets, run_time)``.  Runs and wickets cover this innings
            only; the regrets and the run time keep accumulating across
            calls.
        """
        self.__total_runs = 0
        self.__total_wickets = 0
        self.__runs_scored = 0

        # Reference values of the best action under each criterion.
        lowest_p_out = np.min(self.__p_out)
        highest_s = np.max(self.__s)
        highest_rho = np.max(self.__rho)

        for _ in range(self.num_balls):
            action = self.__get_action()
            self.__wicket, self.__runs_scored = self.__get_outcome(action)

            self.__total_runs += self.__runs_scored
            self.__total_wickets += self.__wicket

            self.__regret_w += self.__p_out[action] - lowest_p_out
            self.__regret_s += highest_s - self.__s[action]
            self.__regret_rho += highest_rho - self.__rho[action]

        summary = (
            self.__regret_w,
            self.__regret_s,
            self.__regret_rho,
            self.__total_runs,
            self.__total_wickets,
            self.__run_time,
        )
        return summary
https://colab.research.google.com/drive/1WseZcW4oq6tTOXqnuRpoeqsV--4EJkkb#printMode=true 2/3
5/12/24, 8:36 AM Q1.ipynb - Colab
# Play one 1000-ball innings with a fresh agent and print the six values
# returned by Environment.innings() in order.
agent = Agent()
environment = Environment(num_balls=1000, agent=agent)
(
    regret_w,
    regret_s,
    reger_rho,
    total_runs,
    total_wickets,
    run_time,
) = environment.innings()

print(regret_w, regret_s, reger_rho, total_runs, total_wickets, run_time)
https://colab.research.google.com/drive/1WseZcW4oq6tTOXqnuRpoeqsV--4EJkkb#printMode=true 3/3