/* -*- c++ -*- ----------------------------------------------------------
   LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
   https://www.lammps.org/, Sandia National Laboratories
   LAMMPS development team: developers@lammps.org

   Copyright (2003) Sandia Corporation.  Under the terms of Contract
   DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
   certain rights in this software.  This software is distributed under
   the GNU General Public License.

   See the README file in the top-level LAMMPS directory.
------------------------------------------------------------------------- */

#ifdef DIHEDRAL_CLASS
// clang-format off
// Style registrations. This branch is compiled only when the including file
// defines DIHEDRAL_CLASS; otherwise the class declaration in the #else branch
// is used. Each DihedralStyle(name, class) entry pairs a style name with one
// instantiation of DihedralCharmmfswKokkos:
//   charmmfsw/kk, charmmfsw/kk/device -> DihedralCharmmfswKokkos<LMPDeviceType>
//   charmmfsw/kk/host                  -> DihedralCharmmfswKokkos<LMPHostType>
// NOTE(review): the DihedralStyle macro is defined by the includer (presumably
// LAMMPS' style-factory machinery); confirm its expansion there.
DihedralStyle(charmmfsw/kk,DihedralCharmmfswKokkos<LMPDeviceType>);
DihedralStyle(charmmfsw/kk/device,DihedralCharmmfswKokkos<LMPDeviceType>);
DihedralStyle(charmmfsw/kk/host,DihedralCharmmfswKokkos<LMPHostType>);
// clang-format on
#else

// clang-format off
#ifndef LMP_DIHEDRAL_CHARMMFSW_KOKKOS_H
#define LMP_DIHEDRAL_CHARMMFSW_KOKKOS_H

#include "dihedral_charmmfsw.h"
#include "kokkos_type.h"
#include "dihedral_charmm_kokkos.h" // needed for s_EVM_FLOAT

namespace LAMMPS_NS {

// Empty tag type used to select the tag-dispatched operator() overloads of
// DihedralCharmmfswKokkos. It holds no data: the NEWTON_BOND and EVFLAG
// switches travel purely in its template arguments, so each combination is a
// distinct type chosen at compile time.
template <int NEWTON_BOND, int EVFLAG>
struct TagDihedralCharmmfswCompute {
};

/* ----------------------------------------------------------------------
   Kokkos-enabled variant of DihedralCharmmfsw.

   DeviceType is the Kokkos device that the views and kernels of an
   instance are bound to (through ArrayTypes<DeviceType> and
   KKDevice<DeviceType>). This header holds declarations only; member
   function bodies are defined elsewhere.
------------------------------------------------------------------------- */
template<class DeviceType>
class DihedralCharmmfswKokkos : public DihedralCharmmfsw {
 public:
  // ---- type aliases --------------------------------------------------------
  using device_type = DeviceType;
  using AT = ArrayTypes<DeviceType>;
  // Kokkos requires functors passed to parallel_reduce (without an explicit
  // reducer) to expose value_type. Here it is the type accumulated by the
  // EVM_FLOAT& operator() overload below.
  using value_type = EVM_FLOAT;
  using KKDeviceType = typename KKDevice<DeviceType>::value;

  // ---- Dihedral interface --------------------------------------------------
  DihedralCharmmfswKokkos(class LAMMPS *lmp);
  ~DihedralCharmmfswKokkos() override;
  void compute(int eflag_in, int vflag_in) override;
  void coeff(int narg, char **arg) override;
  void init_style() override;
  void read_restart(FILE *fp) override;

  // ---- tag-dispatched kernels ----------------------------------------------
  // The int argument is presumably the index of the dihedral being processed
  // (TODO confirm against the implementation file).
  // This overload also accumulates into an EVM_FLOAT reduction value.
  template<int NEWTON_BOND, int EVFLAG>
  KOKKOS_INLINE_FUNCTION
  void operator()(TagDihedralCharmmfswCompute<NEWTON_BOND,EVFLAG>,
                  const int &n, EVM_FLOAT &ev) const;

  // The same kernel signature without a reduction argument.
  template<int NEWTON_BOND, int EVFLAG>
  KOKKOS_INLINE_FUNCTION
  void operator()(TagDihedralCharmmfswCompute<NEWTON_BOND,EVFLAG>,
                  const int &n) const;

  // ---- energy / virial tallies ---------------------------------------------
  // Four-atom tally: atoms i1..i4, the term energy edihedral, the forces
  // f1/f3/f4 and the three bond vectors vb1, vb2 and vb3.
  KOKKOS_INLINE_FUNCTION
  void ev_tally(EVM_FLOAT &evm,
                const int i1, const int i2, const int i3, const int i4,
                KK_FLOAT &edihedral,
                KK_FLOAT *f1, KK_FLOAT *f3, KK_FLOAT *f4,
                const KK_FLOAT &vb1x, const KK_FLOAT &vb1y, const KK_FLOAT &vb1z,
                const KK_FLOAT &vb2x, const KK_FLOAT &vb2y, const KK_FLOAT &vb2z,
                const KK_FLOAT &vb3x, const KK_FLOAT &vb3y, const KK_FLOAT &vb3z) const;

  // Pairwise tally between atoms i and j: energies evdwl/ecoul, scalar force
  // fpair and separation (delx, dely, delz). Presumably used for the 1-4
  // interactions (TODO confirm).
  KOKKOS_INLINE_FUNCTION
  void ev_tally(EVM_FLOAT &evm, const int i, const int j,
                const KK_FLOAT &evdwl, const KK_FLOAT &ecoul, const KK_FLOAT &fpair,
                const KK_FLOAT &delx, const KK_FLOAT &dely, const KK_FLOAT &delz) const;

  // Per-atom energy and 6-component virial. Each TransformView pairs a
  // KK_ACC_FLOAT view with a double view (assumed: device accumulator plus a
  // double copy for the host side; confirm in kokkos_type.h).
  TransformView<KK_ACC_FLOAT*,double*,Kokkos::LayoutRight,KKDeviceType> k_eatom;
  TransformView<KK_ACC_FLOAT*[6],double*[6],LMPDeviceLayout,KKDeviceType> k_vatom;

 protected:
  class NeighborKokkos *neighborKK;

  // Per-atom state and topology used by the kernels (presumably coordinates,
  // atom types, charges, forces and the dihedral list).
  typename AT::t_kkfloat_1d_3_lr_randomread x;
  typename AT::t_int_1d_randomread atomtype;
  typename AT::t_kkfloat_1d_randomread q;
  typename AT::t_kkacc_1d_3 f;
  typename AT::t_int_2d_lr dihedrallist;

  // Views carrying Kokkos::Atomic memory traits, so every access through
  // them is performed atomically. This makes concurrent per-atom
  // accumulation from different threads safe.
  Kokkos::View<KK_ACC_FLOAT*,Kokkos::LayoutRight,KKDeviceType,Kokkos::MemoryTraits<Kokkos::Atomic>> d_eatom;
  Kokkos::View<KK_ACC_FLOAT*[6],LMPDeviceLayout,KKDeviceType,Kokkos::MemoryTraits<Kokkos::Atomic>> d_vatom;

  // Separate per-atom energy/virial storage for the pairwise tally
  // (assumed to be routed to the pair style's accumulators; TODO confirm).
  TransformView<KK_ACC_FLOAT*,double*,Kokkos::LayoutRight,KKDeviceType> k_eatom_pair;
  TransformView<KK_ACC_FLOAT*[6],double*[6],LMPDeviceLayout,KKDeviceType> k_vatom_pair;
  Kokkos::View<KK_ACC_FLOAT*,Kokkos::LayoutRight,KKDeviceType,Kokkos::MemoryTraits<Kokkos::Atomic>> d_eatom_pair;
  Kokkos::View<KK_ACC_FLOAT*[6],LMPDeviceLayout,KKDeviceType,Kokkos::MemoryTraits<Kokkos::Atomic>> d_vatom_pair;

  // Scalar state copied into the kernels.
  int nlocal;
  int newton_bond;
  int eflag;
  int vflag;
  KK_FLOAT qqrd2e;

  // Warning flag as a dual view, with device and host handles.
  DAT::tdual_int_scalar k_warning_flag;
  typename AT::t_int_scalar d_warning_flag;
  HAT::t_int_scalar h_warning_flag;

  // 2-D coefficient tables for the 1-4 Lennard-Jones terms.
  typename AT::t_kkfloat_2d d_lj14_1;
  typename AT::t_kkfloat_2d d_lj14_2;
  typename AT::t_kkfloat_2d d_lj14_3;
  typename AT::t_kkfloat_2d d_lj14_4;

  // Per-type dihedral coefficients as dual views ...
  DAT::tdual_kkfloat_1d k_k;
  DAT::tdual_kkfloat_1d k_multiplicity;
  DAT::tdual_kkfloat_1d k_shift;
  DAT::tdual_kkfloat_1d k_sin_shift;
  DAT::tdual_kkfloat_1d k_cos_shift;
  DAT::tdual_kkfloat_1d k_weight;

  // ... and their DeviceType-side views.
  typename AT::t_kkfloat_1d d_k;
  typename AT::t_kkfloat_1d d_multiplicity;
  typename AT::t_kkfloat_1d d_shift;
  typename AT::t_kkfloat_1d d_sin_shift;
  typename AT::t_kkfloat_1d d_cos_shift;
  typename AT::t_kkfloat_1d d_weight;

  void allocate() override;
};

}

#endif
#endif

