# coding: utf-8
#!/usr/bin/ruby
require "numru/gphys"
require 'getoptlong'
#require "./gp_dcpam_methods_v1.0"
include NumRu


# output files in VTK format (and netcdf format) of U, V, OMG at a potential temperature surface

w_vector_factor =  50.0
w_zm_factor     =  50.0
#w_zm_factor     = 100.0

filename0 = "test5"

rmiss = 9.969209968386869e+36

def print_usage
  <<~USAGE
  Usage : 
    $ ruby s2p.rb Ps.nc Temp.nc U.nc V.nc OMG.nc outdir

      Ps.nc   : NetCDF file for surface pressure
      Temp.nc : NetCDF file for temperature
      U.nc    : Input NetCDF file U
      V.nc    : Input NetCDF file V
      OMG.nc  : Input NetCDF file OMG
      outdir  : Output directory

    options:
      --merge
        Distributed files are used as input. Those files are merged
        and output is one file.
        Note that the name of NetCDF files to be merged is IN_rank??????.nc, 
        if this option is given.
     --planet <planet name>
        Planet name is given. It is used to identify pressure levels
        used in interpolation.
     --varname <variable name>
        [optional] Variable name is given.
  USAGE
end

###############################################################################
# interpolate_on_p
#   Values in GPhys object is interpolated on pressure levels.
#
# arguments:
# plev (array)           : pressure to be interpolated, unit is Pa
# gp_ps_part (GPhys obj) : surface pressure
# gp_part (GPhys obj)    : values to be interpolated
#
# return value:
# (GPhys obj) : interpolated values
#------------------------------------------------------------------------------

def interpolate_on_p( rmiss, plev, gp_ps_part, gp_part )

  na_plev = NArray.to_na(plev)
  va_plev = VArray.new( na_plev, {"units"=>"Pa"}, "pressure")

  # make a GPhys object of pressure
  gp_p  = mkgpp( gp_ps_part, gp_part )

  # set pressure as an assocated coordinate
  gp_part.set_assoc_coords([gp_p])

  # interpolate values on pressure levels
  gp_part_onplev = gp_part.interpolate("sig"=>va_plev)

  gp_part_onplev.put_att('missing_value',[rmiss])
  gp_part_onplev.put_att("FillValue",[rmiss])

  return gp_part_onplev

end


###############################################################################
# interpolate_on_pt
#   Values in GPhys object is interpolated on potential temperature levels.
#
# arguments:
# p00                    : reference temperature
# kappa                  : R/Cp
# ptlev (array)          : potential temperature to be interpolated, unit is K
# gp_ps (GPhys obj)      : surface pressure
# gp_t (GPhys obj)       : temperature
# gp (GPhys obj)         : values to be interpolated
#
# return value:
# (GPhys obj) : interpolated values
#------------------------------------------------------------------------------

def interpolate_on_pt( rmiss, p00, kappa, ptlev, gp_ps, gp_t, gp )

  na_ptlev = NArray.to_na(ptlev)
  va_ptlev = VArray.new( na_ptlev, {"units"=>"K"}, "pt")

  # make a GPhys object of pressure
  gp_p  = mkgpp( gp_ps, gp_t )
  # make a GPhys object of potential temperature
  gp_pt = mkgppt( p00, kappa, gp_p, gp_t, "pt" )

  # set pressure as an assocated coordinate
  gp.set_assoc_coords([gp_pt])

  # interpolate values on pressure levels
  gp_onptlev = gp.interpolate("sig"=>va_ptlev)

  gp_onptlev.put_att('missing_value',[rmiss])
  gp_onptlev.put_att("FillValue",[rmiss])

  return gp_onptlev

end


# make GPhys object of pressure
def mkgpp( gp_ps, gp )

  imax   = gp.coord('lon').val.size
  jmax   = gp.coord('lat').val.size
  kmax   = gp.coord('sig').val.size
  tmax   = gp.coord('time').val.size
  na_sig = gp.coord('sig').val

  # calculate pressure at each grid
  na_p =
    gp_ps.val.reshape(imax,jmax,1,tmax) *
    na_sig.reshape(1,1,kmax,1)
  va_p = VArray.new( na_p,
                     {"long_name"=>"air_pressure", "units"=>"Pa"},
                     "pressure" )
  # time axis of object is extracted
  ax_lon     = gp.axis('lon')
  ax_lat     = gp.axis('lat')
  ax_sig     = gp.axis('sig')
  ax_time_sl = gp.axis('time')
  # make GPhys object of pressure
  gp_p = GPhys.new( Grid.new(ax_lon,ax_lat,ax_sig,ax_time_sl), va_p )

  return gp_p

end

# make GPhys object of potential temperature
def mkgppt( p00, kappa, gp_p, gp_t, name = "PotTemp" )

  gp_pt = gp_t * ( p00 / gp_p )**kappa

  gp_pt.name = name
  gp_pt.long_name = 'potential temperature'

  return gp_pt

end


parser = GetoptLong.new

parser.set_options(
  ['--merge', '-m',              GetoptLong::NO_ARGUMENT],
  ['--planet',                   GetoptLong::REQUIRED_ARGUMENT],
  ['--time_index_s', '--tis',    GetoptLong::REQUIRED_ARGUMENT],
  ['--time_index_e', '--tie',    GetoptLong::REQUIRED_ARGUMENT],
  ['--p00',                      GetoptLong::REQUIRED_ARGUMENT],
  ['--kappa',                    GetoptLong::REQUIRED_ARGUMENT],
)

$OPT_merge = false
$OPT_planet = "Earth"
$OPT_time_index_s = 0
$OPT_time_index_e = -1
$OPT_p00 = UNumeric[ 1.0e5, 'Pa' ]
$OPT_kappa = UNumeric[ ( 8.3144621 / 28.9644e-3 ) / 1004.6, '1' ]

begin
  parser.each_option do |name, arg|
    eval "$OPT_#{name.sub(/^--/, '').gsub(/-/, '_')} = '#{arg}'"
#    print name, ":", arg, "\n"
    if name == "--merge" then
      $OPT_merge = true
    end
    if name == "--time_index_s" then
      $OPT_time_index_s = $OPT_time_index_s.to_i
    end
    if name == "--time_index_e" then
      $OPT_time_index_e = $OPT_time_index_e.to_i
    end
    if name == "--p00" then
      $OPT_p00 = UNumeric[ $OPT_p00.to_f, 'Pa' ]
    end
    if name == "--kappa" then
      $OPT_kappa = UNumeric[ $OPT_kappa.to_f, '1' ]
    end
  end
rescue
  exit(1)
end

#print $OPT_merge, "\n"
#print $OPT_planet, "\n"
#print $OPT_p00
#print $OPT_kappa
#exit


if ARGV.size < 6 then
  puts print_usage
  exit
end


# 惑星表面圧力
ncfn_ps = ARGV[0]
vname_ps = "Ps"
# 温度
ncfn_t = ARGV[1]
vname_t = "Temp"
# 入力
ncfn_u = ARGV[2]
vname_u = "U"
# 入力
ncfn_v = ARGV[3]
vname_v = "V"
# 入力
ncfn_w = ARGV[4]
vname_w = "OMG"
# 出力 directory
dir_out = ARGV[5]


if Dir.exist?(dir_out) then
  ncfn_out = dir_out + "/out.nc"
  if File.exist?(ncfn_out) then
    print "File, ", ncfn_out, " exists.\n"
    print "Overwrite the file? (yes/no)\n"
    input = $stdin.gets
    if input.chomp != 'yes' then
      print "STOP\n"
      exit
    end
  end
else
  print "Directory, ", dir_out, " does exist.\n"
  exit
end

ncfn = ncfn_u
outncfn = ncfn_out

unless ncfn[-3..-1] == '.nc' then
  print "ERROR : Unexpected extention of file name: ", ncfn, "\n"
  exit
end
is = ncfn.rindex("/") != nil ? ncfn.rindex("/") : -1
is += 1
ie = -4


print "   Input (Ps)         : ", ncfn_ps, "\n"
print "   Input (Temp)       : ", ncfn_t, "\n"
print "   Input (U)          : ", ncfn_u, "\n"
print "   Input (V)          : ", ncfn_v, "\n"
print "   Input (OMG)        : ", ncfn_w, "\n"
print "   Output directory   : ", dir_out, "\n"
print "   Output netcdf file : ", outncfn, "\n"

if $OPT_merge then
  url = ncfn_ps[0..-4] + "_rank??????.nc@" + vname_ps
else
  url = ncfn_ps + "@" + vname_ps
end
gp_ps = GPhys::IO.open_gturl( url )
na_time = gp_ps.coord('time').val
times = na_time[$OPT_time_index_s]
timee = na_time[$OPT_time_index_e]
gp_ps = gp_ps.cut('time'=>times..timee)
#
if $OPT_merge then
  url = ncfn_t[0..-4] + "_rank??????.nc@" + vname_t
else
  url = ncfn_t + "@" + vname_t
end
gp_t = GPhys::IO.open_gturl( url )
na_time = gp_t.coord('time').val
times = na_time[$OPT_time_index_s]
timee = na_time[$OPT_time_index_e]
gp_t = gp_t.cut('time'=>times..timee)
#
if $OPT_merge then
  url = ncfn_u[0..-4] + "_rank??????.nc@" + vname_u
else
  url = ncfn_u + "@" + vname_u
end
gp_u = GPhys::IO.open_gturl( url )
na_time = gp_u.coord('time').val
times = na_time[$OPT_time_index_s]
timee = na_time[$OPT_time_index_e]
gp_u = gp_u.cut('time'=>times..timee)
#
if $OPT_merge then
  url = ncfn_v[0..-4] + "_rank??????.nc@" + vname_v
else
  url = ncfn_v + "@" + vname_v
end
gp_v = GPhys::IO.open_gturl( url )
na_time = gp_v.coord('time').val
times = na_time[$OPT_time_index_s]
timee = na_time[$OPT_time_index_e]
gp_v = gp_v.cut('time'=>times..timee)
#
if $OPT_merge then
  url = ncfn_w[0..-4] + "_rank??????.nc@" + vname_w
else
  url = ncfn_w + "@" + vname_w
end
gp_w = GPhys::IO.open_gturl( url )
na_time = gp_w.coord('time').val
times = na_time[$OPT_time_index_s]
timee = na_time[$OPT_time_index_e]
gp_w = gp_w.cut('time'=>times..timee)


itime = 0
itimes = $OPT_time_index_s
if $OPT_time_index_e < 0 then
  itimee = $OPT_time_index_e + na_time.size
else
  itimee = $OPT_time_index_e
end
ntime = (itimee-itimes+1)

a_ptlev = [300.0]
a_ptlev = [310.0]
a_plev = [1000e2, 900e2, 800e2, 700e2, 600e2, 500e2, 400e2, 300e2, 250e2, 200e2, 150e2, 100e2]
a_plev = [950e2, 900e2, 800e2, 700e2, 600e2, 500e2, 400e2, 350e2, 300e2, 275e2, 250e2, 225e2, 200e2, 175e2, 150e2, 125e2, 100e2]

na_lon = gp_u.coord('lon').val
na_lat = gp_u.coord('lat').val
na_pt  = NArray.to_na(a_ptlev)
imax = na_lon.size
jmax = na_lat.size
kmax = na_pt.size

is = 0
ie = imax-1
js = 0
js = jmax/2-1
je = jmax-1
ks = 0
ke = kmax-1

outfile = NetCDF.create(outncfn)
GPhys::NetCDF_IO.each_along_dims_write(gp_u, outfile, -1) do |sub|
  # https://qiita.com/hokaccha/items/3abd55aa23894b57ffd1#comment-00bb23380f8b3b7cc489
  progress = ((itime+1).to_f/ntime.to_f*100).round(3)
  print "working... "+progress.round(3).to_s+"%\r"
  STDOUT.flush
  if (itime+1) == ntime then
    print "\n"
  end


  #====================
  # an isentropic surface

  sub_ps = gp_ps[is..ie,js..je,itime..itime]
  # calculate potential temperature (GPhys object)
  sub_t = gp_t[is..ie,js..je,true,itime..itime]
  sub_p = mkgpp( sub_ps, sub_t )
  # calculate potential temperature (GPhys object)
  sub_pt = mkgppt( $OPT_p00, $OPT_kappa, sub_p, sub_t )

  sub_u = gp_u[is..ie,js..je,true,itime..itime]
  sub_v = gp_v[is..ie,js..je,true,itime..itime]
  sub_w = gp_w[is..ie,js..je,true,itime..itime]

  # calculate pressure at a specified potential temperature surface
  sub_p = interpolate_on_pt( rmiss, $OPT_p00, $OPT_kappa, a_ptlev, sub_ps, sub_t, sub_p )
  # calculate U at a specified potential temperature surface
  sub_u = interpolate_on_pt( rmiss, $OPT_p00, $OPT_kappa, a_ptlev, sub_ps, sub_t, sub_u )
  # calculate V at a specified potential temperature surface
  sub_v = interpolate_on_pt( rmiss, $OPT_p00, $OPT_kappa, a_ptlev, sub_ps, sub_t, sub_v )
  # calculate W at a specified potential temperature surface
  sub_w = interpolate_on_pt( rmiss, $OPT_p00, $OPT_kappa, a_ptlev, sub_ps, sub_t, sub_w )

  # output files in VTK format
  outfilename = dir_out + "/"+filename0+"_ptsfc_"+itime.to_s.rjust(5,"0")+".vtk"
  file = File.open(outfilename, "w")
  file.write '# vtk DataFile Version 1.0', "\n"
  file.write outfilename, "\n"
  file.write 'ASCII', "\n"
  file.write 'DATASET STRUCTURED_GRID', "\n"
  file.write 'DIMENSIONS', ' ', (ie-is+1), ' ', (je-js+1), ' ', (ke-ks+1), "\n"
  file.write 'POINTS ', (ie-is+1)*(je-js+1)*(ke-ks+1), ' float', "\n"
  for k in 0..(ke-ks)
    for j in 0..(je-js)
      for i in 0..(ie-is)
        zval = sub_p.val[i,j,k,0]
        if ( zval == rmiss ) then
          zval = 1.0e5
        end
        zval = - 8.0 * Math::log(zval/1.0e5)
        file.write na_lon[i+is], ' ', na_lat[j+js], ' ', zval, "\n"
      end
    end
  end
  file.write 'POINT_DATA ', (ie-is+1)*(je-js+1)*(ke-ks+1), "\n"
  file.write 'VECTORS velocity_ptsfc float', "\n"
  for k in 0..(ke-ks)
    for j in 0..(je-js)
      for i in 0..(ie-is)
        zval = sub_p.val[i,j,k,0]
        if ( zval == rmiss ) then
          u = 0.0
          v = 0.0
          w = 0.0
        else
          u = sub_u.val[i,j,k,0]
          v = sub_v.val[i,j,k,0]
          dens = zval / ( (8.3/29.0e-3) *sub_t.val[i,j,k,0] )
          grav = 9.8
          w = - sub_w.val[i,j,k,0] / ( dens * grav )
          w *= w_vector_factor
        end
        file.write u, ' ', v, ' ', w, "\n"
      end
    end
  end
#  file.write 'POINT_DATA ', (ie-is+1)*(je-js+1)*(ke-ks+1), "\n"
  file.write 'SCALARS w_ptsfc float', "\n"
  file.write 'LOOKUP_TABLE default', "\n"
  for k in 0..(ke-ks)
    for j in 0..(je-js)
      for i in 0..(ie-is)
        zval = sub_p.val[i,j,k,0]
        if ( zval == rmiss ) then
          w = 0.0
        else
          dens = zval / ( (8.3/29.0e-3) *sub_t.val[i,j,k,0] )
          grav = 9.8
          w = - sub_w.val[i,j,k,0] / ( dens * grav )
          w *= w_zm_factor
        end
        file.write w, "\n"
      end
    end
  end
  file.close

  #====================
  # a meridional plane; zonal mean

  sub_ps = gp_ps[is..ie,js..je,itime..itime]
  sub_v = gp_v[is..ie,js..je,true,itime..itime]
  sub_w = gp_w[is..ie,js..je,true,itime..itime]
  sub_t = gp_t[is..ie,js..je,true,itime..itime]
  # calculate V at a specified potential temperature surface
  sub_v = interpolate_on_p( rmiss, a_plev, sub_ps, sub_v )
  # calculate W at a specified potential temperature surface
  sub_w = interpolate_on_p( rmiss, a_plev, sub_ps, sub_w )
  # calculate PotTemp at a specified potential temperature surface
  sub_pt = interpolate_on_p( rmiss, a_plev, sub_ps, sub_pt )
  # calculate T at a specified potential temperature surface
  sub_t = interpolate_on_p( rmiss, a_plev, sub_ps, sub_t )
  # zonal mean 
  sub_v = sub_v.mean('lon')
  sub_w = sub_w.mean('lon')
  sub_pt = sub_pt.mean('lon')
  sub_t  = sub_t.mean('lon')
  # output files in VTK format
  outfilename = dir_out + "/"+filename0+"_merisfc_"+itime.to_s.rjust(5,"0")+".vtk"
  file = File.open(outfilename, "w")
  file.write '# vtk DataFile Version 1.0', "\n"
  file.write outfilename, "\n"
  file.write 'ASCII', "\n"
  file.write 'DATASET STRUCTURED_GRID', "\n"
  file.write 'DIMENSIONS', ' ', 1, ' ', (je-js+1), ' ', a_plev.size, "\n"
  file.write 'POINTS ', 1*(je-js+1)*a_plev.size, ' float', "\n"
  for k in 0..(a_plev.size-1)
    for j in 0..(je-js)
      for i in 0..0
        zval = a_plev[k]
        zval = - 8.0 * Math::log(zval/1.0e5)
        file.write -10.0, ' ', na_lat[j+js], ' ', zval, "\n"
#        file.write 360.0, ' ', na_lat[j+js], ' ', zval, "\n"
      end
    end
  end
  file.write 'POINT_DATA ', 1*(je-js+1)*a_plev.size, "\n"
  file.write 'VECTORS velocity_zm float', "\n"
  for k in 0..(a_plev.size-1)
    for j in 0..(je-js)
      zval = sub_pt.val[j,k,0]
      if ( zval == rmiss ) then
        u = 0.0
        v = 0.0
        w = 0.0
      else
        u = 0.0
        v = sub_v.val[j,k,0]
        dens = a_plev[k] / ( (8.3/29.0e-3) *sub_t.val[j,k,0] )
        grav = 9.8
        w = - sub_w.val[j,k,0] / ( dens * grav )
        w *= w_zm_factor
      end
      file.write u, ' ', v, ' ', w, "\n"
    end
  end
#  file.write 'POINT_DATA ', (ie-is+1)*(je-js+1)*(ke-ks+1), "\n"
  file.write 'SCALARS w_zm float', "\n"
  file.write 'LOOKUP_TABLE default', "\n"
  for k in 0..(a_plev.size-1)
    for j in 0..(je-js)
#      zval = sub_pt.val[j,k,0]
      #
      zval = sub_w.val[j,k,0]
      if ( zval == rmiss ) then
        zval = 0.0
      else
        dens = a_plev[k] / ( (8.3/29.0e-3) *sub_t.val[j,k,0] )
        grav = 9.8
        w = - sub_w.val[j,k,0] / ( dens * grav )
        w *= w_zm_factor
        zval = w
      end
      file.write zval, "\n"
    end
  end
  file.close


  #====================
  # all wind (u,v,w)

  if ( false ) then
    sub_ps = gp_ps[is..ie,js..je,itime..itime]
    sub_u = gp_u[is..ie,js..je,true,itime..itime]
    sub_v = gp_v[is..ie,js..je,true,itime..itime]
    sub_w = gp_w[is..ie,js..je,true,itime..itime]
    sub_t = gp_t[is..ie,js..je,true,itime..itime]
    # calculate U at a specified potential temperature surface
    sub_u = interpolate_on_p( rmiss, a_plev, sub_ps, sub_u )
    # calculate V at a specified potential temperature surface
    sub_v = interpolate_on_p( rmiss, a_plev, sub_ps, sub_v )
    # calculate W at a specified potential temperature surface
    sub_w = interpolate_on_p( rmiss, a_plev, sub_ps, sub_w )
    # calculate T at a specified potential temperature surface
    sub_t = interpolate_on_p( rmiss, a_plev, sub_ps, sub_t )
    # output files in VTK format
    outfilename = dir_out + "/"+filename0+"_uvw_"+itime.to_s.rjust(5,"0")+".vtk"
    file = File.open(outfilename, "w")
    file.write '# vtk DataFile Version 1.0', "\n"
    file.write outfilename, "\n"
    file.write 'ASCII', "\n"
    file.write 'DATASET STRUCTURED_GRID', "\n"
    file.write 'DIMENSIONS', ' ', (ie-is+1), ' ', (je-js+1), ' ', a_plev.size, "\n"
    file.write 'POINTS ', (ie-is+1)*(je-js+1)*a_plev.size, ' float', "\n"
    for k in 0..(a_plev.size-1)
      for j in 0..(je-js)
        for i in 0..(ie-is)
          zval = a_plev[k]
          zval = - 8.0 * Math::log(zval/1.0e5)
          file.write na_lon[i+is], ' ', na_lat[j+js], ' ', zval, "\n"
        end
      end
    end
    file.write 'POINT_DATA ', (ie-is+1)*(je-js+1)*a_plev.size, "\n"
    file.write 'VECTORS uvw_velocity float', "\n"
    for k in 0..(a_plev.size-1)
      for j in 0..(je-js)
        for i in 0..(ie-is)
          zval = sub_w.val[i,j,k,0]
          if ( zval == rmiss ) then
            u = 0.0
            v = 0.0
            w = 0.0
          else
            u = sub_u.val[i,j,k,0]
            v = sub_v.val[i,j,k,0]
            dens = a_plev[k] / ( (8.3/29.0e-3) *sub_t.val[i,j,k,0] )
            grav = 9.8
            w = - sub_w.val[i,j,k,0] / ( dens * grav )
            w *= w_vector_factor
          end
          file.write u, ' ', v, ' ', w, "\n"
        end
      end
    end
    file.write 'SCALARS uvw_w float', "\n"
    file.write 'LOOKUP_TABLE default', "\n"
    for k in 0..(a_plev.size-1)
      for j in 0..(je-js)
        for i in 0..(ie-is)
          zval = sub_w.val[i,j,k,0]
          if ( zval == rmiss ) then
            zval = 0.0
          else
            dens = a_plev[k] / ( (8.3/29.0e-3) *sub_t.val[i,j,k,0] )
            grav = 9.8
            w = - sub_w.val[i,j,k,0] / ( dens * grav )
            w *= w_vector_factor
            zval = w
          end
          file.write zval, "\n"
        end
      end
    end
    file.close
  end


  itime += 1
  [sub_u,sub_v,sub_w,sub_pt,sub_t]
#  [sub_p,sub_u,sub_v,sub_w]
#  [sub_pt]
end
outfile.close
