/*
 * /dev/policy driver
 *
 * Maintained by: Tom Langan (tlangan@dsl.cis.upenn.edu)
 *                Matt Lehrer (mlehrer@dsl.cis.upenn.edu)
 *
 * Original code: Sotiris Ioannidis (sotiris@dsl.cis.upenn.edu)
 *
 */

#include <policy/policy.h>
#include <sys/types.h>
#include <sys/param.h>
#include <sys/proc.h>
#include <sys/malloc.h>

#include "policy_routines.h"

/*
 * Render an IPv4 address as zero-padded dotted decimal, e.g.
 * "127.000.000.001".
 *
 * src       - address to format; its four bytes are printed in memory order
 * dst       - output buffer; receives 15 characters plus the terminating
 *             NUL, so it must hold at least 16 bytes
 * normalize - when non-zero the address is passed through ntohl() before
 *             its bytes are printed
 *
 * Returns dst.
 */
char *
my_inet_ntop4(src, dst, normalize)
in_addr_t	*src;
char	*dst;
int	normalize;
{
	in_addr_t	addr = normalize ? ntohl(*src) : *src;
	unsigned char	*octet = (unsigned char *) &addr;

	/* Each octet is at most 255, so the output is always 15 characters. */
	sprintf(dst, "%03u.%03u.%03u.%03u",
		octet[0], octet[1], octet[2], octet[3]);

	return(dst);
}

policy_context	*
policy_create_context()
{
	policy_context	*context;

	int		zero = 0;

	char		*tmp;

	policy_mbuf *pd;

	/* possible error checking for failed mallocs */
	/* Allocate the policy_context */
	MALLOC(context, policy_context *, sizeof(policy_context), M_TEMP, M_WAITOK);
	if (context == NULL) {
		printf("what the hell !!!\n");

		return(NULL);
	}
	bzero(context, sizeof(policy_context));


	/* Allocate the first policy_mbuf where we will hold metadata */
	MALLOC(context->p_mbuf, policy_mbuf *, sizeof(policy_mbuf), M_TEMP, M_WAITOK);
	if (context->p_mbuf == NULL) {
		printf("what the hell !!!\n");

		return(NULL);
	}

	
	context->p_mbuf->length = 3 * sizeof(u_int32_t);
	tmp = (char *) context->p_mbuf->data;
	bcopy((void *) &zero, (void *) tmp, sizeof(u_int32_t));
	context->sequence = 0;

	context->p_mbuf->next = NULL;

	/* XXX this needs to go */
	if (curproc == NULL) {
		int	zero = 0;

		bcopy((void *) &zero, (void *) (tmp + sizeof(u_int32_t)), sizeof(uid_t));
	} else {
	  bcopy((void *) &(curproc->p_cred->p_ruid), (void *) (tmp + sizeof(u_int32_t)), sizeof(uid_t));
	}

	bcopy((void *) &(zero), (void *) (tmp + sizeof(u_int32_t) + sizeof(uid_t)), sizeof(uid_t));

	/* Allocate the second policy_mbuf where we will hold data */
	MALLOC(context->p_mbuf->next, policy_mbuf *, sizeof(policy_mbuf), M_TEMP, M_WAITOK);

	if (context->p_mbuf->next == NULL) {
		printf("what the hell !!!\n");

		return(context);
	}

	context->p_mbuf->next->length = 0;
	context->p_mbuf->next->next = NULL;
	context->pc_type=0;
	context->pc_data = NULL;

	return(context);
}

void
policy_destroy_context(context)
policy_context	*context;
{

  /* XXX destroy pmbufs if present */

  if (context->reply) {
    FREE(context->reply, M_TEMP);
    context->reply = NULL;
  }
  
  FREE(context, M_TEMP);
	
}

/*
 * Append a context to the global list delimited by policy_context_head
 * and policy_context_tail.  An empty list gets the context as both head
 * and tail; otherwise it is linked after the current tail and becomes the
 * new tail.
 */
void
policy_commit_context(context)
policy_context	*context;
{
	if (policy_context_head == NULL) {
		policy_context_head = policy_context_tail = context;
	} else {
		policy_context_tail->policy_context_next = context;
		policy_context_tail = context;
	}
}

/*
 * Return a data mbuf with room for `need' more bytes.  `tail' must be the
 * last mbuf of the chain.  If it has enough free space it is returned
 * as-is; otherwise a new empty mbuf is linked after it and returned (the
 * unused remainder of the old tail is left as it is).
 *
 * Returns NULL if a new mbuf was needed and could not be allocated.
 */
static policy_mbuf *
policy_mbuf_reserve(policy_mbuf *tail, u_int32_t need)
{
	policy_mbuf	*fresh;

	if ((POLICY_DATA_SIZE - tail->length) >= need)
		return(tail);

	MALLOC(fresh, policy_mbuf *, sizeof(policy_mbuf), M_TEMP, M_WAITOK);
	if (fresh == NULL)
		return(NULL);

	fresh->length = 0;
	fresh->next = NULL;
	tail->next = fresh;

	return(fresh);
}

/*
 * Append one (name, value) element to a context.
 *
 * The name bytes and then the value bytes are copied into the data chain
 * (no terminating NULs are added), and the two lengths are recorded in the
 * next two slots of the length table in the header mbuf.  The element
 * count in the header grows by two and the header length by two 32-bit
 * words.
 *
 * The element is dropped, with a console message and no change to the
 * context, when:
 *   - the context has no header or data mbuf,
 *   - either part is larger than POLICY_DATA_SIZE (a part is never split
 *     across mbufs), or
 *   - the header mbuf has no room left for two more length slots
 *     (assumes the header mbuf's data area is POLICY_DATA_SIZE bytes, like
 *     the data mbufs -- TODO confirm against the policy_mbuf definition).
 */
static void
policy_append_pair(policy_context *context, char *name, u_int32_t name_len,
    char *value, u_int32_t value_len)
{
	policy_mbuf	*hdr, *tail, *name_mbuf;
	u_int32_t	n, slot;

	if (context == NULL || context->p_mbuf == NULL ||
	    context->p_mbuf->next == NULL) {
		printf("policy: element dropped, context has no buffers\n");
		return;
	}
	hdr = context->p_mbuf;

	if (name_len > POLICY_DATA_SIZE || value_len > POLICY_DATA_SIZE) {
		printf("policy: element dropped, larger than a policy_mbuf\n");
		return;
	}

	/* The element count lives after the leading word and the uid. */
	bcopy((void *) (hdr->data + sizeof(u_int32_t) + sizeof(uid_t)),
		(void *) &n, sizeof(u_int32_t));

	/* Offset of the first free slot in the length table. */
	slot = (n + 2) * sizeof(u_int32_t) + sizeof(uid_t);
	if (slot + 2 * sizeof(u_int32_t) > POLICY_DATA_SIZE) {
		printf("policy: element dropped, header is full\n");
		return;
	}

	/* Find the end of the data chain. */
	tail = hdr->next;
	while (tail->next != NULL)
		tail = tail->next;

	name_mbuf = policy_mbuf_reserve(tail, name_len);
	if (name_mbuf == NULL) {
		printf("policy: element dropped, out of memory\n");
		return;
	}
	bcopy((void *) name, (void *) (name_mbuf->data + name_mbuf->length),
		name_len);
	name_mbuf->length += name_len;

	tail = policy_mbuf_reserve(name_mbuf, value_len);
	if (tail == NULL) {
		/* Take the name back out so data and header stay in step. */
		name_mbuf->length -= name_len;
		printf("policy: element dropped, out of memory\n");
		return;
	}
	bcopy((void *) value, (void *) (tail->data + tail->length), value_len);
	tail->length += value_len;

	/* Record both lengths, then publish the new element count. */
	bcopy((void *) &name_len, (void *) (hdr->data + slot),
		sizeof(u_int32_t));
	bcopy((void *) &value_len,
		(void *) (hdr->data + slot + sizeof(u_int32_t)),
		sizeof(u_int32_t));
	n += 2;
	bcopy((void *) &n,
		(void *) (hdr->data + sizeof(u_int32_t) + sizeof(uid_t)),
		sizeof(u_int32_t));

	hdr->length += 2 * sizeof(u_int32_t);
}

/*
 * Add an integer attribute to a context.
 *
 * The value is stored as its decimal text in a fixed 20-byte field; the
 * bytes after the digits are zero, so nothing uninitialised is exported.
 * A NULL name is ignored.
 */
void
policy_add_int(context, name, value)
policy_context	*context;
char		*name;
int		value;
{
	char	field[20];	/* XXX is this enough ? too much ? how to determine ? */

	if (name == NULL)
		return;

	bzero(field, sizeof(field));
	sprintf(field, "%d", value);

	policy_append_pair(context, name, strlen(name), field, sizeof(field));
}

/*
 * Add a string attribute to a context.  Name and value are stored without
 * their terminating NULs.  A NULL name or value is ignored.
 */
void
policy_add_string(context, name, value)
policy_context	*context;
char		*name;
char		*value;
{
	if (name == NULL || value == NULL)
		return;

	policy_append_pair(context, name, strlen(name), value, strlen(value));
}

/*
 * Add an IPv4 address attribute to a context.  The address is stored as
 * the zero-padded dotted-decimal text produced by my_inet_ntop4() with
 * normalize == 0.  A NULL name or value is ignored.
 */
void
policy_add_ipv4address(context, name, value)
policy_context	*context;
char		*name;
in_addr_t	*value;
{
	char		dst[16];

	if (name == NULL || value == NULL)
		return;

	my_inet_ntop4(value, dst, 0);

	policy_append_pair(context, name, strlen(name), dst, strlen(dst));
}



int context_sessions_inited = 0;

void init_context_session(int freecontexts)
{
  int j;

  for (j=0; j<MAX_POLICY_CONTEXTS; j++)
    {
      if (freecontexts && context_session[j].context!= NULL)
	policy_destroy_context(context_session[j].context);

      context_session[j].context = NULL;
    }

  context_sessions_inited = 1;
};

/*
 * Return the context stored for a session.
 *
 * session_id is 1-based.  An id outside 1..MAX_POLICY_CONTEXTS (including
 * zero and negative values) panics.  NULL is returned when the slot holds
 * no context.
 */
policy_context *get_context_session(int session_id)
{
  if (session_id <= 0 || session_id > MAX_POLICY_CONTEXTS)
    panic("Out of bounds context session in get");

  if (!(context_sessions_inited))
    {
      init_context_session(0);
    }

  return context_session[session_id-1].context;
}


/*
 * Bind a context to the first free slot of the context_session table.
 *
 * On success the 1-based slot number is written into the first 32-bit word
 * of the context's header mbuf and into context->sequence and
 * context->session_id.  The slot receives a copy of the header, a private
 * copy of the element length table, the total byte count (header plus
 * length table plus the sum of all element lengths), and an expiry time
 * of the current kernel time plus POLICY_SESSION_EXP_SEC /
 * POLICY_SESSION_EXP_USEC.
 *
 * Returns the session id (>= 1), or 0 when the context is NULL or has no
 * header mbuf, the table is full, or the length table cannot be allocated.
 * The context is not modified when 0 is returned.
 */
int create_context_session(policy_context *context)
{
  int j, idx, sum;
  u_int32_t i;
  u_int32_t *lengths;
  size_t hdr_len = 2*sizeof(u_int32_t) + sizeof(uid_t);
  size_t lengths_len;

  if (context == NULL || context->p_mbuf == NULL)
    return 0;

  if (!(context_sessions_inited))
    {
      init_context_session(0);
    }

  /* First free slot; ids are 1-based so 0 can mean "none". */
  for (j = 1; j <= MAX_POLICY_CONTEXTS; j++)
    {
      if (context_session[j-1].context == NULL)
	break;
    }

  if (j > MAX_POLICY_CONTEXTS)
    return 0;

  idx = j - 1;

  /*
   * Peek at the header to learn the element count.  The slot is still
   * unclaimed (its context pointer is NULL), so using its hdr as scratch
   * space is harmless if we bail out below.
   */
  bcopy((void *) context->p_mbuf->data, (void *) &context_session[idx].hdr, hdr_len);
  lengths_len = sizeof(u_int32_t) * context_session[idx].hdr.num_elems;

  /* M_NOWAIT can fail; never request zero bytes. */
  MALLOC(lengths, u_int32_t *, (lengths_len ? lengths_len : sizeof(u_int32_t)), M_TEMP, M_NOWAIT);
  if (lengths == NULL)
    return 0;

  /* Nothing can fail from here on: stamp the context and claim the slot. */
  bcopy((void *) &j, (void *) context->p_mbuf->data, sizeof(u_int32_t));
  context->sequence = j;

  /* XXX DEBUG */

  /* Re-snapshot the header so the copy carries the session number. */
  bcopy((void *) context->p_mbuf->data, (void *) &context_session[idx].hdr, hdr_len);
  bcopy((void *) ((char *) context->p_mbuf->data + hdr_len), (void *) lengths, lengths_len);
  context_session[idx].lengths = lengths;

  context_session[idx].expiretime = time;
  context_session[idx].expiretime.tv_sec += POLICY_SESSION_EXP_SEC;
  context_session[idx].expiretime.tv_usec += POLICY_SESSION_EXP_USEC;
  /* Keep tv_usec below one second so timeval comparisons stay valid. */
  while (context_session[idx].expiretime.tv_usec >= 1000000)
    {
      context_session[idx].expiretime.tv_sec++;
      context_session[idx].expiretime.tv_usec -= 1000000;
    }

  context_session[idx].proc = 0;
  context_session[idx].index = 0;
  context_session[idx].reads = 0;
  context_session[idx].bytesread = 0;

  sum = 0;
  for (i = 0; i < context_session[idx].hdr.num_elems; i++)
    {
      sum += lengths[i];
    }

  context_session[idx].totalbytes = (context_session[idx].hdr.num_elems+2)*sizeof(u_int32_t) + sizeof(uid_t);
  context_session[idx].totalbytes += sum;

  context->session_id = j;
  context_session[idx].context = context;

  return j;
}

/*
 * Release a session slot.
 *
 * session_id is 1-based; an id outside 1..MAX_POLICY_CONTEXTS (including
 * zero and negative values) panics.  If the slot holds no context the call
 * does nothing, so releasing the same session twice is harmless.
 * Otherwise the context's session_id is reset to 0, the slot's length
 * array is freed and the slot is marked empty.  The context itself is not
 * destroyed.
 */
void free_context_session(int session_id)
{
  int idx;

  if (session_id <= 0 || session_id > MAX_POLICY_CONTEXTS)
    panic("Out of bounds context session in free");

  idx = session_id - 1;

  if (context_session[idx].context == NULL)
    return;

  context_session[idx].context->session_id = 0;

  if (context_session[idx].lengths != NULL)
    {
      FREE(context_session[idx].lengths, M_TEMP);
      context_session[idx].lengths = NULL;
    }

  context_session[idx].context = NULL;
}

/*
 * Find the session assigned to a process.
 *
 * Scans the context_session table in slot order and returns the 1-based
 * id of the first slot that holds a context and whose proc equals pid.
 * Returns 0 when there is no such slot.
 */
int lookup_context_session(pid_t pid)
{
  int slot;

  for (slot = 0; slot < MAX_POLICY_CONTEXTS; slot++)
    {
      if (context_session[slot].proc == pid && context_session[slot].context != NULL)
	return slot + 1;
    }

  return 0;
}

/*
 * Hand the first unclaimed committed context to a process.
 *
 * Walks the list starting at policy_context_head.  A context that has no
 * session yet (session_id == 0) gets one via create_context_session(); if
 * that fails the context is skipped.  The first context whose session has
 * proc == 0 is claimed by storing pid there.
 *
 * Returns the claimed session id, or 0 when the list is empty or every
 * usable session already belongs to a process.
 */
int assign_context_session(pid_t pid)
{
  policy_context *context;
  int idx;

  for (context = policy_context_head; context != NULL; context = context->policy_context_next)
    {
      /* A failed creation leaves session_id at 0, which has no slot. */
      if (context->session_id == 0 && create_context_session(context) == 0)
	continue;

      /* Session ids are 1-based; the table is 0-based. */
      idx = context->session_id - 1;

      if (context_session[idx].proc == 0)
	{
	  context_session[idx].proc = pid;
	  return context->session_id;
	}
    }

  return 0;
}

/*
 * Debug check: does the length recorded in the header match the chain?
 *
 * The expected size is the fixed header (two 32-bit words plus a uid_t),
 * one 32-bit length slot per element, plus the sum of all element lengths
 * read from the header's length table.  The actual size is the sum of the
 * length fields of every policy_mbuf on the chain.  Both numbers are
 * printed.
 *
 * Returns non-zero when they agree.  A context without a header mbuf
 * returns 1; a header mbuf with a NULL data pointer returns whether its
 * length is zero.
 */
int policy_consistent(policy_context *context)
{
  policy_mbuf *head = context->p_mbuf;
  policy_mbuf *cur;
  policy_mbuf_hdr *hdr;
  u_int32_t *elem_len;
  int expected, actual, k;

  if (head == NULL)
    {
      printf("Consistency check: p_mbuf == NULL\n");
      return 1;
    }

  if (head->data == NULL)
    {
      printf("Consistency check: p_mbuf != NULL but p_mbuf->data == NULL\n");
      return (head->length == 0);
    }

  hdr = (policy_mbuf_hdr *)head->data;
  if (hdr == NULL)
    {
      printf("Consistency check: p_mbuf header == NULL\n");
      return 1;
    }

  /* The length table follows the fixed part of the header. */
  elem_len = (u_int32_t *)((char *)head->data + 2*sizeof(u_int32_t) + sizeof(uid_t));

  /* Bytes actually held by the whole chain, header mbuf included. */
  actual = 0;
  for (cur = head; cur != NULL; cur = cur->next)
    actual += cur->length;

  /* Bytes the header says there should be. */
  expected = 2*sizeof(u_int32_t) + sizeof(uid_t);
  expected += hdr->num_elems * sizeof(u_int32_t);
  for (k = 0; k < hdr->num_elems; k++)
    expected += elem_len[k];

  printf("Consistency Check: %d ?= %d  ---> %s\n", expected, actual, (expected == actual ? "yes" : "no"));

  return (expected == actual);
}
