#----------------------------------------------------------------------
# Python interface for ISPACK3
# Copyright (C) 2023--2024 Toshiki Matsushima <toshiki@gfd-dennou.org>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
# 
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301 USA.
#----------------------------------------------------------------------
from mpi4py import MPI
import numpy as np
import time
import ispack3 as isp

# ---- Problem size and options ----------------------------------------------
jm = 2**10              # base size; every other size below is derived from it
mm = jm - 1             # truncation passed to every isp.sy* call below
im = 2 * jm             # second dimension of the grid buffer G
nm = mm + 1
nn = mm
nt = mm
ntr = 10                # not referenced elsewhere in this script

ig = 1                  # option forwarded to isp.syini2
ipow = 0                # reassigned before the transforms use it

eps = im * 10**(-14)    # pass/fail tolerance on the maximum error, scaled by im

# ---- MPI communicator and ISPACK3 initialisation ---------------------------
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
np_size = comm.Get_size()   # number of MPI processes sharing the work

# jv is the blocking length reported by isp.syqrjv; it sets the row blocking
# of the grid buffer G below. The remaining calls build the tables (it, t, r,
# p, jc, c, d) that the transform and operator routines consume later.
jv = isp.syqrjv(jm)
it, t, r = isp.syini1(mm, nm, im, comm)
p, jc = isp.syini2(mm, nm, jm, ig, r, comm)
c = isp.syinic(mm, nt, comm)
d = isp.syinid(mm, nt, comm)

# ---- Per-process buffer sizes ----------------------------------------------
# Number of jv-row blocks of G held by each process.
_row_blocks = (jm // jv - 1) // np_size + 1
# Per-process share of 0..mm (presumably zonal wavenumbers - confirm).
_m_share = mm // np_size + 1
_m_spread = mm // np_size * np_size

g_shape = (_row_blocks * jv, im)
w_shape = (2 * jv * _row_blocks * _m_share * np_size * 2, )

s_shape = (_m_share * (2 * (nt + 1) - _m_spread), )
sy_shape = (_m_share * (2 * (nt + 2) - _m_spread), )

# ---- Work arrays ---------------------------------------------------------------
# Only rank 0 holds the full spectral arrays SALL/SDALL (filled by the
# syss2s/sygs2s calls further down); the other ranks pass a zero-length
# placeholder. Both branches produce a shape tuple.
sall_shape = ((mm + 1)**2, ) if rank == 0 else (0, )

# Grid buffer and transform workspace, 64-byte aligned.
G = isp.aligned_array(g_shape, align=64)
W = isp.aligned_array(w_shape, dtype=np.float64, align=64)

# Per-process spectral buffers.
S = np.empty(s_shape, dtype=np.float64)
SD = np.empty(s_shape, dtype=np.float64)
SX = np.empty(s_shape, dtype=np.float64)
SXD = np.empty(s_shape, dtype=np.float64)
SXR = np.empty(sy_shape, dtype=np.float64)
SY = np.empty(sy_shape, dtype=np.float64)
SYD = np.empty(s_shape, dtype=np.float64)

SDALL = np.empty(sall_shape, dtype=np.float64)

# Reproducible input coefficients uniformly drawn from [-1, 1). The seed
# matches the original test; rank != 0 draws an empty array.
np.random.seed(0)
SALL = 2 * np.random.rand(*sall_shape) - 1

# ---- Gradient / divergence round trip ----------------------------------------
# Statement order matters: S, SD, SX, SXR, SY, G and W are each passed to
# several calls below and are reused as input/output buffers.

# Distribute the full spectral array SALL into the per-process array S.
# NOTE(review): presumably a scatter from rank 0 (SALL is non-empty only
# there) - confirm in the ISPACK3 SYPACK manual.
isp.syss2s(mm,nn,SALL,S,comm)

# syclap with operation flag 2 (S -> SD), then flag 1 (SD -> S).
# Presumably flag 2 is the inverse Laplacian and flag 1 the Laplacian, so SD
# becomes the potential whose gradient is checked and S is replaced by the
# reference that is collected into SALL at the end - TODO confirm flags.
isp.syclap(mm,nt,S,SD,d,2,comm)
isp.syclap(mm,nt,SD,S,d,1,comm)

# Two derivative components of SD: SX from sycs2x (repacked by sycrpk with
# arguments nt -> nt+1 into SXR), and SY from sycs2y using table c.
# Presumably the longitude- and latitude-direction gradient components.
isp.sycs2x(mm,nt,SD,SX,comm)
isp.sycrpk(mm,nt,nt+1,SX,SXR,comm)
isp.sycs2y(mm,nt,SD,SY,c,comm)

# Spectral -> grid -> spectral round trip of SXR and of SY with ipow=1,
# using G as the grid buffer and W as workspace. Presumably the results are
# written back into SXR and SY, which are read again below.
# NOTE(review): what ipow=1 changes in the transforms is not visible here -
# check the syts2g/sytg2s documentation.
ipow=1
isp.syts2g(mm,nm,nt+1,im,jm,jv,SXR,G,it,t,p,r,jc,W,ipow,comm)
isp.sytg2s(mm,nm,nt+1,im,jm,jv,SXR,G,it,t,p,r,jc,W,ipow,comm)
isp.syts2g(mm,nm,nt+1,im,jm,jv,SY,G,it,t,p,r,jc,W,ipow,comm)
isp.sytg2s(mm,nm,nt+1,im,jm,jv,SY,G,it,t,p,r,jc,W,ipow,comm)

# Derivatives of the two components: sycy2s maps SY -> SYD; SXR is repacked
# back (nt+1 -> nt) into SX and passed through sycs2x again into SXD.
isp.sycy2s(mm,nt,SY,SYD,c,comm)
isp.sycrpk(mm,nt+1,nt,SXR,SX,comm)
isp.sycs2x(mm,nt,SX,SXD,comm)

# The sum is what the report compares against S ("gradient and divergence
# check"). This rebinds SD to a new array rather than writing into the old one.
SD = SXD + SYD

# Collect S and SD into SALL and SDALL (the full-size arrays live on rank 0).
isp.sygs2s(mm,nt,S,SALL,comm)
isp.sygs2s(mm,nt,SD,SDALL,comm)

if rank == 0:
    # ---- Error report (only rank 0 holds the full spectral arrays) ---------
    # SALL/SDALL were sized (mm+1)**2 and collected with truncation mm, so
    # every index <-> (n, m) decode uses mm.
    l = np.arange(len(SALL))
    n, m = isp.sxl2nm(mm, l)

    err = SDALL - SALL
    SL_values = np.zeros_like(SALL, dtype=np.float64)
    # m == 0: one coefficient per (n, 0); plain absolute error.
    SL_values[m == 0] = np.abs(err[m == 0])
    # m > 0: combine each error with that of its m < 0 partner. The two
    # masks are paired by storage order, which assumes the m < 0 entries are
    # stored in the same (n, |m|) order as the m > 0 ones.
    # NOTE(review): confirm against the sxl2nm index layout.
    SL_values[m > 0] = np.sqrt(err[m > 0]**2 + err[m < 0]**2)
    # m < 0 slots stay 0, so each (n, |m|) pair is counted exactly once.

    SLMAX = np.max(SL_values)
    LAS = l[np.argmax(SL_values)]
    SLAMAX = np.sum(SL_values**2)   # sum of squared errors

    # (n, m) of the worst coefficient, taken from the decode computed above.
    n_max, m_max = n[LAS], m[LAS]

    print("maxerror=", SLMAX, "(n=", n_max, ", m=", m_max, ")" )
    # RMS over the (mm+1)*(mm+2)/2 coefficient pairs with m >= 0.
    print("rmserror=", np.sqrt(SLAMAX/((mm+1)*(mm+2)/2) ) )
    print("gradient and divergence check:")
    if SLMAX <= eps:
        print('** OK')
    else:
        print('** Fail')
