#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/shm.h>
#include <sys/time.h>
#include <sys/select.h>
#include <sys/param.h>
#include <assert.h>

//#include <Gossip/global.h>
#include <Util/inetMisc.h>

#include "util.h"
#include "tfrcSender.h"

/* #define DEBUG_EXTENDED */

void TfrcSender::SetBwActual(int estBwHistTime, int estBwUnitTime) {
  if (bwActual != NULL) {
    delete bwActual;
  }

  /* set offset */
  bwActual = new EstimateBandwidth(estBwHistTime, estBwUnitTime);
}


void TfrcSender::SetBwExpected(int estBwHistTime, int estBwUnitTime) {
  if (bwExpected != NULL) {
    delete bwExpected;
  }

  /* set offset */
  bwExpected = new EstimateBandwidth(estBwHistTime, estBwUnitTime);
}

/* in Kbps */
int TfrcSender::GetBwActual() {
  assert(bwActual != NULL);

  return bwActual->Report();
}


int TfrcSender::GetBwExpected() {
  assert(bwExpected != NULL);

  return bwExpected->Report();
}


/* Dump the sender's counters, RTT values and rate state as a single line on
 * stdout.  The line begins with a newline instead of ending with one.
 */
void TfrcSender::Print() {
  printf("\nsnd %ld, rcv %ld, crtt %ld, rtt %3.1f b_rep %.0f, b_exp %.0f, b_act %.2f, D %3.2f, adj_ts %.0f, rep_ts %.0f, round %ld, mode %d",
         this->packets_sent,                  /* snd    */
         this->reports_recv,                  /* rcv    */
         this->rtt_cur,                       /* crtt   */
         this->rtt,                           /* rtt    */
         this->b_rep,                         /* b_rep  */
         this->b_exp,                         /* b_exp  */
         this->b_act,                         /* b_act  */
         this->rate_delta,                    /* D      */
         static_cast<float>(this->adj_ts),    /* adj_ts */
         static_cast<float>(this->rep_ts),    /* rep_ts */
         this->round,                         /* round  */
         this->mode);                         /* mode   */
}


/* Fill in the TFRC header of `packet` and hand `pktlen` bytes to Sendto2()
 * for the remote address/port.
 *
 * Returns the value returned by Sendto2().  A non-negative result is fed
 * into the bwActual estimator (when one is installed).  A negative result is
 * treated as a lost packet; if it happens during slow start, the sender
 * halves b_exp and switches to M_NORMAL.
 */
int TfrcSender::send_packet(struct pkt_dtg *packet, int pktlen) {
  /* header fields go out in network byte order */
  packet->mseq    = htonl(this->packets_sent);
  packet->dseq    = 0; /* unused */
  packet->ts      = htonl(get_time() / MS_TO_US);
  packet->rtt     = htonl(this->rtt_cur);
  packet->bitrate = htonl((u_int32) this->b_act);
  packet->round   = htonl(this->round);
  packet->mode    = htonl(this->mode);

  assert(this->raddr != 0 && this->rport != 0);
  struct sockaddr_in dest = GetSockaddr(this->raddr, this->rport);

  const int ret = Sendto2(this->fd, (const char *)packet, pktlen, 0,
                          (const struct sockaddr *)&dest, sizeof(dest),
                          "tfrcSender");

  /* success: account for the bytes that went out */
  if (ret >= 0) {
    if (bwActual != NULL) bwActual->Update(ret);
    return ret;
  }

  /* XXX should handle differently based on errno */
  /* failure: treat as lost packet, don't retransmit */
  if (this->mode != M_SLOWSTART) {
    return ret;
  }

  /* kick out of SLOWSTART, we probably filled up the send buffer */
  this->mode = M_NORMAL;
  this->b_exp /= 2;
  this->b_act = this->b_exp;
  assert(this->b_act != 0);

  this->max_increase = this->b_act;
  this->rate_delta = 0;
  this->adj_ts = 0;
  /* prevent going back to SLOWSTART when the next report arrives */
  this->force_n_mode = 1;

  if (verbosity > 1) {
    printf("\ncongestion control mode (send error)");
  }

  return ret;
}


/* Advance the sender's mode state machine after a receiver report.
 *
 * b_exp - the b_exp value carried by the report (already in host order)
 *
 *   M_INITRTT   -> M_SLOWSTART  once an RTT estimate exists
 *   M_SLOWSTART -> M_NORMAL     once the report carries b_exp > 0
 *   M_NORMAL    -> M_SLOWSTART  when b_exp == 0 and force_n_mode is not set
 */
void TfrcSender::update_mode(u_int32 b_exp) {
  switch (this->mode) {
  case M_INITRTT: {
    if (!(this->rtt > 0.0)) {
      break;   /* no RTT sample yet, stay here */
    }
    /* b_act and b_rep both start from the same value */
    const double initial_rate =
      this->ave_pkt_size / (this->rtt / MILLISEC) / 2;
    this->mode  = M_SLOWSTART;
    this->b_act = initial_rate;
    this->b_rep = initial_rate;
    if (verbosity > 1) {
      printf("\nslowstart mode");
    }
    break;
  }

  case M_SLOWSTART:
    if (!(b_exp > 0)) {
      break;
    }
    this->mode = M_NORMAL;
    this->adj_ts = 0;
    if (verbosity > 1) {
      printf("\ncongestion control mode (update_mode)");
    }
    break;

  case M_NORMAL:
    /* back to SLOWSTART mode when b_exp == 0 (that's when l_exp == 0) */
    if (b_exp != 0 || this->force_n_mode != 0) {
      break;
    }
    this->b_exp = 0;
    this->mode = M_SLOWSTART;
    this->adj_ts = 0;
#ifdef DEBUG_TFRC
    if (verbosity > 1) {
      printf("\nreentered slowstart mode");
    }
#endif
    break;

  default:
    fprintf(stderr, "\nerror invalid mode (%i)", this->mode);
    assert(0);
  }
}


/* Take over the b_exp value of a receiver report and derive the ramp
 * (rate_delta, max_increase) used for subsequent packets.
 *
 * b_exp - the b_exp value carried by the report (already in host order)
 *
 * Only acts in M_NORMAL and only for b_exp > 0.
 */
void TfrcSender::update_bexp(u_int32 b_exp) {
  if (this->mode != M_NORMAL || !(b_exp > 0)) {
    return;
  }

  /* clamp the reported value into [min_bitrate, max_bitrate] */
  this->b_exp = (double) b_exp;
  this->b_exp = max(this->b_exp, this->min_bitrate);
  this->b_exp = min(this->b_exp, this->max_bitrate);

  if (this->b_act >= this->b_exp) {
    /* currently at or above the limit: drop to it, no ramp */
    this->b_act = this->b_exp;
    this->max_increase = this->b_exp;
    this->rate_delta = 0;
  } else {
    /* inc by 1 packet */
    this->rate_delta  = this->ave_pkt_size / (this->rtt / MILLISEC);
    /* limit total increase to 2 packets/rtt until new rec. reports arrive */
    this->max_increase = min(this->b_exp, this->b_act + 2*this->rate_delta);
    /* and spread increase over the approx. number of packets sent
     * during the next RTT */
    this->rate_delta /= (this->b_act + this->rate_delta / 2)
      / this->ave_pkt_size * (this->rtt / MILLISEC);
  }
  assert(this->max_increase > 0);

#ifdef DEBUG_EXTENDED
  printf("\nTFRC %f %f + %f -> %f", (float) this->rep_ts / MILLISEC,
         this->b_act, this->rate_delta, this->b_exp);
#endif
}


/* Wrap `buf` into one TFRC data packet, advance the sending rate according
 * to the current mode, transmit the packet and compute its pacing delay.
 *
 * buf, buflen - payload; buflen + sizeof(struct pkt_dtg) must not exceed
 *               MAX_UDP_PKTSIZE (asserted, the packet is built on the stack)
 * dataSent    - out: the value returned by send_packet(), i.e. by Sendto2()
 *
 * Returns two values:
 * - the return value is the time the packet should take at the current
 *   rate, computed as pktlen * MICROSEC / rate
 * - *dataSent is the returned value of Sendto
 *
 * packets_sent and ave_pkt_size are updated whether or not the send
 * succeeded.
 */
int64 TfrcSender::send_single(char *buf, const int buflen, int *dataSent) {
  char buffer[MAX_UDP_PKTSIZE];
  struct pkt_dtg *packet = (struct pkt_dtg *) buffer;

  /* the payload goes right behind the header, which send_packet() fills in */
  int pktlen = buflen+sizeof(struct pkt_dtg);
  assert(pktlen <= MAX_UDP_PKTSIZE);
  memcpy(buffer+sizeof(struct pkt_dtg), buf, buflen);

  int64 start = get_time();
  double modified_bact;

  if (this->cbr == 0) {
    switch (this->mode) {
    case M_INITRTT:
      break;
    case M_SLOWSTART:
      /* inc by delta over one rtt */
      if (start > this->adj_ts + this->rtt * MS_TO_US) {
        if (this->b_rep > 0) {
          this->b_exp = min(2 * this->b_rep, 2 * this->b_act);
        }
        if (this->b_exp > this->max_bitrate)
          this->b_exp = this->max_bitrate;
        /* use of b_exp for interval is conservative */
        assert(this->b_exp != 0);
        assert(this->rtt != 0);
        this->rate_delta = (this->b_exp - this->b_act);

        /* and spread increase over packets sent during the next RTT */
        this->rate_delta /= max((this->b_act + this->rate_delta / 2)
                                / pktlen * (this->rtt / MILLISEC), 1);

        if (this->rate_delta < 0.0)
          this->rate_delta = 0.0;
        else
          this->adj_ts = start;
      }
      if (this->b_act + this->rate_delta < this->b_exp)
        this->b_act += this->rate_delta;
      else
        this->b_act = this->b_exp;
      /* this check used to follow the break and was therefore unreachable */
      assert(this->b_act != 0);
      break;

    case M_NORMAL:
      /* inc smoothly each time a packet is sent */
      this->b_act = min(this->b_act + this->rate_delta, this->max_increase);
      assert(this->b_act != 0);
      break;
    default:
      fprintf(stderr, "\nerror invalid mode (%i)", this->mode);
      assert(0);
      break;
    }
  } else {
    /* constant bitrate: the configured rate overrides the TFRC rate */
    this->b_act = this->b_exp = this->cbr;
  }

  /* b_act is the divisor of the pacing delay below */
  assert(this->b_act != 0);

  *dataSent = send_packet(packet, pktlen);  /* packet len including headers */

  /* XXX need to keep this counting up */
  ++this->packets_sent;

  /* XXX: yhchu: figure out average pkt size
   * arguable: may want to consider only cbr traffic */
  if (this->ave_pkt_size == MAX_UDP_PKTSIZE) {
    /* MAX_UDP_PKTSIZE is the initial value: take the first real size as is */
    this->ave_pkt_size = pktlen;
  } else {
    this->ave_pkt_size = this->ave_pkt_size*0.9 + pktlen*0.1;
  }

  /* modify sending rate, should not influence Internet experiments but
     improve Dummynet behavior */
  if (this->rtt_cur != 0 && this->mode == M_NORMAL) {
    modified_bact = this->b_act * this->sqrt_rtt / sqrt(this->rtt_cur);
#ifdef DEBUG_EXTENDED
    printf("\n%6.3f %3.2f \t %3.2f \t %3.2f \t %3.2f",
           (double) start / MICROSEC,
           this->b_act, modified_bact,
           this->sqrt_rtt, sqrt(this->rtt_cur));
#endif
  } else {
    modified_bact = this->b_act;
  }

  /* the time it should take */
  return (int64)(pktlen * MICROSEC / modified_bact);
}


/* Time to wait for a receiver report before recv_report_timeout() is due,
 * derived from the current mode.  The result is multiplied by MICROSEC by
 * the callers before being added to get_time().
 *
 * An invalid mode is reported on stderr and asserted; 0 is returned if the
 * assert is compiled out.
 */
double TfrcSender::get_max_wait() {
  switch (this->mode) {
  case M_INITRTT:
    /* assume 500ms RTT for timeout while in M_INITRTT */
    return DEC_NO_REPORT * 500.0 / MILLISEC / REPORTS_PER_RTT;

  case M_SLOWSTART:
  case M_NORMAL: {
    /* slow start uses SLOWSTART_NO_REPORT instead of DEC_NO_REPORT,
       be more cautious */
    const double factor = (this->mode == M_SLOWSTART)
      ? (double) SLOWSTART_NO_REPORT
      : (double) DEC_NO_REPORT;
    /* use max (RTT_CUR, RTT, time to send one packet)
       -> avoid decrease caused by heavy RTT variations */
    return factor *
      max(max(this->rtt_cur, this->rtt) / MILLISEC / REPORTS_PER_RTT,
          MAX_UDP_PKTSIZE / this->b_act);
  }

  default:
    fprintf(stderr, "\nerror invalid mode (%i)", this->mode);
    assert(0);
    break;
  }
  return 0;
}


/* Read one receiver report from the socket and fold it into the sender
 * state (RTT estimates, b_rep, mode, b_exp).
 *
 * return -1 if connection is believed to be terminated
 * 0 if pkt not for this connection, or nothing usable was received
 *   (receive error other than ECONNREFUSED, datagram of the wrong size)
 * 1 if success
 */
int TfrcSender::recv_report_process() {
  char buffer[REP_SIZE];
  struct rep_dtg *rep = (struct rep_dtg *) buffer;

  struct sockaddr_in sin;
  socklen_t len = sizeof(sin);
  int datalen = Recvfrom2(this->fd, buffer, REP_SIZE, 0,
                          (struct sockaddr *)&sin, &len);

  if (datalen < 0) {
    const int err = errno;  /* keep a copy, printf may overwrite errno */

    if (err == ECONNREFUSED) { /* connection refused (ICMP port unreachable) */
      if (verbosity > 0) {
        printf("\nTFRC SND %s/%d connection refused (ICMP)",
               getInetAddrName(this->raddr), this->rport);
      }
      return -1;
    }

    /* Any other error: neither `sin` nor `buffer` has been filled in, so
     * there is nothing to process.  (This used to be guarded by an assert
     * on a string literal, which is always true and never fired.) */
    if (verbosity > 0) {
      printf("\nTFRC SND WRN %s/%d recvfrom failed: %s",
             getInetAddrName(this->raddr), this->rport, strerror(err));
    }
    return 0;
  }

  if (isConnected != FALSE) {
    /* XXX: temporary fix
     * these conditions can happen due to
     * - bug in connection setup in my port of TFRC
     * - someone else is doing port scanning
     * I am not sure which case it is though at this point.
     */
    if (this->raddr != ntohl(sin.sin_addr.s_addr) ||
        this->rport != ntohs(sin.sin_port)) {
      printf("\nTFRC SND WRN %s -> %s/%d: received bad pkt %s/%d (%d)",
             getInetAddrName(GetMyAddr()),
             getInetAddrName(this->raddr), this->rport,
             getInetAddrName(ntohl(sin.sin_addr.s_addr)), ntohs(sin.sin_port),
             datalen);

      fprintf(stderr,"\nTFRC SND WRN %s -> %s/%d: received bad pkt %s/%d (%d)",
              getInetAddrName(GetMyAddr()),
              getInetAddrName(this->raddr), this->rport,
              getInetAddrName(ntohl(sin.sin_addr.s_addr)), ntohs(sin.sin_port),
              datalen);
      return 0;
    }
  }

  /* The datagram comes from the network: a report of the wrong size is
   * dropped instead of being asserted on, and before it can make an
   * unconnected sender adopt its source address. */
  if (datalen != REP_SIZE) {
    fprintf(stderr,
            "\nTFRC SND WRN %s/%d: dropped report with bad length %d (expected %d)",
            getInetAddrName(ntohl(sin.sin_addr.s_addr)), ntohs(sin.sin_port),
            datalen, (int) REP_SIZE);
    return 0;
  }

  /* adapt to receiver handoff parameters when first pkt is received */
  if (isConnected == FALSE) {
    isConnected = TRUE;
    this->raddr = ntohl(sin.sin_addr.s_addr);
    this->rport = ntohs(sin.sin_port);
  }

  /* ETB emulate transient behavior: do not send the pkt probabilistically */
  /* XXX should emulate: but this makes tfrc processing a bit more complex...
  if (ETBRecvDrop(htonl(sin.sin_addr.s_addr))) {
     return -1;
  }
  */

  this->rep_ts = get_time() / MS_TO_US;
  this->rtt_cur = this->rep_ts - ntohl(rep->ts);  /* current rtt */
  ++this->round; /* new report -> start new round for loss measurement */

  /* smooth out rtt; the very first sample is taken as is */
  if (this->rtt > 0.0)
    this->rtt = (1 - RTT_DECAY) * this->rtt + RTT_DECAY * this->rtt_cur;
  else
    this->rtt = this->rtt_cur;

  /* same smoothing for sqrt(rtt), which send_single() uses */
  if (this->sqrt_rtt > 0.0)
    this->sqrt_rtt = (1 - RTT_DECAY) * this->sqrt_rtt + RTT_DECAY * sqrt(this->rtt_cur);
  else
    this->sqrt_rtt = sqrt(this->rtt_cur);

  ++this->reports_recv;
  this->b_rep = ntohl(rep->b_rep);
  /* the mode must be updated first: update_bexp() only acts in M_NORMAL */
  update_mode(ntohl(rep->b_exp));
  update_bexp(ntohl(rep->b_exp));

  PrintReport(rep);

  return 1;
}

void TfrcSender::PrintReport(struct rep_dtg *rep) {
  /* timestamp      : current time (ms)

     cur_rtt        : current rtt (ms)
     cur_bitrate    : current sending rate  (Kbps)
     max_tfrc_rate  : cur max tfrc rate (Kbps)
     loss_cur       : current loss rate (max 1.0)

     rtt_ave        : average rtt
     rate_ave       : average sending rate
     maxrate_ave    : average max tfrc rate
     loss_ave       : average loss
  */
  int64 cur_time = get_time();
  num_samples++;
  rtt_sum += this->rtt_cur;

  if (num_samples == 1) {
    prev_report = cur_time;
  }

  if (cur_time - prev_report < 1000*1000) {
    return;
  }
  prev_report = cur_time;
  
  double loss_smooth = ((double)(ntohl(rep->loss_smooth)))/1000000.0;
  double rtt_smooth = ((double)(ntohl(rep->rtt_smooth)))/100.0;
  double loss_ses = ((double)(ntohl(rep->loss_ses)))/1000000.0;
  
  if (verbosity > 0) {
    printf("\nTFRC_SND %.1f %s -> %s R %.1f L %.6f SB %d RB %lu MB %.0f CB %.0f SL %.6f SR %.1f SB %lu",
	   (cur_time-start_time)/1000000.0,
	   getInetAddrName(GetMyAddr()), 
	   getInetAddrName(GetRemoteAddr()), 
	   rtt_smooth,                /* 3 smooth rtt (ms, rcv) */
	   loss_smooth,               /* 4 smooth loss (<= 1.0, rcv) */
	   GetBwExpected(),
	   (long unsigned int)ntohl(rep->bw_smooth),/* current rcvd bw (Kbps)*/
	   ntohl(rep->b_exp)*8.0/1000.0,     /* 6 max bitrate */
	   ntohl(rep->b_rep)*8.0/1000.0,     /* 7 cur bitrate */
	   loss_ses,                  /* 8 ses loss (rcv) */
	   this->rtt_sum/num_samples, /* 9 session rtt (snd) */
	   (long unsigned int)ntohl(rep->bw_ses)    /* 10 session bw (rcv) */
	   );
  }
}


/* React to a missing receiver report.
 *
 * max_wait - the wait that just expired; only used for debug output
 *
 * no receiver report for DEC_NO_REPORT x "expected time interval"
 * -> assume congestion and reduce rate:
 *   M_SLOWSTART: leave slow start and keep sending at the current rate
 *   M_NORMAL:    halve b_act, but not below min_bitrate
 * In both cases the ramp is cleared and a new round is started.
 */
void TfrcSender::recv_report_timeout(double max_wait) {
#ifdef DEBUG_EXTENDED
  /* timestamp first, then the timeout; the format used to have a single
     specifier, so the timestamp was printed as if it were the timeout */
  printf("\n%.3f timeout no report for %3.2fms",
         (float) this->rep_ts / MILLISEC, max_wait * MILLISEC);
#endif

  if (this->mode == M_SLOWSTART) {
    this->mode = M_NORMAL;
    /* set adj_ts to 0 so that rate_delta is reset to an appropriate value
       as soon as we get out of slowstart */
    this->adj_ts = 0;
    /* keep sending at the current rate */
    this->b_exp = this->b_act;
    this->rate_delta = 0;
    this->max_increase = this->b_act;
#ifdef DEBUG_TFRC
    /* the message has no conversion, so no argument is passed */
    printf("\ncongestion control mode (recv_report)");
#endif
  } else if (this->mode == M_NORMAL) {
    this->b_act = max(this->b_act / 2, this->min_bitrate);
    this->max_increase = this->b_act;
    this->rate_delta = 0;
    this->b_exp = this->b_act;
  }

  assert(this->max_increase > 0);
  assert(this->b_act > 0);

  ++this->round; /* sender reacted by halving bitrate -> start new round*/
}


/* Create a sender for the receiver at addr/port (both must be non-zero).
 *
 * Opens a non-blocking UDP socket bound to INADDR_ANY with port 0,
 * allocates the idle timer and the send queue, and starts in M_INITRTT.
 * No bandwidth estimators are installed.
 */
TfrcSender::TfrcSender(int addr, int port) {
  /* counters and RTT state */
  this->packets_sent    = 0;
  this->reports_recv    = 0;
  this->round           = 0;
  this->rtt_cur         = 0;
  this->rtt             = 0.0;
  this->sqrt_rtt        = 0.0;
  this->adj_ts          = 0;
  this->rep_ts          = 0;

  /* rate state */
  this->min_bitrate     = MIN_BITRATE;
  this->max_bitrate     = MAX_BITRATE;
  this->b_rep           = 0.0;
  this->b_exp           = MAX_UDP_PKTSIZE * 2;
  this->b_act           = this->b_exp;
  this->max_increase    = this->b_exp;
  this->rate_delta      = 0.0;
  this->cbr             = 0;   /* 0 means not constant bitrate */
  /* MAX_UDP_PKTSIZE marks "no packet seen yet" for send_single() */
  this->ave_pkt_size    = MAX_UDP_PKTSIZE;

  /* mode state */
  this->mode            = M_INITRTT;
  this->force_n_mode    = 0;
  this->isConnected     = FALSE;

  /* optional estimators, installed through SetBwActual()/SetBwExpected() */
  this->bwActual        = NULL;
  this->bwExpected      = NULL;

  /* yhchu: statistics for PrintReport() */
  this->rtt_sum         = 0;
  this->num_samples     = 0;
  this->start_time      = 0;

  this->timer = new TfrcTimer((int64)TFRC_SND_MAX_IDLE * (int64)MICROSEC);

  /* socket setup */
  this->fd = Socket(AF_INET, SOCK_DGRAM, 0);
  SetsockoptReuseAddrPort(this->fd);
  SetsockNonBlocking(this->fd);

  this->laddr = INADDR_ANY;
  this->lport = 0;  /* any port */

  assert(addr != 0 && port != 0);
  this->raddr = addr;
  this->rport = port;

  struct sockaddr_in local = GetSockaddr(this->laddr, this->lport);
  Bind(this->fd, (struct sockaddr *)&local, sizeof(local));

  /* deadlines; get_max_wait() depends on the mode set above */
  this->max_wait = get_max_wait();
  this->timeout_recv_report = get_time() + (int64)(this->max_wait*MICROSEC);
  this->timeout_send_data = get_time();

  this->sendBuffer = new PacketBuffer(TFRC_SND_QLEN);
}

/* Close the socket and free everything the sender allocated: both
 * bandwidth estimators (which may be NULL), the idle timer and the send
 * queue.
 */
TfrcSender::~TfrcSender() {
  this->Close();

  delete this->bwActual;
  delete this->bwExpected;
  delete this->timer;
  delete this->sendBuffer;
}

int TfrcSender::IsConnected() {
  return isConnected;
}


/* Non-zero when the send queue holds TFRC_SND_QLEN or more entries. */
int TfrcSender::IsBufferFull() {
  return (TFRC_SND_QLEN <= this->sendBuffer->Size());
}

/* Register the socket with the caller's descriptor sets and report how
 * long the caller may wait before this sender needs attention.
 *
 * rs - read set; the socket is always added
 * ws - write set; the socket is added only while data is queued
 *
 * Returns the time until the earliest of the report timeout, the send
 * timeout (only while data is queued) and the idle timer, divided by 1000
 * and never negative.
 */
int TfrcSender::SetFD(fd_set *rs, fd_set *ws) {
  int64 wait_time;

  FD_SET(this->fd, rs);

  if (this->sendBuffer->Size() != 0) {
    /* if buffer is not empty, set the write FD */
    FD_SET(this->fd, ws);
    wait_time = min(this->timeout_recv_report,
                    this->timeout_send_data) - get_time();
  } else {
    wait_time = this->timeout_recv_report - get_time();
  }

  wait_time = min(wait_time, this->timer->Expire());

  /* a deadline in the past means "do not wait at all" */
  return (wait_time < 0) ? 0 : wait_time/1000;
}


/* Try to transmit the packet at the head of the send queue.
 *
 * return -1 if error (due to Sendto())
 * 0 if buf empty or timer not expired
 * return data len if data sent
 *
 * The packet is only removed from the queue after a successful send, so a
 * failed send leaves it queued for the next attempt.
 */
int TfrcSender::SendInternal() {
  int64 cur_time = get_time();

  /* send timer not yet expired */
  if (this->timeout_send_data > cur_time) return 0;

  /* buffer empty, return immediately */
  if (this->sendBuffer->Size() == 0) return 0;

  /* The queue operations must not live inside assert(): with NDEBUG the
   * whole expression disappears, which would leave buf/buflen
   * uninitialised and the packet never dequeued. */
  char *buf = NULL; int buflen = 0;
  const bool peek_ok = (this->sendBuffer->Peek(&buf, &buflen) >= 0);
  assert(peek_ok);
  if (!peek_ok) return 0;  /* only reachable when asserts are compiled out */

  /* waited too long / too short */
  int dataSent;
  int64 delay = send_single(buf, buflen, &dataSent);

  /* if data was not sent due to Sendto(), requeue the data
   * (implemented using peek => dequeue if transmitted successfully)
   */
  if (dataSent >= 0) {
    const bool dequeue_ok = (this->sendBuffer->Dequeue(NULL, NULL) >= 0);
    assert(dequeue_ok);
    (void) dequeue_ok;  /* unused when asserts are compiled out */
  }

  /* compute next eligible time */
  this->timeout_send_data += delay;
  if (this->sendBuffer->Size() == 0) {
    /* XXX allow 2 second of "advanced data" */
    this->timeout_send_data =
      max(cur_time+delay - (int64)TFRC_SND_ADVANCE_DATA*(int64)MICROSEC,
          this->timeout_send_data);
  }

  return dataSent;
}

int TfrcSender::GetSessionDuration() {
  return (int)((get_time() - start_time)/1000000);
}


int TfrcSender::GetRemoteAddr() {
  return raddr;
}


int TfrcSender::GetRemotePort() {
  return rport;
}


/* Queue `buf` for transmission and immediately send as much queued data as
 * the pacing timer allows.
 *
 * Returns -1 if the queue rejects the data, otherwise the number of bytes
 * sent during this call (can be more than buflen, since older queued
 * packets go out first).
 */
int TfrcSender::Send(const char *buf, int buflen, int priority) {
  /* the first call marks the start of the session */
  if (start_time == 0) {
    start_time = get_time();
  }

  if (bwExpected != NULL) {
    bwExpected->Update(buflen);
  }

  if (this->sendBuffer->Enqueue(buf, buflen, priority) < 0) {
    return -1;
  }

  /* be very aggressive in clearing out the data (non-blocking):
     keep sending until SendInternal() reports nothing sent or an error */
  int total = 0;
  for (int sent = SendInternal(); sent > 0; sent = SendInternal()) {
    total += sent;
  }

  return total;
}

int TfrcSender::Process(fd_set *rs, fd_set *ws) {
  return Process(rs, ws, get_time());
}

/* One step of the sender: drain the send queue, handle a pending receiver
 * report, and check the report timeout and the idle timer.
 *
 * rs, ws    - descriptor sets as filled by select(); this sender's bit is
 *             cleared from rs once the report has been read, and from ws
 *             when sending failed
 * cur_time  - current time used for the timeout bookkeeping
 *
 * return -1 if connection is believed to be terminated, 0 otherwise
 */
int TfrcSender::Process(fd_set *rs, fd_set *ws, int64 cur_time) {
#ifdef DEBUG_EXTENDED
  printTfrc();  /* yhchu */
#endif

  /* be very aggressive in clearing out the data (non-blocking) */
  int sent;
  do {
    sent = SendInternal();
  } while (sent > 0);
  if (sent < 0) {
    FD_CLR(this->fd, ws);
  }

  /* report available */
  if (FD_ISSET(this->fd, rs)) {
    FD_CLR(this->fd, rs);
    const int status = recv_report_process();

    if (status == -1) {
      return -1;
    }
    if (status == 0) {
      return 0;
    }
    if (status > 0) {
      /* a valid report restarts both the report timeout and the idle timer */
      this->max_wait = get_max_wait();
      this->timeout_recv_report = cur_time + (int64)(this->max_wait*MICROSEC);
      this->timer->Update();
      return 0;
    }
    assert(0);
  }

  /* max_wait expires */
  if (this->timeout_recv_report <= cur_time) {
    recv_report_timeout(this->max_wait);

    this->max_wait = get_max_wait();
    this->timeout_recv_report = cur_time + (int64)(this->max_wait*MICROSEC);
  }

  if (this->timer->Expire() != 0) {
    return 0;
  }

  /* idle timer ran out: give up on the connection */
  if (verbosity > 0)
    printf("\nTFRC SND %s/%d timer expired", getInetAddrName(raddr), rport);
  return -1;
}


int TfrcSender::Close() {
  return close(this->fd);
}
