%module _IPT
%{
#include <signal.h>

#include <ipt/ipt.h>
#include <ipt/message.h>
#include <ipt/messagetype.h>
#include <ipt/callbacks.h>
#include <ipt/connection.h>
#include <ipt/timer.h>
#include <ipt/server.h>

#include <utils/formatting/Format.h>
#include <utils/formatting/DeleteAction.h>
#include <utils/List.h>
#include <utils/SymbolTable.h>
#include <utils/ConfigFile.h>
#include <utils/Generator.h>

/* Short aliases for the IPT library classes, so that the interface
   declarations further down can use the unprefixed names (Connection,
   Message, Communicator, Format, Timer, MessageType). */
typedef IPConnection Connection;
typedef IPMessage Message;
typedef IPCommunicator Communicator;
typedef IPFormat Format;
typedef IPTimer Timer;
typedef IPMessageType MessageType;

#include "PyPackAction.h"
#include "PyUnpackAction.h"
#include "PyMessageHandler.h"
#include "PyConnectHandler.h"
#include "PyTimerHandler.h"

/* Maximum number of peers read from the "peer_names"/"peer_hosts"
   configuration entries by create_server(). */
#define MAX_PEERS 10

/* SIGINT handler installed by the module's %init section: ends the process
   with exit status 0.
   NOTE(review): exit() is not async-signal-safe (CERT SIG30-C) -- it runs
   atexit handlers and flushes stdio from signal context.  Kept as-is so
   buffered output still reaches its destination on Ctrl-C. */
static void my_exit(int /* signum: always SIGINT, unused */) { exit(0); }

/* Find the PyConnectHandler in `callbacks` that wraps exactly the Python
   object `method` (pointer identity, not Python equality).
   Returns that callback, or NULL if no such handler is registered.
   Non-Python callbacks in the list are skipped by checking their type name. */
static IPConnectionCallback*
lookup_callback(IPList<IPConnectionCallback>* callbacks, PyObject* method) 
{
  IPListIterator<IPConnectionCallback> walker(*callbacks);
  IPConnectionCallback* candidate = walker.first();
  while (candidate) {
    /* Only downcast after the type name confirms it is a PyConnectHandler. */
    if (strcmp(candidate->getTypeName(), "PyConnectHandler") == 0 &&
        ((PyConnectHandler*) candidate)->getMethod() == method)
      return candidate;
    candidate = walker.next();
  }
  return NULL;
}

/* Map an empty configuration string to NULL, so that an empty value in the
   config file behaves the same as a missing key.  NULL stays NULL. */
static const char* null_if_empty(const char* s)
{
  return (s && *s) ? s : NULL;
}

/* Generator callback for the "server" interface (registered by
   Communicator.Create): builds an IPServer from configuration parameters.

   params  - configuration source for ModuleName, message_file, domain_name,
             log_file, peer_names and peer_hosts.
   globals - symbol table; a "ModuleName" entry here takes precedence over
             the one in params (which defaults to "Module").
   Returns a newly allocated IPServer. */
static IPCommunicator* create_server(IPGenerator<IPCommunicator>*,
                                     IPConfigFile* params,
                                     IPSymbolTable* globals)
{
  const char* module_name = (const char*) globals->get("ModuleName");
  if (!module_name)
    module_name = params->getString("ModuleName", "Module");
  const char* message_file =
    null_if_empty(params->getString("message_file", NULL));
  const char* domain_name =
    null_if_empty(params->getString("domain_name", NULL));
  const char* log_file = null_if_empty(params->getString("log_file", NULL));

  /* Peers are given as two parallel lists; only complete name/host pairs
     are used. */
  const char* peer_names[MAX_PEERS];
  const char* peer_hosts[MAX_PEERS];
  int n1 = params->getStrings("peer_names", peer_names, MAX_PEERS, NULL, 0);
  int n2 = params->getStrings("peer_hosts", peer_hosts, MAX_PEERS, NULL, 0);
  int num_peers = MIN(n1,n2);
  /* Guard the fixed-size arrays: only MAX_PEERS entries were filled in, and
     peers[num_peers] is written below as the terminator.  Clamping is a
     no-op if getStrings() already never reports more than it stored or a
     negative count. */
  if (num_peers < 0)
    num_peers = 0;
  else if (num_peers > MAX_PEERS)
    num_peers = MAX_PEERS;

  IPDomainSpec peers[MAX_PEERS+1];
  for (int i=0;i<num_peers;i++) {
    peers[i].name = (char*) peer_names[i];
    peers[i].host = (char*) peer_hosts[i];
  }
  /* NULL name/host marks the end of the peer list. */
  peers[num_peers].name = peers[num_peers].host = NULL;
    
  return new IPServer(module_name, message_file, domain_name, peers, log_file);
}


%}

// Module initialisation: install my_exit() as the SIGINT handler, replacing
// whatever handler was installed before, so Ctrl-C ends the process with
// exit status 0.
%init
%{
    signal(SIGINT, my_exit);
%}

// Pass PyObject* arguments through unconverted, so that methods such as
// AddConnectCallback or RegisterHandler receive the Python object itself
// (e.g. a callable) rather than a C value.
%typemap(python, in) PyObject* {
    $target = $source;
}

// Return PyObject* results to Python as-is.  The %addmethods helpers in this
// file signal failure by setting a Python exception and returning NULL,
// which this typemap hands straight back to the interpreter.
%typemap(python, out) PyObject* {
    $target = $source;
}

// Convert a message-name argument to a C string and check, before the call,
// that the message type is known to the communicator.
// Assumes parameters named msg_name appear only on Communicator methods,
// where _arg0 is the wrapped Communicator (self).
%typemap(python, in) const char* msg_name {
    $target = PyString_AsString($source);
    if (!$target)
        return NULL;    /* not a string: PyString_AsString set TypeError */
    IPMessageType* type = _arg0->LookupMessage($target);
    if (!type) {
        PyErr_Format(PyExc_RuntimeError, "Unknown message type %s",
                     $target);
        return NULL;
    }
}

// Convert a format-string argument: Python None is passed to C as NULL,
// anything else must be a Python string.
%typemap(python, in) const char* format_string {
    if ($source == Py_None) {
        $target = NULL;
    } else {
        $target = PyString_AsString($source);
        if (!$target)
            return NULL;    /* not a string: TypeError is already set */
    }
}

// Turn a NULL pointer result into a Python RuntimeError.  If the wrapped
// function already raised a more specific error (e.g. Query's "Unknown
// message type ..." or "Error querying ..."), keep it instead of replacing
// it with the generic message.
%typemap(python, except) Communicator* {
    $function
    if (!$source) {
        if (!PyErr_Occurred())
            PyErr_SetString(PyExc_RuntimeError, "NULL Communicator");
        return NULL;
    }
}

%typemap(python, except) Connection* {
    $function
    if (!$source) {
        if (!PyErr_Occurred())
            PyErr_SetString(PyExc_RuntimeError, "NULL Connection");
        return NULL;
    }
}

%typemap(python, except) Message* {
    $function
    if (!$source) {
        if (!PyErr_Occurred())
            PyErr_SetString(PyExc_RuntimeError, "NULL Message");
        return NULL;
    }
}

%typemap(python, except) MessageType* {
    $function
    if (!$source) {
        if (!PyErr_Occurred())
            PyErr_SetString(PyExc_RuntimeError, "NULL Message type");
        return NULL;
    }
}

// Base class for inter-module IPT connections, exposed to Python under the
// name _Connection.
%name(_Connection) class Connection {
  public:
    void InvokeConnectCallbacks();
    void InvokeDisconnectCallbacks();

    // Accessors: file descriptor, owning communicator, and the name and
    // host of the connected module.
    int FD() const { return _fd; }
    Communicator* Communicator() const { return _communicator; }
    const char* Name() const { return _name; }
    const char* Host() const { return _host; }

    int Active();
    int Viable();
    int DataAvailable();

    // Python-level callback registration.  Each Python callable is wrapped
    // in a new PyConnectHandler.  Remove* looks up the handler whose stored
    // method is the very same Python object (pointer identity, via
    // lookup_callback) and silently does nothing if none is registered.
    %addmethods {
      void AddConnectCallback(PyObject* method) {
          self->AddConnectCallback(new PyConnectHandler(method));
      }
      void RemoveConnectCallback(PyObject* method) {
          IPConnectionCallback* cb =
              lookup_callback(self->ConnectCallbacks(), method);
          if (cb)
              self->RemoveConnectCallback(cb);
      }
      void AddDisconnectCallback(PyObject* method) {
          self->AddDisconnectCallback(new PyConnectHandler(method));
      }
      void RemoveDisconnectCallback(PyObject* method) {
          IPConnectionCallback* cb =
              lookup_callback(self->DisconnectCallbacks(), method);
          if (cb)
              self->RemoveDisconnectCallback(cb);
      }
    }
};

// An IPT message.  Communicator methods marked %new (ReceiveMessage,
// ReceiveMessageFrom, Query) tell SWIG that the caller owns the returned
// Message; the destructor is exposed so it can be deleted.
class Message {
  public:
    ~Message();

    %addmethods {
        /* Convert the message's formatted data into a Python object using
           the formatter of the message's type (PyPackAction::pack).
           Raises RuntimeError for a NULL message or when pack() yields
           nothing. */
        PyObject* Data() {
            if (!self) {
                PyErr_Format(PyExc_RuntimeError, "Using NULL message");
                return NULL;
            }
            void* data = self->FormattedData();
            utils::FormatSpec* spec = self->Type()->Formatter();

            PyObject* res = PyPackAction::pack(spec, data);
            if (!res) {
                PyErr_Format(PyExc_RuntimeError,
                             "No formattable data in message");
                return NULL;
            }
            return res;
        }
        /* Name of this message's type, as a string (the C++ Type() returns
           the IPMessageType object). */
        const char* Type() const {
            return self->Type()->Name();
        }
    }
        
    int Instance() const;
    int ID() const;

    // Connection associated with this message (presumably the one it
    // arrived on -- TODO confirm).
    Connection* Connection() const;

    int Print(int print_data = 0);
};

// Handle for a timer created by Communicator.AddTimer / AddOneShot.
// NOTE(review): semantics below are inferred from the names -- confirm
// against ipt/timer.h.
class Timer {
public:
    int one_shot() const;       // presumably non-zero for AddOneShot timers
    double interval() const;    // presumably the period passed at creation
    double time_left() const;   // presumably time until the next firing
    void clear();
};

// Module-level IPT communicator, exposed to Python under the name
// _Communicator.  The %addmethods below adapt the C++ API to Python:
// message data is converted with PyUnpackAction::unpack (Python -> C) and
// PyPackAction::pack (C -> Python), and failures are reported by setting a
// Python RuntimeError and returning NULL.
%name(_Communicator) class Communicator {
  public:
    /* Register a message with a name and a format: create its message type
       and add the name to the internal message hash table.  A None format
       is passed as NULL. */
    MessageType* RegisterMessage(const char* msg, const char* format_string);
    Format* RegisterNamedFormatter(const char*, const char*);

    /* Make a connection to another module */
    Connection* Connect(const char* name,
                        int required = IPT_REQUIRED);
    Connection* DirectConnect(const char* parameters,
                              int required=IPT_REQUIRED);

    /* Get and set handler status.  When handlers are disabled _no_ handlers
       work, including the ones necessary to connect modules */
    int HandlersActive() const;
    void DisableHandlers();
    void EnableHandlers();

    /* Get the next available message (caller owns the result) */
    %new Message* ReceiveMessage(double timeout = IPT_BLOCK);
    /* Overload of ReceiveMessage taking a connection and a message name,
       exposed to Python as ReceiveMessageFrom */
    %name(ReceiveMessageFrom)
        %new Message* ReceiveMessage(Connection* conn, const char* msg_name,
                                     double timeout=IPT_BLOCK);


    /* After calling this for a message type, the result of a GetMessage,
       or the data passed to a callback for a message of that type, is
       guaranteed to be the most recent message of that type from a
       connection. */
    void PigeonHole(const char* msg_name);

    /* Routines for calling back when things connect and disconnect */
    void InvokeConnectCallbacks(Connection* res);
    void InvokeDisconnectCallbacks(Connection* res);

    /* Timer operations */
    void RemoveTimer(Timer*);

    /* Loops for handling events */
    void MainLoop();
    int Sleep(double time = IPT_BLOCK);
    int Idle(double time = IPT_BLOCK);
    void Finish();
    void Restart();
    int Finished();

    /* Look up a connection by name */
    Connection* LookupConnection(const char*);

    /* Return the name of the machine the module is on */
    const char* ThisHost() const;
    /* Return the name of the module */
    const char* ModuleName() const { return _mod_name; }
    /* Return the name of the machine the server is on */
    const char* ServerHostName() const { return _host_name; }
    const char* DomainName() const { return _domain_name; }
    void SetDomainName(char* name) { _domain_name = name; }

    /* Disable and enable handlers for a particular message type */
    void DisableHandler(const char*);
    void EnableHandler(const char*);

    /* Publication/subscription routines */
    Connection* Subscribe(const char*, const char*);
    %name(SubscribeTimed) Connection* Subscribe(const char*, const char*,
                                                double interval);
    int NumSubscribers(const char*);
    void Unsubscribe(const char*, const char*);
    
    %addmethods {
        /* Destroy the underlying communicator.  The Python object still
           refers to it afterwards and must not be used again. */
        void Close() { 
            delete self;
        }
               
        /* Return the communicator for mod_name (host_name optional) and
           install the Python formatting support on its format parser. */
        static Communicator* Instance(const char* mod_name,
                                      const char* host_name=NULL) {
            Communicator* res = Communicator::Instance(mod_name, host_name);
            if (res)
                initPyFormatting(*res->FormatParser());
            return res;
        }
        /* Build a communicator from the generator specification `spec`,
           with the "server" interface handled by create_server and
           ModuleName set to mod_name. */
        static Communicator* Create(const char* mod_name, const char* spec) {
            IPGenerator<IPCommunicator>* gen = Communicator::generator();
            gen->registerInterface("server", create_server, NULL);
            
            IPSymbolTable sym_table;
            sym_table.set("ModuleName", mod_name);

            Communicator* res = gen->interface(spec, &sym_table);
            if (res)
                initPyFormatting(*res->FormatParser());
            return res;
        }

        /* Register the Python callable `callback` as the handler for
           incoming messages named msg_name.  When to handle the message is
           controlled by spec. */
        void RegisterHandler(const char* msg_name, PyObject* callback,
                             int spec = IPT_HNDL_STD) {
            self->RegisterHandler(msg_name, new PyMessageHandler(callback),
                                  spec);
        }

        /* Receive a message named `name` from conn and return its data
           converted to a Python object.  Raises RuntimeError if the type is
           unknown, nothing was received, or the data cannot be converted.
           NOTE(review): `data` is not released here; confirm whether
           IPCommunicator::ReceiveFormatted hands ownership to the caller. */
        PyObject* ReceiveFormatted(Connection* conn, const char* name, 
                               double timeout = IPT_BLOCK) {
            IPMessageType* type = self->LookupMessage(name);
            if (!type) {
                PyErr_Format(PyExc_RuntimeError, "Unknown message type %s",
                             name);
                return NULL;
            }

            void* data = self->ReceiveFormatted(conn, type, timeout);
            if (!data) {
                PyErr_Format(PyExc_RuntimeError, "Receive error");
                return NULL;
            }
            utils::FormatSpec* spec = type->Formatter();
            PyObject* res = PyPackAction::pack(spec, data);

            if (!res) {
                PyErr_Format(PyExc_RuntimeError,
                             "No formattable data in message");
                return NULL;
            }
            return res;
        }
        
        /* Convert `data` to the C layout of message type send_msg_name and
           send it on conn.  Returns None; raises RuntimeError if the type
           is unknown, the data does not match the format, or the send
           fails. */
        PyObject* SendMessage(Connection* conn, const char* send_msg_name,
                              PyObject* data) {
            IPMessageType* type = self->LookupMessage(send_msg_name);
            if (!type) {
                PyErr_Format(PyExc_RuntimeError, "Unknown message type %s",
                             send_msg_name);
                return NULL;
            }
            utils::FormatSpec* spec = type->Formatter();
            void* msg_data = PyUnpackAction::unpack(spec, data);
            /* NULL data is only an error when the format carries data. */
            if (!msg_data && spec && spec->dataSize() > 0) {
                PyErr_Format(PyExc_RuntimeError, "Data does not match format");
                return NULL;
            }
            PyObject* res;
            if (!self->SendMessage(conn, type, msg_data)) {
                /* %s needs a C string, so report the connection's name. */
                const char* conn_name = (conn && conn->Name()) ?
                    conn->Name() : "(unknown connection)";
                PyErr_Format(PyExc_RuntimeError, "Error sending %s to %s",
                             send_msg_name, conn_name);
                res = NULL;
            } else {
                res = Py_None;
                Py_INCREF(res);
            }
            utils::DeleteAction::deleteData(spec, msg_data);
            return res;
        }

        /* Routines for calling back when things connect and disconnect.
           Remove* matches the identical Python object and silently does
           nothing if it is not registered. */
        void AddConnectCallback(PyObject* method) {
            self->AddConnectCallback(new PyConnectHandler(method));
        }
        void RemoveConnectCallback(PyObject* method) {
            IPConnectionCallback* cb =
                lookup_callback(self->GetConnectCallbacks(), method);
            if (cb)
                self->RemoveConnectCallback(cb);
        }
        void AddDisconnectCallback(PyObject* method) {
            self->AddDisconnectCallback(new PyConnectHandler(method));
        }
        void RemoveDisconnectCallback(PyObject* method) {
            IPConnectionCallback* cb =
                lookup_callback(self->GetDisconnectCallbacks(), method);
            if (cb)
                self->RemoveDisconnectCallback(cb);
        } 

        /* Timer routines: call the Python callable `method` after
           `interval` (repeatedly for AddTimer, once for AddOneShot). */
        Timer* AddTimer(double interval, PyObject* method) {
            return self->AddTimer(interval, new PyTimerHandler(method));
        }
        Timer* AddOneShot(double interval, PyObject* method) { 
            return self->AddOneShot(interval, new PyTimerHandler(method));
        }

        /* Send req_data as a req_name message on conn and wait for a
           repl_name reply.  The caller owns the returned Message.  Raises
           RuntimeError on unknown types, mismatched data, or a failed
           query. */
        %new Message* Query(Connection* conn,
                            const char* req_name, PyObject* req_data,
                            const char* repl_name,
                            double timeout = IPT_BLOCK) {
            IPMessageType* req_type = self->LookupMessage(req_name);
            if (!req_type) {
                PyErr_Format(PyExc_RuntimeError, "Unknown message type %s",
                             req_name);
                return NULL;
            }
            IPMessageType* repl_type = self->LookupMessage(repl_name);
            if (!repl_type) {
                PyErr_Format(PyExc_RuntimeError, "Unknown message type %s",
                             repl_name);
                return NULL;
            }
            
            utils::FormatSpec* req_spec = req_type->Formatter();
            void* msg_data = PyUnpackAction::unpack(req_spec, req_data);
            if (!msg_data && req_spec && req_spec->dataSize() > 0) {
                PyErr_Format(PyExc_RuntimeError, "Data does not match format");
                return NULL;
            }
            Message* res = self->Query(conn, req_type, msg_data, repl_type,
                                       timeout);
            if (!res) {
                /* %s needs a C string, so report the connection's name. */
                const char* conn_name = (conn && conn->Name()) ?
                    conn->Name() : "(unknown connection)";
                PyErr_Format(PyExc_RuntimeError,
                             "Error querying %s for %s with %s",
                             conn_name, repl_name, req_name);
            } 
            utils::DeleteAction::deleteData(req_spec, msg_data);
            return res;
        }
        
        /* Not implemented: always raises RuntimeError.  (Returning NULL
           without setting an exception makes Python raise SystemError.) */
        PyObject* QueryFormatted(Connection*, const char*, PyObject*,
                                 const char*, double timeout = IPT_BLOCK) {
            PyErr_SetString(PyExc_RuntimeError,
                            "QueryFormatted is not implemented");
            return NULL;
        }
        /* Reply to msg with a repl_name message carrying repl_data.
           Returns None; raises RuntimeError on an unknown type or data that
           does not match the format. */
        PyObject* Reply(Message* msg, const char* repl_name,
                        PyObject* repl_data) {
            IPMessageType* repl_type = self->LookupMessage(repl_name);
            if (!repl_type) {
                PyErr_Format(PyExc_RuntimeError, "Unknown message type %s",
                             repl_name);
                return NULL;
            }
            utils::FormatSpec* repl_spec = repl_type->Formatter();
            void* data = PyUnpackAction::unpack(repl_spec, repl_data);
            if (!data && repl_spec && repl_spec->dataSize() > 0) {
                PyErr_Format(PyExc_RuntimeError, "Data does not match format");
                return NULL;
            }
            self->Reply(msg, repl_type, data);
            utils::DeleteAction::deleteData(repl_spec, data);
            Py_INCREF(Py_None);
            return Py_None;
        }

        /* Call the Python callable `method` for each connection.  The
           handler is ref'd for the duration of the iteration and unref'd
           afterwards. */
        void IterateConnections(PyObject* method) {
            IPConnectionCallback* cb = new PyConnectHandler(method);
            cb->ref();
            self->IterateConnections(cb);
            cb->unref();
        }

        /* Publication/subscription routines.  A None method registers no
           callback (NULL). */
        void DeclareSubscription(const char* msg_name,
                                 PyObject* method = Py_None) {
            IPConnectionCallback* cb;
            if (method == Py_None)
                cb = NULL;
            else
                cb = new PyConnectHandler(method);
            self->DeclareSubscription(msg_name, cb);
        }
        void DeclareTimedSubscription(const char* msg_name,
                                      PyObject* method = Py_None) {
            IPConnectionCallback* cb;
            if (method == Py_None)
                cb = NULL;
            else
                cb = new PyConnectHandler(method);
            self->DeclareTimedSubscription(msg_name, cb);
        }
        
        /* Publish msg_data as a message of type `name`.  Returns None. */
        PyObject* Publish(const char* name, PyObject* msg_data) {
            IPMessageType* msg_type = self->LookupMessage(name);
            if (!msg_type) {
                PyErr_Format(PyExc_RuntimeError, "Unknown message type %s",
                             name);
                return NULL;
            }
            utils::FormatSpec* msg_spec = msg_type->Formatter();
            void* data = PyUnpackAction::unpack(msg_spec, msg_data);
            if (!data && msg_spec && msg_spec->dataSize() > 0) {
                PyErr_Format(PyExc_RuntimeError, "Data does not match format");
                return NULL;
            }
            self->Publish(msg_type, data);
            utils::DeleteAction::deleteData(msg_spec, data);
            Py_INCREF(Py_None);
            return Py_None;
        }

        /* General server/client routines.  A None method registers no
           handler (NULL). */
        void Server(const char* msg_name, PyObject* method = Py_None) {
            IPHandlerCallback* cb;
            if (method == Py_None)
                cb = NULL;
            else
                cb = new PyMessageHandler(method);
            self->Server(msg_name, cb);
        }

        /* Broadcast msg_data as a message of type `name`.  Returns None. */
        PyObject* Broadcast(const char* name, PyObject* msg_data) {
            IPMessageType* msg_type = self->LookupMessage(name);
            if (!msg_type) {
                PyErr_Format(PyExc_RuntimeError, "Unknown message type %s",
                             name);
                return NULL;
            }
            utils::FormatSpec* msg_spec = msg_type->Formatter();
            void* data = PyUnpackAction::unpack(msg_spec, msg_data);
            if (!data && msg_spec && msg_spec->dataSize() > 0) {
                PyErr_Format(PyExc_RuntimeError, "Data does not match format");
                return NULL;
            }
            self->Broadcast(msg_type, data);
            utils::DeleteAction::deleteData(msg_spec, data);
            Py_INCREF(Py_None);
            return Py_None;
        }
        /* Send msg_data as a client message of type `name` on conn.
           Returns None. */
        PyObject* Client(Connection* conn,
                         const char* name, PyObject* msg_data) {
            IPMessageType* msg_type = self->LookupMessage(name);
            if (!msg_type) {
                PyErr_Format(PyExc_RuntimeError, "Unknown message type %s",
                             name);
                return NULL;
            }
            utils::FormatSpec* msg_spec = msg_type->Formatter();
            void* data = PyUnpackAction::unpack(msg_spec, msg_data);
            if (!data && msg_spec && msg_spec->dataSize() > 0) {
                PyErr_Format(PyExc_RuntimeError, "Data does not match format");
                return NULL;
            }
            self->Client(conn, name, data);
            utils::DeleteAction::deleteData(msg_spec, data);
            Py_INCREF(Py_None);
            return Py_None;
        }
    }
};

