Import('env')

# Feed the source file to the generator on stdin and capture its stdout
# as the target file.
cuda_gen = Builder(
    action='./compiler/generator < $SOURCE > $TARGET')

# Compile one source file into an object file using $NVCC and $NVCC_FLAGS.
cuda_object = Builder(
    action='$NVCC -c $NVCC_FLAGS $SOURCE -o $TARGET')

# Run the source as a program and diff its stdout against the file named
# "<source>.out"; the action fails when diff reports a difference.
run_test = Builder(
    action='./$SOURCE | diff ${SOURCE}.out -')

# Command-line options; the values are read back with GetOption() using the
# `dest` names given here.

# Library name added to LIBS for the CUDA runtime.
AddOption(
    '--rtlib',
    dest='rtlib',
    default='ocelot',
    help='CUDA runtime library (ocelot or cudart)')

# Value substituted into nvcc's -arch= flag.
AddOption(
    '--cuda_arch',
    dest='cuda_arch',
    default='sm_20',
    help='CUDA architecture')

# Directory added to LIBPATH when linking.
AddOption(
    '--cuda_lib',
    dest='cuda_lib',
    default='/usr/local/cuda/lib',
    help='Cuda lib directory')

# Boolean switch (no argument) selecting a fixed preset of the options above.
AddOption(
    '--8400gs',
    dest='8400gs',
    action='store_true',
    default=False,
    help='Build for the GeForce 8400 GS')

# Host compiler flags common to every configuration.
ccflags = "-Wall -g"

if GetOption('8400gs'):
    # Personal shortcut preset for a GeForce 8400 GS: it replaces whatever
    # was given for --rtlib, --cuda_lib and --cuda_arch and fixes the block
    # size.  Do not pass this flag when building against ocelot.
    rtlib, cuda_lib, cuda_arch = 'cudart', '/opt/cuda/lib64', 'sm_11'
    ccflags += " -DBLOCK_SIZE=256"
else:
    # Normal case: take the values straight from the command line (or the
    # option defaults).
    rtlib = GetOption('rtlib')
    cuda_lib = GetOption('cuda_lib')
    cuda_arch = GetOption('cuda_arch')

# Use the C++ compiler driver as the linker.
env['LINK'] = env['CXX']

# Link/compile settings: the chosen runtime library plus libm, the CUDA
# library directory, the runtime headers, and the host compiler flags.
env.Append(LIBS=[rtlib, 'm'],
           LIBPATH=[cuda_lib],
           CPPPATH=['./rtlib'],
           CCFLAGS=ccflags)

# nvcc settings consumed by the CudaObject builder's command line.
# NOTE(review): nvcc's -Xcompiler normally takes one comma-separated
# argument, so presumably only the first word of $CCFLAGS is forwarded to
# the host compiler and the remaining words are parsed by nvcc itself;
# likewise -I$CPPPATH only prefixes the first CPPPATH entry -- confirm this
# is intended before adding more flags or include directories.
env.Append(CUDA_ARCH=cuda_arch,
           NVCC='nvcc',
           NVCC_FLAGS='-arch=$CUDA_ARCH -I$CPPPATH -Xcompiler $CCFLAGS')

# Expose the custom builders as env.CudaGen / env.CudaObject / env.RunTest.
env.Append(BUILDERS=dict(CudaGen=cuda_gen,
                         CudaObject=cuda_object,
                         RunTest=run_test))

# To have a new example built and tested, append its base name here; a
# matching "<name>.vec" source is expected next to this script.
test_cases = ['hello', 'reduce', 'control_flow', 'arrays', 'functions', 'float',
              'complex', 'map', 'logic', 'inline', 'strings', 'length', 'pfor',
              'dotprod', 'time']
headers = ['libvector', 'vector_utils', 'vector_array', 'vector_iter']

# Runtime headers listed as extra sources of every object file so that a
# header change triggers a rebuild (the compile command itself only uses
# the first source, $SOURCE).
header_deps = ['../rtlib/' + h + '.hpp' for h in headers]

for name in test_cases:
    generated = name + '.cu'
    obj = name + '.o'

    # .vec -> .cu -> .o -> executable -> test run
    env.CudaGen(generated, [name + '.vec'])
    env.CudaObject(obj, [generated] + header_deps)
    env.Program(name, obj)
    env.RunTest(name + '_test', name)
