import math
import tqdm

import torch
from torch import nn

from data import get_dataloader

# Compute device for the run: the CUDA GPU when one is available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class PositionalEncoding(nn.Module):
    """Fixed (non-trainable) sinusoidal positional encoding.

    A ``(max_len, d_model)`` table is precomputed once: even columns hold
    ``sin(pos / 10000 ** (2i / d_model))`` and odd columns hold the matching
    cosine. The table is stored as a non-persistent buffer named ``encoding``,
    so it moves with the module on ``.to(...)`` / ``.cuda()`` but is not
    written to ``state_dict``.

    Args:
        d_model: Width of each encoding row (may be odd).
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=500):
        super().__init__()

        # Column vector of positions: shape (max_len, 1).
        pos = torch.arange(0, max_len).float().unsqueeze(dim=1)
        # Even column indices 0, 2, 4, ... -> ceil(d_model / 2) entries.
        _2i = torch.arange(0, d_model, step=2).float()
        angles = pos / (10000 ** (_2i / d_model))

        # Same width as the embedded input so the two can be added.
        encoding = torch.zeros(max_len, d_model)
        encoding[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so trim the cosine block to the number of odd columns.
        encoding[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # A buffer follows the module across devices; persistent=False keeps
        # the precomputed table out of saved checkpoints.
        self.register_buffer("encoding", encoding, persistent=False)

    def forward(self, x):
        """Return the encoding rows for the positions present in ``x``.

        Args:
            x: Tensor whose second dimension is the sequence length
                (``x.shape[1]``); its values are not read.

        Returns:
            The first ``x.shape[1]`` rows of the table, shape
            ``(seq_len, d_model)``. If the sequence is longer than
            ``max_len`` only ``max_len`` rows are returned.
        """
        seq_len = x.shape[1]
        return self.encoding[:seq_len, :]
    
class TransformerEmbedding(nn.Module):
    """Token embedding plus positional encoding, followed by dropout.

    Args:
        vocab_size: Number of rows in the token embedding table.
        d_model: Width of every embedding vector.
        max_len: Number of positions handed to ``PositionalEncoding``.
        drop_prob: Dropout probability applied to the summed embeddings.
    """

    def __init__(self, vocab_size, d_model, max_len=500, drop_prob=0.1):
        super().__init__()
        # Index 0 is declared as the padding index of the embedding table.
        self.tok_emb = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=d_model,
            padding_idx=0,
        )
        self.pos_emb = PositionalEncoding(d_model, max_len=max_len)
        self.drop_out = nn.Dropout(drop_prob)

    def forward(self, x):
        """Embed the token ids in ``x`` and add positional information.

        Args:
            x: Tensor of token indices; its second dimension is used as the
                sequence length by the positional encoding.

        Returns:
            Dropout applied to ``token embedding + positional encoding``.
        """
        token_vectors = self.tok_emb(x)
        position_vectors = self.pos_emb(x)
        combined = token_vectors + position_vectors
        return self.drop_out(combined)

class AttentionLayer(nn.Module):
    """Single-head scaled dot-product attention.

    Args:
        d_model: Size of the last dimension of the query/key/value inputs.
        d_k: Size of the projected queries and keys.
        d_v: Size of the projected values, and therefore of the output.
    """

    def __init__(self, d_model, d_k, d_v):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v

        # Bias-free linear maps play the role of the W_q / W_k / W_v matrices.
        self.W_q = nn.Linear(d_model, d_k, bias=False)
        self.W_k = nn.Linear(d_model, d_k, bias=False)
        self.W_v = nn.Linear(d_model, d_v, bias=False)
        # Normalise the scores over the key positions (last dimension).
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, inp_q, inp_k, inp_v):
        """Compute ``softmax(Q K^T / sqrt(d_k)) V``.

        Args:
            inp_q: Query source, shape ``(..., len_q, d_model)``.
            inp_k: Key source, shape ``(..., len_k, d_model)``.
            inp_v: Value source, shape ``(..., len_k, d_model)``.

        Returns:
            Tensor of shape ``(..., len_q, d_v)``.
        """
        q = self.W_q(inp_q)
        k = self.W_k(inp_k)
        v = self.W_v(inp_v)
        # Scaling by sqrt(d_k) keeps the logits in a range where softmax
        # does not saturate as d_k grows.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return torch.matmul(self.softmax(scores), v)


class MultiHeadAttention(nn.Module):
    """Run several ``AttentionLayer`` heads in parallel and merge them.

    Args:
        n_heads: Number of independent attention heads (must be >= 1).
        d_model: Input and output feature size.
        d_k: Query/key size per head.
        d_v: Value size per head.

    Raises:
        ValueError: If ``n_heads`` is smaller than 1.
    """

    def __init__(self, n_heads, d_model, d_k, d_v):
        super().__init__()
        if n_heads < 1:
            raise ValueError(f"n_heads must be >= 1, got {n_heads}")
        self.heads = nn.ModuleList(
            [AttentionLayer(d_model, d_k, d_v) for _ in range(n_heads)]
        )
        # Maps the concatenated head outputs back to the model width.
        self.linear = nn.Linear(n_heads * d_v, d_model)

    def forward(self, inp_q, inp_k, inp_v):
        """Apply every head to the same inputs and project the concatenation.

        Returns:
            Tensor of shape ``(..., len_q, d_model)``.
        """
        head_outputs = [head(inp_q, inp_k, inp_v) for head in self.heads]
        return self.linear(torch.cat(head_outputs, dim=-1))


class Residual(nn.Module):
    """Wrap a sub-layer with dropout, a skip connection and layer norm.

    Args:
        module: The sub-layer to wrap; it is called with all forward inputs.
        d_model: Feature size normalised by the layer norm.
        drop_p: Dropout probability applied to the sub-layer output.
    """

    def __init__(self, module, d_model, drop_p=0.1):
        super().__init__()
        self.module = module
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p=drop_p)

    def forward(self, *inp):
        """Return ``LayerNorm(inp[0] + Dropout(module(*inp)))``.

        The first positional input is the one carried along the skip
        connection, so the wrapped module's output must match its shape.
        """
        sublayer_out = self.dropout(self.module(*inp))
        return self.norm(inp[0] + sublayer_out)


class FeedForward(nn.Module):
    """Position-wise two-layer feed-forward network with a ReLU in between.

    Args:
        d_model: Input and output feature size.
        d_lin: Hidden layer size.
    """

    def __init__(self, d_model, d_lin):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_lin),
            nn.ReLU(),
            nn.Linear(d_lin, d_model),
        )

    def forward(self, inp):
        """Apply the network to the last dimension of ``inp``."""
        return self.net(inp)


class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward, each residual.

    Args:
        n_heads: Number of attention heads.
        d_model: Feature size of the block's input and output.
        d_k: Query/key size per head.
        d_v: Value size per head.
        d_lin: Hidden size of the feed-forward network.
    """

    def __init__(self, n_heads, d_model, d_k, d_v, d_lin):
        super().__init__()
        self.attention = Residual(
            MultiHeadAttention(n_heads, d_model, d_k, d_v), d_model
        )
        self.feed_forward = Residual(FeedForward(d_model, d_lin), d_model)

    def forward(self, inp):
        """Transform ``inp`` of shape ``(..., seq_len, d_model)``; shape is kept."""
        # Self-attention: the same tensor provides queries, keys and values.
        attended = self.attention(inp, inp, inp)
        return self.feed_forward(attended)


class Encoder(nn.Module):
    """Embedding followed by a stack of ``EncoderLayer`` blocks.

    Args:
        n_heads: Number of attention heads per layer.
        d_model: Embedding / hidden feature size.
        d_k: Query/key size per head.
        d_v: Value size per head.
        d_lin: Hidden size of each feed-forward network.
        n_layers: Number of stacked encoder layers.
        vocab_size: Size of the token vocabulary.
    """

    def __init__(self, n_heads, d_model, d_k, d_v, d_lin, n_layers, vocab_size):
        super().__init__()
        self.embedding = TransformerEmbedding(vocab_size, d_model)
        self.layers = nn.ModuleList(
            [
                EncoderLayer(n_heads, d_model, d_k, d_v, d_lin)
                for _ in range(n_layers)
            ]
        )

    def forward(self, inp):
        """Embed the token ids in ``inp`` and pass them through every layer.

        Returns:
            Tensor with one ``d_model``-sized vector per input token.
        """
        hidden = self.embedding(inp)
        for layer in self.layers:
            hidden = layer(hidden)
        return hidden


class Transformer(nn.Module):
    """Encoder-only Transformer producing one positive-class probability.

    Args:
        vocab_size: Size of the token vocabulary.
        n_heads: Number of attention heads per encoder layer.
        d_model: Embedding / hidden feature size.
        d_k: Query/key size per head.
        d_v: Value size per head.
        d_lin: Hidden size of each feed-forward network.
        n_layers: Number of encoder layers.
    """

    def __init__(self, vocab_size, n_heads=4, d_model=128, d_k=64, d_v=64, d_lin=128, n_layers=6):
        super().__init__()
        self.encoder = Encoder(n_heads, d_model, d_k, d_v, d_lin, n_layers, vocab_size)
        # Single logit; the sigmoid in forward() turns it into a probability.
        self.classifier = nn.Linear(d_model, 1)

    def forward(self, inp):
        """Return the probability that each review is positive.

        Args:
            inp: Integer tensor of token ids, shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch,)`` with values in ``[0, 1]``.
        """
        hidden = self.encoder(inp)
        # Average only over real tokens. NOTE(review): assumes token id 0 is
        # padding, matching padding_idx=0 in TransformerEmbedding - confirm
        # against the data pipeline.
        keep = (inp != 0).unsqueeze(-1).to(hidden.dtype)
        # clamp() avoids a division by zero for an all-padding sequence.
        pooled = (hidden * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)
        return torch.sigmoid(self.classifier(pooled)).squeeze(-1)


def main():
    """Set the run's hyper-parameters and fetch the data loaders.

    ``get_dataloader`` is called with the vocabulary size, maximum review
    length and batch size, and its three results are unpacked. Building,
    training and evaluating the model is still pending (see the TODO below).
    """
    # Settings passed to the data pipeline.
    vocabulary_size = 20000
    max_review_length = 50
    batch_size = 128

    # Settings reserved for the model / training code that is still to come.
    d_model = 64
    lr = 1e-4
    num_epochs = 10

    loaders = get_dataloader(vocabulary_size, max_review_length, batch_size)
    train_dataloader, test_dataloader, vocabulary = loaders

    # TODO: Initialize model, optimizer, and criterion and train/test the model


if __name__ == '__main__':
    main()
    