/* modified from RFC 2104 sample code */

#include <stdio.h>
#include <string.h>
#include <assert.h>
#include "sysdep.h" // includes sys/types.h for sys/socket.h
#ifdef __unix
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <inttypes.h>
#endif
#include "sysutil.h"
#include "md5.h"
#include "hmac_md5.h"

// static
// Process-wide key storage for HMAC_auth_central. Both are filled by
// HMAC_auth_central::init(), which reads the key file via hmac_read_key()
// (a 16-byte key; secret_len receives its length).
unsigned char HMAC_auth_central::shared_secret[16];
int HMAC_auth_central::secret_len;

/* assumption: hmac_result has to be at least 16 bytes long */
void hmac_md5(unsigned char *text, int text_len,
              unsigned char *key, int key_len,
              unsigned char *hmac_result) /* hmac digest to be filled in */
{
  MD5_CTX context;
  unsigned char k_ipad[64], k_opad[64];
  unsigned char k_digest[16];
  int i;

  if (key_len > 64) {
    md5_simple(key, key_len, k_digest);
    key = k_digest;
    key_len = 16;
  }

  memset(k_ipad, 0, sizeof(k_ipad));
  memset(k_opad, 0, sizeof(k_opad));
  memcpy(k_ipad, key, key_len);
  memcpy(k_opad, key, key_len);

  for (i = 0; i < 64; i++) {
    k_ipad[i] ^= 0x36; /* K XOR ipad */
    k_opad[i] ^= 0x5c; /* K XOR opad */
  }

  MD5Init(&context);
  MD5Update(&context, k_ipad, 64);
  MD5Update(&context, text, text_len);
  MD5Final(hmac_result, &context); /* temp result MD5(k_ipad, text) */

  MD5Init(&context);
  MD5Update(&context, k_opad, 64);
  MD5Update(&context, hmac_result, 16);
  MD5Final(hmac_result, &context); /* final result MD5(k_opad, ...) */
}


/* assumes the digest has at least 16-byte long storage */
void md5_simple(unsigned char *text, int text_len, unsigned char *digest)
{
  MD5_CTX context;

  MD5Init(&context);
  MD5Update(&context, text, text_len);
  MD5Final(digest, &context);
}


/* hmac_hex_nibble - value of one ASCII hex digit, or -1 if c is not one. */
static int hmac_hex_nibble(char c)
{
  if (c >= '0' && c <= '9') return c - '0';
  if (c >= 'a' && c <= 'f') return c - 'a' + 10;
  if (c >= 'A' && c <= 'F') return c - 'A' + 10;
  return -1;
}

/*
 * hmac_read_key - load the 16-byte shared secret from a text key file.
 *
 * The file must start with 32 ASCII hex digits, two per key byte, most
 * significant nibble first. Anything after them (e.g. a trailing newline)
 * is ignored.
 *
 * keyfile : path of the key file
 * key     : output buffer; must provide at least 16 bytes
 * key_len : set to the key length after the call (currently always 16)
 *
 * Prints a message and terminates the process with exit(1) if the file
 * cannot be opened, or if it does not begin with 32 valid hex digits.
 * A partially read key must never be used as a shared secret.
 */
void hmac_read_key(char *keyfile, unsigned char *key, int *key_len)
{
  const int key_bytes = 16;
  FILE *fp;

  fp = fopen(keyfile, "r");
  if (fp == NULL) {
    printf("can't open %s\n", keyfile);
    exit(1);
  }

  printf("reading shared secret key from file %s\n", keyfile);

  for (int i = 0; i < key_bytes; i++) {
    char pair[2];
    int hi = -1, lo = -1;

    if (fread(pair, sizeof(char), 2, fp) == 2) {
      hi = hmac_hex_nibble(pair[0]);
      lo = hmac_hex_nibble(pair[1]);
    }
    if (hi < 0 || lo < 0) {
      /* short file or non-hex character: refuse to continue with a bad key */
      printf("error reading key file\n");
      fclose(fp);
      exit(1);
    }
    key[i] = (unsigned char) ((hi << 4) | lo);
  }

  *key_len = key_bytes;
  fclose(fp);
}


// HMAC_auth - set up the per-connection authentication state for socket sd.
//
// prog_type 0x1 (controller): the agent at the other end of sd sends a
//   challenge. Its nonce is stored twice: nonce_net in network order, for the
//   wire and digest, and nonce_val in host order, for checks. A random initial
//   replay-prevention counter is chosen, and the shared secret is read from
//   keyfile. A controller's key file differs from an endpoint's because the
//   controller needs to know the shared keys of all endpoints.
// prog_type 0x2 (endpoint): a random nonce is chosen for this
//   connection/thread, the replay counter starts at 0, and the key already
//   held by HMAC_auth_central is copied. keyfile is ignored (1 key per agent).
// Any other prog_type, a socket failure, or a malformed challenge prints a
// message and terminates the process with exit(1).
HMAC_auth::HMAC_auth(int sd, int prog_type, char *keyfile)
  : sock(sd), my_type(prog_type)
{
  struct sockaddr_in sa;
  socklen_t s_len = sizeof(sa);
  struct test_request challenge;
  long got;

  switch (my_type)
    {
    case 0x1: // I am a controller; the peer is the endpoint that will
              // receive/process the request
      if (getpeername(sock, (struct sockaddr *) &sa, &s_len) < 0) {
        printf("getpeername() failed!\n");
        exit(1);
      }
      got = readn(sd, (char *)&challenge, sizeof(challenge));
      if (got <= 0) {
        printf("agent side closed connection!\n");
        exit(1);
      }
      if (got != (long) sizeof(challenge)) {
        // a short read would leave part of the challenge uninitialized
        printf("short read of agent challenge!\n");
        exit(1);
      }
      challenge.req_cmd = ntohl(challenge.req_cmd);
      // The challenge comes from the network, so check it at run time;
      // assert() would be compiled out under NDEBUG.
      if (challenge.req_cmd != 0x3) { // 0x3 = challenge with nonce value
        printf("agent did not send a nonce challenge!\n");
        exit(1);
      }
      nonce_net = challenge.un.nonce_val;        // network order, as received
      nonce_val = ntohl(challenge.un.nonce_val); // host order, stored here
      printf("agent %s sent a nonce value of %lu\n",
             inet_ntoa(sa.sin_addr), nonce_val);
      // controller picks a random starting point for the replay prevention
      // counter
      replay_prv_cnt = random32(0);
      printf("I picked initial replay prevent counter %lu\n",
             replay_prv_cnt);
      hmac_read_key(keyfile, shared_secret, &secret_len);
      assert(secret_len == 16);
      break;
    case 0x2: // is an endpoint, the local address is my interface's IP
      if (getsockname(sock, (struct sockaddr *) &sa, &s_len) < 0) {
        printf("getsockname() failed!\n");
        exit(1);
      }
      nonce_val = random32(0); // same nonce value per connection/thread
      printf("setting nonce value %lu\n", nonce_val);
      replay_prv_cnt = 0; // agent is more conservative, use 0.
      // each endpoint has and needs only 1 shared key, loaded centrally
      secret_len = HMAC_auth_central::get_keylen();
      assert(secret_len == 16);
      memcpy(shared_secret, HMAC_auth_central::get_key(), secret_len);
      break;
    default:
      printf("illegal type %d\n", my_type);
      exit(1);
    }

}


// calc_digest - compute HMAC-MD5 over (replay counter || nonce || req) into
// auth_data, using this connection's shared secret.
//
// The counter (replay_prv_net) and the nonce (nonce_net) are hashed in their
// on-the-wire form, so sender and receiver digest identical bytes. The 8-byte
// prefix plus the payload are assembled in internal_buf. NOTE(review): the
// assert encodes the assumption that internal_buf is 1024 bytes, so req_len
// must be in [0, 1015]. Confirm against the class declaration.
void HMAC_auth::calc_digest(unsigned char *req, int req_len)
{
  // Check before touching internal_buf. A negative length would otherwise be
  // converted by memcpy into a huge size_t.
  assert(req_len >= 0 && req_len + 8 < 1024);
  memcpy(internal_buf, replay_prv_net, 4);
  memcpy(internal_buf + 4, &nonce_net, 4);
  memcpy(internal_buf + 8, req, req_len);
  hmac_md5(internal_buf, req_len + 8, shared_secret, secret_len, auth_data);
}

// send_req - send one authenticated request on this connection.
//
// Steps:
//   1. Advance the per-connection replay prevention counter.
//   2. Compute the HMAC over counter, nonce and req.
//   3. Write the 28-byte header, taken from this object's leading members
//      (mlen_net, replay_prv_net, nonce_net, auth_data), followed by the
//      req_len payload bytes.
// mlen_net carries the total length (header + payload) in network order.
// Return values of writen() are not checked here.
void HMAC_auth::send_req(unsigned char *req, int req_len)
{
  //HMAC_auth_central::inc(replay_prevent); // increment by one
  // The replay counter is kept per connection. It is stored in network
  // (big-endian) order one byte at a time. Copying the first 4 bytes of an
  // htonl() result held in an unsigned long would send zeros on 64-bit
  // big-endian hosts.
  unsigned long cnt = ++replay_prv_cnt;
  unsigned char *wire_cnt = (unsigned char *) replay_prv_net;
  wire_cnt[0] = (unsigned char) (cnt >> 24);
  wire_cnt[1] = (unsigned char) (cnt >> 16);
  wire_cnt[2] = (unsigned char) (cnt >> 8);
  wire_cnt[3] = (unsigned char) cnt;
  calc_digest(req, req_len); // req_len upper limit is checked in calc_digest()

  mlen_net = htonl(28 + req_len);
  writen(sock, (char *) this, 28);
  writen(sock, (char *) req, req_len);

  //HMAC_auth_central::sync(); // this approach may be too aggressive
}

// hmac_get_be32 - decode a 4-byte big-endian (network order) field.
// It yields the same value as ntohl() over a memcpy'd uint32, but does not
// depend on host endianness or on sizeof(unsigned long).
static unsigned long hmac_get_be32(const void *field)
{
  const unsigned char *b = (const unsigned char *) field;
  return ((unsigned long) b[0] << 24) | ((unsigned long) b[1] << 16) |
         ((unsigned long) b[2] << 8) | (unsigned long) b[3];
}

// receive_req - read one authenticated request from the peer and verify it.
//
// Wire format: a 28-byte header read directly over this object's leading
// members (total length mlen_net, replay counter replay_prv_net, nonce
// nonce_net, 16-byte HMAC auth_data), all in network order. It is followed by
// (total length - 28) payload bytes, which are copied into req.
//
// req     : caller buffer; at most 1015 payload bytes are ever written to it
// req_len : set to the number of payload bytes placed in req (0 when the
//           message is rejected for its length)
// Returns 2 if the remote side closed the connection (readn returned 0 or -1).
// Returns true (1) if the nonce, replay counter and digest all verify; the
// stored replay counter then advances to the received value.
// Returns false (0) otherwise: short read, malformed or oversized length,
// wrong nonce, non-increasing counter, or digest mismatch.
int HMAC_auth::receive_req(unsigned char *req, int *req_len)
{
  const unsigned long hdr_len = 28;
  // calc_digest() requires req_len + 8 < 1024. This is the largest payload
  // whose digest can be verified without overrunning its assembly buffer.
  const int max_payload = 1024 - 8 - 1;
  int m;

  m = readn(sock, (char *) this, 28);
  if (m == 0 || m == -1) {
    printf("remote side closed connection\n");
    //close(sock);
    return 2; // closed
  }
  if (m != 28) {
    printf("weird socket read %d!\n", m);
    return false; // incomplete header: its fields cannot be trusted
  }

  // The length comes from the peer. Reject values below the header size,
  // which would otherwise yield a negative payload length.
  unsigned long total = ntohl(mlen_net);
  if (total < hdr_len) {
    printf("illegal message length %lu\n", total);
    *req_len = 0;
    return false;
  }
  unsigned long payload = total - hdr_len;
  if (payload > (unsigned long) max_payload) {
    printf("warning: received a message of length %lu, discarding it\n",
           payload);
    // Drain the payload so the next header is read at the right offset.
    char scratch[256];
    while (payload > 0) {
      int chunk = payload < sizeof(scratch) ? (int) payload
                                            : (int) sizeof(scratch);
      if (readn(sock, scratch, chunk) != chunk) {
        printf("weird socket read!\n");
        break;
      }
      payload -= chunk;
    }
    *req_len = 0;
    return false;
  }

  *req_len = (int) payload;
  if (readn(sock, (char *) req, *req_len) != *req_len) {
    printf("weird socket read!\n");
    return false; // incomplete payload cannot match the digest
  }

  // first verify that the sender used the same nonce value we requested
  unsigned long x = hmac_get_be32(&nonce_net);
  if (x != nonce_val) {
    printf("illegal nonce value %lu != expected %lu\n", x, nonce_val);
    return false;
  }

  // next verify the replay prevention counter is not violated
  x = hmac_get_be32(replay_prv_net);
  //  cur = HMAC_auth_central::cur_replay_value();
  if (x <= replay_prv_cnt) {
    printf("illegal replay prevent counter %lu <= %lu\n", x, replay_prv_cnt);
    return false;
  }

  // then verify the HMAC-MD5 digest. calc_digest() overwrites auth_data, so
  // keep a copy of the received value first.
  unsigned char digest_rcvd[16];
  memcpy(digest_rcvd, auth_data, 16);
  calc_digest(req, *req_len);
  // Constant-time compare: do not leak how many leading bytes matched.
  unsigned char diff = 0;
  for (int i = 0; i < 16; i++)
    diff |= digest_rcvd[i] ^ auth_data[i];
  if (diff != 0)
    return false;

  replay_prv_cnt = x;
  return true;
}

// init - load the process-wide shared secret from keyfile into
// HMAC_auth_central's static storage. Endpoint HMAC_auth objects later fetch
// the key through HMAC_auth_central::get_key()/get_keylen(). Presumably these
// return this storage; confirm in the header. hmac_read_key() exits the
// process if the file cannot be opened.
void HMAC_auth_central::init(char *keyfile)
{
  hmac_read_key(keyfile, HMAC_auth_central::shared_secret,
                &HMAC_auth_central::secret_len);
  assert(HMAC_auth_central::secret_len == 16); // key length is fixed at 16
}
