open Ast
open Sast
open Printf

(* Register the language's builtin functions in [varmap] so that semantic
   analysis can resolve calls to them, and return the extended map.

   Each entry binds a builtin's source-level name either to
   [Fun (kind, return_type, arg_types)] or to [Special name].
   NOTE(review): [Fix] / [Len_Flex] look like fixed vs. variadic arity, and
   [Bits 0], [Bits (-1)] and [Wild_Card 1] look like placeholder types that
   the checker resolves per call site; [Special] builtins presumably get
   bespoke checking - confirm against the semantic checker.

   Each name is listed exactly once: [VarMap.add] would silently overwrite
   an earlier binding, so a duplicate entry is at best dead code. *)
let add_builtin_fun varmap =
    let builtins = [
        (* control flow / generic *)
        ("if", Fun(Fix, Wild_Card(1), [Bits(1); Wild_Card(1); Wild_Card(1)]));

        (* integer arithmetic *)
        ("+", Fun(Len_Flex, Int, [Int]));
        ("*", Fun(Len_Flex, Int, [Int]));
        ("-", Fun(Fix, Int, [Int; Int]));
        ("/", Fun(Fix, Int, [Int; Int]));
        ("mod", Fun(Fix, Int, [Int; Int]));
        ("pow", Fun(Fix, Int, [Int; Int]));
        ("inverse", Fun(Fix, Int, [Int; Int]));

        (* logical operators on single bits *)
        ("and", Fun(Len_Flex, Bits(1), [Bits(1)]));
        ("or", Fun(Len_Flex, Bits(1), [Bits(1)]));
        ("not", Fun(Fix, Bits(1), [Bits(1)]));

        (* integer comparisons *)
        ("less", Fun(Fix, Bits(1), [Int; Int]));
        ("greater", Fun(Fix, Bits(1), [Int; Int]));
        ("leq", Fun(Fix, Bits(1), [Int; Int]));
        ("geq", Fun(Fix, Bits(1), [Int; Int]));

        (* generic equality *)
        ("eq", Fun(Fix, Bits(1), [Wild_Card(1); Wild_Card(1)]));
        ("neq", Fun(Fix, Bits(1), [Wild_Card(1); Wild_Card(1)]));

        (* bitwise operators *)
        ("&", Fun(Len_Flex, Bits(0), [Bits(0)]));
        ("|", Fun(Len_Flex, Bits(0), [Bits(0)]));
        ("^", Fun(Len_Flex, Bits(0), [Bits(0)]));
        ("parity", Fun(Fix, Bits(1), [Bits(0)]));

        (* shifts, rotations and bit flips *)
        ("<<", Fun(Fix, Bits(0), [Bits(0); Int]));
        (">>", Fun(Fix, Bits(0), [Bits(0); Int]));
        (">>>", Fun(Fix, Bits(0), [Bits(0); Int]));
        ("<<<", Fun(Fix, Bits(0), [Bits(0); Int]));
        ("flip-bit", Fun(Fix, Bits(0), [Bits(0); Int]));
        ("flip", Fun(Fix, Bits(0), [Bits(0)]));

        ("set", Fun(Fix, Wild_Card(1), [Wild_Card(1); Wild_Card(1)]));

        (* builtins whose signatures are not expressible as [Fun] *)
        ("group", Special("group"));
        ("merge", Special("merge"));
        ("map", Special("map"));
        ("reduce", Special("reduce"));
        ("transpose", Special("transpose"));
        ("zero", Special("zero"));
        ("rand", Special("rand"));

        (* conversions *)
        ("int-of-bits", Fun(Fix, Int, [Bits(0)]));
        ("string-of-bits", Fun(Fix, String, [Bits(0)]));
        ("bits-of-int", Fun(Fix, Bits(0), [Int; Int]));
        ("bits-of-string", Fun(Fix, Bits(0), [Int; String]));
        ("pad", Fun(Fix, Bits(0), [Bits(-1); Int]));

        (* number theory *)
        ("is-prime", Fun(Fix, Bits(1), [Int]));
        ("next-prime", Fun(Fix, Int, [Int]));
    ] in
    List.fold_left
        (fun acc (name, signature) -> VarMap.add name signature acc)
        varmap builtins
;;


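(* C++ code generators for builtins that are specialized per type: each gen_*
   function below returns the source text of one C++ function definition,
   instantiated for the given type names and/or bit widths. *)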

(* Emit [if__<name>]: a C++ function that returns [x] when the 1-bit
   condition [b] equals bitset<1>(1) and [y] otherwise.  [ts] is the C++
   type of both branches and of the result. *)
let gen_if name ts =
    let signature =
        sprintf "%s if__%s(bitset<1> b, %s x, %s y) {" ts name ts ts
    in
    String.concat "\n" [
        signature;
        "    if (b == bitset<1>(1))";
        "        return x;";
        "    else";
        "        return y;";
        "}\n";
    ]
;;

(* Emit [eq__<name>]: a C++ function comparing two values of C++ type [ts]
   with ==, returning bitset<1>(1) when equal and bitset<1>(0) otherwise. *)
let gen_eq name ts =
    let header = sprintf "bitset<1> eq__%s(%s x, %s y) {" name ts ts in
    String.concat "\n" [
        header;
        "    if (x == y)";
        "        return bitset<1>(1);";
        "    else";
        "        return bitset<1>(0);";
        "}\n";
    ]
;;

(* Emit [neq__<name>]: the negation of [eq__<name>]; compares two values of
   C++ type [ts] with == and returns bitset<1>(0) when equal, bitset<1>(1)
   otherwise. *)
let gen_neq name ts =
    let header = sprintf "bitset<1> neq__%s(%s x, %s y) {" name ts ts in
    let body = [
        "    if (x == y)";
        "        return bitset<1>(0);";
        "    else";
        "        return bitset<1>(1);";
        "}\n";
    ] in
    String.concat "\n" (header :: body)
;;

(* Build the C++ expression "bs1 ^ bs2 ^ ... ^ <s>" by prepending
   "bs<k> ^ " for k = len down to 1.  [len] is expected to be
   non-negative: the recursion only stops when it reaches 0. *)
let rec gen_xor_xors len s =
    match len with
    | 0 -> s
    | k -> gen_xor_xors (k - 1) (sprintf "bs%d ^ %s" k s)
;;

(* Build a C++ parameter list "<bt> bs1, <bt> bs2, ..., <s>" by prepending
   "<bt> bs<k>, " for k = len down to 1.  Shared by the xor/or/and
   generators.  [len] is expected to be non-negative. *)
let rec gen_xor_args len s bt =
    match len with
    | 0 -> s
    | k ->
        let decl = sprintf "%s bs%d" bt k in
        gen_xor_args (k - 1) (decl ^ ", " ^ s) bt
;;

(* Emit [xor__<bit_len>__<arg_len>]: a C++ function taking [arg_len]
   parameters bs1..bs<arg_len> of type bitset<bit_len> and returning their
   bitwise XOR (expression built by [gen_xor_xors]). *)
let gen_xor bit_len arg_len =
    let bt = sprintf "bitset<%d>" bit_len in
    let params =
        gen_xor_args (arg_len - 1) (sprintf "%s bs%d" bt arg_len) bt
    and expr = gen_xor_xors (arg_len - 1) (sprintf "bs%d" arg_len) in
    sprintf "%s xor__%d__%d (%s) {\n    return %s;\n}\n"
        bt bit_len arg_len params expr

(* Build the C++ expression "bs1 | bs2 | ... | <s>" by prepending
   "bs<k> | " for k = len down to 1.  [len] is expected to be
   non-negative. *)
let rec gen_or_ors len s =
    match len with
    | 0 -> s
    | k -> gen_or_ors (k - 1) (sprintf "bs%d | %s" k s)
;;

(* Emit [or__<bit_len>__<arg_len>]: a C++ function taking [arg_len]
   parameters bs1..bs<arg_len> of type bitset<bit_len> and returning their
   bitwise OR (expression built by [gen_or_ors]). *)
let gen_or bit_len arg_len =
    let bt = sprintf "bitset<%d>" bit_len in
    let last = arg_len in
    let params = gen_xor_args (last - 1) (sprintf "%s bs%d" bt last) bt in
    let expr = gen_or_ors (last - 1) (sprintf "bs%d" last) in
    sprintf "%s or__%d__%d (%s) {\n    return %s;\n}\n"
        bt bit_len arg_len params expr

(* Build the C++ expression "bs1 & bs2 & ... & <s>" by prepending
   "bs<k> & " for k = len down to 1.  The recursion must stay in this
   function: delegating to the OR builder after the first step would
   produce a mixed "bs1 | bs2 & bs3".  [len] is expected to be
   non-negative. *)
let rec gen_and_ands len s =
    if len = 0 then
        s
    else
        let news = sprintf "bs%d & %s" len s in
        gen_and_ands (len-1) news;;

(* Emit [and__<bit_len>__<arg_len>]: a C++ function taking [arg_len]
   parameters bs1..bs<arg_len> of type bitset<bit_len> and returning the
   expression built by [gen_and_ands] (intended to be their bitwise AND). *)
let gen_and bit_len arg_len =
    let bt = sprintf "bitset<%d>" bit_len in
    let last_param = sprintf "%s bs%d" bt arg_len in
    let params = gen_xor_args (arg_len - 1) last_param bt in
    let expr = gen_and_ands (arg_len - 1) (sprintf "bs%d" arg_len) in
    sprintf "%s and__%d__%d (%s) {\n    return %s;\n}\n"
        bt bit_len arg_len params expr

(* Emit a C++ function [name] that concatenates the string forms of a
   vector of bitset<in_b_len> and parses the result as a
   bitset<out_b_len>. *)
let gen_merge out_b_len name in_b_len =
    String.concat "\n" [
        sprintf "bitset<%d> %s(vector< bitset<%d> > vb) {"
            out_b_len name in_b_len;
        "    string s = \"\";";
        "    for (int i = 0; i < vb.size(); i++)";
        "        s = s + vb[i].to_string();";
        sprintf "    return bitset<%d> (s);" out_b_len;
        "}\n";
    ]
;;

(* Emit a C++ function [name] that splits a bitset<in_b_len> into a vector
   of bitset<out_b_len> chunks taken left to right from its string form.
   The runtime argument [ns] is parsed with atoi and the vector length is
   in_b_len / atoi(ns).
   NOTE(review): ns == "0" divides by zero in the generated code. *)
let gen_group out_b_len name in_b_len =
    let o = out_b_len in
    String.concat "\n" [
        sprintf "vector< bitset<%d> > %s(bitset<%d> b, string ns) {"
            o name in_b_len;
        "    int v_len = atoi(ns.c_str());";
        sprintf "    v_len = %d / v_len;" in_b_len;
        "    string s = b.to_string();";
        sprintf "    vector < bitset<%d> > result;" o;
        "    result.resize(v_len);";
        "    for (int i = 0; i < v_len; i++) {";
        sprintf "        result[i] = bitset<%d>(s.substr(%d*i, %d));" o o o;
        "    }";
        "    return result;";
        "}\n";
    ]
;;

(* Emit a C++ function [name] that applies [f : in_t_s -> out_t_s] to every
   element of a vector< in_t_s > and returns the resulting
   vector< out_t_s >. *)
let gen_map out_t_s name in_t_s =
    let lines = [
        sprintf "vector< %s > %s(function<%s (%s)> f, vector< %s > b) {"
            out_t_s name out_t_s in_t_s in_t_s;
        sprintf "    vector< %s > result;" out_t_s;
        "    result.resize(b.size());";
        "    for (int i = 0; i < b.size(); i++) {";
        "        result[i] = f(b[i]);";
        "    }";
        "    return result;";
        "}\n";
    ] in
    String.concat "\n" lines
;;

(* Emit a C++ function [name] that left-folds [f] over a vector< in_t_s >,
   starting from its first element.
   NOTE(review): the generated code reads bsv[0] unconditionally (UB on an
   empty vector) and passes the out_t_s accumulator where f expects an
   in_t_s, so it presumably assumes out_t_s and in_t_s coincide - confirm. *)
let gen_reduce out_t_s name in_t_s =
    let header =
        sprintf "%s %s(function<%s (%s, %s)> f, vector< %s > bsv) {"
            out_t_s name out_t_s in_t_s in_t_s in_t_s
    and init = sprintf "    %s result = bsv[0];" out_t_s in
    String.concat "\n" [
        header;
        init;
        "    for (int i = 1; i < bsv.size(); i++)";
        "        result = f(result, bsv[i]);";
        "    return result;";
        "}\n";
    ]
;;

(* Emit a C++ function [name] that transposes a vector<vector<ts> >.
   The column count is taken from row 0 only, so rows are assumed to have
   equal length.  An empty input yields an empty result: the row-0 lookup
   is guarded because indexing m[0] on an empty vector is undefined
   behaviour. *)
let gen_transpose ts name =
    sprintf
"vector<vector<%s> > %s(vector<vector<%s> > m) {
    int nrow = m.size();
    int ncol = (nrow > 0) ? (int) m[0].size() : 0;
    vector<vector<%s> > newm;
    newm.resize(ncol);
    for (int i = 0; i < ncol; i++)
        newm[i].resize(nrow);
    for (int i = 0; i < nrow; i++)
        for (int j = 0; j < ncol; j++)
            newm[j][i] = m[i][j];
    return newm;
}\n" ts name ts ts
;;

(* Emit [rotate_r__<bit_len>]: rotates a bitset<bit_len> right by
   atoi(ns) mod bit_len.  The generated code doubles the bit string into a
   bitset<2*bit_len>, shifts left by (bit_len - k) and then right by
   bit_len, and keeps the low bit_len characters.
   NOTE(review): a negative count is not normalised by C++ %. *)
let gen_rotate_r bit_len =
    let n = bit_len and dbl = 2 * bit_len in
    String.concat "\n" [
        sprintf "bitset<%d> rotate_r__%d(bitset<%d> bs, string ns) {" n n n;
        "    int n = atoi(ns.c_str());";
        sprintf "    int _n = n %% %d;" n;
        "    string s = bs.to_string();";
        "    s = s + s;";
        sprintf "    bitset<%d> dbs = bitset<%d>(s);" dbl dbl;
        sprintf "    dbs <<= %d - _n;" n;
        sprintf "    dbs >>= %d;" n;
        "    s = dbs.to_string();";
        sprintf "    s = s.substr(%d, %d);" n n;
        sprintf "    bitset<%d> result = bitset<%d>(s);" n n;
        "    return result;";
        "}\n";
    ]
;;

(* Emit [rotate_l__<bit_len>]: rotates a bitset<bit_len> left by
   atoi(ns) mod bit_len.  The generated code doubles the bit string into a
   bitset<2*bit_len>, shifts left by k and then right by bit_len, and keeps
   the low bit_len characters.
   NOTE(review): a negative count is not normalised by C++ %. *)
let gen_rotate_l bit_len =
    let w = bit_len in
    let twice = w * 2 in
    let lines = [
        sprintf "bitset<%d> rotate_l__%d(bitset<%d> bs, string ns) {" w w w;
        "    int n = atoi(ns.c_str());";
        sprintf "    int _n = n %% %d;" w;
        "    string s = bs.to_string();";
        "    s = s + s;";
        sprintf "    bitset<%d> dbs = bitset<%d>(s);" twice twice;
        "    dbs <<= _n;";
        sprintf "    dbs >>= %d;" w;
        "    s = dbs.to_string();";
        sprintf "    s = s.substr(%d, %d);" w w;
        sprintf "    bitset<%d> result = bitset<%d>(s);" w w;
        "    return result;";
        "}\n";
    ] in
    String.concat "\n" lines
;;

(* Emit [shift_r__<bit_len>]: logical right shift of a bitset<bit_len> by
   atoi(ns) positions using bitset's operator>>. *)
let gen_shift_r bit_len =
    let sig_line =
        sprintf "bitset<%d> shift_r__%d(bitset<%d> bs, string ns) {"
            bit_len bit_len bit_len
    in
    sig_line ^ "\n    int n = atoi(ns.c_str());\n    return bs >> n;\n}\n"
;;

(* Emit [shift_l__<bit_len>]: left shift of a bitset<bit_len> by atoi(ns)
   positions using bitset's operator<<. *)
let gen_shift_l bit_len =
    let bt = sprintf "bitset<%d>" bit_len in
    String.concat "\n" [
        sprintf "%s shift_l__%d(%s bs, string ns) {" bt bit_len bt;
        "    int n = atoi(ns.c_str());";
        "    return bs << n;";
        "}\n";
    ]
;;

(* Emit [flip_bit__<bit_len>]: inverts bit bs[bit_len - n - 1] of a
   bitset<bit_len>, i.e. the n-th bit counted from the most-significant
   end, where n = atoi(ns).  No range check is generated. *)
let gen_flip_bit bit_len =
    let w = bit_len in
    String.concat "\n" [
        sprintf "bitset<%d> flip_bit__%d(bitset<%d> bs, string ns) {" w w w;
        "    int n = atoi(ns.c_str());";
        sprintf "    bs[%d-n-1] = !bs[%d-n-1];" w w;
        "    return bs;";
        "}\n";
    ]
;;

(* Emit [flip__<bit_len>]: returns the bitwise complement of a
   bitset<bit_len> via bitset::flip(). *)
let gen_flip bit_len =
    let bt = sprintf "bitset<%d>" bit_len in
    sprintf "%s flip__%d(%s bs) {\n    return bs.flip();\n}\n" bt bit_len bt

(* Emit a C++ function [name] that ignores its string argument and returns
   an all-zero bitset<n>. *)
let gen_zero n name =
    let bt = sprintf "bitset<%d>" n in
    sprintf "%s %s (string s) {\n    return %s(0);\n}\n" bt name bt
;;

(* Emit a C++ function [name] returning a random bitset<n> whose
   most-significant bit is always 1 and whose remaining n-1 bits are drawn
   from a GMP Mersenne-Twister generator.  The runtime argument [ns] is
   unused.

   The generator state is a function-local static seeded once on the first
   call.  Reseeding from clock() before every bit would feed the same seed
   repeatedly (clock() rarely advances between iterations), yielding long
   runs of identical bits.  gmp_urandomb_ui returns the bit directly, so no
   GMP memory is allocated per bit and nothing needs to be freed per call.
   NOTE(review): relies on the generated prologue including <ctime> for
   time() - confirm. *)
let gen_rand n name =
    sprintf
"bitset<%d> %s (string ns) {
    static gmp_randstate_t state;
    static bool seeded = false;
    if (!seeded) {
        gmp_randinit_mt(state);
        gmp_randseed_ui(state, (unsigned long) time(NULL));
        seeded = true;
    }
    string s = \"1\";
    for (int i = 1; i < %d; i++)
        s += (gmp_urandomb_ui(state, 1) == 0) ? \"0\" : \"1\";
    return bitset<%d>(s);
}\n" n name n n
;;

(* Emit [int_of_bits__<n>]: converts a bitset<n> to the decimal string of
   its unsigned value using GMP.  mpz_set_str reads the binary string
   directly, so no copy is made.  The output buffer is sized
   mpz_sizeinbase(n, 10) + 2, as the GMP manual requires for mpz_get_str,
   which leaves room for the terminating NUL.  Both the buffer and the
   mpz_t are released before returning. *)
let gen_int_of_bits n =
    sprintf
"string int_of_bits__%d (bitset<%d> b) {
    string s = b.to_string();
    mpz_t n;
    mpz_init(n);
    mpz_set_str(n, s.c_str(), 2);
    char *cp = new char [mpz_sizeinbase(n, 10) + 2];
    mpz_get_str(cp, 10, n);
    string result(cp);
    delete [] cp;
    mpz_clear(n);
    return result;
}\n" n n
;;

(* Emit [bits_of_int__<n>]: parses [ns] with atoi and returns it as a
   bitset<n>.  The first argument [a] is unused by the generated code.
   NOTE(review): atoi limits the value to the C int range. *)
let gen_bits_of_int n =
    let bt = sprintf "bitset<%d>" n in
    String.concat "\n" [
        sprintf "%s bits_of_int__%d (string a, string ns) {" bt n;
        "    int n = atoi(ns.c_str());";
        sprintf "    return %s(n);" bt;
        "}\n";
    ]
;;

(* Emit [bits_of_string__<n>]: encodes each character of [s] as 8 bits,
   appends '0' characters until the string reaches n bits, and parses the
   result as a bitset<n>.  The first argument [a] is unused by the
   generated code. *)
let gen_bits_of_string n =
    let head = [
        sprintf "bitset<%d> bits_of_string__%d(string a, string s) {" n n;
        "    string result = \"\";";
        "    string pad = \"\";";
        sprintf "    if (s.length()*8 < %d) {" n;
        sprintf "        string tmp((%d - s.length()*8), '0');" n;
        "        pad = tmp;";
        "    }";
    ] and tail = [
        "    for (int i = 0; i < s.length(); i++) {";
        "        short c = s.at(i);";
        "        result += (bitset<8>(c)).to_string();";
        "    }";
        "    result += pad;";
        sprintf "    return bitset<%d>(result);" n;
        "}\n";
    ] in
    String.concat "\n" (head @ tail)
;;

(* Emit [string_of_bits__<n>]: appends '0' characters to the bit string of
   a bitset<n> until its length is a multiple of 8, then decodes each
   8-bit group (most-significant first) into one character. *)
let gen_string_of_bits n =
    let header = sprintf "string string_of_bits__%d(bitset<%d> bs) {" n n in
    let body = [
        "    string result = \"\";";
        "    string bss = bs.to_string();";
        "    string pad;";
        "    int pad_len = bss.length() % 8;";
        "    if (pad_len > 0) {";
        "        pad_len = 8 - pad_len;";
        "        string tmp(pad_len, '0');";
        "        pad = tmp;";
        "    }";
        "    bss += pad;";
        "    for (int i = 0; i < bss.length(); i = i+8) {";
        "        string subbss = bss.substr(i, 8);";
        "        bitset<8> subbs(subbss);";
        "        unsigned int un = subbs.to_ulong();";
        "        int n = un;";
        "        string tmp(1, n);";
        "        result += tmp;";
        "    }";
        "    return result;";
        "}\n";
    ] in
    header ^ "\n" ^ String.concat "\n" body
;;


(* Emit [parity_<n>]: returns bitset<1> holding the number of set bits of a
   bitset<n> modulo 2.  The generated return statement ends with a single
   ';' (no stray empty statement).
   NOTE(review): the name uses a single underscore, unlike the other
   helpers' "__" separator; call sites elsewhere presumably emit exactly
   this spelling, so it is left unchanged. *)
let gen_parity n =
    sprintf
"bitset<1> parity_%d (bitset<%d> bs) {
    return bitset<1>((bs.count()) %% 2);
}\n" n n
;;

(* Emit [pad__<m>__<n>]: converts a bitset<n> to a bitset<m>.  When m > n,
   (m - n) '0' characters are appended to its string form, i.e. at the
   least-significant end.  Otherwise no padding is added and
   bitset<m>(string) keeps the first m characters.  The runtime argument
   [ns] is unused.  Like the other generators, the definition ends with a
   newline so that concatenated definitions stay on separate lines. *)
let gen_pad m n =
    sprintf
"bitset<%d> pad__%d__%d(bitset<%d> bs, string ns) {
    string bss = bs.to_string();
    string pad = \"\";
    int pad_len;
    if (%d > %d) {
        pad_len = %d - %d;
        string tmp(pad_len, '0');
        pad = tmp;
    }
    bss += pad;
    return bitset<%d>(bss);
}\n" m m n n m n m n m

;;