#ifndef _NET_ATTACKS_H_
#define _NET_ATTACKS_H_

#include <iostream>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <vector>

#include "Network.hpp"
#include "nodeattr.h"

using namespace std;

namespace GraphGraphics
{
  
struct  NetworkInfo
{
  Host*   src_host;
  Host*   dest_host;
  Attack* attack;
  Firewall *dest_firewall;
  
  NetworkInfo() 
  {
    src_host = 0;
    dest_host = 0;
    attack = 0;
    dest_firewall = 0;
  }
};

// Identifies one attack instance by integer ids: (attack, source host, target).
// The target is a host, or a firewall when m_is_dest_firewall is set.
// A value of -1 marks an id as unset (this is what the default constructor uses).
class AttackTriple
{
  int  m_src_host_id;
  int  m_dest_id;
  int  m_attack_id;
  bool m_is_dest_firewall;
  
  static const string TARGET_FW_TYPE;
  
public:
  AttackTriple()
    : m_src_host_id( -1 ), m_dest_id( -1 ), m_attack_id( -1 ), m_is_dest_firewall( false )
  {}

  AttackTriple( int attack_id, int src_id, int dest_id, bool dest_firewall = false )
    : m_src_host_id( src_id ), m_dest_id( dest_id ), m_attack_id( attack_id ),
      m_is_dest_firewall( dest_firewall )
  {}

  // All members are plain values, so copy construction, copy assignment and
  // destruction use the compiler-generated memberwise versions.

  // NOTE: the ids are stored signed but returned as unsigned int, so an
  // unset id (-1) is reported as UINT_MAX.
  unsigned int get_source_host_id() const { return m_src_host_id; } 
  unsigned int get_target_id() const { return m_dest_id; }
  bool         is_target_firewall() const { return m_is_dest_firewall; }
  unsigned int get_attack_id() const { return m_attack_id; }
    
  void set_source_host_id( int src_id ) { m_src_host_id = src_id; }
  void set_target_id( int dest_id, bool is_firewall = false ) { m_dest_id = dest_id; m_is_dest_firewall = is_firewall; }
  void set_attack_id( int attack_id ) { m_attack_id = attack_id; }

  // Resolves the ids against `network`.
  // Returns a newly allocated NetworkInfo, which the caller owns and must
  // delete. Returns 0 if `network` is null or any id is out of range.
  // Exactly one of dest_host / dest_firewall is set in the result.
  NetworkInfo *get_network_info( Network *network ) const
  {
    if ( network == 0 )
      return 0;

    const bool src_valid = m_src_host_id >= 0 &&
                           m_src_host_id < network->GetNumHosts();

    // The valid target range depends on whether the target is a firewall or a host.
    const bool dest_valid = m_dest_id >= 0 &&
                            ( m_is_dest_firewall
                                ? m_dest_id < static_cast<int>( network->GetFirewallNumber() )
                                : m_dest_id < network->GetNumHosts() );

    const bool attack_valid = m_attack_id >= 0 &&
                              m_attack_id < network->GetNumAttacks();

    if ( !( src_valid && dest_valid && attack_valid ) )
      return 0;

    NetworkInfo *res = new NetworkInfo();
    res->src_host = network->GetHost( m_src_host_id );
    if ( m_is_dest_firewall )
    {
      res->dest_firewall = network->GetFirewall( m_dest_id );
      res->dest_host = 0;
    }
    else
    {
      res->dest_host = network->GetHost( m_dest_id );
      res->dest_firewall = 0;
    }
    res->attack = network->GetAttack( m_attack_id );

    return res;
  }

  // Two triples are equal when all four fields are equal.
  bool operator == ( const AttackTriple &attack_triple ) const 
  {
    return (attack_triple.m_attack_id == m_attack_id &&
            attack_triple.m_src_host_id == m_src_host_id &&
            attack_triple.m_dest_id == m_dest_id &&
            attack_triple.m_is_dest_firewall == m_is_dest_firewall);
  }

  // Lexicographic strict ordering over (attack id, source host id, target id,
  // target kind), with host targets ordered before firewall targets.
  // This is consistent with operator==.
  bool less( const AttackTriple &other ) const
  {
    if ( m_attack_id != other.m_attack_id )
      return m_attack_id < other.m_attack_id;
    if ( m_src_host_id != other.m_src_host_id )
      return m_src_host_id < other.m_src_host_id;
    if ( m_dest_id != other.m_dest_id )
      return m_dest_id < other.m_dest_id;
    return !m_is_dest_firewall && other.m_is_dest_firewall;
  }

  // Builds the text
  //   "{ GlobalVariables{ net.attack_id=A; net.source_ip=S; net.target_ip=D; }}"
  // and parses it into `attr_list` with GraphAttributes::parse_string.
  // For a firewall target, "net.target_fwi=" is written instead of "net.target_ip=".
  void make_graph_attr_list( NodeAttributeList &attr_list ) const
  {
    stringstream str_stream;
    str_stream << "{ GlobalVariables{ net.attack_id=" << m_attack_id << "; ";
    str_stream << "net.source_ip=" << m_src_host_id << "; ";
    str_stream << ( m_is_dest_firewall ? "net.target_fwi=" : "net.target_ip=" );
    str_stream << m_dest_id << "; }}";

    // Kept as a named lvalue in case parse_string takes a non-const reference.
    string str = str_stream.str();
    GraphAttributes::parse_string( str, attr_list );
  }

  // Defined out of line; presumably fills this triple from graph attributes.
  bool make_attack_triple( GraphAttributes *attr );

  // Defined out of line; presumably resolves the given names against `network`.
  bool make_attack_triple( const string &attack_name, const string &source_host, const string &dest, Network* network, bool is_dest_firewall = false );

  // Fills this triple from the names of the objects in `net_info`.
  // Returns false (and leaves the triple untouched) if any required pointer
  // is null, including when neither a target host nor a target firewall is set.
  bool make_attack_triple( NetworkInfo *net_info, Network* network )
  {
    if ( net_info == 0 || network == 0 )
      return false;
    if ( net_info->src_host == 0 || net_info->attack == 0 )
      return false;

    if ( net_info->dest_host != 0 )
      return make_attack_triple( net_info->attack->GetName().transcode(), 
                                 net_info->src_host->GetName().transcode(),
                                 net_info->dest_host->GetName().transcode(), network );

    // Without this check, a NetworkInfo with no target at all would cause a null dereference here.
    if ( net_info->dest_firewall != 0 )
      return make_attack_triple( net_info->attack->GetName().transcode(), 
                                 net_info->src_host->GetName().transcode(),
                                 net_info->dest_firewall->GetName().transcode(), network, true );

    return false;
  }

  // Defined out of line; parses a triple from its textual form in `data`.
  bool make_attack_triple( const string &data, Network *net );

  // Defined out of line; renders this triple as text using names from `net`.
  string make_attack_triple_string( Network *net ) const;
  
};

class AttackTripleLess
{
public:
  bool operator () ( const AttackTriple &key1, const AttackTriple &key2 ) const
  {
    return key1.less( key2 );
  }
};

// Ordered set of attack triples. Ordering follows AttackTriple::less:
// attack id, then source host id, then target id, then host before firewall.
typedef set<AttackTriple, AttackTripleLess> AttackTriplesSet;

class AttackMeasure: public AttackTriplesSet
{
private:
  string m_name;

public:
  AttackMeasure(){ m_name = ""; }
  ~AttackMeasure(){}
  
  AttackMeasure( const AttackMeasure& measure )
  {
    m_name = measure.m_name;
    insert( measure.begin(), measure.end() );
  }
  
  AttackMeasure& operator = ( const AttackMeasure& measure )
  {
    clear();
    
    m_name = measure.m_name;
    insert( measure.begin(), measure.end() );
    
    return *this;
  }
  
  bool operator == ( const AttackMeasure& measure ) const
  {
    bool res = (m_name == measure.m_name);
    
    if ( res )
      res = (size() == measure.size());
    
    for ( AttackMeasure::const_iterator p = begin(); p != end() && res; p++ )
      if ( measure.find( *p ) == measure.end() )
        res = false;
      
    return res;
  }
  
  void   set_name( const string &name ){ m_name = name; }
  string get_name( ) { return m_name; }
  
  bool load_from_file( const string& file_name, Network *net );
  bool save_into_file( const string& file_name, Network *net ) const;
  
  void write_to( ostream& s_out, Network *net, bool save_name = false ) const;
  bool read_from( istream& s_in, Network *net );
};

// Integer key of a measure inside an AttackMeasuresSet. The AttackMeasuresSet
// accessors return -1 to mean "no such measure".
typedef int AttackMeasureID;
// Set of measure ids, e.g. the measures that cover one attack triple.
typedef set<AttackMeasureID> AttackMeasureIDSet;

class AttackMeasuresSet
{
typedef map<AttackMeasureID, AttackMeasure> AttackMeasuresMap;
typedef map<AttackTriple, AttackMeasureIDSet, AttackTripleLess> AttackTripleMeasuresMap;
  
private:
  AttackMeasuresMap m_measures;
  AttackTripleMeasuresMap m_attack_measure;
  AttackMeasureID   m_last_measure_id;
  
public:
  AttackMeasuresSet(): m_last_measure_id(0){}
  ~AttackMeasuresSet(){}
    
  AttackMeasuresSet( const AttackMeasuresSet &measure_set )
  {
    m_last_measure_id = measure_set.m_last_measure_id;
    m_measures.insert( measure_set.m_measures.begin(), measure_set.m_measures.end() );
    m_attack_measure.insert( measure_set.m_attack_measure.begin(), measure_set.m_attack_measure.end() );
  }
  
  AttackMeasuresSet& operator = ( const AttackMeasuresSet &measure_set )
  {
    m_last_measure_id = measure_set.m_last_measure_id;
    m_measures.insert( measure_set.m_measures.begin(), measure_set.m_measures.end() );
    m_attack_measure.insert( measure_set.m_attack_measure.begin(), measure_set.m_attack_measure.end() );
    
    return *this;
  }
    
  AttackMeasureID add_measure( const AttackMeasure& measure );
    
  AttackMeasureID get_first_measure_id() const
  {
    AttackMeasureID res = -1;
    AttackMeasuresMap::const_iterator p;
    
    if ( (p = m_measures.begin()) != m_measures.end() )
      res = p->first;
    
    return res;
  }
  
  AttackMeasureID get_next_measure_id( AttackMeasureID measure_id ) const
  {
    AttackMeasureID res = -1;
    AttackMeasuresMap::const_iterator p;
    
    if ( (p = m_measures.find( measure_id )) != m_measures.end() )
    {
      p++;
      if ( p != m_measures.end() )
        res = p->first;
    }
    
    return res;
  }
  
  AttackMeasureID find_measure( const AttackMeasure& measure ) const
  {
    AttackMeasureID res = -1;
    AttackMeasuresMap::const_iterator p;
    
    for ( p = m_measures.begin(); p != m_measures.end() && res < 0; p++ )
    {
      if ( p->second == measure )
        res = p->first;
    }
    
    return res;
  }
  
  AttackMeasure get_measure( AttackMeasureID measure_id ) const
  {
    AttackMeasuresMap::const_iterator p;
    if ( (p = m_measures.find( measure_id )) == m_measures.end() )
      return AttackMeasure();
    
    return p->second;
  }
  
  void set_measure( AttackMeasureID measure_id, const AttackMeasure &measure );
    
  AttackMeasureIDSet get_measures_covering( const AttackTriple &attack ) const
  {
    AttackTripleMeasuresMap::const_iterator p;
    if ( (p = m_attack_measure.find( attack )) == m_attack_measure.end() )
      return AttackMeasureIDSet();
    
    return p->second;
  }
  
  AttackMeasureID get_bigest_measure_covering( const AttackTriple &attack ) const
  {
    AttackTripleMeasuresMap::const_iterator p;
    AttackMeasuresMap::const_iterator iter;
    unsigned int        max_attack_cover = 0;
    AttackMeasureID     res = -1;
    AttackMeasure       measure;
    
    if ( (p = m_attack_measure.find( attack )) != m_attack_measure.end() )
    {
      for ( AttackMeasureIDSet::const_iterator mp = p->second.begin(); mp != p->second.end(); mp++ )
      {
        iter = m_measures.find( *mp );
        measure = iter->second;
        if ( max_attack_cover < measure.size() )
        {
          max_attack_cover = measure.size();
          res = *mp;
        }
      }
    }      
        
    return res;
   }
  
  void remove_measure( AttackMeasureID measure_id )
  {
    if ( m_measures.find( measure_id ) != m_measures.end() )
    {
      AttackMeasure measure = m_measures[measure_id];
    
      for ( AttackMeasure::iterator p = measure.begin(); p != measure.end(); p++ )
        m_attack_measure.erase( *p );
    
      m_measures.erase( measure_id );
    }
  }
  
  void clear()
  {
    m_measures.clear();
    m_attack_measure.clear();
    m_last_measure_id = 0;
  }
    
  void get_graph_attr_lists( AttackMeasureID measure_id, vector<NodeAttributeList> &attr_lists ) const
  {
    AttackMeasure measure = get_measure( measure_id );
    NodeAttributeList attr_list;
    
    attr_lists.clear();
    
    for ( AttackMeasure::const_iterator p = measure.begin(); p != measure.end(); p++ )
    {
      p->make_graph_attr_list( attr_list );
      attr_lists.push_back( attr_list );
    }
  }
  
  unsigned int size() const { return m_measures.size(); }
  
  bool load_from_file( const string& file_name, Network *net );
  bool save_into_file( const string& file_name, Network *net ) const;
};

// Map from measure id to an unsigned counter.
// NOTE(review): the meaning of the count is defined by the users of this type — confirm.
typedef map<AttackMeasureID, unsigned int> AttackMeasuresCount; 
// Sequence of per-measure counters; presumably one entry per graph node — TODO confirm.
typedef vector<AttackMeasuresCount> GraphMeasureMarks;

};

#endif
