import torch.nn as nn
import torch.nn.functional as F

class CubeClassifierNet(nn.Module):
    """Small CNN that maps a single-channel image to one sigmoid score.

    Architecture: two ``Conv2d -> ReLU -> MaxPool2d(2, 2)`` stages, then
    three fully connected layers with dropout on the first two, and a
    final sigmoid.

    Both convolutions use "same" padding (``padding == kernel_size // 2``
    with stride 1), so only the two pooling stages shrink the spatial
    size, each halving height and width. The first linear layer expects
    ``32 * 60 * 40`` flattened features, i.e. a pooled feature map with
    60 * 40 elements per channel.
    NOTE(review): this presumably corresponds to 240x160 inputs -- confirm
    the intended height/width orientation against the data pipeline.
    """

    # Flattened size of the conv stack output consumed by ``fc1``:
    # 32 channels (conv2 out_channels) times a 60 * 40 pooled feature map.
    _FLAT_FEATURES = 32 * 60 * 40

    def __init__(self):
        """Create the convolutional, pooling, dropout and linear layers."""
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=16, kernel_size=9, stride=1, padding=4
        )
        self.conv2 = nn.Conv2d(
            in_channels=16, out_channels=32, kernel_size=7, stride=1, padding=3
        )
        self.pool = nn.MaxPool2d(2, 2)
        self.drop = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(self._FLAT_FEATURES, 120)
        self.fc2 = nn.Linear(120, 64)
        self.fc3 = nn.Linear(64, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the network on a batch and return per-sample scores.

        Args:
            x: Tensor of shape ``(batch, 1, H, W)`` whose pooled feature
                map flattens to ``32 * 60 * 40`` values per sample.

        Returns:
            Tensor of shape ``(batch, 1)`` holding sigmoid outputs.

        Raises:
            RuntimeError: If the flattened feature size does not match
                what ``fc1`` expects (raised by the linear layer).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike
        # ``view(-1, N)``, this never folds a spatial-size mismatch into
        # the batch dimension (which silently produced extra "samples"),
        # and it also accepts non-contiguous tensors.
        x = x.flatten(1)
        # Dropout is applied to the linear output before the ReLU; this
        # layer order is kept as-is.
        x = F.relu(self.drop(self.fc1(x)))
        x = F.relu(self.drop(self.fc2(x)))
        x = self.fc3(x)
        return self.sigmoid(x)
