import os

import numpy as np
import torch
import torch.utils.data.dataset
import torchvision
import torchvision.transforms

# Split MNIST into a "1 vs 7" subset and a "every other digit" subset and
# store each as .npy files under ./data/binarymnist/{train,test}/.

# Shared preprocessing: ToTensor, then Normalize with the constants
# (0.1307,) / (0.3081,) used for both splits.
_TRANSFORM = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# Root directory for the generated files; one sub-directory per split.
_OUT_ROOT = './data/binarymnist'


def _make_loader(train):
    """Return a non-shuffled DataLoader over the requested MNIST split.

    Args:
        train: True for the training split, False for the test split.

    The batch size equals the dataset length, so the whole split is
    normally delivered as a single batch (this replaces the hard-coded
    60000 / 10000 sizes).
    """
    dataset = torchvision.datasets.MNIST('./data', train=train, download=True,
                                         transform=_TRANSFORM)
    return torch.utils.data.DataLoader(dataset, batch_size=len(dataset),
                                       shuffle=False)


def _split_and_save(loader, split):
    """Partition a split into digits {1, 7} vs. the rest and save as .npy.

    Args:
        loader: DataLoader yielding (images, targets) batches.
        split: 'train' or 'test'. It is used both as the output
            sub-directory name and inside the file names
            (x<split>_17.npy, y<split>_17.npy, x<split>_no17.npy,
            y<split>_no17.npy).
    """
    # Concatenate every batch so nothing is lost (or overwritten) if the
    # loader ever yields more than one batch.
    image_batches = []
    target_batches = []
    for batch_images, batch_targets in loader:
        image_batches.append(batch_images)
        target_batches.append(batch_targets)
    images = torch.cat(image_batches)
    targets = torch.cat(target_batches)

    # True where the label is 1 or 7. Its negation is exactly
    # (targets != 1) & (targets != 7), i.e. all other digits.
    mask_17 = (targets == 1) | (targets == 7)
    mask_others = ~mask_17

    out_dir = os.path.join(_OUT_ROOT, split)
    # np.save does not create missing parent directories.
    os.makedirs(out_dir, exist_ok=True)

    np.save(os.path.join(out_dir, f'x{split}_17.npy'), images[mask_17].numpy())
    np.save(os.path.join(out_dir, f'y{split}_17.npy'), targets[mask_17].numpy())
    np.save(os.path.join(out_dir, f'x{split}_no17.npy'),
            images[mask_others].numpy())
    np.save(os.path.join(out_dir, f'y{split}_no17.npy'),
            targets[mask_others].numpy())


train_loader = _make_loader(train=True)
test_loader = _make_loader(train=False)

print('Loading training data and formatting...')
_split_and_save(train_loader, 'train')

print('Loading testing data and formatting...')
_split_and_save(test_loader, 'test')

print('Done!')