
#include "policy_routines.h"
#include "btree.h"
#include "pdphack.h"

#include <netinet/ip_var.h>
#include <netinet/tcp_timer.h>
#include <netinet/tcp_var.h>
#include <netinet/udp_var.h>
#include <netinet/in_pcb.h>
#include <netinet/in.h>
#include <netinet/udp_var.h>
#include <sys/socket.h>
#include <sys/socketvar.h>
#include <netinet/tcpip.h>

/* Uncomment to enable the IPOUTVERBOSE debug printf()s below. */
//#define IPOUTVERBOSE

/*
 * TCP port exempted by the "pdp hack" in policy_check_inet_out():
 * connections whose local or foreign port equals POLICYPORT are allowed
 * without consulting the policy daemon.
 * NOTE(review): presumably the policy daemon's own port -- confirm.
 */
#define POLICYPORT 3232

/* Incremented each time policy_check_inet_out() goes on to build a policy context. */
long ip_out_packet_count = 0;

extern u_char ip_protox[];

/*
 * Length of the pending policy-context queue (defined elsewhere);
 * incremented when policy_check_inet_out() commits a context.
 */
extern int policy_queue_length;

/* Argument passed to bt_create_tree() when the outbound cache is created. */
extern long CACHE_SIZE;

/*
 * Outbound decision cache, keyed by (foreign address ^ foreign port).
 * NULL until the first insert or reset.
 */
btree inet_out_cache_tree = NULL;

/*
 * Foreign addresses (strings produced by my_inet_ntop4()) recorded by
 * the pdp hack in policy_check_inet_out().
 */
struct pdphack_t *pdphack_out = NULL;

/*
 * Callers currently asleep waiting for a policy reply; bumped around the
 * tsleep() in policy_check_inet_out() (defined elsewhere, may be used
 * by other hooks too).
 */
extern long sleepingguys;

int
policy_check_inet_out(struct policy_inet_out_t *pi) {
    struct mbuf *m;
    int spl;
    int pass;
    register long space, len;
    register quad_t resid;
    policy_context	*context;
    policy_request_entry_t *newnode;
    int hlen;
    struct socket *so;
    struct inpcb *inp;
    int policyres;

    so = pi->so;
    inp = sotoinpcb(so);

    spl = splhigh();
    
    //pdp hack
    if (so->so_proto->pr_protocol == 6)
      {
	char *ipaddr;
	struct pdphack_t *pdptmp;
	int hackallow=0;
	int sport,dport;
	int nofree=0;
	
	sport = ntohs(inp->inp_lport);
	dport = ntohs(inp->inp_fport);
	
	MALLOC(ipaddr, char*,sizeof(char)*20,M_TEMP,M_WAITOK);
	
	my_inet_ntop4(&inp->inp_faddr.s_addr, ipaddr, 0);
	
	pdptmp = pdphack_lookup(&pdphack_out,ipaddr);
	
	if (sport == POLICYPORT)
	  {
	    if (pdptmp == NULL)
	      {
		pdphack_add(&pdphack_out,ipaddr);
		nofree=1;
	      }
	    hackallow=1;
	  } else {
	    if (pdptmp != NULL)
	      {
		if (pdptmp->port == 0)
		  {
		    if (sport != POLICYPORT)
		      pdptmp->port = sport;
		    inet_out_add_to_cache(pi, CACHE_ACCEPT);
		    hackallow=1;
		  } else {
		    if (pdptmp->port == sport)
		      {
			hackallow=1;
		      }
		  }
	      }
	    
	  }
	
	if ((!hackallow)&&(dport == POLICYPORT))
	  {
	    if (pdptmp == NULL)
	      {
		pdphack_add(&pdphack_out,ipaddr);
		nofree=1;
	      } 
	    hackallow=1;
	  } else {
	    if (pdptmp != NULL)
	      {
		if (pdptmp->port == 0)
		  {
		    if (dport != POLICYPORT)
		      pdptmp->port = dport;
		    inet_out_add_to_cache(pi, CACHE_ACCEPT);
		    hackallow=1;
		  } else {
		    if (pdptmp->port == dport)
		      {
			hackallow=1;
		      }
		  }
	      }
	  }
	if (!nofree)
	  {
	    FREE(ipaddr, M_TEMP);
	  }
	
	if (hackallow)
	  {
	    splx(spl);

	    return(0);
	  } 
      }

    
    splx(spl);

    ip_out_packet_count++;
    
    
    context = policy_create_context();
    if (context == NULL) {
      printf("context was null !!!\n");
      
      //our fault -- jump to bottom and deal with it XXX
      goto bad;
    }
    
    policy_add_string(context, "app_domain", "dist_firewall");
    policy_add_string(context, "agent","Agent");
    policy_add_string(context, "operation", "ipv4output");
    policy_add_int(context, "family", so->so_proto->pr_protocol);
    policy_add_ipv4address(context, "dst_addr", (in_addr_t *) &inp->inp_faddr.s_addr);
    policy_add_ipv4address(context, "src_addr", (in_addr_t *) &inp->inp_laddr.s_addr);
    policy_add_int(context, "dst_port", (int)ntohs(inp->inp_fport));
    policy_add_int(context, "src_port", (int)ntohs(inp->inp_lport));
    
    

    if (so->so_proto->pr_protocol == 6)
      {
	//tcp
	
	policy_add_string(context, "subject", "level2_protocol");
	policy_add_string(context, "l2protocol_name", "tcp_ip");

      } else if (so->so_proto->pr_protocol ==17) {
	//udp
	
	policy_add_string(context, "subject", "level2_protocol");
	policy_add_string(context, "l2protocol_name", "udp_ip");

      } else {
	policy_add_string(context, "subject", "level1_protocol");
	policy_add_string(context, "l1protocol_name", "ip");
      }
    
    context->pc_type = PLC_INET_OUT;
    context->pc_data = pi; 
    
    context->session_id = 0;
    
    context->sleep_id = 1;

    if (!(policy_in_use))
      {
	return(0);
      }

#ifdef IPOUTVERBOSE 
	printf("\n#1# Context Created.. Enqueuing");
#endif

    spl = splhigh();
    
    create_context_session(context);
    
    policy_commit_context(context);
    
    policy_queue_length++;

    sleepingguys++;

    splx(spl);
    
#ifdef IPOUTVERBOSE 
    printf("\n#2# Context Enqueued... Sleeping");
#endif
    

    //all we can do now is wait for the call to complete and find out
    //if we are allowed to perform the operation or not
    
    policyres = tsleep(&context->sleep_id, PUSER | PCATCH , 0, /* (int)(hz / 2) */ 0 );

    spl = splhigh();
    
    sleepingguys--;
    
    if (context->sleep_id)
      {
	if (policyres==0)
	  panic("Why did we stop here??");
	
	// sleep was killed because the system call was interrupted
	
	free_context_session(context->session_id);
	
	if (remove_from_list(&policy_context_head, context))
	  policy_queue_length--;
	
	remove_from_list(&policy_context_waiting, context);
	
	policy_destroy_context(context);
	
	return(1);
      }
    
    
    splx(spl);
    
#ifdef FSVERBOSE 
    printf("\n<3> Woken up from sleep.. Checking Result");
#endif
    

    {
      int a;
      for(a=0;a<MAX_POLICY_CONTEXTS;a++)
	if (context == context_session[a].context)
	  panic("SESSION SHOULD BE FREE MOFO!!! (OUT)");
      
      if ((context == policy_context_head)||(context==policy_context_waiting))
	panic("thats not it!!");
    }
    
    if (!(policy_in_use))
      {
	policy_destroy_context(context);
	return(0);
      }
    
    if (context->reply == NULL)
      {
	printf("\nNULL REPLY FROM POLICY DAEMON!!");
      }
    
    if ((context->reply!=NULL)&&(context->reply[0] == 's'))
      {
#ifdef FSVERBOSE 
	printf("\n<4> ALLOW!  Waking up policywrite()");
#endif
	policy_destroy_context(context);
	return(0);
      }	
    else 
      {
#ifdef FSVERBOSE 
	printf("\n<4> DENY!  Waking up policywrite()");
#endif
	policy_destroy_context(context);
	return(1);
      }
	
 bad:
    //we fucked up -- network dies!
    printf("Problem with IP network filter! (mem error in policy filter)");
    return(1);
    
};
      
/*
 * policy_cache_inet_out -- look up the cached decision for pi->so's flow.
 *
 * Returns:
 *   -1            if no cache tree exists or no entry is stored under the
 *                 flow's key;
 *   CACHE_UNKNOWN if an entry exists under the key but describes a
 *                 different flow (the key only mixes faddr and fport);
 *   otherwise     the cached decision for this exact flow.
 */
int
policy_cache_inet_out(struct policy_inet_out_t *pi)
{
    struct inpcb *pcb = sotoinpcb(pi->so);
    ip_cache_entry entry;
    u_int32_t lookup_key;

    if (inet_out_cache_tree == NULL)
        return (-1);

    lookup_key = pcb->inp_faddr.s_addr ^ pcb->inp_fport;
    entry = bt_lookup(inet_out_cache_tree, lookup_key);
    if (entry == (ip_cache_entry)-1)
        return (-1);

    /* Keys can collide, so confirm every field of the flow. */
    if (pcb->inp_laddr.s_addr != entry->ip_src ||
        pcb->inp_faddr.s_addr != entry->ip_dst ||
        pi->so->so_proto->pr_protocol != entry->ip_p ||
        pcb->inp_lport != entry->sport ||
        pcb->inp_fport != entry->dport)
        return (CACHE_UNKNOWN);

    return entry->decision;
}

/*
 * inet_out_reset_cache -- throw away the outbound decision cache (if one
 * exists) and install a fresh tree built with bt_create_tree(CACHE_SIZE).
 */
void
inet_out_reset_cache()
{
    if (inet_out_cache_tree)
        bt_destroy(inet_out_cache_tree, 1);

    inet_out_cache_tree = bt_create_tree(CACHE_SIZE);
}

/*
 * inet_out_clear_cache -- destroy the outbound decision cache, if any.
 *
 * The global is reset to NULL afterwards.  Later lookups then see "no
 * cache" (policy_cache_inet_out returns -1), and the next
 * inet_out_add_to_cache() builds a new tree.  Without the reset they
 * would touch the freed tree, and a second clear would destroy it twice.
 */
void
inet_out_clear_cache()
{
    if (inet_out_cache_tree != NULL) {
        bt_destroy(inet_out_cache_tree, 1);
        inet_out_cache_tree = NULL;
    }
}


/*
 * inet_out_add_to_cache -- remember `decision` for pi->so's flow.
 *
 * Entries are keyed by (foreign address ^ foreign port), so distinct flows
 * can share a key.  Three cases:
 *   - No tree yet: one is created and the entry inserted.
 *   - The identical flow (local/foreign address, ports, protocol) is
 *     already cached: the existing entry is kept and the new one freed.
 *   - A different flow occupies the key: it is evicted and replaced.
 *
 * Always returns 1.
 */
int
inet_out_add_to_cache(struct policy_inet_out_t *pi, int decision)
{
    ip_cache_entry ce;
    ip_cache_entry old_ce;
    long key;
    struct inpcb *inp;

    inp = sotoinpcb(pi->so);

    key = inp->inp_faddr.s_addr ^ inp->inp_fport;

    MALLOC(ce, ip_cache_entry, sizeof(struct ip_cache_entry_t), M_TEMP, M_WAITOK);

    ce->ip_p = pi->so->so_proto->pr_protocol;
    ce->ip_src = inp->inp_laddr.s_addr;
    ce->ip_dst = inp->inp_faddr.s_addr;
    ce->sport = inp->inp_lport;
    ce->dport = inp->inp_fport;
    /* MALLOC without M_ZERO leaves memory uninitialized: assign, don't XOR. */
    ce->key = key;
    ce->decision = decision;

    if (inet_out_cache_tree == NULL) {
        inet_out_cache_tree = bt_create_tree(CACHE_SIZE);
        bt_insert(inet_out_cache_tree, key, ce);
        return (1);
    }

    old_ce = bt_lookup(inet_out_cache_tree, key);
    if (old_ce == (ip_cache_entry)-1) {
        bt_insert(inet_out_cache_tree, key, ce);
    } else if ((inp->inp_laddr.s_addr == old_ce->ip_src) &&
               (inp->inp_faddr.s_addr == old_ce->ip_dst) &&
               (pi->so->so_proto->pr_protocol == old_ce->ip_p) &&
               (inp->inp_lport == old_ce->sport) &&
               (inp->inp_fport == old_ce->dport)) {
        /* This exact flow is already cached: keep the existing entry. */
        FREE(ce, M_TEMP);
    } else {
        /* Key collision with a different flow: evict it and insert ours. */
        ip_cache_entry tmp;

        tmp = bt_remove(inet_out_cache_tree, key);
        FREE(tmp, M_TEMP);
        bt_insert(inet_out_cache_tree, key, ce);
    }

    return (1);
}
