[Code post] Building a Snake AI with Deep Reinforcement Learning: A Step-by-Step Guide Using Keras

16/9/2024 ● 9 minutes to read

In this blog post, we will explore how to build a Snake AI using deep reinforcement learning (DRL). This is a simple project and the idea is to make you familiar with the development of DRL in a simulation (or game). In this guide, we will train a neural network to control a snake in a discrete grid, where the goal is to consume as many apples as possible. Whether you are a seasoned ML enthusiast or new to RL, this post will provide a practical, hands-on approach to creating a smart Snake AI.

Some of you may ask: "why Snake?". First of all, great question - it is like you read my mind. Seriously, the Snake game is an excellent choice for experimenting with DRL for several reasons. First, the game's simple rules and minimalistic environment make it easy to understand and implement, even for those new to AI. The small grid and clear objectives (eating apples and avoiding collisions) create a manageable problem space. Second, the Snake game involves a series of decisions that impact the game's outcome, making it a classic example for testing RL algorithms. The agent must learn to make decisions that balance immediate rewards (eating apples) with long-term survival (avoiding collisions). Third, as the snake grows, the environment becomes more complex, with fewer available spaces and more potential for self-collision. This increasing complexity challenges the agent to adapt its strategy as the game progresses, providing a rich learning experience.

Well, now that you are motivated, it is time to plan our Snake game.

The Plan

A personal recommendation - plan your project in advance as much as possible. It will change during the implementation phase. However, you always want a clear plan in your mind, even if it is not perfect (or even good). For our Snake agent, the plan involves several key steps:

  • Environment Setup: Define the game environment using a 10x10 grid, with the snake and apple represented by different colored pixels. This will be our input space for the neural network.
  • Neural Network Design: Construct a convolutional neural network (CNN) to process the game frames. The network will be trained to predict the best action (keep going, turn left, or turn right) based on the current state of the game.
  • Training Process: Implement a reinforcement learning loop where the snake explores the environment, collects rewards (eating apples), and learns from its experiences. We'll use techniques like random exploration and experience replay to enhance learning and prevent overfitting.
  • Evaluation and Fine-Tuning: Once the initial training is complete, we will evaluate the agent's performance and adjust hyperparameters or network architecture as needed to improve results.
  • Visualization and Testing: Finally, we will visualize the agent's decision-making process by running test games and observing how it adapts to different scenarios. This step will help us understand the AI's strengths and weaknesses.
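The "random exploration" mentioned in the training step is commonly implemented as an epsilon-greedy rule. Here is a minimal sketch of that idea, not the code used later in this post; the function name `choose_action` and its parameters are hypothetical.

```python
import random


def choose_action(q_values, exploration_rate, rng=random):
    """Pick an action index with an epsilon-greedy rule.

    q_values: list of estimated values, one per action (illustrative input).
    exploration_rate: probability of ignoring the estimates and acting randomly.
    """
    if rng.random() < exploration_rate:
        # Explore: any action, uniformly at random.
        return rng.randrange(len(q_values))
    # Exploit: the action with the highest estimated value.
    return max(range(len(q_values)), key=lambda a: q_values[a])


# With exploration switched off the choice is always the arg-max.
print(choose_action([0.1, 0.7, 0.2], exploration_rate=0.0))  # prints 1
```

Early in training the exploration rate is kept high so the snake tries many moves; it is then lowered so the agent relies more on what the network has learned.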

Based on these components, and since I like OOP (object-oriented programming) too much (I came from C# to Python, and it is hard to change good habits), let's define the classes we will need for this project.

  • main: The entry point that runs the project.
  • Apple: The apple the snake will eat.
  • Environment: The board of the game.
  • Pixel: A specific location on the board.
  • Snake: The snake player.
  • SnakeGame: The simulation class, which runs the game logic using the previous classes.
  • Memory: A class responsible for recalling the actions taken by the Snake player.
  • KAgent: Our DRL agent's brain.

For this project, we will use Pygame for the simulation, simply to make the visualization logic easier, and Keras for the DRL. Some of you may prefer PyTorch, which is fine; I just prefer Keras personally.

A quick reminder: Pygame is a Python library designed for creating games and multimedia applications. It provides modules to handle various aspects of game development, such as graphics rendering, sound, and user input. With Pygame, you can easily create and control 2D game environments, manage sprites, and capture keyboard or mouse events.

Keras is a high-level deep learning library written in Python that simplifies the process of building and training neural networks. It provides an easy-to-use interface for creating complex models with minimal code, allowing developers to quickly prototype and experiment with different architectures. Keras runs on top of a backend framework: older multi-backend releases supported TensorFlow, Theano, and CNTK, while Keras 3 supports TensorFlow, JAX, and PyTorch.

Let's Code

Let us start with the easy part - the game. The Snake game is a classic arcade game where the player controls a snake that moves around the screen, eating apples. Each time the snake's head hits the apple, it grows by one pixel. The goal is to keep the snake alive for as long as possible by avoiding collisions with the walls and itself while eating as many apples as possible.


pixel.py


class Pixel(object):

    BLANK = 0
    WALL  = 1
    APPLE = 2
    SNAKE_HEAD = 3
    SNAKE_BODY = 4

    def __init__(self, x, y):
        self.x = x
        self.y = y
    
    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.__dict__ == other.__dict__
        return NotImplemented

    def __ne__(self, other):
        if isinstance(other, self.__class__):
            return not self.__eq__(other)
        return NotImplemented

    def __add__(self, other):
        if isinstance(other, self.__class__):
            return Pixel(self.x + other.x, self.y + other.y)
        return NotImplemented

    def __sub__(self, other):
        if isinstance(other, self.__class__):
            return Pixel(self.x - other.x, self.y - other.y)
        return NotImplemented

    def __hash__(self):
        return hash(tuple(sorted(self.__dict__.items())))

    def __str__(self):
        return str(self.__dict__.items())
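The point of the comparison and hashing methods is that two pixels with the same coordinates count as the same pixel, so they can be compared and stored in sets. As an aside, the same value semantics can be sketched with a frozen dataclass. This is an alternative for illustration only, and the name `GridPoint` is hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GridPoint:
    """Hypothetical stand-in for Pixel: equality and hashing come for free."""
    x: int
    y: int

    def __add__(self, other):
        if isinstance(other, GridPoint):
            return GridPoint(self.x + other.x, self.y + other.y)
        return NotImplemented


head = GridPoint(3, 4)
step_left = GridPoint(-1, 0)

print(head + step_left == GridPoint(2, 4))    # True: compared by value
print(GridPoint(2, 4) in {head + step_left})  # True: usable in sets
```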

apple.py


import random
from env.pixel import Pixel
from env.config import *


class Apple(object):

    def __init__(self):
        self.location = Pixel(0, 0)

    def reposition(self, snake):
        pxs = []
        for x in range(WALL_THICKNESS, SCREEN_WIDTH // PIXEL_SIZE - WALL_THICKNESS):
            for y in range(WALL_THICKNESS, SCREEN_HEIGHT // PIXEL_SIZE - WALL_THICKNESS):
                pxs.append(Pixel(x, y))
        
        # Remove every cell occupied by the snake, tail included.
        for part in snake.body:
            px = Pixel(part.x, part.y)
            if px in pxs:
                pxs.remove(px)

        # Every remaining free cell is a valid spot, so index from 0.
        new_position_index = random.randint(0, len(pxs) - 1)
        new_position = pxs[new_position_index]
        self.location = new_position
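The repositioning logic boils down to "list the free interior cells, then pick one at random". A compact sketch of that idea with plain tuples follows; the helper `free_cells` and its parameters are hypothetical and measured in grid cells rather than screen pixels.

```python
import random


def free_cells(width, height, wall, occupied):
    """Return every interior cell (x, y) not covered by the snake.

    width/height are in grid cells; wall is the border thickness;
    occupied is an iterable of (x, y) tuples. All names are illustrative.
    """
    taken = set(occupied)
    return [
        (x, y)
        for x in range(wall, width - wall)
        for y in range(wall, height - wall)
        if (x, y) not in taken
    ]


snake_body = [(4, 5), (5, 5), (6, 5)]
candidates = free_cells(10, 10, 1, snake_body)
apple = random.choice(candidates)

print(len(candidates))  # 8 * 8 interior cells minus 3 snake cells = 61
```

Using a set for the occupied cells keeps the membership test cheap, which matters little on a 10x10 board but scales better on larger ones.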

snake.py


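import random
from pixel import Pixel
# LEFT, RIGHT, UP, DOWN and the board constants are assumed to come from
# env/config.py, the module apple.py already imports them from.
from env.config import *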

head = 0
tail = -1
init_length = 3

class Snake(object):

    def __init__(self):
        self.body = []
        self._born()


    def _born(self):
        head_x = random.randint(WALL_THICKNESS + 1, SCREEN_WIDTH // PIXEL_SIZE - init_length - WALL_THICKNESS)
        head_y = random.randint(WALL_THICKNESS + 1, SCREEN_HEIGHT // PIXEL_SIZE - init_length - WALL_THICKNESS)
        self.body.insert(0, Pixel(head_x, head_y))
        for i in range(init_length - 1):
            part = Pixel(self.body[head].x + i + 1, self.body[head].y)
            self.body.append(part)
    
    @property
    def head(self):
        return self.body[head]


    @property
    def direction(self):
        if self.body[head].x < self.body[head + 1].x:
            return LEFT
        elif self.body[head].x > self.body[head + 1].x:
            return RIGHT
        elif self.body[head].y < self.body[head + 1].y:
            return UP
        else:
            return DOWN


    def growup(self):
        self.body.append(self.body[tail])
    

    def move(self):
        current_direction = self.direction
        del self.body[tail]
        if current_direction == LEFT:
            new_head = Pixel(self.body[head].x - 1, self.body[head].y)
            self.body.insert(0, new_head)
        elif current_direction == RIGHT:
            new_head = Pixel(self.body[head].x + 1, self.body[head].y)            
            self.body.insert(0, new_head)
        elif current_direction == UP:
            new_head = Pixel(self.body[head].x, self.body[head].y - 1)            
            self.body.insert(0, new_head)
        else:
            new_head = Pixel(self.body[head].x, self.body[head].y + 1)                
            self.body.insert(0, new_head)
    

    def turn(self, direction):
        # Every value other than LEFT, RIGHT, or UP is treated as DOWN.
        if direction == LEFT:
            self.turn_to_left()
        elif direction == RIGHT:
            self.turn_to_right()
        elif direction == UP:
            self.turn_to_up()
        else:
            self.turn_to_down()
        return True
    

    def turn_to_up(self):
        del self.body[-1]
        new_head = Pixel(self.body[head].x, self.body[head].y - 1)
        self.body.insert(0, new_head)


    def turn_to_down(self):
        del self.body[-1]
        new_head = Pixel(self.body[head].x, self.body[head].y + 1)
        self.body.insert(0, new_head)


    def turn_to_left(self):
        del self.body[-1]
        new_head = Pixel(self.body[head].x - 1, self.body[head].y)
        self.body.insert(0, new_head)


    def turn_to_right(self):
        del self.body[-1]
        new_head = Pixel(self.body[head].x + 1, self.body[head].y)
        self.body.insert(0, new_head)
    

    def turn_left(self):
        direction = self.direction
        if direction == LEFT:
            self.turn_to_down()
        elif direction == UP:
            self.turn_to_left()
        elif direction == RIGHT:
            self.turn_to_up()
        elif direction == DOWN:
            self.turn_to_right()


    def turn_right(self):
        direction = self.direction
        if direction == LEFT:
            self.turn_to_up()
        elif direction == UP:
            self.turn_to_right()
        elif direction == RIGHT:
            self.turn_to_down()
        elif direction == DOWN:
            self.turn_to_left()
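The relative turns above can also be read as a rotation through the four headings in clockwise order. The sketch below illustrates that reading with plain tuples; `rotate`, `advance`, `CLOCKWISE`, and `STEP` are hypothetical names, and only the direction codes are taken from the game's config.

```python
# Direction codes as defined in the game's config section.
LEFT, RIGHT, UP, DOWN = 0, 1, 2, 3

# Headings in clockwise order on screen (y grows downwards).
CLOCKWISE = [UP, RIGHT, DOWN, LEFT]

# Unit step (dx, dy) for each heading.
STEP = {LEFT: (-1, 0), RIGHT: (1, 0), UP: (0, -1), DOWN: (0, 1)}


def rotate(direction, turn):
    """Return the new heading; turn is +1 for right, -1 for left, 0 to go on."""
    i = CLOCKWISE.index(direction)
    return CLOCKWISE[(i + turn) % len(CLOCKWISE)]


def advance(body, direction):
    """Move a snake (list of (x, y), head first) one cell without growing."""
    dx, dy = STEP[direction]
    head_x, head_y = body[0]
    return [(head_x + dx, head_y + dy)] + body[:-1]


body = [(4, 5), (5, 5), (6, 5)]          # heading LEFT
print(rotate(LEFT, -1) == DOWN)          # True, same as turn_left above
print(advance(body, rotate(LEFT, +1)))   # [(4, 4), (4, 5), (5, 5)]
```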

environment.py


import numpy as np
from pixel import Pixel

class Environment(object):

    def __init__(self, width, height, pixel_size):
        self.pixels = None
        self.width = width
        self.height = height
        self.pixel_size = pixel_size
        self._build()
    

    def _build(self):
        w = self.width // self.pixel_size
        h = self.height // self.pixel_size
        self.pixels = np.zeros((w, h))
        for i in range(0, w):
            self.pixels[i, 0] = Pixel.WALL
            self.pixels[i, h - 1] = Pixel.WALL
        for j in range(0, h):
            self.pixels[0, j] = Pixel.WALL
            self.pixels[w - 1, j] = Pixel.WALL


    def read_pixel(self, x, y):
        return self.pixels[x, y]


    def write_pixel(self, px, px_type):
        self.pixels[px.x, px.y] = px_type


    @property
    def pixel_total_count(self):
        w, h = np.shape(self.pixels)
        return w * h


    @property
    def shape(self):
        return np.shape(self.pixels)

    
    def reset(self):
        self._build()
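The two loops in `_build` draw a one-cell wall around the board. As a small sketch of the same result with NumPy slicing (the function `build_grid` and its `thickness` parameter are hypothetical):

```python
import numpy as np

WALL = 1  # same code as Pixel.WALL


def build_grid(cells_w, cells_h, thickness=1):
    """Return a (cells_w, cells_h) grid of zeros with a wall border."""
    grid = np.zeros((cells_w, cells_h))
    grid[:thickness, :] = WALL
    grid[-thickness:, :] = WALL
    grid[:, :thickness] = WALL
    grid[:, -thickness:] = WALL
    return grid


grid = build_grid(10, 10)
print(int(grid.sum()))  # 36 border cells in a 10x10 grid
```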

game.py


# library imports
import sys
import pygame
import numpy as np
from pygame.locals import *

# project imports
from pixel import Pixel
from snake import Snake
from apple import Apple
from environment import Environment

# config #

SCREEN_WIDTH = 200
SCREEN_HEIGHT = 200
PIXEL_SIZE = 20

WALL_THICKNESS = 1

LEFT = 0
RIGHT = 1
UP = 2
DOWN = 3

MOVE_ON = 0
TURN_LEFT = 1
TURN_RIGHT = 2

SNAKE_ACTIONS = [
    MOVE_ON,
    TURN_LEFT,
    TURN_RIGHT,
]

# end - config #

class Colors(object):

    BLANK = (170, 204, 153)
    WALL = (56, 56, 56)
    APPLE = (173, 52, 80)
    SNAKE_HEAD = (122, 154, 191)
    SNAKE_BODY = (105, 132, 164)


class SnakeGame(object):

    def __init__(self, is_tick=False):
        pygame.init()
        global screen, FPS
        screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
        FPS = pygame.time.Clock()
        self.is_tick = is_tick
        self._build_enviroment()


    def _build_enviroment(self):
        self.environment = Environment(SCREEN_WIDTH, SCREEN_HEIGHT, PIXEL_SIZE)
        self.snake = Snake()
        self.apple = Apple()
        self.apple.reposition(self.snake)
        self.score = 0
    
    @property
    def observation_shape(self):
        return np.shape(self.environment.pixels)

    def new_round(self):
        self._build_enviroment()
        feedback = Feedback(
            observation=np.copy(self.environment.pixels),
            reward=0,
            game_over=False
        )
        return feedback

    def step(self, action):
        for event in pygame.event.get():
            if event.type == QUIT:
                pygame.quit()
                sys.exit()

        if self.is_tick:
            FPS.tick(10)
        
        if action == MOVE_ON:
            self.snake.move()
        elif action == TURN_LEFT:
            self.snake.turn_left()
        elif action == TURN_RIGHT:
            self.snake.turn_right()

        eat_apple = self.eat_apple(self.snake, self.apple)
        game_over = self.game_is_over(self.snake)

        if game_over is False:
            self.render()

        reward = 1 if eat_apple is True else 0
        if game_over:
            reward = -1
        
        feedback = Feedback(
            observation=np.copy(self.environment.pixels),
            reward=reward,
            game_over=game_over
        )

        return feedback

    @property
    def actions_num(self):
        return len(SNAKE_ACTIONS)
    
    @property
    def current_score(self):
        return self.score    

    def draw_node(self, x, y, px):
        rect = pygame.Rect(x * PIXEL_SIZE, y * PIXEL_SIZE, PIXEL_SIZE, PIXEL_SIZE)
        if px == Pixel.WALL:
            pygame.draw.rect(screen, Colors.WALL, rect)
        elif px == Pixel.APPLE:
            pygame.draw.rect(screen, Colors.APPLE, rect)
        elif px == Pixel.SNAKE_HEAD:
            pygame.draw.rect(screen, Colors.SNAKE_HEAD, rect)
        elif px == Pixel.SNAKE_BODY:
            pygame.draw.rect(screen, Colors.SNAKE_BODY, rect)
    
    def draw_environment(self, environment):
        screen.fill((Colors.BLANK))
        w, h = environment.shape
        for i in range(w):
            for j in range(h):
                self.draw_node(i, j, environment.read_pixel(i, j))

    def render(self):
        self.update_enviroment(self.snake,self.apple, self.environment)
        self.draw_environment(self.environment)
        pygame.display.update()

    def update_enviroment(self, snake, apple, environment):
        environment.reset()
        for px in snake.body:
            environment.write_pixel(Pixel(px.x, px.y), Pixel.SNAKE_BODY)
        environment.write_pixel(snake.head, Pixel.SNAKE_HEAD)
        environment.write_pixel(Pixel(apple.location.x, apple.location.y), Pixel.APPLE)
    
    def eat_apple(self, snake, apple):
        if snake.head.x == apple.location.x and snake.head.y == apple.location.y:
            snake.growup()
            apple.reposition(snake)
            self.score += 1
            return True
        return False

    def game_is_over(self, snake):
        # The head is out of bounds once it enters the wall border.
        if (snake.head.x * PIXEL_SIZE < WALL_THICKNESS * PIXEL_SIZE
                or snake.head.x * PIXEL_SIZE >= SCREEN_WIDTH - PIXEL_SIZE
                or snake.head.y * PIXEL_SIZE < WALL_THICKNESS * PIXEL_SIZE
                or snake.head.y * PIXEL_SIZE >= SCREEN_HEIGHT - PIXEL_SIZE):
            return True
        else:
            for part in snake.body[1:]:
                if part == snake.head:
                    return True
        return False

    def gameOver(self):
        screen.fill((0, 0, 0))
        fontObj = pygame.font.Font('freesansbold.ttf', 20)
        textSurfaceObj1 = fontObj.render('Game over!', True, (255, 0, 0))
        textRectObj1 = textSurfaceObj1.get_rect()
        textRectObj1.center = (SCREEN_WIDTH / 3, SCREEN_HEIGHT / 3)
        screen.blit(textSurfaceObj1, textRectObj1)

        textSurfaceObj2 = fontObj.render('Score: %s' % self.score, True, (255, 0, 0))
        textRectObj2 = textSurfaceObj2.get_rect()
        textRectObj2.center = (SCREEN_WIDTH*2/3, SCREEN_HEIGHT*2/3)
        screen.blit(textSurfaceObj2, textRectObj2)

        pygame.display.update()

        over = True
        while(over):
            for event in pygame.event.get():
                if event.type == QUIT:
                    pygame.quit()
                    sys.exit()


    def destroy(self):
        pygame.quit()
        sys.exit()

class Feedback(object):

    def __init__(self, observation, reward, game_over):
        self.observation = observation
        self.reward = reward
        self.game_over = game_over

The code above is mostly technical. Readers who are not familiar with Pygame may have questions about one function call or another, but overall, if you read the code carefully, you should not find anything too complex. The following code is a different story.

The Memory class is designed to help a DRL agent learn more effectively by storing and recalling past experiences. During training, the agent needs to remember what it has done before: what worked and what did not. Each stored experience holds a state, the action taken, the reward received, the next state, and whether the game ended; the agent later samples random batches of these experiences to learn from.
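The core idea, a bounded buffer that is sampled at random, can be sketched in a few lines. This is a simplified stand-in rather than the class listed next: the names `ReplayBuffer` and `Experience` are hypothetical, and it stores named tuples where the real class stores flat NumPy rows.

```python
import collections
import random

# Hypothetical record type; the Memory class stores flat NumPy rows instead.
Experience = collections.namedtuple(
    "Experience", ["state", "action", "reward", "next_state", "done"]
)


class ReplayBuffer:
    """Fixed-size store of past transitions with uniform random sampling."""

    def __init__(self, capacity):
        # deque(maxlen=...) silently drops the oldest item once full.
        self.items = collections.deque(maxlen=capacity)

    def store(self, *fields):
        self.items.append(Experience(*fields))

    def sample(self, batch_size):
        batch_size = min(batch_size, len(self.items))
        # Copy to a list: indexing into the middle of a deque is O(n).
        return random.sample(list(self.items), batch_size)


buffer = ReplayBuffer(capacity=3)
for step in range(5):
    buffer.store(step, 0, 0.0, step + 1, False)

print([e.state for e in buffer.items])  # [2, 3, 4]: the two oldest were evicted
print(len(buffer.sample(10)))           # 3: capped at what is stored
```

Sampling at random, instead of replaying the latest steps in order, means each training batch mixes experiences from different moments of different games.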


memory.py


import random
import numpy as np
import collections

class Memory(object):

    def __init__(self, input_shape, num_actions, memory_size=100):
        self.memory = collections.deque()  # double-ended queue
        self.input_shape = input_shape
        self.num_actions = num_actions
        self.memory_size = memory_size
    
    def reset(self):
        self.memory = collections.deque()

    def store(self, state, action, reward, next_state, is_end):
        s = state.flatten()
        a = np.array(action).flatten()
        r = np.array(reward).flatten()
        s_ = next_state.flatten()
        end = 1 * np.array(is_end).flatten()
        experience = np.concatenate([s, a, r, s_, end])
        self.memory.append(experience)
        if 0 < self.memory_size < len(self.memory):
            self.memory.popleft()

    def get_batch(self, batch_size):
        batch_size = min(len(self.memory), batch_size)
        experience = np.array(random.sample(self.memory, batch_size))
        input_dim = np.prod(self.input_shape)  # product of the shape's dimensions

        states = experience[:, 0:input_dim]
        actions = experience[:, input_dim]
        rewards = experience[:, input_dim+1]
        next_state = experience[:, input_dim + 2 : input_dim * 2 + 2]
        ends = experience[:, input_dim * 2 + 2]

        states = states.reshape((batch_size, ) + self.input_shape)
        next_state = next_state.reshape((batch_size, ) + self.input_shape)

        return states, actions, rewards, next_state, ends, batch_size
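Each experience is stored as one flat row laid out as state, action, reward, next state, end flag, and `get_batch` slices the columns back out by position. The toy example below walks through that layout with a 2x2 state; the values are made up for illustration.

```python
import numpy as np

input_shape = (2, 2)                 # toy shape; the post uses (4, 10, 10)
input_dim = int(np.prod(input_shape))

state = np.arange(4.0).reshape(input_shape)
next_state = state + 10
row = np.concatenate([state.flatten(), [2], [1.0], next_state.flatten(), [0]])

# Column layout: [state | action | reward | next_state | end flag]
action = int(row[input_dim])
reward = row[input_dim + 1]
restored_next = row[input_dim + 2: input_dim * 2 + 2].reshape(input_shape)
end_flag = row[input_dim * 2 + 2]

print(row.size)                                   # 11 = 4 + 1 + 1 + 4 + 1
print(action, reward, end_flag)                   # 2 1.0 0.0
print(np.array_equal(restored_next, next_state))  # True
```

Note that the action comes back as a float because the whole row shares one dtype, which is why the agent casts actions to integers before using them as indices.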

Finally, the code you have all been waiting for: our DRL agent using Keras. For this project, we use a CNN on the pixels of the board, fed with the last four frames stacked as input channels. This is a simple approach; sequence models such as an LSTM may work better, but defining the state space for them is harder. After some trial and error, I came up with the following NN structure:


Layer (type)                 Output Shape          Param #
conv2d_1 (Conv2D)            (None, 16, 8, 8)      592
activation_1 (Activation)    (None, 16, 8, 8)      0
conv2d_2 (Conv2D)            (None, 32, 6, 6)      4640
activation_2 (Activation)    (None, 32, 6, 6)      0
flatten_1 (Flatten)          (None, 1152)          0
dense_1 (Dense)              (None, 256)           295168
activation_3 (Activation)    (None, 256)           0
dense_2 (Dense)              (None, 3)             771
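The shapes and parameter counts in this summary can be checked by hand. The short sketch below does the arithmetic, assuming 3x3 kernels with stride 1 and no padding, a 10x10 board, and 4 stacked input frames; the helper names are hypothetical.

```python
def conv_out(size, kernel=3, stride=1):
    """Output width of an unpadded ('valid') convolution."""
    return (size - kernel) // stride + 1


def conv_params(in_channels, filters, kernel=3):
    """Weights plus one bias per filter."""
    return filters * (kernel * kernel * in_channels + 1)


def dense_params(inputs, units):
    """Weights plus one bias per unit."""
    return units * (inputs + 1)


side1 = conv_out(10)            # 8
side2 = conv_out(side1)         # 6
flat = 32 * side2 * side2       # 1152

print(conv_params(4, 16))       # 592
print(conv_params(16, 32))      # 4640
print(dense_params(flat, 256))  # 295168
print(dense_params(256, 3))     # 771
```

Most of the weights sit in the first dense layer, so that is the first place to look if you want a smaller model.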

The code is as follows:


brain.py


import numpy as np
import keras
from keras.models import Sequential
from keras.layers import *
from keras.optimizers import *
from memory import Memory


class KAgent(object):
    def __init__(self, features_shape, actions):
        self.actions = actions
        self.memory = Memory(features_shape, actions, 1000)
        self.model = self.build_layers(features_shape, actions)

    def build_layers(self, features_shape, actions):
        model = Sequential()
        model.add(Conv2D(
            16,
            kernel_size=(3, 3),
            strides=(1, 1),
            data_format='channels_first',
            input_shape=features_shape))

        model.add(Activation('relu'))
        model.add(Conv2D(
            32,
            kernel_size=(3, 3),
            strides=(1, 1),
            data_format='channels_first'))
        model.add(Activation('relu'))
        model.add(Flatten())
        model.add(Dense(256))
        model.add(Activation('relu'))
        model.add(Dense(actions))
        
        # Recent Keras releases call this argument learning_rate (older ones used lr).
        model.compile(optimizer=SGD(learning_rate=0.01, momentum=0.9, nesterov=True), loss='MSE')

        # Only arguments shared by old and new Keras versions are kept here.
        callback = keras.callbacks.TensorBoard(log_dir='./logs', histogram_freq=0, write_graph=True,
            write_images=False, embeddings_freq=0, embeddings_metadata=None)

        callback.set_model(model)

        model.summary()

        return model

    def store_memory(self, state, action, reward, state_next, game_is_over):
        self.memory.store(state, action, reward, state_next, game_is_over)

    def learn(self, discount_factor):
        states, actions, rewards, states_next, is_ends, size = self.memory.get_batch(50)
        actions = actions.astype(int)  # np.cast was removed in NumPy 2.0
        rewards = rewards.repeat(self.actions).reshape((size, self.actions))
        is_ends = is_ends.repeat(self.actions).reshape((size, self.actions))

        X = np.concatenate([states, states_next], axis=0)
        y = self.model.predict(X)
        Q_next = np.max(y[size:], axis=1).repeat(self.actions).reshape((size, self.actions))

        delta = np.zeros((size, self.actions))
        delta[np.arange(size), actions] = 1

        targets = (1 - delta) * y[:size] + delta * (rewards + discount_factor * (1 - is_ends) * Q_next)  
        loss = self.model.train_on_batch(states, targets)

        return loss

    def predict(self, states):
        q = self.model.predict(states)[0]
        return np.argmax(q)

    def save_model(self):
        self.model.save('agent-keras-final.model')
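The heart of `learn` is the target it builds for each sampled transition: keep the network's own predictions for the actions that were not taken, and replace the taken action's value with the reward plus the discounted best value of the next state. Here is a NumPy-only sketch of that step with made-up numbers; the function `q_targets` is hypothetical and uses direct indexing instead of the mask in the listing.

```python
import numpy as np


def q_targets(q_now, q_next, actions, rewards, ends, discount):
    """Build training targets for one batch.

    q_now, q_next: arrays of shape (batch, n_actions) with the network's
    predictions for the current and next states.
    actions: integer array with the action taken in each transition.
    Only the entry of the taken action is overwritten.
    """
    targets = q_now.copy()
    rows = np.arange(len(actions))
    best_next = q_next.max(axis=1)
    targets[rows, actions] = rewards + discount * (1 - ends) * best_next
    return targets


q_now = np.array([[0.2, 0.5, 0.1], [0.0, 0.3, 0.4]])
q_next = np.array([[1.0, 0.0, 0.5], [0.9, 0.8, 0.7]])
targets = q_targets(
    q_now, q_next,
    actions=np.array([1, 2]),
    rewards=np.array([0.0, -1.0]),
    ends=np.array([0.0, 1.0]),
    discount=0.95,
)
print(targets)
# Row 0: action 1 becomes 0 + 0.95 * 1.0 = 0.95
# Row 1: the episode ended, so action 2 becomes just the reward, -1.0
```

Because the untouched entries equal the current predictions, they contribute zero error, and only the taken action drives the weight update.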

Of course, we need to make all of this run. So, here is a "main.py" file to set things going!


main.py


import collections
import numpy as np
# Module paths must match your project layout (this post lists the agent file as brain.py).
from agent.brain_keras import KAgent
from env.game import SnakeGame


def main():
    discount_factor = 0.95
    last_frames_num = 4
    actions_num = 3
    exploration_rate = 1.0
    min_exploration_rate = 0.1
    episode_num = 30000
    exploration_decay = ((exploration_rate - min_exploration_rate) / episode_num * 0.5)

    game = SnakeGame(is_tick=False)
    agent = KAgent((last_frames_num,) + game.observation_shape, actions_num)

    for episode in range(episode_num):
        random_count = 0
        predict_count = 0
        loss = 0.0
        w, h = game.observation_shape
        first_step = game.new_round()
        game_over = False        
        
        game.render()
        
        last_frames = collections.deque([first_step.observation] * last_frames_num)
        state = np.array(last_frames)
        state = np.reshape(state, (-1, last_frames_num, w, h))

        while not game_over:
            if np.random.random() < exploration_rate:
                action = np.random.randint(actions_num)
                random_count += 1
                action_type = 'random'
            else:
                action = agent.predict(state)
                predict_count += 1
                action_type = 'predict'

            one_step = game.step(action)

            # print("action_type: %s"%(action_type))

            reward = one_step.reward
            last_frames.append(one_step.observation)
            last_frames.popleft()
            next_state = np.array(last_frames)
            next_state = np.reshape(next_state, (-1, last_frames_num, w, h))
            game_over = one_step.game_over

            agent.store_memory(state, action, reward, next_state, game_over)
            loss += float(agent.learn(discount_factor))

            if game_over is True:
                log = 'episode {:5d} || exploration_rate {:.2f} || random count {:3d} || predict count {:3d}' + \
                        ' || loss {:8.4f} || score {:3d}'
                print(log.format(episode, exploration_rate, random_count, predict_count, loss, game.current_score))
                break

            state = next_state

        if exploration_rate > min_exploration_rate:
            exploration_rate -= exploration_decay

    agent.save_model()

if __name__ == '__main__':
    main()
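One detail worth checking is how far the exploration rate actually falls during training. The sketch below replays the schedule from `main()` without running any games. With the expression as written, the rate ends at about 0.55 rather than 0.1; whether that is intended is not stated, and the variant at the end (`faster_decay`) is a hypothetical alternative.

```python
exploration_rate = 1.0
min_exploration_rate = 0.1
episode_num = 30000

# Same expression as in main(): division happens before the multiplication.
exploration_decay = (exploration_rate - min_exploration_rate) / episode_num * 0.5

final_rate = exploration_rate
for _ in range(episode_num):
    if final_rate > min_exploration_rate:
        final_rate -= exploration_decay

print(round(exploration_decay, 8))  # 1.5e-05
print(round(final_rate, 3))         # 0.55

# Hypothetical variant: reach the minimum after half of the episodes.
faster_decay = (exploration_rate - min_exploration_rate) / (episode_num * 0.5)
print(round(faster_decay, 8))       # 6e-05
```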

To run the code, you need to install several dependencies: TensorFlow, Keras, Pygame, and NumPy.
