Skip to content

Commit 6b7a9de

Browse files
committed
Added numpy based snake game
0 parents  commit 6b7a9de

File tree

4 files changed

+707
-0
lines changed

4 files changed

+707
-0
lines changed

SnakeEnv.py

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
import numpy as np;
2+
3+
4+
class Grid:
5+
def __init__(self, length, width):
6+
self.length, self.width = length, width
7+
self.grid = np.chararray((length, width))
8+
self.grid[:] = '.'
9+
10+
def reset(self):
11+
self.grid = np.chararray((self.length,self.width));
12+
self.grid[:]='.'
13+
def itemset(self,pos,val):
14+
self.grid.itemset(pos,val)
15+
16+
def display(self):
17+
for row in self.grid:
18+
print('|'+' '.join(row.decode('utf-8'))+'|')
19+
20+
21+
class Snake:
22+
def __init__(self, pos, snake_id):
23+
self.length = 1
24+
self.head = pos
25+
self.snake_id = snake_id
26+
self.pos_ary = [pos]
27+
self.direction = np.random.randint(4)
28+
29+
def move(self, action,food_pos):
30+
self.head = tuple(map(sum, zip(self.head, self.decodeAction(action))))
31+
eaten_something = False
32+
if(self.head == food_pos):
33+
eaten_something = True
34+
self.pos_ary.append(self.head)
35+
if not eaten_something:
36+
self.pos_ary.pop(0)
37+
else:
38+
self.length += 1
39+
return eaten_something;
40+
41+
def reset(self,pos):
42+
self.length=1;
43+
self.head=pos;
44+
self.pos_ary=[pos];
45+
46+
47+
def decodeAction(self,a):
48+
dirAry=[(-1,0),(0,1),(1,0),(0,-1)] #NESW
49+
if a == 1: # Left
50+
self.direction=(self.direction+4-1)%4
51+
elif a == 2:
52+
self.direction=(self.direction+1)%4
53+
54+
return dirAry[self.direction]
55+
56+
57+
58+
class SnakeGame:
59+
def __init__(self,length,width,n_snakes=1):
60+
self.grid = Grid(length,width)
61+
self.n_snakes=n_snakes
62+
if n_snakes+1>=length*width:
63+
raise Exception('too many snakes')
64+
self.snakes=[]
65+
self.update_food()
66+
self.addSnakes()
67+
68+
def has_snake(self,pos):
69+
for i in range(len(self.snakes)):
70+
if pos in self.snakes[i].pos_ary:
71+
return True;
72+
return False;
73+
74+
def sample_empty_pos(self):
75+
pos = (np.random.randint(self.grid.length),np.random.randint(self.grid.width))
76+
while self.has_snake(pos) or pos == self.food_pos:
77+
pos = (np.random.randint(self.grid.length),np.random.randint(self.grid.width))
78+
return pos
79+
80+
def addSnakes(self):
81+
pos_ary=[]
82+
for i in range(self.n_snakes):
83+
pos = (np.random.randint(self.grid.length),np.random.randint(self.grid.width))
84+
while pos in pos_ary or (pos == self.food_pos):
85+
pos = (np.random.randint(self.grid.length),np.random.randint(self.grid.width))
86+
pos_ary.append(pos);
87+
self.snakes.append(Snake(pos,i))
88+
89+
def inside_grid(self,pos):
90+
pos = list(pos)
91+
if pos[0]>=0 and pos[1]>=0 and pos[0]<self.grid.length and pos[1]<self.grid.width:
92+
return True
93+
return False
94+
95+
def hasCollided(self):
96+
for snake in self.snakes:
97+
head, snake_id = snake.head, snake.snake_id;
98+
if not self.inside_grid(head):
99+
return True
100+
for snake in self.snakes:
101+
if snake.snake_id==snake_id:
102+
if head in snake.pos_ary[:-1]:
103+
return True
104+
else:
105+
if head in snake.pos_ary:
106+
return True
107+
108+
109+
def update_food(self):
110+
food_pos = (np.random.randint(self.grid.length),np.random.randint(self.grid.width))
111+
while self.has_snake(food_pos):
112+
food_pos = (np.random.randint(self.grid.length),np.random.randint(self.grid.width))
113+
self.food_pos = food_pos
114+
115+
116+
117+
def get_observation(self):
118+
obs = np.zeros((1,self.grid.length,self.grid.width,1),dtype=float)
119+
for x in range(self.grid.length):
120+
for y in range(self.grid.width):
121+
obs[0][x][y][0] = int(ord(self.grid.grid[x][y]))
122+
obs[obs == ord('.')] = 5
123+
obs[obs == ord('x')] = 255
124+
125+
return np.array(obs,dtype=np.uint8);
126+
127+
def display(self,verbose=False):
128+
self.grid.reset()
129+
self.grid.itemset(self.food_pos,'x')
130+
for snake in self.snakes:
131+
for pos in snake.pos_ary:
132+
self.grid.itemset(pos,'o')
133+
self.grid.itemset(snake.head, chr(48+snake.snake_id))
134+
#self.grid.itemset(snake.head, '#')
135+
if verbose == True :
136+
self.grid.display()
137+
return self.grid
138+
139+
def get_state(self):
140+
states = np.zeros((self.n_snakes,4),dtype=np.float)
141+
for i in range(self.n_snakes):
142+
states[i][0] = float(self.snakes[i].direction)/4.0 # Direction of snake head
143+
states[i][1] = float(self.snakes[i].head[0])/self.grid.length
144+
states[i][2] = float(self.snakes[i].head[1])/self.grid.width
145+
states[i][3] = float(np.sum(np.absolute(np.array(self.snakes[i].head)-np.array(self.food_pos))))/(self.grid.length+self.grid.width)
146+
return states
147+
148+
149+
def step(self,action_list,verbose=False):
150+
"""
151+
state = 4 x nsnakes info with (direction,xpos,ypos,distance from food)
152+
"""
153+
rewards,dones = np.zeros((self.n_snakes),dtype=np.float), np.zeros((self.n_snakes)).astype('bool')
154+
for i in range(self.n_snakes):
155+
eaten_food = self.snakes[i].move(action_list[i],self.food_pos)
156+
if eaten_food == True :
157+
self.update_food()
158+
rewards[i]=0.9; # Eaten food
159+
elif self.hasCollided() == True:
160+
self.snakes[i].reset(self.sample_empty_pos())
161+
rewards[i]=-0.2;
162+
dones[i]=True;
163+
else:
164+
rewards[i] = 0.1*(1.0-float(np.sum(np.absolute(np.array(self.snakes[i].head)-np.array(self.food_pos))))/(self.grid.length+self.grid.width))
165+
if action_list[i] != 0:
166+
rewards[i] = rewards[i] - 0.1
167+
168+
self.display(verbose=verbose)
169+
observation = self.get_observation();
170+
observation = np.repeat(observation,repeats=self.n_snakes,axis=0)
171+
states = self.get_state()
172+
return observation,rewards,dones,states;
173+
174+
175+
176+
def close(self):
177+
pass
178+
179+
class MultiAgentSnakeGame:
180+
def __init__(self,length,width,n_envs=1,n_snakes=1):
181+
self.n_envs = n_envs;
182+
self.n_snakes = n_snakes;
183+
self.s_games = [ SnakeGame(length,width,n_snakes=n_snakes) for i in range(n_envs)]
184+
185+
def step(self,action_list,verbose = False):
186+
action_list = action_list.reshape(self.n_envs,self.n_snakes)
187+
m_obs,m_rewards,m_dones,m_states = [],[],[],[]
188+
for i in range(self.n_envs):
189+
observation,rewards,dones,states = self.s_games[i].step(action_list[i])
190+
for j in range(self.n_snakes):
191+
m_obs.append(observation[j])
192+
m_rewards.append(rewards[j])
193+
m_dones.append(dones[j])
194+
m_states.append(states[j])
195+
if verbose == True:
196+
self.display(verbose=verbose)
197+
m_obs,m_rewards,m_dones,m_states = np.array(m_obs),np.array(m_rewards),np.array(m_dones),np.array(m_states)
198+
return m_obs,m_rewards,m_dones,m_states;
199+
200+
def get_state(self):
201+
m_states = []
202+
for i in range(self.n_envs):
203+
states = self.s_games[i].get_state()
204+
for j in range(self.n_snakes):
205+
m_states.append(states[j])
206+
return np.array(m_states)
207+
208+
def display(self,verbose=False,limit = -1):
209+
if limit == -1:
210+
limit = self.n_envs;
211+
for i in range(min(self.n_envs,limit)):
212+
print ('PLAYER : ',i)
213+
self.s_games[i].display(verbose=verbose)
214+
215+
216+
def close(self):
217+
pass
218+

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy