#include <assert.h>
#include "Network.h"


// Protocol version; stored in the high nibble of header byte 0 and checked
// on every received packet.
const int VERSION = 2;
// Packet types; stored in the low nibble of header byte 0.
const int INIT = 0;
const int ACK = 1;
const int ALIVE = 2;
const int DATA = 3;
const int TERMINATE = 4;

// Header length in bytes: [0] = version/type, [1] = sequence number.
const int PACKET_HEADER_SIZE = 2;

// Sequence numbers wrap modulo MAX_SEQ (kept below 128 so they fit one byte).
const int MAX_SEQ = 128;
// Retransmit-count limit for DATA/INIT packets; once reached, network_sendfunc
// drops the packet instead of re-arming its timeout.
const int MAX_RETRANSMIT = 3;


// Posted by network_recvfunc when an INIT packet is received.
// NOTE(review): its sem_init is not in this file view — confirm it is
// initialized before the receiver thread starts.
sem_t monitorLock;

/*
 * Create a UDP socket bound to a kernel-chosen port on all local IPv4
 * interfaces.
 *
 * Returns the socket descriptor. On failure, prints a diagnostic via perror()
 * and terminates the process with exit(1).
 */
int createSocketID() {
  int sockfd;
  sockaddr_in client_addr;

  memset(&client_addr, 0, sizeof(client_addr));
  client_addr.sin_family = AF_INET;
  client_addr.sin_addr.s_addr = htonl(INADDR_ANY);
  // sin_port is a 16-bit field, so it takes htons(); port 0 lets the kernel
  // pick any free port.
  client_addr.sin_port = htons(0);

  if ((sockfd = socket(AF_INET, SOCK_DGRAM, 0)) < 0) {
    perror("Client: Cannot open socket");
    exit(1);
  }

  if (bind(sockfd, (sockaddr*)&client_addr, sizeof(client_addr)) < 0) {
    perror("Client: Cannot bind with local address");
    exit(1);
  }

  return sockfd;
}


/*
 * Unfinished worker. Unpacks an argument block laid out as
 *   int size | `size` payload bytes | BufferQueue* | TransmissionTable* | sem_t*
 * and currently does nothing further with it. Always returns NULL.
 */
void *network_packetfunc(char* args) {
  int size;
  char* data = NULL;
  BufferQueue* sendBuffer;
  TransmissionTable* table;
  sem_t* timeoutLock;

  memcpy((char*)&size, (char*)args, sizeof(int));
  if (size < 0)  // corrupt length: every offset below would be meaningless
    return NULL;

  // Copy the payload into owned storage; it is released before returning.
  if (size > 0) {
    data = new char[size];
    memcpy(data, args + sizeof(int), size);
  }

  memcpy((char*)&sendBuffer, (char*)(args + sizeof(int) + size), sizeof(BufferQueue*));
  memcpy((char*)&table,
         (char*)(args + sizeof(int) + size + sizeof(BufferQueue*)),
         sizeof(TransmissionTable*));
  memcpy((char*)&timeoutLock,
         (char*)(args + sizeof(int) + size + sizeof(BufferQueue*) +
                 sizeof(TransmissionTable*)),
         sizeof(sem_t*));

  // Not used yet; keeps the compiler quiet until the body is written.
  (void)sendBuffer;
  (void)table;
  (void)timeoutLock;

  delete[] data;
  return NULL;
}

/*
 * Sender thread. `args` is the block built by initializeNetworkComponents:
 *   int sockfd | sockaddr_in monitor | BufferQueue* | TransmissionTable* | sem_t*
 *
 * Each time timeoutLock is posted, sends the packet at the head of sendBuffer
 * to the monitor:
 *   - ACK packets are one-shot: removed from the queue immediately, and the
 *     lock is re-posted so the next packet goes out.
 *   - DATA/INIT packets stay queued with a 3-tick timeout armed (the timer
 *     thread posts the lock when it expires). After MAX_RETRANSMIT attempts
 *     the packet is dropped.
 *   - Any other type is logged and dropped like a one-shot packet.
 * Exits the process if sendto() fails or sends a short datagram. Never returns.
 */
void *network_sendfunc(char* args) {

  int sockfd;
  sockaddr_in monitorAddress;
  BufferQueue* sendBuffer;
  TransmissionTable *table;
  sem_t* timeoutLock;

  memcpy((char*)&sockfd, (char*)args, sizeof(int));
  memcpy((char*)&monitorAddress, (char*)(args + sizeof(int)), sizeof(sockaddr_in));
  memcpy((char*)&sendBuffer,
         (char*)(args + sizeof(int) + sizeof(sockaddr_in)), sizeof(BufferQueue*));
  memcpy((char*)&table,
         (char*)(args + sizeof(int) + sizeof(sockaddr_in) + sizeof(BufferQueue*)),
         sizeof(TransmissionTable*));
  memcpy((char*)&timeoutLock,
         (char*)(args + sizeof(int) + sizeof(sockaddr_in) + sizeof(BufferQueue*) +
                 sizeof(TransmissionTable*)),
         sizeof(sem_t*));

  while (1) {
    sem_wait(timeoutLock);
    Packet packet;

    if (table->getRetransmitCount() == -1) { // refresh packet
      // Previous packet is finished; NOTE(review): checkIfExist presumably
      // waits until the queue is non-empty — confirm in BufferQueue.
      sendBuffer->checkIfExist();
      sendBuffer->getHead(packet);
      table->resetRetransmitCount();
    }
    else {
      // Retransmission of the current head packet.
      sendBuffer->getHead(packet);
    }

    // Send to Monitor
    if (sendto(sockfd, (char*)(packet.data), packet.packetSize, 0,
               (const sockaddr*)&monitorAddress, sizeof(sockaddr_in)) !=
        packet.packetSize) {
      perror("Client: sendto");
      exit(1);
    }

    int type = (int)(packet.data[0] & 0xF);

    switch (type) {
    case ACK:
      {  // Just send and no need to retain any states
        sendBuffer->removeHead();
        table->setTimeoutValue(-1);
        table->setRetransmitCount(-1);
        sem_post(timeoutLock);
      }
      break;
    case DATA:
    case INIT:
      {
        int retransmit = table->getRetransmitCount();

        if (retransmit != -1) {
          if (retransmit < MAX_RETRANSMIT) {
            table->setRetransmitCount(retransmit + 1);
            table->setTimeoutValue(3);  // retransmit after 3 timer ticks
          }
          else { // take it out and ignore it
            sendBuffer->removeHead();
            table->setTimeoutValue(-1);
            table->setRetransmitCount(-1);
            sem_post(timeoutLock);
          }
        }
      }
      break;
    default:
      cerr << "should not reach here...\n";
      // Nothing in this file arms a timeout or sends an ACK for other types
      // (e.g. TERMINATE), so leaving the packet queued would block the next
      // sem_wait forever. Drop it like a one-shot packet and keep draining.
      sendBuffer->removeHead();
      table->setTimeoutValue(-1);
      table->setRetransmitCount(-1);
      sem_post(timeoutLock);
      break;
    }
    cout<<"Packet sent"<<endl; // yhchu XXX
  }

  return NULL;
}

void *network_recvfunc(char* args) {

  int sockfd;
  sockaddr_in monitorAddress;
  BufferQueue* sendBuffer;
  TransmissionTable *table;
  sem_t* timeoutLock;

  memcpy((char*)&sockfd, (char*)args, sizeof(int));
  memcpy((char*)&monitorAddress, (char*)(args + sizeof(int)), sizeof(sockaddr_in));
  memcpy((char*)&sendBuffer,
	 (char*)(args + sizeof(int) + sizeof(sockaddr_in)), sizeof(BufferQueue*));
  memcpy((char*)&table, 
	 (char*)(args + sizeof(int) + sizeof(sockaddr_in) + sizeof(BufferQueue*)),
	 sizeof(TransmissionTable*));
  memcpy((char*)&timeoutLock,
	 (char*)(args + sizeof(int) + sizeof(sockaddr_in) + sizeof(BufferQueue*) +
		 sizeof(TransmissionTable*)),
	 sizeof(sem_t*));

  while (1) {
    Packet packet;
    // TODO: Testing
    packet.setSize(20);

    struct sockaddr_in sock;
    socklen_t socklen = sizeof(sock);
    int number_bytes = recvfrom(sockfd, (char*)(packet.data), 
				packet.packetSize, 0, 
				(struct sockaddr *)&sock, &socklen);

    if (number_bytes < 0) {
      perror("client: recvfrom");
      exit(1);
    }

    // Process the packket
    cout<<"Received: "<<(int)packet.data[0]<<"  "<<(int)packet.data[1]<<endl;


    int version = (packet.data[0] & 0xF0) >> 4;
    if (version != VERSION)
      continue;
    
    // Process the received packet data
    int type = (int)(packet.data[0] & 0xF);

    switch (type) {
    case INIT:
      {
	// Extract the parameter
	sendBuffer->removeHead();
	table->setTimeoutValue(-1);
	table->setRetransmitCount(-1);

	sem_post(&monitorLock);
	sem_post(timeoutLock);
      }
      break;
    case ACK:
      {
	int ack = (int)(packet.data[1]);
	Packet curPac;
	sendBuffer->getHead(curPac);
	int curAck = (int)(curPac.data[1]);
	//cout<<"Ack received: "<<ack<<endl;
	//cout<<"  waiting "<<curAck<<endl;
	
	if (ack == curAck) { // ack match with sent packet
	  sendBuffer->removeHead();
	  table->setTimeoutValue(-1);
	  table->setRetransmitCount(-1);
	  sem_post(timeoutLock);
	}  // else if previous ack, simply ignore it
      }
      break;
    case ALIVE:
      {
	int expectingSeq = table->getRecvSeqNum();
	int in_seq = (int)(packet.data[1]);

	//cout<<"Alive Message Received: "<<in_seq<<endl;
	//cout<<"   expected: "<<expectingSeq<<endl;
	
	if (in_seq == expectingSeq) { // if it is expecting one
	  // Send back ack to reply
	  Packet packetAck;
	  createAckPacket(packetAck, in_seq);
	  table->setRecvSeqNum((in_seq + 1) % MAX_SEQ);
	  sendBuffer->enqueue(packetAck);
	}
	else if (in_seq + 1 % MAX_SEQ == expectingSeq) { 
	  // if it is last one... due to missing ack
	  // Only send back ack
	  Packet packetAck;
	  createAckPacket(packetAck, in_seq);
	  sendBuffer->enqueue(packetAck);
	}
	else {
	  // probably disconnection for a period
	  // take appropriate action by invoke upper later
	  Packet packetAck;
	  createAckPacket(packetAck, in_seq);
	  table->setRecvSeqNum((in_seq + 1) % MAX_SEQ);
	  sendBuffer->enqueue(packetAck);
	}
      }
      break;
    default:
      cerr << "should not reach here network_recvfun\n";
      assert(0);
    }
    
  }
  
  return NULL;
}

/*
 * Timer thread. `args` is the block built by initializeNetworkComponents:
 *   int sockfd | sockaddr_in monitor | BufferQueue* | TransmissionTable* | sem_t*
 * Only the table and timeoutLock are used; the other fields are unpacked but
 * unused.
 *
 * Once per second (sleep(1)), while the table's timeout value is not -1, it
 * decrements the value. When the value reaches 0 it posts timeoutLock, which
 * wakes network_sendfunc to retransmit the head packet. Never returns.
 *
 * NOTE(review): the get/set/get sequence below is not atomic with the resets
 * to -1 done by the send/receive threads. A reset between the reads can leave
 * the value at -2, which then keeps counting down until the next
 * setTimeoutValue(). A reset right after the "== 0" check can produce an
 * extra post. Whether TransmissionTable locks internally is not visible
 * here — confirm.
 */
void *network_retransmitfunc(char* args) {
  //cout<<"Retransmit"<<endl;

  int sockfd;
  sockaddr_in monitorAddress;
  BufferQueue* sendBuffer;
  TransmissionTable *table;
  sem_t* timeoutLock;

  memcpy((char*)&sockfd, (char*)args, sizeof(int));
  memcpy((char*)&monitorAddress, (char*)(args + sizeof(int)), sizeof(sockaddr_in));
  memcpy((char*)&sendBuffer,
	 (char*)(args + sizeof(int) + sizeof(sockaddr_in)), sizeof(BufferQueue*));
  memcpy((char*)&table, 
	 (char*)(args + sizeof(int) + sizeof(sockaddr_in) + sizeof(BufferQueue*)),
	 sizeof(TransmissionTable*));
  memcpy((char*)&timeoutLock,
	 (char*)(args + sizeof(int) + sizeof(sockaddr_in) + sizeof(BufferQueue*) +
		 sizeof(TransmissionTable*)),
	 sizeof(sem_t*));

  while(1) {
    // -1 means no timeout is armed (nothing awaiting an ACK).
    if (table->getTimeoutValue() != -1) {// in transmission
      table->setTimeoutValue(table->getTimeoutValue() - 1);
      //cout<<" ** countdown **"<<endl;
      // Countdown expired: wake the sender so it retransmits.
      if (table->getTimeoutValue() == 0)
	sem_post(timeoutLock);
    }
    //cout<<"---------- timeout ------------"<<endl;
    sleep(1);  // one tick of the retransmission timer
  }

  return NULL;
}
  


// pthread_create() requires a `void *(*)(void *)` start routine, while the
// network workers take the packed `char*` argument block. These adapters
// bridge the two instead of passing a mismatched function-pointer type.
static void *recvThreadMain(void *arg) {
  return network_recvfunc(static_cast<char*>(arg));
}

static void *sendThreadMain(void *arg) {
  return network_sendfunc(static_cast<char*>(arg));
}

static void *retransmitThreadMain(void *arg) {
  return network_retransmitfunc(static_cast<char*>(arg));
}

// Start `entry(args)` on its own detached thread. The workers loop forever and
// are never joined. Exits the process if the thread cannot be created.
static void startNetworkThread(void *(*entry)(void *), char *args,
                               const char *name) {
  pthread_t thread;
  int rc = pthread_create(&thread, NULL, entry, args);
  if (rc != 0) {
    fprintf(stderr, "Client: cannot create %s thread (error %d)\n", name, rc);
    exit(1);
  }
  pthread_detach(thread);
}

/*
 * Resolve the monitor host, build the shared argument block, and start the
 * receive, send and retransmit-timer threads.
 *
 * sockfd      - UDP socket used for all traffic (e.g. from createSocketID()).
 * monitorHost - host name or dotted address of the monitor (IPv4 only).
 * port        - monitor UDP port, in host byte order.
 * sendBuffer  - outgoing packet queue shared by the threads.
 * t           - transmission state shared by the threads.
 *
 * The argument block is laid out as
 *   int sockfd | sockaddr_in monitor | BufferQueue* | TransmissionTable* | sem_t*
 * It is shared read-only by all three threads. It and the timeout semaphore
 * are intentionally never freed, because the threads run for the life of the
 * process. Exits the process on any setup failure.
 */
void initializeNetworkComponents(int sockfd, char* monitorHost, int port,
                                 BufferQueue *sendBuffer,
                                 TransmissionTable* t) {

  sockaddr_in monitorAddress;
  hostent* monitor;
  TransmissionTable* table = t;

  memset(&monitorAddress, 0, sizeof(monitorAddress));
  if ((monitor = gethostbyname(monitorHost)) == NULL) {
    fprintf(stderr, "Cannot get address for host %s\n", monitorHost);
    exit(1);
  }
  // Only an IPv4 address fits in sin_addr; refuse anything else rather than
  // copying h_length bytes into a 4-byte field.
  if (monitor->h_addrtype != AF_INET ||
      monitor->h_length != (int)sizeof(monitorAddress.sin_addr.s_addr)) {
    fprintf(stderr, "Cannot get address for host %s\n", monitorHost);
    exit(1);
  }

  memcpy(&(monitorAddress.sin_addr.s_addr), monitor->h_addr,
         monitor->h_length);
  monitorAddress.sin_family = AF_INET;
  monitorAddress.sin_port = htons(port);

  // Starts at 1 so the sender thread transmits its first packet immediately.
  sem_t* timeoutLock = new sem_t;
  if (sem_init(timeoutLock, 0, 1) != 0) {
    perror("Client: sem_init");
    exit(1);
  }

  // create arguments for threads
  char* args = new char[sizeof(int) + sizeof(sockaddr_in) +
                        sizeof(BufferQueue*) + sizeof(TransmissionTable*) +
                        sizeof(sem_t*)];
  memcpy((char*)args, (char*)&sockfd, sizeof(int));
  memcpy((char*)(args + sizeof(int)), (char*)&monitorAddress, sizeof(sockaddr_in));
  memcpy((char*)(args + sizeof(int) + sizeof(sockaddr_in)),
         (char*)&sendBuffer, sizeof(BufferQueue*));
  memcpy((char*)(args + sizeof(int) + sizeof(sockaddr_in) + sizeof(BufferQueue*)),
         (char*)&table, sizeof(TransmissionTable*));
  memcpy((char*)(args + sizeof(int) + sizeof(sockaddr_in) + sizeof(BufferQueue*) +
                 sizeof(TransmissionTable*)),
         (char*)&timeoutLock, sizeof(sem_t*));

  // Initialize Threads
  startNetworkThread(recvThreadMain, args, "receive");
  startNetworkThread(sendThreadMain, args, "send");
  startNetworkThread(retransmitThreadMain, args, "retransmit");
}


/*
 * Fill `packet` with a header-only ACK acknowledging sequence `recvSeq`:
 * byte 0 holds the protocol version (high nibble) and type ACK (low nibble),
 * and byte 1 holds the acknowledged sequence number.
 */
void createAckPacket(Packet& packet, int recvSeq) {
  const int versionNibble = (VERSION & 0xF) << 4;
  const int typeNibble = ACK & 0xF;

  packet.setSize(PACKET_HEADER_SIZE);
  packet.data[0] = versionNibble | typeNibble;
  packet.data[1] = (char)recvSeq;
}

/*
 * Fill `packet` with a DATA packet. The header is the version/type byte
 * followed by the sequence byte `seqNum`; it is followed by `len` payload
 * bytes copied from `data`. `data` may be NULL when `len` is 0.
 */
void createDataPacket(Packet& packet, int seqNum, char* data, int len) {
  packet.setSize(PACKET_HEADER_SIZE + len);
  packet.data[0] = ((VERSION & 0xF) << 4) | (DATA & 0xF);
  packet.data[1] = (char)seqNum;
  // Payload starts right after the header. Skip the copy for an empty payload
  // so memcpy never receives a NULL source.
  if (len > 0)
    memcpy((char*)(packet.data + PACKET_HEADER_SIZE), data, len);
}

/*
 * Fill `packet` with a header-only INIT packet: version in the high nibble and
 * type INIT in the low nibble of byte 0. The sequence byte carries no value
 * for INIT and is set to 0.
 */
void createInitPacket(Packet& packet) {
  packet.setSize(PACKET_HEADER_SIZE);
  packet.data[0] = (char)(((VERSION & 0xF) << 4) | (INIT & 0xF));
  // Both header bytes are transmitted; never put an uninitialized byte on the wire.
  packet.data[1] = 0;
}
  
/*
 * Fill `packet` with a header-only TERMINATE packet: version in the high
 * nibble and type TERMINATE in the low nibble of byte 0. The sequence byte
 * is set to 0.
 */
void createTerminatePacket(Packet& packet) {
  packet.setSize(PACKET_HEADER_SIZE);
  packet.data[0] = (char)(((VERSION & 0xF) << 4) | (TERMINATE & 0xF));
  // Both header bytes are transmitted; never put an uninitialized byte on the wire.
  packet.data[1] = 0;
}
