#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <unistd.h>
#ifdef HAVE_BSTRING_H
#include <bstring.h>
#endif
#include <errno.h>

#include <sys/types.h>
#include <sys/user.h>
#include <sys/ipc.h>
#include <sys/shm.h>
#include <sys/sem.h>
#include <asm/page.h>
#include <utils/Time.h>

#if defined(__GNU_LIBRARY__) && !defined(_SEM_SEMUN_UNDEFINED)
/* union semun is defined by including <sys/sem.h> */
#else
/* according to X/OPEN we have to define it ourselves */
/* Argument union passed as the optional fourth parameter of semctl();
   each member is the payload for the commands named beside it. */
union semun {
  int val;                    /* value for SETVAL */
  struct semid_ds *buf;       /* buffer for IPC_STAT, IPC_SET */
  unsigned short int *array;  /* array for GETALL, SETALL */
  struct seminfo *__buf;      /* buffer for IPC_INFO */
};
#endif
#define SHMMIN 1
#define SHMMAX 0x2000000

#include <ipt/ipt.h>
#include <ipt/posixmem.h>
#include <ipt/message.h>

#include "internal_messages.h"
#include "connection.h"

#include <utils/ConfigFile.h>
#include <utils/Generator.h>
#include <utils/SymbolTable.h>
#include <utils/List.h>
#include <utils/formatting/Format.h>
#include <utils/formatting/FormatParser.h>
#include <utils/formatting/UnpackAction.h>
#include <utils/formatting/PackAction.h>
#include <utils/formatting/DeleteAction.h>
#include <utils/formatting/BufferSizeAction.h>
#include <utils/Time.h>

struct IPSharedMemoryInfo {
  char* tag;
  IPSharedMemory* (*create_interface)(IPGenerator<IPSharedMemory>*,
                                      IPConfigFile*, IPSymbolTable*);
};

/* Factory routines referenced by the IP_shmems table below; their
   definitions appear further down in this file. */
IPSharedMemory* create_shmem_owner(IPGenerator<IPSharedMemory>*,
                                        IPConfigFile*, IPSymbolTable*);
IPSharedMemory* create_shmem_client(IPGenerator<IPSharedMemory>*,
                                    IPConfigFile*, IPSymbolTable*);
IPSharedMemory* create_shmem_managed(IPGenerator<IPSharedMemory>*,
                                     IPConfigFile*, IPSymbolTable*);

/* Table of selectable shared-memory flavours.  The {0, 0} row terminates
   the table: generator() stops at the first entry whose tag is null. */
static struct IPSharedMemoryInfo IP_shmems[] = {
  {"owner", create_shmem_owner},
  {"client", create_shmem_client},
  {"managed", create_shmem_managed},
  {0, 0}
};

#define DEFAULT_INTF "client"

/* Build a fresh generator and register every flavour listed in IP_shmems.
   The flavour whose tag equals DEFAULT_INTF is registered a second time
   under the name "default".  The caller receives the new'd generator. */
static IPGenerator<IPSharedMemory>* generator()
{
  IPGenerator<IPSharedMemory>* result = new IPGenerator<IPSharedMemory>();

  // walk the table up to its null-tag terminator
  for (struct IPSharedMemoryInfo* entry = IP_shmems; entry->tag; ++entry) {
    result->registerInterface(entry->tag, entry->create_interface, 0L);

    const bool is_default = (strcmp(entry->tag, DEFAULT_INTF) == 0);
    if (is_default)
      result->registerInterface("default", entry->create_interface, 0L);
  }

  return result;
}

/* Return the communicator's shared-memory generator, building it on the
   first request. */
IPGenerator<IPSharedMemory>* IPCommunicator::SharedMemoryGenerator()
{
  if (_shmem_gen)
    return _shmem_gen;

  _shmem_gen = ::generator();
  return _shmem_gen;
}

/* Open a shared-memory region described by a configuration structure.
   fmt (when non-null) and max_size (when positive) are offered to the
   factory through a symbol table under the keys "format" and "max_size";
   the communicator itself is passed under "Communicator".  A non-null
   result is added to the communicator's shared-memory list. */
IPSharedMemory* IPCommunicator::OpenSharedMemory(IPConfigFile* params,
                                                 const char* fmt, int max_size)
{
  IPSymbolTable syms;
  if (fmt != NULL)
    syms.set("format", fmt);
  if (max_size > 0) {
    // the size travels through the table disguised as a pointer value
    syms.set("max_size", (void*) (long) max_size);
  }
  syms.set("Communicator", this);

  // create the generator lazily
  if (_shmem_gen == NULL)
    _shmem_gen = ::generator();

  IPSharedMemory* opened = _shmem_gen->interface(params, &syms);
  UseSharedMemory(opened);
  return opened;
}

/* Append shmem to the communicator's shared-memory list; a null pointer
   is ignored. */
void IPCommunicator::UseSharedMemory(IPSharedMemory* shmem)
{
  if (!shmem)
    return;

  _shmem_list->append(shmem);
}

/* Open a shared-memory region described by a specification string.
   fmt (when non-null) and max_size (when positive) are offered to the
   factory through a symbol table under the keys "format" and "max_size";
   the communicator itself is passed under "Communicator".  A non-null
   result is added to the communicator's shared-memory list. */
IPSharedMemory* IPCommunicator::OpenSharedMemory(const char* spec,
                                                 const char* fmt, int max_size)
{
  IPSymbolTable syms;
  if (fmt != NULL)
    syms.set("format", fmt);
  if (max_size > 0) {
    // the size travels through the table disguised as a pointer value
    syms.set("max_size", (void*) (long) max_size);
  }
  syms.set("Communicator", this);

  // create the generator lazily
  if (_shmem_gen == NULL)
    _shmem_gen = ::generator();

  IPSharedMemory* opened = _shmem_gen->interface(spec, &syms);
  UseSharedMemory(opened);
  return opened;
}

/* Remove mem from the communicator's shared-memory list and delete it.
   Nothing is deleted when the list reports that mem was not removed. */
void IPCommunicator::CloseSharedMemory(IPSharedMemory* mem)
{
  if (!_shmem_list->remove(mem))
    return;

  delete mem;
}

/* Record the maximum size and parse fmt with the communicator's format
   parser.  A successfully parsed format is referenced here and released
   again in the destructor. */
IPSharedMemory::IPSharedMemory(IPCommunicator* com, const char* fmt,
                               int max_size)
{
  _max_size = max_size;

  _format = com->FormatParser()->parseString(fmt);
  if (_format != NULL)
    _format->ref();
}

/* Drop the reference taken on the parsed format, if there is one. */
IPSharedMemory::~IPSharedMemory()
{
  if (_format != NULL) {
    _format->unref();
    _format = NULL;
  }
}

int IPSharedMemory::Update()
{
  int align, byte_order;
  unsigned char* data;
  int size;

  return Data(size, data, align, byte_order);
}

void* IPSharedMemory::FormattedData(int force_copy)
{
  if (!_format)
    return NULL;

  unsigned char* data;
  int size_data, alignment, byte_order;
  if (!Data(size_data, data, alignment, byte_order))
    return NULL;
  if (!size_data)
    return NULL;

  if (force_copy)
    return IPUnpackAction::unpack(_format, data, 0, alignment, byte_order);
  else
    return IPUnpackAction::unpack(_format, data, size_data,
                                  alignment, byte_order);
}

int IPSharedMemory::FormattedData(void* output, int force_copy)
{
  if (!_format)
    return 0;

  unsigned char* data;
  int size_data, alignment, byte_order;
  if (!Data(size_data, data, alignment, byte_order))
    return 0;

  if (!size_data) {
    bzero(output, _format->dataSize());
    return 1;
  }

  if (force_copy)
    return IPUnpackAction::unpack(_format, data, 0, output, 0,
                                  alignment, byte_order);
  else
    return IPUnpackAction::unpack(_format, data, size_data,
                                  output, 0, alignment, byte_order);
}

/* Hand `data` to IPDeleteAction::deleteData together with the region's
   format, input area and maximum size, and pass its result through. */
int IPSharedMemory::DeleteFormatted(void* data)
{
  const int status =
    IPDeleteAction::deleteData(_format, data, GetInputArea(), _max_size);
  return status;
}

/* Hand `data` to IPDeleteAction::deleteStruct together with the region's
   format, input area and maximum size, and pass its result through. */
int IPSharedMemory::DeleteContents(void* data)
{
  const int status =
    IPDeleteAction::deleteStruct(_format, data, GetInputArea(), _max_size);
  return status;
}

/* Pack the structure `input` into the output area according to the
   region's format and publish it with PutData().
   Returns 1 on success and 0 when there is no format, no output area, the
   packed size exceeds the maximum size, packing fails, or PutData()
   reports failure. */
int IPSharedMemory::PutFormattedData(void* input)
{
  // like FormattedData(), refuse to work without a parsed format
  if (!_format)
    return 0;

  unsigned char* output = GetOutputArea();
  if (!output)
    return 0;
  
  int output_size = IPBufferSizeAction::size(_format, input);
  if (output_size > _max_size)
    return 0;

  if (!IPPackAction::pack(_format, output, output_size, input))
    return 0;

  // propagate a failed publish instead of always claiming success
  return PutData(output_size) ? 1 : 0;
}

/* Install a new semaphore-set id and shared-memory id.
   If the new shm id is negative while the current one is not, the current
   mapping is detached first.  When both new ids are non-negative the
   segment is attached and _shm points at it; a failed attach leaves _shm
   NULL.
   NOTE(review): replacing one valid shm id with another does not detach
   the previous mapping here. */
void IPPosixSharedMemory::setIDs(int sem_id, int shm_id)
{
  const bool dropping_segment = (shm_id < 0) && (_shm_id >= 0);
  if (dropping_segment) {
    if (_shm) {
      if (shmdt((void *)_shm) == -1)
        perror ("shmdt");
    }
    _shm = NULL;
  }

  _sem_id = sem_id;
  _shm_id = shm_id;

  const bool both_valid = (_sem_id >= 0) && (_shm_id >= 0);
  if (both_valid) {
    _shm = (unsigned char*) shmat(_shm_id, 0, 0);
    const bool attach_failed = ((long)(_shm) == -1);
    if (attach_failed) {
      perror("IPT: Owner attaching shared memory:");
      _shm = NULL;
    }
  }
}

/* Detach and remove the shared-memory segment and remove the semaphore
   set, for whichever of the two ids is currently set.  Afterwards both
   ids are -1 and _shm is NULL, so the object no longer refers to the
   removed resources. */
void IPPosixSharedMemory::clear()
{
  if (_shm_id != -1) {
    /* detach and clear shared memory */
    if (_shm && (-1 == shmdt ((void *)_shm))) 
      perror ("shmdt");
    _shm = NULL;  /* never keep a pointer into a detached segment */
    if (-1 == shmctl (_shm_id, IPC_RMID, NULL)) 
      perror ("shmctl");
    _shm_id = -1;
  }
  if (_sem_id != -1) {
    union semun un;
    un.val = 0;
    /* clear semaphore */
    if (-1 == semctl (_sem_id, 1, IPC_RMID, un)) 
      perror ("semctl");
    _sem_id = -1;
  }
}
  

/* Owner-side shared memory whose semaphore-set and segment ids are handed
   in by the caller rather than created here.  flag_sem_id names an
   additional semaphore set on which PutData() performs a non-blocking
   operation after every successful write. */
class IPOwnedManagedMemory : public IPOwnedSharedMemory {
public:
  IPOwnedManagedMemory(IPCommunicator*, const char* fmt, int max_size,
                       int sem_id, int shm_id, int flag_sem_id);
  virtual ~IPOwnedManagedMemory();

  virtual int PutData(int size);

private:
  int _flag_sem_id;  // semaphore set touched after each successful PutData()
};

/* Client-side shared memory that attaches to a semaphore-set id and a
   segment id handed in by the caller. */
class IPClientManagedMemory : public IPClientSharedMemory {
public:
  IPClientManagedMemory(IPCommunicator*, const char* fmt, int max_size,
                        int sem_id, int shm_id);
};

/* Turn a requested payload size into a segment size: add room for the
   two-int header (tag and size), clamp the result to [SHMMIN, SHMMAX],
   then round up to a multiple of PAGE_SIZE. */
static int pad(int size)
{
  int padded = size;

  /* account for tag and size -> note should eventually be split into
     administrative page followed by data page(s) */
  padded += 2*sizeof(int);

  if (padded < SHMMIN)
    padded = SHMMIN;
  if (padded > SHMMAX)
    padded = SHMMAX;

  // round up to the next page boundary unless already on one
  if (padded % PAGE_SIZE != 0)
    padded += PAGE_SIZE - padded % PAGE_SIZE;

  return padded;
}

/* Create (exclusively) the shared-memory segment for shm_key and a set of
   four semaphores for sem_key, initialise the mutex semaphore, attach the
   segment and zero it.
   On any failure a diagnostic is printed, whatever was created is removed
   again, and the ids are left at -1 so callers can detect the failure via
   shmID()/semID(). */
IPOwnedSharedMemory::IPOwnedSharedMemory(key_t shm_key, key_t sem_key,
                                         IPCommunicator* com, const char* fmt,
                                         int max_size)
  : IPPosixSharedMemory(com, fmt, pad(max_size))
{
  max_size = MaxSize();

  _output_area = new unsigned char[max_size];
  _tag = 0;

  _shm_id = shmget(shm_key, max_size, IPC_CREAT|IPC_EXCL|0666);

  if (_shm_id == -1) {
    printf("IPT: Problem getting shared memory with key %d\n", shm_key);
    perror("shmget:");
    printf("\tI advise running 'iptmemdel %d %d'\n", shm_key, sem_key);
    _sem_id = -1;
    return;
  }

  _sem_id = semget(sem_key, 4, IPC_CREAT|IPC_EXCL|0666);
  if (_sem_id == -1) {
    printf("IPT: Problem getting semaphore with key %d\n", sem_key);
    perror("semget:");
    printf("\tI advise running 'iptmemdel %d %d'\n", shm_key, sem_key);
    if (-1 == shmctl (_shm_id, IPC_RMID, NULL)) 
      perror ("shmctl");
    _shm_id = -1;
    return;
  }

  // set up mutual exclusion semaphore; semctl expects its optional
  // argument as a union semun, not a bare int
  union semun sem_arg;
  sem_arg.val = 1;
  if (semctl(_sem_id, 2, SETVAL, sem_arg) == -1) {
    // without an initialised mutex every PutData() would block, so give
    // up and remove what was just created
    perror("semctl");
    sem_arg.val = 0;
    if (-1 == semctl (_sem_id, 0, IPC_RMID, sem_arg))
      perror ("semctl");
    if (-1 == shmctl (_shm_id, IPC_RMID, NULL))
      perror ("shmctl");
    _sem_id = -1;
    _shm_id = -1;
    return;
  }

  setIDs(_sem_id, _shm_id);
  if (_shm) {
    bzero(_shm, MaxSize());
  }
}

/* Set up an owner that is not yet bound to any segment or semaphore set:
   only the local output area is allocated, and both ids start out as -1. */
IPOwnedSharedMemory::IPOwnedSharedMemory(IPCommunicator* com, const char* fmt,
                                         int max_size)
  : IPPosixSharedMemory(com, fmt, pad(max_size))
{
  // from here on work with the padded size
  const int padded_size = MaxSize();

  _output_area = new unsigned char[padded_size];
  _tag = 0;

  _shm_id = -1;
  _sem_id = -1;
}

/* Remove the shared resources via clear() and free the local output
   area. */
IPOwnedSharedMemory::~IPOwnedSharedMemory()
{
  // detach/remove segment and semaphores first
  clear();

  // then release the staging buffer allocated by the constructors
  delete[] _output_area;
}

/* Report the contents currently stored in the segment, read in place.
   On success returns 1 with `size` taken from the header's second int,
   `data` pointing just past the two-int header, and alignment/byte_order
   set to this build's ALIGN and BYTE_ORDER.
   Returns 0 (size 0, data NULL) when no segment is attached. */
int IPOwnedSharedMemory::Data(int& size, unsigned char*& data,
                              int& alignment, int& byte_order)
{
  alignment = ALIGN;
  byte_order = BYTE_ORDER;

  // _shm stays NULL until setIDs() attaches successfully
  if (!_shm) {
    size = 0;
    data = NULL;
    return 0;
  }

  size = *(((int*) _shm)+1);
  data = _shm+2*sizeof(int);

  return 1;
}

/* Return the tag stored in the first int of the segment, or -1 (the same
   failure value the client-side UpdatedTag() uses) when no segment is
   attached. */
int IPOwnedSharedMemory::UpdatedTag()
{
  if (!_shm)
    return -1;

  return *((int*)_shm);
}

/* Publish the first `size` bytes of the output area.
   Under the writer lock the tag is incremented and stored in the first
   header int, `size` in the second, and the payload is copied behind the
   header.  Afterwards either a waiting client is signalled (semaphore 3
   non-zero) or the condition mutex (semaphore 2) is simply released.
   Returns 1 on success; 0 when the ids are unset, no segment is attached,
   `size` is negative or does not fit behind the header, or a semaphore
   operation fails. */
int IPOwnedSharedMemory::PutData(int size)
{
  const int header_bytes = (int) (2*sizeof(int));

  if (_sem_id == -1 || _shm_id == -1 || !_shm)
    return 0;

  // The keyed constructor creates the segment with MaxSize() bytes and
  // the payload starts after the header, so only MaxSize() - header_bytes
  // bytes are available for it.
  if (size < 0 || size > MaxSize() - header_bytes)
    return 0;

  static sembuf op_lock[3] = {
    { 0, 0, 0 },   // wait for clients to say clear to write
    { 1, 1, SEM_UNDO },   // indicate we are writing
    { 2, -1, SEM_UNDO },  // grab condition mutex
  };
  static sembuf op_unlock[1] = {
    { 1, -1, SEM_UNDO},  // finished writing
  };
  static sembuf op_release_mutex[1] = {
    { 2, 1, SEM_UNDO }
  };
  static sembuf op_signal[2] = {
    { 2, 1, SEM_UNDO },
    { 3, -1, 0 }
  };

  // retry when the blocking semop is merely interrupted by a signal
  while (semop(_sem_id, op_lock, 3) == -1) {
    if (errno == EINTR)
      continue;
    perror("owner lock semop");
    return 0;
  }

  *((int*)_shm) = ++_tag;
  *(((int*)_shm)+1) = size;
  bcopy(_output_area, _shm+header_bytes, size);

  if (semop(_sem_id, op_unlock, 1) == -1) {
    perror("owner unlock semop");
    return 0;
  }

  int val = semctl(_sem_id, 3, GETVAL);
  if (val < 0) {
    perror("get val");
    return 0;
  }
  if (val > 0) {
    // a client is waiting: release the mutex and clear the wait flag
    if (semop(_sem_id, op_signal, 2) == -1) {
      perror("signal");
      return 0;
    }
  } else {
    if (semop(_sem_id, op_release_mutex, 1) == -1) {
      perror("release mutex");
      return 0;
    }
  }

  return 1;
}

#define WAIT_INCREMENT 25000

/* Attach to an existing segment (shm_key) and semaphore set (sem_key).
   Looking up the segment is retried every WAIT_INCREMENT microseconds
   until it succeeds or `timeout` seconds have been used up; at least one
   attempt is always made.  On failure a diagnostic is printed and the
   affected id is left at -1 so callers can detect it via
   shmID()/semID(). */
IPClientSharedMemory::IPClientSharedMemory(key_t shm_key, key_t sem_key,
                                           float timeout,
                                           IPCommunicator* com,const char* fmt,
                                           int max_size)
  : IPPosixSharedMemory(com, fmt, pad(max_size))
{
  max_size = MaxSize();

  _cur_tag = 0;
  _input_area = new unsigned char[max_size];

  // start from a known "not found" state
  _shm_id = -1;
  _sem_id = -1;

  const int delay_time = (int) (timeout*1000000);
  int cur_time = 0;
  do {
    _shm_id = shmget(shm_key, max_size, 0666);
    if (_shm_id != -1)
      break;
    utils::Time::sleep(double(WAIT_INCREMENT)/1000000.0);
    cur_time += WAIT_INCREMENT;
  } while (cur_time <= delay_time);

  if (_shm_id == -1) {
    printf("IPT: Problem getting shared memory with key %d\n", shm_key);
    perror("shmget:");
    _sem_id = -1;
    return;
  }

  _sem_id = semget(sem_key, 4, 0666);
  if (_sem_id == -1) {
    // report the semaphore key, not the shared-memory key
    printf("IPT: Problem getting semaphore with key %d\n", sem_key);
    perror("semget:");
    return;
  }

  setIDs(_sem_id, _shm_id);
}
  
/* Set up a client that is not yet bound to any segment or semaphore set:
   only the local input area is allocated, and both ids start out as -1. */
IPClientSharedMemory::IPClientSharedMemory(IPCommunicator* com,const char* fmt,
                                           int max_size)
  : IPPosixSharedMemory(com, fmt, pad(max_size))
{
  // from here on work with the padded size
  const int padded_size = MaxSize();

  _cur_tag = 0;
  _input_area = new unsigned char[padded_size];

  _shm_id = -1;
  _sem_id = -1;
}
  
/* Detach from the segment (without removing it) and free the local input
   area. */
IPClientSharedMemory::~IPClientSharedMemory()
{
  /* detach shared memory */
  const bool attached = (_shm_id != -1) && _shm;
  if (attached && shmdt((void *)_shm) == -1)
    perror ("shmdt");

  delete[] _input_area;
}

/* Take the reader side of the lock: wait until semaphore 1 is zero and
   increment semaphore 0, both in a single semop call.
   Returns 1 on success, 0 when either id is unset or semop fails. */
int IPClientSharedMemory::lock()
{
  static sembuf op_lock[2] = {
    { 1, 0, 0 },   // wait for owner to say clear to read
    { 0, 1, SEM_UNDO },   // indicate we are reading
  };

  const bool ids_valid = (_sem_id != -1) && (_shm_id != -1);
  if (!ids_valid)
    return 0;

  const int rc = semop(_sem_id, op_lock, 2);
  if (rc != -1)
    return 1;

  perror("client lock semop");
  return 0;
}

/* Give up the reader side of the lock by decrementing semaphore 0.
   Returns 1 on success, 0 when either id is unset or semop fails. */
int IPClientSharedMemory::unlock()
{
  static sembuf op_unlock[1] = {
    { 0, -1, SEM_UNDO},  // finished reading
  };

  const bool ids_valid = (_sem_id != -1) && (_shm_id != -1);
  if (!ids_valid)
    return 0;

  const int rc = semop(_sem_id, op_unlock, 1);
  if (rc != -1)
    return 1;

  perror("client unlock semop");
  return 0;
}

/* Report the most recently published contents.
   Under the reader lock the header is read; when the tag differs from the
   last one seen, the payload is copied into the local input area.  On
   success returns 1 with `data` pointing at the input area, `size` taken
   from the header (limited to what fits behind the header, never
   negative) and alignment/byte_order set to this build's ALIGN and
   BYTE_ORDER.  Returns 0 when no segment is attached or locking or
   unlocking fails. */
int IPClientSharedMemory::Data(int& size, unsigned char*& data,
                               int& alignment, int& byte_order)
{
  // lock() only checks the ids; setIDs() leaves _shm NULL if the attach
  // itself failed
  if (!_shm)
    return 0;

  if (!lock())
    return 0;

  // the payload lives behind the two-int header, so it can never be
  // larger than the segment size minus that header
  const int header_bytes = (int) (2*sizeof(int));
  const int capacity = MaxSize() - header_bytes;

  int tag = *((int*)_shm);
  size = *(((int*)_shm)+1);
  if (size > capacity) {
    fprintf(stderr, "IPClientSharedMemory: "
            "Warning, input size %d exceeds maximum size %d\n", 
            size, capacity);
    size = capacity;
  } else if (size < 0) {
    fprintf(stderr, "IPClientSharedMemory: "
            "Warning, ignoring negative input size %d\n", size);
    size = 0;
  }
  if (tag != _cur_tag) {
    _cur_tag = tag;
    bcopy(_shm+header_bytes, _input_area, size);
  }

  if (!unlock())
    return 0;

  data = _input_area;
  alignment = ALIGN;
  byte_order = BYTE_ORDER;

  return 1;
}

/* Read the tag stored in the first int of the segment under the reader
   lock.  Returns the tag, or -1 when no segment is attached or locking or
   unlocking fails. */
int IPClientSharedMemory::UpdatedTag()
{
  // lock() only checks the ids; the attach itself may have failed
  if (!_shm)
    return -1;

  if (!lock())
    return -1;

  int tag = *((int*)_shm);

  if (!unlock())
    return -1;

  return tag;
}

/* Block until the owner publishes contents this client has not seen.
   With the condition mutex (semaphore 2) held, the current tag is
   compared against the last tag seen; if it differs the call returns at
   once.  Otherwise the wait flag (semaphore 3) is raised unless it is
   already non-zero, the mutex is released, and the call blocks until
   semaphore 3 is zero again.
   Returns 1 when new contents are available and 0 on any semaphore
   failure. */
int IPClientSharedMemory::Wait()
{
  static sembuf op_grab_mutex[1] = {
    { 2, -1, SEM_UNDO }
  };
  static sembuf op_release_mutex[1] = {
    { 2, 1, SEM_UNDO }
  };
  static sembuf op_initiate_wait[2] = {
    { 2, 1, SEM_UNDO },
    { 3, 1, 0 }
  };
  static sembuf op_wait[1] = {
    { 3, 0, 0 }
  };

  if (semop(_sem_id, op_grab_mutex, 1) == -1) {
    perror("grab mutex");
    return 0;
  }

  int t = UpdatedTag();
  if (t != -1 && t != _cur_tag) {
    // something new is already there: no need to wait
    if (semop(_sem_id, op_release_mutex, 1) == -1) {
      perror("release mutex");
      return 0;
    }
    return 1;
  }

  int val = semctl(_sem_id, 3, GETVAL);
  if (val < 0) {
    perror("get val");
    // do not leave with the mutex held, or the owner's next PutData()
    // would block on it
    if (semop(_sem_id, op_release_mutex, 1) == -1)
      perror("release mutex");
    return 0;
  }
  if (val > 0) {
    // the wait flag is already raised: just release the mutex
    if (semop(_sem_id, op_release_mutex, 1) == -1) {
      perror("release mutex");
      return 0;
    }
  } else {
    // release the mutex and raise the wait flag in one step
    if (semop(_sem_id, op_initiate_wait, 2) == -1) {
      perror("initiate wait");
      return 0;
    }
  }

  while (semop(_sem_id, op_wait, 1) == -1) {
    // remember errno before perror() gets a chance to change it
    const int wait_errno = errno;
    perror("wait");
    if (wait_errno != EINTR) {
      return 0;
    }
  }

  if (UpdatedTag() == -1)
    return 0;

  return 1;
}

/* Bind an owner to the semaphore-set and segment ids handed in by the
   caller, remember the flag semaphore set, and immediately publish an
   empty (size 0) payload. */
IPOwnedManagedMemory::IPOwnedManagedMemory(IPCommunicator* com,
                                           const char* fmt,
                                           int max_size,
                                           int sem_id, int shm_id,
                                           int flag_sem_id)
  : IPOwnedSharedMemory(com, fmt, max_size),
    _flag_sem_id(flag_sem_id)
{
  setIDs(sem_id, shm_id);

  // initial publish; its status is not examined
  (void) PutData(0);
}

/* Forget both ids before the base-class destructor runs. */
IPOwnedManagedMemory::~IPOwnedManagedMemory()
{
  // make sure we never remove semaphores or memory: that's the manager's job
  const int no_id = -1;
  setIDs(no_id, no_id);
}

/* Publish via IPOwnedSharedMemory::PutData() and, if that succeeded, make
   one non-blocking decrement attempt on semaphore 0 of the flag semaphore
   set.  The outcome of that attempt is not examined; the base-class
   result is returned. */
int IPOwnedManagedMemory::PutData(int size)
{
  static struct sembuf cli_flag[1] = {
    { 0, -1, IPC_NOWAIT }   /* Set flag, if not already done */
  };

  const int published = IPOwnedSharedMemory::PutData(size);
  if (!published)
    return published;

  (void) semop(_flag_sem_id, cli_flag, 1);
  return published;
}

/* Bind a client to the semaphore-set and segment ids handed in by the
   caller. */
IPClientManagedMemory::IPClientManagedMemory(IPCommunicator* com,
                                             const char* fmt,
                                             int max_size,
                                             int sem_id, int shm_id)
  : IPClientSharedMemory(com, fmt, max_size)
{
  // the base class starts out unbound; install the supplied ids now
  this->setIDs(sem_id, shm_id);
}

/* Work out the shared-memory and semaphore keys from the configuration.
   For each of the two: an explicit integer ("shm_key" / "sem_key") wins
   unless it is -1; otherwise a key is derived with ftok() from the file
   named by "shm_key_file" / "sem_key_file" and the project number
   "shm_key_proj" / "sem_key_proj" (default 1).  An empty file name or a
   failing ftok() yields IPC_PRIVATE. */
static void make_keys(IPConfigFile* params, key_t& shm_key, key_t& sem_key)
{
  shm_key = params->getInt("shm_key", -1);
  if (shm_key == -1) {
    const char* shm_file = params->getString("shm_key_file", "");
    if (*shm_file) {
      shm_key = ftok(shm_file, params->getInt("shm_key_proj", 1));
      if (shm_key == -1) {
        perror("ftok");
        shm_key = IPC_PRIVATE;
      }
    } else {
      shm_key = IPC_PRIVATE;
    }
  }

  sem_key = params->getInt("sem_key", -1);
  if (sem_key == -1) {
    const char* sem_file = params->getString("sem_key_file", "");
    if (*sem_file) {
      sem_key = ftok(sem_file, params->getInt("sem_key_proj", 1));
      if (sem_key == -1) {
        perror("ftok");
        sem_key = IPC_PRIVATE;
      }
    } else {
      sem_key = IPC_PRIVATE;
    }
  }
}

/* Factory for the "owner" flavour.  Format and maximum size come from the
   configuration, falling back to the symbol-table entries "format" and
   "max_size".  On success the resulting ids are written back into the
   configuration as "int sem_id" / "int shm_id"; if either id is negative
   the object is deleted and NULL is returned. */
IPSharedMemory* create_shmem_owner(IPGenerator<IPSharedMemory>*,
                                   IPConfigFile* params, IPSymbolTable* syms)
{
  const char* fmt = params->getString("format", (char*) syms->get("format"));
  int max_size = params->getInt("max_size", (long) syms->get("max_size"));
  IPCommunicator* com = (IPCommunicator*) syms->get("Communicator");

  key_t shm_key, sem_key;
  make_keys(params, shm_key, sem_key);

  IPOwnedSharedMemory* owner =
    new IPOwnedSharedMemory(shm_key, sem_key, com, fmt, max_size);

  const int sem_id = owner->semID();
  const int shm_id = owner->shmID();
  const bool usable = !(sem_id < 0 || shm_id < 0);
  if (usable) {
    params->setInt("int sem_id", sem_id);
    params->setInt("int shm_id", shm_id);
    return owner;
  }

  delete owner;
  return NULL;
}

/* Factory for the "client" flavour.  Format and maximum size come from
   the configuration, falling back to the symbol-table entries "format"
   and "max_size"; the attach timeout is the "timeout" setting (default
   1.0).  On success the resulting ids are written back into the
   configuration as "int sem_id" / "int shm_id"; if either id is negative
   the object is deleted and NULL is returned. */
IPSharedMemory* create_shmem_client(IPGenerator<IPSharedMemory>*,
                                    IPConfigFile* params, IPSymbolTable* syms)
{
  const char* fmt = params->getString("format", (char*) syms->get("format"));
  int max_size = params->getInt("max_size", (long) syms->get("max_size"));
  IPCommunicator* com = (IPCommunicator*) syms->get("Communicator");

  key_t shm_key, sem_key;
  make_keys(params, shm_key, sem_key);

  IPClientSharedMemory* client =
    new IPClientSharedMemory(shm_key, sem_key,
                             params->getFloat("timeout", 1.0),
                             com, fmt, max_size);

  const int sem_id = client->semID();
  const int shm_id = client->shmID();
  const bool usable = !(sem_id < 0 || shm_id < 0);
  if (usable) {
    params->setInt("int sem_id", sem_id);
    params->setInt("int shm_id", shm_id);
    return client;
  }

  delete client;
  return NULL;
}

/* Factory for the "managed" flavour.
   The memory manager connection is looked up by the name given in
   "string mgr_name" (default "iptshmgr"); if there is no active
   connection a direct connection is attempted from a copy of the
   configuration.  The manager is then queried with an owned- or
   client-initialisation message (chosen by the "owner" setting) and its
   reply supplies the ids used to build an IPOwnedManagedMemory or
   IPClientManagedMemory.
   Returns NULL when no name is configured, no manager can be reached, the
   query gets no reply, the reply cannot be unpacked or reports a negative
   id, or the resulting object is not Valid().  The unpacked reply and the
   reply message are released on every path that obtained them. */
IPSharedMemory* create_shmem_managed(IPGenerator<IPSharedMemory>*,
                                     IPConfigFile* params, IPSymbolTable* syms)
{
  IPTShmMemoryInitStruct smi;

  smi.fmt = params->getString("format", (char*) syms->get("format"));
  smi.max_size = params->getInt("max_size", (long) syms->get("max_size"));
  smi.use_tcp = params->getBool("use_tcp", false);
  IPCommunicator* com = (IPCommunicator*) syms->get("Communicator");

  // host and port are only used for the diagnostic further down
  const char* mgr_host = params->getString("host", com->ThisHost());
  int port = params->getInt("port", 1389);

  smi.name = params->getString("name", "");
  if (!*smi.name) {
    com->printf("Managed shared memory requires a name\n");
    return NULL;
  }

  bool required = params->getBool("required", false);
  int opt = (required ? IPT_REQUIRED : IPT_OPTIONAL);

  const char* mgr_name = params->getString("string mgr_name", "iptshmgr");

  IPConnection* mgr = com->LookupConnection(mgr_name);
  if (!mgr || !mgr->Active()) {
    // no usable connection yet: try to connect to the manager directly
    IPConfigFile mgr_params;
    IPConfigFile::copy(*params, mgr_params);
    mgr_params.setString("string name", mgr_name);
    const char* mgr_tag = params->getString("string mgr_tag", "Unix");
    mgr_params.setString("string tag", mgr_tag);
    mgr = com->DirectConnect(&mgr_params, opt);
  }
  if (!mgr || !mgr->Active()) {
    com->printf("No memory manager on %s, %d %p\n", mgr_host, port, mgr);
    return NULL;
  }

  bool owner = params->getBool("owner", false);
  const char* msg_name;
  if (owner)
    msg_name = IPT_SHM_OWNED_MEMORY_INIT_MSG;
  else
    msg_name = IPT_SHM_CLIENT_MEMORY_INIT_MSG;

  IPMessage* repl =
    com->Query(mgr, msg_name, &smi, IPT_SHM_MEMORY_INITIALIZED_MSG);
  if (!repl)
    return NULL;

  IPTShmMemorySpecStruct* sms = (IPTShmMemorySpecStruct*)repl->FormattedData();
  if (!sms) {
    // the reply could not be unpacked; there is nothing to inspect
    com->printf("Memory initialization error: unreadable manager reply\n");
    delete repl;
    return NULL;
  }
  if (sms->shm_id < 0 || sms->sem_id < 0 || sms->flag_sem_id < 0) {
    com->printf("Memory initialization error %d %d %d\n",
           sms->shm_id, sms->sem_id, sms->flag_sem_id);
    // release the unpacked reply here too, as the success path does
    repl->DeleteFormatted(sms);
    delete repl;
    return NULL;
  }

  IPSharedMemory* res;
  if (owner) {
    res = new IPOwnedManagedMemory(com, smi.fmt, smi.max_size,
                                   sms->sem_id, sms->shm_id, sms->flag_sem_id);
  } else {
    res = new IPClientManagedMemory(com, smi.fmt, smi.max_size,
                                    sms->sem_id, sms->shm_id);
  }
  if (!res->Valid()) {
    delete res;
    res = NULL;
  }
  repl->DeleteFormatted(sms);
  delete repl;
  return res;
}

