from numba import cuda
import cupy as cp
from utils import compute_lcm

# Load the CUDA C++ source containing both kernels used by this module.
# NOTE(review): the relative path is resolved against the current working
# directory, not this file's directory.
with open('epgpu.cu') as f:
    code = f.read()

# Kernel names, passed to NVRTC as name expressions so that
# ``get_function`` can look each kernel up by exactly these strings.
kers = ('gpu_egyensulyi', )
kers_pontok = ('gpu_pontok', )

# Compile the source once for both kernels instead of running NVRTC twice on
# identical code. ``ep_pontok_module_pontok`` is kept as an alias so existing
# references to it keep working.
ep_pontok_module = cp.RawModule(code=code, options=('--std=c++11',),
                                name_expressions=kers + kers_pontok)
ep_pontok_module_pontok = ep_pontok_module

fun = ep_pontok_module.get_function(kers[0])
fun_pontok = ep_pontok_module.get_function(kers_pontok[0])

def start_kernel(Cx, Cy, Dx, Dy, Dz, v, w):
    """Launch the ``gpu_egyensulyi`` kernel and return the matrix it fills.

    The launch covers ``Cx.size * Dx.size`` threads, rounded up to whole
    blocks of ``fun.max_threads_per_block``. The result has one row per
    element of that product, presumably one per (C, D) pair (TODO confirm
    against ``epgpu.cu``).

    Args:
        Cx, Cy: CuPy arrays passed unchanged to the kernel. On the host side
            only their ``.size`` is used (launch size, size args, debug print).
        Dx, Dy, Dz: CuPy arrays passed unchanged to the kernel. On the host
            side only their ``.size`` is used.
        v, w: forwarded unchanged as the kernel's first two arguments.

    Returns:
        A zero-initialised ``(Cx.size * Dx.size, 4, 4)`` ``int8`` CuPy array
        that the kernel writes into. If either size is 0, the empty array is
        returned without launching the kernel.
    """
    n_pairs = Cx.size * Dx.size
    print(f"Cnt: {Cx.size}x{Dx.size}={n_pairs}")
    print("Res size (byte): ", n_pairs*4*4)
    print(fun.attributes)
    lcm = compute_lcm(Cx.size, Dx.size)
    egyensulyi_mtx = cp.zeros((n_pairs, 4, 4), dtype=cp.int8)
    threads = fun.max_threads_per_block
    # Integer ceiling division: enough blocks for every row to get a thread,
    # without the float round-trip of int(a / b).
    num_blocks = (n_pairs + threads - 1) // threads
    print(f"{Cx.size}, {Cy.size}, {Dx.size}, {Dy.size}, {Dz.size}, {egyensulyi_mtx.shape}, {egyensulyi_mtx.nbytes}")
    if n_pairs == 0:
        # Nothing to compute, and CUDA rejects a launch with a zero-sized grid.
        return egyensulyi_mtx
    # NOTE(review): the sizes and lcm are plain Python ints, so CuPy chooses
    # their C type. Confirm it matches the kernel's parameter types.
    fun((num_blocks,), (threads,),
        (v, w, Cx, Cy, Dx, Dy, Dz, Cx.size, Dx.size, lcm, egyensulyi_mtx))

    return egyensulyi_mtx

def start_kernel_save_points(Cx, Cy, Dx, Dy, Dz, v, w):
    """Launch the ``gpu_pontok`` kernel and return the two buffers it fills.

    Args:
        Cx, Cy, Dx, Dy, Dz: CuPy arrays forwarded unchanged to the kernel.
        v, w: integer counts. ``v * w`` rows are allocated, at least that many
            threads are launched, and both values are forwarded to the kernel.

    Returns:
        ``(points, points_type)``: a ``(v*w, 3)`` float64 array and a
        ``(v*w, 3)`` int8 array, both zero-initialised and written by the
        kernel. If ``v * w`` is 0, both empty arrays are returned without
        launching the kernel.
    """
    n_points = v * w
    points = cp.zeros((n_points, 3), dtype=cp.double)
    # ``cp.char`` is not a dtype. int8 is the 1-byte type that lays out like a
    # C ``char`` buffer (assumes the kernel parameter is ``char*`` -- TODO confirm).
    points_type = cp.zeros((n_points, 3), dtype=cp.int8)
    threads = fun_pontok.max_threads_per_block
    # Integer ceiling division so every row gets a thread.
    num_blocks = (n_points + threads - 1) // threads
    if n_points == 0:
        # CUDA rejects a launch with a zero-sized grid.
        return points, points_type
    # CuPy expects the grid and block dimensions as tuples, not bare ints.
    fun_pontok((num_blocks,), (threads,),
               (v, w, Cx, Cy, Dx, Dy, Dz, points, points_type))
    return points, points_type