#----------------------------------------------------------------------
# Python interface for ISPACK3
# Copyright (C) 2023--2024 Toshiki Matsushima <toshiki@gfd-dennou.org>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
# 
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301 USA.
#----------------------------------------------------------------------
from mpi4py import MPI
import numpy as np
import time
import ispack3 as isp

# Truncation / grid parameters.  Everything is derived from the latitude
# count JM: MM = NM = NN = JM - 1 and IM = 2*JM longitudes.
jm = 2**10
mm = jm - 1
im = 2 * jm
nm = mm
nn = nm
ntr = 1      # number of repetitions of each timed transform

ig = 1
ipow = 0

comm = MPI.COMM_WORLD
rank, np_size = comm.Get_rank(), comm.Get_size()

# ISPACK3 initialization (order preserved: syini2 consumes r from syini1).
jv = isp.syqrjv(jm)
it, t, r = isp.syini1(mm, nm, im, comm)
p, jc = isp.syini2(mm, nm, jm, ig, r, comm)

# Per-process blocking factors shared by the buffer sizes below.
# NOTE(review): presumably the number of JV-wide latitude blocks and the
# number of zonal wavenumbers assigned to each process — confirm.
_lat_blocks = (jm // jv - 1) // np_size + 1
_m_per_proc = mm // np_size + 1

g_shape = (_lat_blocks * jv, im)
w_shape = (4 * jv * _lat_blocks * _m_per_proc * np_size,)
s_shape = (_m_per_proc * (2 * (nn + 1) - (_m_per_proc - 1) * np_size),)

# Grid-space array and work buffer for the transforms, both 64-byte aligned.
G = isp.aligned_array(g_shape, align=64)
W = isp.aligned_array(w_shape, dtype=np.float64, align=64)

# S: input spectral coefficients; SD: receives the grid -> spectral result.
S = np.zeros(s_shape, dtype=np.float64)
SD = np.empty(s_shape, dtype=np.float64)

# (n, m) of every local storage slot.  Slots with n < 0 are never filled
# (presumably padding in the distributed layout) and stay zero in S.
l = np.arange(len(S))
n, m = isp.syl2nm(mm, nn, l, comm)

# Reproducible, rank-dependent uniform coefficients in [-1, 1).
np.random.seed(rank)
used = n >= 0
S[used] = 2.0 * np.random.rand(np.count_nonzero(used)) - 1.0

# Report the run configuration once, from the root process.
if rank == 0:
    print("MM=", mm, ", IM=", im, ", JM=", jm, ", JV=", jv, ", NTR=", ntr)
    print("SSE=", isp.mxgcpu())
    print("number of threads =", isp.mxgomp())
    print("number of processes =", np_size)

# Nominal operation count of one transform, used only to convert wall-clock
# time into a GFlops figure (presumably an FFT term 5*IM*log2(IM)/2 per
# latitude plus an (MM+1)^2-per-latitude Legendre term).
rc = 5*im*np.log(im)/np.log(2.0)*0.5*jm + (mm+1)*(mm+1)*jm

# Spectral -> grid timing.  `_` is the loop index on purpose: the names `n`
# and `m` hold the per-slot index arrays from isp.syl2nm, which the error
# check after the benchmark still relies on.
comm.Barrier()
start_time = time.perf_counter()
for _ in range(ntr):
    isp.syts2g(mm, nm, nm, im, jm, jv, S, G, it, t, p, r, jc, W, ipow, comm)
elapsed_time = time.perf_counter() - start_time

GFLOPS = rc*ntr/elapsed_time/10**9
if rank == 0:
    print("S2G:", elapsed_time/ntr, "sec (", GFLOPS, "GFlops)")

# Grid -> spectral timing, synchronized the same way as the S2G timing so
# every rank starts the clock together.
comm.Barrier()
start_time = time.perf_counter()
for _ in range(ntr):
    isp.sytg2s(mm, nm, nm, im, jm, jv, SD, G, it, t, p, r, jc, W, ipow, comm)
elapsed_time = time.perf_counter() - start_time

GFLOPS = rc*ntr/elapsed_time/10**9
if rank == 0:
    print("G2S:", elapsed_time/ntr, "sec (", GFLOPS, "GFlops)")

# Round-trip error check: compare the input coefficients S with SD, the
# result of the spectral -> grid -> spectral round trip.
#
# The (n, m) index arrays are recomputed here instead of reusing the
# globals `n`/`m`, so this check stays correct even if those names were
# rebound earlier in the script.
n_idx, m_idx = isp.syl2nm(mm, nn, l, comm)
valid = n_idx >= 0          # slots with n < 0 are unused and ignored

zonal = (m_idx == 0) & valid
pos = (m_idx > 0) & valid
neg = (m_idx < 0) & valid

SL_values = np.zeros_like(S, dtype=np.float64)
# m == 0: a single real coefficient -> absolute error.
SL_values[zonal] = np.abs(S[zonal] - SD[zonal])
# m != 0: combine the m > 0 and m < 0 coefficients of the same (n, |m|)
# into one amplitude error.
# NOTE(review): assumes the m > 0 and m < 0 slots are stored in the same
# (n, |m|) order so the two masked arrays pair up elementwise — confirm.
SL_values[pos] = np.sqrt((S[pos] - SD[pos])**2 + (S[neg] - SD[neg])**2)

SLMAX = np.max(SL_values)           # local maximum error
SLAMAX = np.sum(SL_values**2)       # local sum of squared errors
LAS = l[np.argmax(SL_values)]       # local slot of the maximum error
n, m = isp.syl2nm(mm, nn, LAS, comm)

master_rank = 0
SLMAX_gathered = comm.gather(SLMAX, root=master_rank)
SLAMAX_gathered = comm.gather(SLAMAX, root=master_rank)
n_gathered = comm.gather(n, root=master_rank)
m_gathered = comm.gather(m, root=master_rank)

if rank == master_rank:
    worst = int(np.argmax(SLMAX_gathered))   # rank holding the global max error
    SLMAX = np.max(SLMAX_gathered)
    SLAMAX = np.sum(SLAMAX_gathered)

    n = n_gathered[worst]
    m = m_gathered[worst]

    print("maxerror=", SLMAX, "(n=", n, ", m=", m, ")" )
    # RMS over the (MM+1)(MM+2)/2 distinct (n, |m|) pairs with n <= MM.
    print("rmserror=", np.sqrt(SLAMAX/((mm+1)*(mm+2)/2) ) )
