import math
import re
import string
from collections import Counter

import pandas as pd
import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from torch.utils.data import DataLoader, Dataset

# Seed torch's global RNG at import time. random_split in get_dataloader is
# called without an explicit generator and the training DataLoader shuffles,
# so both draw from this global state and are reproducible across runs.
torch.manual_seed(0)

class IMDB_Dataset(Dataset):
    """IMDB movie reviews encoded as ``(token-index tensor, sentiment label)`` pairs.

    Construction loads the CSV, normalises every review with
    :meth:`preprocessText`, builds a vocabulary of the most frequent words,
    and encodes each review as ``[SOS, w1, w2, ...]`` truncated to
    ``max_len`` indices. Index 0 is reserved for ``'PAD'`` and index 1 for
    ``'SOS'``. Words outside the vocabulary are dropped (there is no UNK
    token).

    Args:
        csv_file: Path (or any buffer accepted by ``pandas.read_csv``) to a
            CSV with ``review`` and ``sentiment`` columns. ``sentiment`` is
            either numeric already or the strings ``"positive"`` /
            ``"negative"``.
        vocabulary_size: Total vocabulary size, including PAD and SOS.
            Must be at least 2.
        max_len: Maximum encoded length of a review, SOS included.

    Raises:
        ValueError: If ``vocabulary_size < 2`` or a non-numeric
            ``sentiment`` value is neither ``"positive"`` nor ``"negative"``.
    """

    # Any run of characters outside ASCII letters/digits collapses to one space.
    _NON_ALNUM_RE = re.compile(r"[^a-zA-Z0-9]+")

    def __init__(self, csv_file: str, vocabulary_size: int = 10000, max_len: int = 200):
        if vocabulary_size < 2:
            # A smaller value would make the top-k slice below negative and
            # silently keep (almost) the whole vocabulary.
            raise ValueError(
                f"vocabulary_size must be >= 2 to hold the PAD and SOS tokens, got {vocabulary_size}"
            )
        self.df = pd.read_csv(csv_file)
        self.length = len(self.df)
        self.max_len = max_len
        self.vocab_size = vocabulary_size
        self.texts = []
        self.labels = []

        # Empty cells are read as NaN (a float, which has no .lower()); treat
        # them as empty reviews instead of crashing inside preprocessText.
        self.df["review"] = self.df["review"].fillna("").astype(str).apply(self.preprocessText)

        # Convert textual sentiment to binary labels; fail loudly on anything
        # unexpected rather than silently producing NaN labels.
        sentiment = self.df["sentiment"]
        if not pd.api.types.is_numeric_dtype(sentiment):
            mapped = sentiment.map({"positive": 1, "negative": 0})
            unknown = sentiment[mapped.isna()]
            if not unknown.empty:
                raise ValueError(f"Unexpected sentiment values: {list(unknown.unique()[:5])}")
            self.df["sentiment"] = mapped

        # Count word frequencies over the whole corpus.
        counts = Counter(word for review in self.df["review"] for word in review.split())
        print(f"There are {len(counts)} distinct words")

        # PAD sits at index 0 so zero-padded positions decode to PAD; SOS is 1.
        # Reviews are lower-cased, so these upper-case names cannot collide
        # with real words. most_common() breaks count ties by first appearance.
        self.word_count = dict(
            [("PAD", 1), ("SOS", 1)] + counts.most_common(vocabulary_size - 2)
        )

        # Map each vocabulary word to its integer index.
        self.word_dict = {word: i for i, word in enumerate(self.word_count)}
        print(f"Vocabulary size: {len(self.word_dict)}")

        sos = self.word_dict["SOS"]
        for review, label in zip(self.df["review"], self.df["sentiment"]):
            # Every review starts with SOS; out-of-vocabulary words are dropped.
            indices = [sos] + [self.word_dict[w] for w in review.split() if w in self.word_dict]
            self.texts.append(torch.tensor(indices[:max_len], dtype=torch.long))
            self.labels.append(label)

    def preprocessText(self, s):
        """Normalise a raw review into space-separated lower-case ASCII tokens.

        HTML line breaks (``<br />``) become spaces, and every run of
        characters outside ``[a-zA-Z0-9]`` (punctuation, whitespace,
        underscores, non-ASCII letters) collapses to a single space. Leading
        and trailing spaces are removed.
        """
        s = s.lower().replace("<br />", " ")
        # One substitution suffices: collapsing repeated punctuation or
        # padding punctuation with spaces beforehand would be absorbed here.
        return self._NON_ALNUM_RE.sub(" ", s).strip()

    def __len__(self):
        """Return the number of reviews in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(token_indices, label)`` for review ``idx``."""
        return self.texts[idx], self.labels[idx]


def imdb_collate(batch):
    """Merge ``(token_tensor, label)`` samples into one padded batch.

    Returns a tuple ``(padded, lengths, targets)``:
      * ``padded`` -- the samples stacked along a new leading batch dimension
        and right-padded with 0 to the longest one (``(batch, longest)`` for
        1-D samples);
      * ``lengths`` -- int64 tensor holding each sample's unpadded length;
      * ``targets`` -- float32 tensor of the labels.
    """
    sequences, targets = zip(*batch)
    lengths = torch.tensor(list(map(len, sequences)))
    padded = pad_sequence(sequences, batch_first=True)
    return padded, lengths, torch.tensor(targets, dtype=torch.float32)


def get_dataloader(vocabulary_size, max_review_length, batch_size, csv_file="IMDB_Dataset.csv"):
    """Build training and validation DataLoaders over the IMDB review CSV.

    The dataset is split 80/10/10 into train / validation / held-out test
    parts; for a 50,000-review file this is exactly 40,000 / 5,000 / 5,000.
    The test part is not returned.

    Args:
        vocabulary_size: Vocabulary size passed to ``IMDB_Dataset``
            (PAD and SOS included).
        max_review_length: Maximum encoded review length (SOS included).
        batch_size: Samples per batch for both loaders.
        csv_file: Path to the CSV; a relative path resolves against the
            current working directory.

    Returns:
        ``(train_loader, validate_loader, word_dict)`` where ``word_dict``
        maps vocabulary words to their integer indices.
    """
    all_dataset = IMDB_Dataset(csv_file, vocabulary_size=vocabulary_size, max_len=max_review_length)

    # Size the splits from the actual row count: random_split requires the
    # lengths to sum to len(dataset), so fixed absolute sizes only work for
    # one specific file size.
    total = len(all_dataset)
    n_val = total // 10
    n_test = total // 10
    n_train = total - n_val - n_test
    # No generator is passed, so the split follows torch's global RNG state.
    train_dataset, validate_dataset, _ = torch.utils.data.random_split(
        all_dataset, [n_train, n_val, n_test]
    )

    # drop_last=True keeps every batch at exactly batch_size samples; as a
    # result up to batch_size - 1 samples of each split are skipped per epoch.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              collate_fn=imdb_collate, drop_last=True)
    validate_loader = DataLoader(validate_dataset, batch_size=batch_size, shuffle=False,
                                 collate_fn=imdb_collate, drop_last=True)

    return train_loader, validate_loader, all_dataset.word_dict
    