import gym

import q_learning as q

# Single CartPole environment instance reused for every training episode.
env = gym.make('CartPole-v1')

# Length of the observation vector and number of discrete actions.
STATE_DIM, ACTIONS = len(env.observation_space.high), env.action_space.n

# Learning hyper-parameters: learning rate (alpha), discount factor (gamma)
# and the number of training episodes.
alpha, gamma = 0.1, 0.95
EPISODES = 10_000

# Exploration parameters: value passed to q.epsilon_greedy as epsilon.
# NOTE(review): presumably the probability of a random action, so 1 means
# fully exploratory — confirm against q_learning.
epsilon = 1

def unpack(state):
    """Return the bare observation from an ``env.reset()``/``env.step()`` value.

    Gym 0.26+ makes ``env.reset()`` return an ``(observation, info)`` pair,
    while older releases return the observation alone. This helper accepts
    either form.

    Args:
        state: Either an ``(observation, info)`` tuple whose second item is
            the ``info`` dict, or a bare observation.

    Returns:
        The observation extracted from an ``(observation, info)`` pair;
        any other value is returned unchanged.
    """
    # Only unwrap when the value really has the (obs, info) shape. Blindly
    # unpacking under a bare ``except:`` would also split any 2-element
    # observation and silently swallow unrelated errors, including
    # KeyboardInterrupt.
    if (
        isinstance(state, tuple)
        and len(state) == 2
        and isinstance(state[1], dict)
    ):
        return state[0]
    return state
    

# Number of discretization bins per observation dimension. Shared by the
# bucket construction and the Q-table so the two sizes cannot drift apart.
BINS = 25

# Reward substituted when the pole falls (a real termination, not a
# time-limit truncation).
FALL_PENALTY = -475

# Epsilon schedule. Without decay, epsilon stays at its initial value for
# every episode, so the policy never shifts toward exploiting the learned
# Q-values (assuming q.epsilon_greedy treats epsilon as the random-action
# probability). Decay multiplicatively once per episode down to a floor.
EPSILON_DECAY = 0.999
MIN_EPSILON = 0.01

# Presumably the low/high edges per observation dimension, split into BINS
# buckets each — confirm against q.create_discretization.
buckets = q.create_discretization([-4.8, -5, -.418, -5], [4.8, 5, .418, 5], BINS)
q_table = q.create_q_table(BINS)

scores = []  # number of steps survived in each episode
data = {'episode': [], 'avg score': []}  # running averages, every 100 episodes

try:
    for episode in range(EPISODES):
        state = unpack(env.reset())
        discrete_state = q.get_discretized_state(state, buckets)
        terminated = False
        truncated = False

        score = 0

        while not (terminated or truncated):
            # Increment score for keeping pole upright
            score += 1

            # Get action based on epsilon greedy policy
            action = q.epsilon_greedy(q_table, discrete_state, epsilon)

            # Take the action, and advance the environment state
            next_state, reward, terminated, truncated, _ = env.step(action)
            next_state = unpack(next_state)
            discrete_next_state = q.get_discretized_state(next_state, buckets)

            # Penalty for the pole falling over. A time-limit truncation is
            # not a failure, so it keeps the environment's own reward.
            if terminated and not truncated:
                reward = FALL_PENALTY

            # Update the Q-table.
            # NOTE(review): q.q_update receives no terminal flag; if it
            # bootstraps from discrete_next_state even on termination, the
            # fall penalty is mixed with a meaningless future value —
            # confirm in q_learning.
            q.q_update(q_table, discrete_state, action, reward,
                       discrete_next_state, alpha, gamma)

            # Update state
            discrete_state = discrete_next_state

        scores.append(score)

        # Shrink exploration after each episode, never below MIN_EPSILON.
        epsilon = max(MIN_EPSILON, epsilon * EPSILON_DECAY)

        if episode % 100 == 0:
            # Average over (up to) the last 100 episodes.
            episode_scores = scores[-100:]
            avg_score = sum(episode_scores) / len(episode_scores)
            data['episode'].append(episode)
            data['avg score'].append(avg_score)
            print("Episode:", episode, "Average score:", avg_score)
finally:
    # Release the environment even if training is interrupted.
    env.close()
