import cupy as cp
from numba import cuda
from utils import expSpace
from mpi import size, rank

def gen_angels_to_pick(n):
    """Build the sorted array of candidate angles used for angle pairing.

    Two ``expSpace`` samplings with ``N = (n + 3) // 2`` points each are
    generated (presumably ``expSpace(start, stop, count)`` -- confirm against
    ``utils``): one for ``0 .. pi/2`` and one for ``pi/2 .. pi``.  The second
    sampling is mirrored with ``x -> 3*pi/2 - x``, which maps ``pi/2`` to
    ``pi`` and ``pi`` to ``pi/2``.  Both samplings are merged, sorted and
    de-duplicated, and the smallest and largest values are dropped.

    Args:
        n: Odd integer controlling the number of samples.

    Returns:
        1-D CuPy array of the remaining angles in ascending order.

    Raises:
        ValueError: If ``n`` is even.
    """
    if n % 2 == 0:
        # Raising a bare string is itself a TypeError in Python 3; raise a
        # proper exception so callers see the real cause.
        raise ValueError(f"n must be odd, got {n}")

    # n is odd here, so n + 3 is even and the floor division is exact.
    N = (n + 3) // 2
    X1 = expSpace(0.0, cp.pi / 2.0, N)
    X3 = expSpace(cp.pi / 2.0, cp.pi, N)
    # Mirror the second sampling inside its own interval.
    X3 = -1 * X3 + 3.0 * cp.pi / 2.0

    anglestopick = cp.concatenate((X1, X3), axis=None)
    # Careful: cp.unique only merges values that are bit-for-bit equal, so
    # whether a value shared by both halves collapses into one entry depends
    # on floating-point rounding of the mirroring above.
    anglestopick = cp.unique(anglestopick)
    # Drop the first and last (smallest and largest) values of the sorted set.
    anglestopick = anglestopick[1:-1]

    return anglestopick

# CUDA kernel "parosit": pairs every x1[tid] (alpha) with every x2[i] (beta).
# Arguments: x1, x2 -- input arrays with m doubles each; a, b -- row-major
# m x m output arrays; m -- element count; PI -- the value compared against
# alpha + beta.
# For each pair, a[tid*m + i] / b[tid*m + i] receive (alpha, beta) when
# alpha + beta < PI, beta >= alpha and alpha > 0; otherwise both get -1.0.
parosit = cp.RawKernel(r'''
extern "C"
__global__
void parosit(const double* x1, const double* x2, double* a, double* b, const int m, const double PI) {
    // One thread per element of x1; each thread fills one output row.
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid >= m)
        return;
    double alpha = x1[tid];
    // 64-bit row offset so that tid * m cannot overflow a 32-bit int.
    const long long row = (long long)tid * m;
    for (int i = 0; i < m; i++) {
        double betha = x2[i];
        if ((alpha + betha) < PI && betha >= alpha && alpha > 0.0) {
            a[row + i] = alpha;
            b[row + i] = betha;
        } else {
            // Sentinel marking a rejected pair.
            a[row + i] = -1.0;
            b[row + i] = -1.0;
        }
    }
}
''', 'parosit')

# CUDA kernel "parosit2": pairs every x1[tid] (alpha) with every x2[i] (beta).
# Arguments: x1, x2 -- input arrays with m doubles each; a, b -- row-major
# m x m output arrays; m -- element count; PI -- the value compared against
# alpha + beta.
# For each pair, a[tid*m + i] / b[tid*m + i] receive (alpha, beta) when
# alpha + beta < PI and alpha > 0; otherwise both get -1.0.  Unlike the
# "parosit" kernel there is no beta >= alpha requirement.
parosit2 = cp.RawKernel(r'''
extern "C"
__global__
void parosit2(const double* x1, const double* x2, double* a, double* b, const int m, const double PI) {
    // One thread per element of x1; each thread fills one output row.
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid >= m)
        return;
    double alpha = x1[tid];
    // 64-bit row offset so that tid * m cannot overflow a 32-bit int.
    const long long row = (long long)tid * m;
    for (int i = 0; i < m; i++) {
        double betha = x2[i];
        if ((alpha + betha) < PI && alpha > 0.0) {
            a[row + i] = alpha;
            b[row + i] = betha;
        } else {
            // Sentinel marking a rejected pair.
            a[row + i] = -1.0;
            b[row + i] = -1.0;
        }
    }
}
''', 'parosit2')

# CUDA kernel "parosit2_mpi": pairs every x1[tid] (alpha, m values) with every
# x2[i] (beta, m2 values -- the slice of angles handled by this rank).
# Arguments: a, b -- row-major m x m2 output arrays; PI -- the value compared
# against alpha + beta.
# For each pair, a[tid*m2 + i] / b[tid*m2 + i] receive (alpha, beta) when
# alpha + beta < PI and alpha > 0; otherwise both get -1.0.
#
# The outputs have m rows of m2 entries, so the row stride is m2.  Indexing
# rows with stride m (and bailing out when tid*m + m2 - 1 would leave the
# buffer) meant that only threads with tid < m2 ever wrote anything, silently
# dropping every remaining alpha whenever m2 < m.
parosit2_mpi = cp.RawKernel(r'''
extern "C"
__global__
void parosit2_mpi(const double* x1, const double* x2, double* a, double* b, const int m, const int m2, const double PI) {
    // One thread per element of x1; each thread fills one output row.
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid >= m)
        return;
    double alpha = x1[tid];
    // Row stride is m2 (outputs are m x m2); 64-bit to avoid int overflow.
    const long long row = (long long)tid * m2;
    for (int i = 0; i < m2; i++) {
        double betha = x2[i];
        if ((alpha + betha) < PI && alpha > 0.0) {
            a[row + i] = alpha;
            b[row + i] = betha;
        } else {
            // Sentinel marking a rejected pair.
            a[row + i] = -1.0;
            b[row + i] = -1.0;
        }
    }
}
''', 'parosit2_mpi')

def angles_alap(anglestopick):
    """Compute apex coordinates for every accepted (alpha, beta) angle pair.

    All pairs of ``anglestopick`` are evaluated on the GPU by the ``parosit``
    kernel; rejected pairs are stored as ``-1.0`` and filtered out here.  For
    each remaining pair the coordinates are

        Cy = tan(alpha) * tan(beta) / (tan(alpha) + tan(beta))
        Cx = Cy / tan(alpha)

    i.e. the apex of a triangle over the unit base with base angles alpha
    and beta.

    Args:
        anglestopick: 1-D CuPy array of angles; presumably float64, since the
            kernel reads it as ``const double*`` -- confirm at the call site.

    Returns:
        Tuple ``(Cx, Cy)`` of 1-D CuPy arrays with one entry per accepted
        pair.
    """
    m = anglestopick.size
    alpha_arr = cp.zeros((m, m), dtype=cp.float64)
    beta_arr = cp.zeros((m, m), dtype=cp.float64)

    threads = 64
    blocks = (m + threads - 1) // threads  # ceil(m / threads)
    # The kernel declares `const int m`; a plain Python int would be passed
    # as a 64-bit value, so hand over an explicit 32-bit scalar.
    parosit(
        (blocks,),
        (threads,),
        (anglestopick, anglestopick, alpha_arr, beta_arr, cp.int32(m), cp.pi),
    )

    # NOTE(review): the kernel only stores pairs with alpha + beta < pi and
    # alpha > 0, so every stored beta is below pi and `obtuse` never selects
    # anything.  The threshold may have been meant to be pi/2 -- confirm
    # before changing, since that would reorder the output.
    obtuse = beta_arr > cp.pi
    acute = (beta_arr <= cp.pi) & (beta_arr > 0.0)

    tompa_beta_mpi_arr = beta_arr[obtuse] - cp.pi / 2  # beta minus pi/2
    hegyes_beta_arr = beta_arr[acute]
    tompa_alpha_arr = alpha_arr[obtuse]
    hegyes_alpha_arr = alpha_arr[acute]

    tghB = cp.tan(hegyes_beta_arr)
    tghA = cp.tan(hegyes_alpha_arr)
    tgtB_mpi = cp.tan(tompa_beta_mpi_arr)
    tgtA = cp.tan(tompa_alpha_arr)

    # Acute branch: Cy = tanA * tanB / (tanA + tanB), Cx = Cy / tanA.
    sztgh = tghB * tghA
    otgh = tghA + tghB
    hCy = sztgh / otgh
    hCx = hCy / tghA

    # Obtuse branch: same apex expressed through tan(beta - pi/2).
    sztgt = tgtB_mpi * tgtA
    mtgt = 1.0 - sztgt
    tCy = tgtA / mtgt
    tCx = tCy / tgtA

    Cx = cp.concatenate((hCx, tCx), axis=None)
    Cy = cp.concatenate((hCy, tCy), axis=None)

    return Cx, Cy


def angles_ratet(anglestopick, mpi):
    """Compute 3-D points from all accepted angle pairs and rotation angles.

    Angle pairs are generated on the GPU (``parosit2`` or, when ``mpi`` is
    truthy, ``parosit2_mpi`` restricted to this rank's slice of beta values).
    For every accepted pair the planar coordinates

        Ey = tan(alpha) * tan(beta) / (tan(alpha) + tan(beta))
        Ex = Ey / tan(alpha)

    are computed, and each ``(Ex, Ey)`` is then combined with every angle
    ``phi`` in ``anglestopick`` as ``(Ex, cos(phi) * Ey, sin(phi) * Ey)``.

    Args:
        anglestopick: 1-D CuPy array of angles; presumably float64, since the
            kernels read it as ``const double*`` -- confirm at the call site.
        mpi: If truthy, only the beta values in
            ``anglestopick[rank*m//size : (rank+1)*m//size]`` are paired,
            using the module-level ``rank`` and ``size``.

    Returns:
        Tuple ``(Dx, Dy, Dz)`` of flat CuPy arrays of equal length
        (``anglestopick.size`` times the number of accepted pairs).
    """
    m = anglestopick.size
    # Fixed block size: launching m threads per block fails once m exceeds
    # the CUDA per-block thread limit, and the kernels only need m threads
    # in total.
    threads = 64
    blocks = (m + threads - 1) // threads  # ceil(m / threads)

    # Scalars are passed as explicit 32-bit values because the kernels
    # declare them as `const int`; plain Python ints are passed as 64-bit.
    if not mpi:
        alpha_arr = cp.zeros((m, m), dtype=cp.float64)
        beta_arr = cp.zeros((m, m), dtype=cp.float64)
        parosit2(
            (blocks,),
            (threads,),
            (anglestopick, anglestopick, alpha_arr, beta_arr,
             cp.int32(m), cp.pi),
        )
    else:
        # Integer arithmetic keeps the per-rank bounds exact; the float form
        # rank/size*m can round just below an integer and shift a boundary.
        also = (rank * m) // size
        felso = ((rank + 1) * m) // size
        anglestopick2 = anglestopick[also:felso]
        m2 = anglestopick2.size
        alpha_arr = cp.zeros((m, m2), dtype=cp.float64)
        beta_arr = cp.zeros((m, m2), dtype=cp.float64)
        parosit2_mpi(
            (blocks,),
            (threads,),
            (anglestopick, anglestopick2, alpha_arr, beta_arr,
             cp.int32(m), cp.int32(m2), cp.pi),
        )

    # NOTE(review): the kernels only store pairs with alpha + beta < pi and
    # alpha > 0, so every stored beta is below pi and `obtuse` never selects
    # anything.  The threshold may have been meant to be pi/2 -- confirm
    # before changing, since that would reorder the output.
    obtuse = beta_arr > cp.pi
    acute = (beta_arr <= cp.pi) & (beta_arr > 0.0)

    tompa_beta_mpi_arr = beta_arr[obtuse] - cp.pi / 2  # beta minus pi/2
    hegyes_beta_arr = beta_arr[acute]
    tompa_alpha_arr = alpha_arr[obtuse]
    hegyes_alpha_arr = alpha_arr[acute]

    tghB = cp.tan(hegyes_beta_arr)
    tghA = cp.tan(hegyes_alpha_arr)
    tgtB_mpi = cp.tan(tompa_beta_mpi_arr)
    tgtA = cp.tan(tompa_alpha_arr)

    # Acute branch: Ey = tanA * tanB / (tanA + tanB), Ex = Ey / tanA.
    sztgh = tghB * tghA
    otgh = tghA + tghB
    hCy = sztgh / otgh
    hCx = hCy / tghA

    # Obtuse branch: same point expressed through tan(beta - pi/2).
    sztgt = tgtB_mpi * tgtA
    mtgt = 1.0 - sztgt
    tCy = tgtA / mtgt
    tCx = tCy / tgtA

    Ey = cp.concatenate((hCy, tCy), axis=None)
    Ex = cp.concatenate((hCx, tCx), axis=None)

    cos_phi = cp.cos(anglestopick)
    sin_phi = cp.sin(anglestopick)

    # Every phi is combined with every (Ex, Ey); x is unchanged, so Ex is
    # simply repeated once per phi, in the same order as the outer products.
    Dx = cp.tile(Ex, m)
    Dy = cp.outer(cos_phi, Ey).flatten()
    Dz = cp.outer(sin_phi, Ey).flatten()

    return Dx, Dy, Dz