import cupy as cp
import time
import numpy as np
from mpi import size, rank

# Load the CUDA source for the filtering kernels. The path is relative, so it
# is resolved against the current working directory when this module is imported.
with open('filtering.cu') as f:
    code = f.read()

# Names of the kernels that are looked up in the module below.
kers = ('exact_one', 'filter')
# Wrap the source in a CuPy RawModule (built with --std=c++11), registering the
# kernel names as name expressions, then fetch a callable handle per kernel.
ep_pontok_module = cp.RawModule(code=code, options=('--std=c++11',), name_expressions=kers)
exact_one_cuda = ep_pontok_module.get_function(kers[0])
filter_cuda = ep_pontok_module.get_function(kers[1])

def expSpace(min, max, N, exponentialliness=20.0):
    """Return N exponentially spaced float64 samples running from min to max.

    The exponents are spaced linearly between 0 and
    ``log10(exponentialliness + 1)``; ``10**exponent - 1`` therefore runs from
    0 to ``exponentialliness`` and is rescaled onto ``[min, max]``.

    Args:
        min: Value of the first sample.
        max: Value of the last sample.
        N: Number of samples.
        exponentialliness: Ratio parameter of the spacing; it is used as a
            divisor, so it must be non-zero.

    Returns:
        A CuPy array holding the samples.
    """
    top_exponent = cp.log10(exponentialliness + 1, dtype=cp.float64)
    exponents = cp.linspace(0, top_exponent, N, dtype=cp.float64)
    scale = (max - min) / exponentialliness
    return scale * (10.0**exponents - 1) + min


def convert(egyensulyi_mtx):
    """Flatten a sequence of 4x4 matrices into a two-column table of index pairs.

    For every matrix, each cell equal to 1 contributes one row holding its
    1-based ``(S, U)`` indices, visited in row-major order. A ``(0, 0)`` row
    is emitted after each matrix as a separator.

    The rows are collected in a Python list and converted once at the end;
    growing a NumPy array with ``np.append`` copies the whole array on every
    call and is quadratic in the number of rows.

    Args:
        egyensulyi_mtx: Iterable of matrices indexable as ``m[S][U]`` for
            ``S`` and ``U`` in ``0..3``.

    Returns:
        A NumPy array of shape ``(rows, 2)`` with NumPy's default integer
        dtype (the dtype ``np.append`` promoted the result to before).
    """
    rows = []
    for matrix in egyensulyi_mtx:
        rows.extend(
            (S + 1, U + 1)
            for S in range(4)
            for U in range(4)
            if matrix[S][U] == 1
        )
        # Separator row marking the end of this matrix.
        rows.append((0, 0))

    # Explicit reshape keeps the (0, 2) shape for empty input.
    return np.array(rows, dtype=np.int_).reshape(len(rows), 2)

def printresults(egyensulyi_mtx):
    """Print the 1-based (S, U) index pairs of the cells equal to 1, per matrix.

    For every matrix three things are printed: a ``"<count>x2"`` header, the
    pairs as a ``(count, 2)`` NumPy array (row-major order), and a blank line.

    Args:
        egyensulyi_mtx: Iterable of matrices indexable as ``m[S][U]`` for
            ``S`` and ``U`` in ``0..3``.
    """
    for matrix in egyensulyi_mtx:
        pairs = [
            (S + 1, U + 1)
            for S in range(4)
            for U in range(4)
            if matrix[S][U] == 1
        ]
        print(f"{len(pairs)}x2")
        # Explicit reshape keeps the (0, 2) shape when no cell matched.
        print(np.array(pairs, dtype=np.int_).reshape(len(pairs), 2))
        print()

def writetofile(filename, Cx_cpu, Cy_cpu, Dx_cpu, Dy_cpu, Dz_cpu, mtx_cpu, mpi):
    """Write the coordinates and the matrix of every (C, D) combination to a file.

    For each index ``i`` in ``0 .. size_C * size_D - 1`` two records are
    written: a comma-separated line with the C point (x, y) and the D point
    (x, y, z), followed by ``np.array2string(mtx_cpu[i])`` and a newline.

    Args:
        filename: Output path. When ``mpi`` is truthy, ``R{rank}-{size}.out``
            (``rank`` and ``size`` come from the ``mpi`` module) is appended.
        Cx_cpu, Cy_cpu: Coordinate arrays of the C points; ``Cx_cpu.size``
            is taken as the number of C points.
        Dx_cpu, Dy_cpu, Dz_cpu: Coordinate arrays of the D points;
            ``Dx_cpu.size`` is taken as the number of D points.
        mtx_cpu: Sequence of matrices, indexed by ``i``.
        mpi: Whether to add the per-rank suffix to ``filename``.
    """
    size_C = Cx_cpu.size
    size_D = Dx_cpu.size
    lcm = compute_lcm(size_C, size_D)
    pos = size_C * size_D

    if mpi:
        filename = filename + f'R{rank}-{size}.out'
    print("Filename: ", filename)

    # The context manager closes the file even if a write raises.
    with open(filename, "w") as f:
        for i in range(pos):
            c = i % size_C
            # The D index advances with i and is shifted by one extra step
            # after every `lcm` entries (presumably so that no (C, D) pair
            # repeats - confirm against the kernel that fills mtx_cpu).
            # Integer division keeps the shift exact for large i.
            d = (i + i // lcm) % size_D
            f.write(f"{Cx_cpu[c]}, {Cy_cpu[c]}, {Dx_cpu[d]}, {Dy_cpu[d]}, {Dz_cpu[d]}\n")
            f.write(np.array2string(mtx_cpu[i]))
            f.write("\n")

def writetofile2(filename, Cx_cpu, Cy_cpu, Dx_cpu, Dy_cpu, Dz_cpu, mtx_cpu, mpi):
    """Write a compact ``c_index, d_index, code`` line per (C, D) combination.

    ``code`` packs the 16 cells of ``mtx_cpu[i]`` into one integer: the cells
    are visited in row-major order and each one is shifted in as the next
    lower digit in base 2, so cell ``[0][0]`` is the most significant.

    Args:
        filename: Output path. When ``mpi`` is truthy, ``R{rank}-{size}.out``
            (``rank`` and ``size`` come from the ``mpi`` module) is appended.
        Cx_cpu: Array whose ``.size`` is taken as the number of C points.
        Cy_cpu: Unused; kept for signature parity with ``writetofile``.
        Dx_cpu: Array whose ``.size`` is taken as the number of D points.
        Dy_cpu, Dz_cpu: Unused; kept for signature parity.
        mtx_cpu: Sequence of 4x4 matrices, indexed by ``i``.
        mpi: Whether to add the per-rank suffix to ``filename``.
    """
    size_C = Cx_cpu.size
    size_D = Dx_cpu.size
    lcm = compute_lcm(size_C, size_D)
    pos = size_C * size_D

    if mpi:
        filename = filename + f'R{rank}-{size}.out'

    print("Filename: ", filename)

    # The context manager closes the file even if a write raises.
    with open(filename, "w") as f:
        for i in range(pos):
            matrix = mtx_cpu[i]
            s = 0
            for S in range(4):
                for U in range(4):
                    # Coerce to a Python int: adding a narrow NumPy scalar
                    # (e.g. int8) would make the accumulator take that dtype
                    # under NumPy 2 promotion rules and overflow within the
                    # 16 doublings.
                    s = s * 2 + int(matrix[S][U])

            # Integer division keeps the D-index shift exact for large i.
            f.write(f"{i % size_C}, {(i + i // lcm) % size_D}, {s}\n")

        
def search(egyensulyi_mtx, S, U):
    """Print every matrix whose cell at row S, column U equals 1."""
    matching = (matrix for matrix in egyensulyi_mtx if matrix[S][U] == 1)
    for matrix in matching:
        print(matrix)

def exact_one(egyensulyi_mtx, S, U):
    """Print the matrices in which (S, U) is the only cell that is set.

    A matrix is printed when its ``[S][U]`` cell is not equal to 0 and none of
    the other cells of the 4x4 grid is equal to 1.

    Args:
        egyensulyi_mtx: Iterable of matrices indexable as ``m[j][k]`` for
            ``j`` and ``k`` in ``0..3``.
        S: Row index of the cell that must be set.
        U: Column index of the cell that must be set.
    """
    for matrix in egyensulyi_mtx:
        violations = 0
        for row in range(4):
            for col in range(4):
                is_target = row == S and col == U
                # The target cell must not be 0; every other cell must not be 1.
                forbidden = 0 if is_target else 1
                if matrix[row][col] == forbidden:
                    violations += 1
        if violations == 0:
            print(matrix)

def exact_one_gpu(egyensulyi_mtx, S, U):
    """Select matrices with the ``exact_one`` CUDA kernel and return them.

    The input is treated as a stack of 16-element matrices. A zeroed boolean
    mask with one entry per matrix is handed to the kernel together with the
    matrix count and the (S, U) indices, and is then used to index the input.
    NOTE(review): the kernel is expected to set the mask entries of the
    matrices to keep - confirm against filtering.cu.

    Args:
        egyensulyi_mtx: CuPy array of matrices; ``size / 16`` is taken as the
            number of matrices.
        S: Row index passed to the kernel.
        U: Column index passed to the kernel.

    Returns:
        The entries of ``egyensulyi_mtx`` selected by the mask.
    """
    threads_per_block = 256
    matrix_count = egyensulyi_mtx.size // 16
    mask = cp.zeros((matrix_count,), dtype=bool)
    # Round up so the launched threads cover every matrix.
    block_count = (matrix_count + threads_per_block - 1) // threads_per_block
    kernel_args = (egyensulyi_mtx, matrix_count, mask, S, U)
    exact_one_cuda((block_count,), (threads_per_block,), kernel_args)
    return egyensulyi_mtx[mask]

def filter_gpu(egyensulyi_mtx, S, U):
    """Select matrices with the ``filter`` CUDA kernel and return them.

    The input is treated as a stack of 16-element matrices. A zeroed boolean
    mask with one entry per matrix is handed to the kernel together with the
    matrix count and the (S, U) indices, and is then used to index the input.
    NOTE(review): the kernel is expected to set the mask entries of the
    matrices to keep - confirm against filtering.cu.

    Args:
        egyensulyi_mtx: CuPy array of matrices; ``size / 16`` is taken as the
            number of matrices.
        S: Row index passed to the kernel.
        U: Column index passed to the kernel.

    Returns:
        The entries of ``egyensulyi_mtx`` selected by the mask.
    """
    size = int(egyensulyi_mtx.size / 16)
    # Builtin `bool`, as in exact_one_gpu: it always resolves to the boolean
    # dtype, whereas the `cp.bool` alias may not exist in every CuPy release
    # (NOTE(review): confirm for the pinned CuPy version).
    indexes = cp.zeros((size,), dtype=bool)
    # Round up so the launched 256-thread blocks cover every matrix.
    numBlock = int((size + 256 - 1) / 256)
    filter_cuda((numBlock,), (256,), (egyensulyi_mtx, size, indexes, S, U))
    return egyensulyi_mtx[indexes]

def compute_lcm(x, y):
    """Return the least common multiple of two integers.

    Uses ``|x * y| / gcd(x, y)`` rather than counting upward from
    ``max(x, y)``, so the cost no longer grows with the size of the result.

    Args:
        x: First integer.
        y: Second integer.

    Returns:
        The smallest positive integer divisible by both arguments, or 0 when
        either argument is 0 (the upward search raised ZeroDivisionError
        there).
    """
    from math import gcd

    if x == 0 or y == 0:
        return 0
    # Divide before multiplying to keep the intermediate value small.
    return abs(x // gcd(x, y) * y)