10#include <sundials/sundials_nvector.h>
11#include <sundials/sundials_core.hpp>
12#include <nvector/nvector_serial.h>
13#include <nvector/nvector_kokkos.hpp>
14#include <arkode/arkode_arkstep.h>
15#include <arkode/arkode_erkstep.h>
46 void* arkodMem =
reinterpret_cast<void*
>(ark);
47 ARKodeFree(&arkodMem);
59 if (key ==
"Forward-Euler")
return ARKODE_FORWARD_EULER_1_1;
62 NF_ERROR_EXIT(
"Currently unsupported until field time step-stage indexing resolved.");
63 return ARKODE_HEUN_EULER_2_1_2;
65 if (key ==
"Midpoint")
67 NF_ERROR_EXIT(
"Currently unsupported until field time step-stage indexing resolved.");
68 return ARKODE_EXPLICIT_MIDPOINT_EULER_2_1_2;
71 "Unsupported Runge-Kutta time integration method selectied: " + key +
".\n"
72 +
"Supported methods are: Forward-Euler, Heun, Midpoint."
74 return ARKODE_ERK_NONE;
85template<
typename SKVectorType,
typename ValueType>
88 auto view = ::sundials::kokkos::GetVec<SKVectorType>(vector)->View();
89 auto fieldView = field.
view();
91 field.
exec(), field.
range(), KOKKOS_LAMBDA(
const localIdx i) { view(i) = fieldView[i]; }
102template<
typename ValueType>
106 if (std::holds_alternative<NeoN::GPUExecutor>(field.
exec()))
108 fieldToSunNVectorImpl<::sundials::kokkos::Vector<Kokkos::DefaultExecutionSpace>>(
113 if (std::holds_alternative<NeoN::CPUExecutor>(field.
exec()))
115 fieldToSunNVectorImpl<::sundials::kokkos::Vector<Kokkos::DefaultHostExecutionSpace>>(
120 if (std::holds_alternative<NeoN::SerialExecutor>(field.
exec()))
122 fieldToSunNVectorImpl<::sundials::kokkos::Vector<Kokkos::Serial>>(field, vector);
136template<
typename SKVectorType,
typename ValueType>
139 auto view = ::sundials::kokkos::GetVec<SKVectorType>(vector)->View();
140 ValueType* fieldData = field.
data();
142 field.
exec(), field.
range(), KOKKOS_LAMBDA(
const localIdx i) { fieldData[i] = view(i); }
152template<
typename ValueType>
155 if (std::holds_alternative<NeoN::GPUExecutor>(field.
exec()))
157 sunNVectorToVectorImpl<::sundials::kokkos::Vector<Kokkos::DefaultExecutionSpace>>(
162 if (std::holds_alternative<NeoN::CPUExecutor>(field.
exec()))
164 sunNVectorToVectorImpl<::sundials::kokkos::Vector<Kokkos::DefaultHostExecutionSpace>>(
169 if (std::holds_alternative<NeoN::SerialExecutor>(field.
exec()))
171 sunNVectorToVectorImpl<::sundials::kokkos::Vector<Kokkos::Serial>>(vector, field);
191template<
typename SolutionVectorType>
192int explicitRKSolve([[maybe_unused]] sunrealtype t, N_Vector y, N_Vector ydot,
void* userData)
195 using ValueType =
typename SolutionVectorType::VectorValueType;
198 sunrealtype* yDotArray = N_VGetArrayPointer(ydot);
199 sunrealtype* yArray = N_VGetArrayPointer(y);
202 yDotArray !=
nullptr && yArray !=
nullptr && pdeExpre !=
nullptr,
203 "Failed to dereference pointers in sundails."
206 auto size =
static_cast<localIdx>(N_VGetLength(y));
209 if (std::holds_alternative<NeoN::GPUExecutor>(pdeExpre->
exec()))
227template<
typename Vector>
230 vec.initNVector(size, context);
239template<
typename Vector>
242 return vec.sunNVector();
251template<
typename Vector>
254 return vec.sunNVector();
263template<
typename ValueType>
271 : kvector_(other.kvector_), svector_(other.kvector_) {};
273 : kvector_(std::move(other.kvector_)), svector_(std::move(other.svector_)) {};
278 using KVector = ::sundials::kokkos::Vector<Kokkos::Serial>;
279 void initNVector(
size_t size, std::shared_ptr<SUNContext> context)
281 kvector_ =
KVector(size, *context);
290 N_Vector svector_ {
nullptr};
298template<
typename ValueType>
303 using KVector = ::sundials::kokkos::Vector<Kokkos::DefaultHostExecutionSpace>;
308 : kvector_(other.kvector_), svector_(other.kvector_) {};
310 : kvector_(std::move(other.kvector_)), svector_(std::move(other.svector_)) {};
314 void initNVector(
size_t size, std::shared_ptr<SUNContext> context)
316 kvector_ =
KVector(size, *context);
325 N_Vector svector_ {
nullptr};
333template<
typename ValueType>
338 using KVector = ::sundials::kokkos::Vector<Kokkos::DefaultExecutionSpace>;
343 : kvector_(other.kvector_), svector_(other.kvector_) {};
345 : kvector_(std::move(other.kvector_)), svector_(std::move(other.svector_)) {};
349 void initNVector(
size_t size, std::shared_ptr<SUNContext> context)
351 kvector_ =
KVector(size, *context);
362 N_Vector svector_ {
nullptr};
371template<
typename ValueType>
379 using SKVectorVariant = std::variant<SKVectorSerialV, SKVectorHostDefaultV, SKDefaultVectorV>;
384 SKVector() { vector_.template emplace<SKVectorHostDefaultV>(); };
419 if (std::holds_alternative<NeoN::GPUExecutor>(exec))
421 vector_.template emplace<SKDefaultVectorV>();
424 if (std::holds_alternative<NeoN::CPUExecutor>(exec))
426 vector_.template emplace<SKVectorHostDefaultV>();
429 if (std::holds_alternative<NeoN::SerialExecutor>(exec))
431 vector_.template emplace<SKVectorSerialV>();
436 "Unsupported NeoN executor: "
437 << std::visit([](
const auto& e) {
return e.name(); }, exec) <<
"."
446 void initNVector(
size_t size, std::shared_ptr<SUNContext> context)
470 return std::visit([](
auto& vec) -> N_Vector& {
return detail::sunNVector(vec); }, vector_);
A class to contain the data and executors for a field and define some basic operations.
ValueType * data()
Direct access to the underlying field data.
std::pair< localIdx, localIdx > range() const
Gets the range of the field.
const Executor & exec() const
Gets the executor associated with the field.
View< ValueType > view() &&=delete
Vector< ValueType > explicitOperation(localIdx nCells) const
const Executor & exec() const
Default executor SUNDIALS Kokkos vector wrapper.
~SKVectorDefault()=default
::sundials::kokkos::Vector< Kokkos::DefaultExecutionSpace > KVector
SKVectorDefault(SKVectorDefault &&other) noexcept
void initNVector(size_t size, std::shared_ptr< SUNContext > context)
SKVectorDefault & operator=(const SKVectorDefault &other)=delete
SKVectorDefault(const SKVectorDefault &other)
const N_Vector & sunNVector() const
SKVectorDefault()=default
SKVectorDefault & operator=(SKVectorDefault &&other)=delete
Host default executor SUNDIALS Kokkos vector wrapper.
void initNVector(size_t size, std::shared_ptr< SUNContext > context)
SKVectorHostDefault(SKVectorHostDefault &&other) noexcept
SKVectorHostDefault(const SKVectorHostDefault &other)
const N_Vector & sunNVector() const
SKVectorHostDefault & operator=(SKVectorHostDefault &&other)=delete
~SKVectorHostDefault()=default
SKVectorHostDefault()=default
::sundials::kokkos::Vector< Kokkos::DefaultHostExecutionSpace > KVector
SKVectorHostDefault & operator=(const SKVectorHostDefault &other)=delete
Serial executor SUNDIALS Kokkos vector wrapper.
SKVectorSerial & operator=(SKVectorSerial &&other)=delete
~SKVectorSerial()=default
SKVectorSerial & operator=(const SKVectorSerial &other)=delete
::sundials::kokkos::Vector< Kokkos::Serial > KVector
void initNVector(size_t size, std::shared_ptr< SUNContext > context)
SKVectorSerial(const SKVectorSerial &other)
const N_Vector & sunNVector() const
SKVectorSerial(SKVectorSerial &&other) noexcept
Unified interface for SUNDIALS Kokkos vector management.
SKVectorVariant & variant()
Gets mutable reference to variant storing implementation.
void initNVector(size_t size, std::shared_ptr< SUNContext > context)
Initializes underlying vector with given size and context.
const N_Vector & sunNVector() const
Gets const reference to underlying N_Vector.
N_Vector & sunNVector()
Gets mutable reference to underlying N_Vector.
SKVector()
Default constructor. Initializes with host-default vector.
const SKVectorVariant & variant() const
Gets const reference to variant storing implementation.
SKVector & operator=(const SKVector &)=delete
Copy assignment operator (deleted).
void setExecutor(const NeoN::Executor &exec)
Sets appropriate vector implementation based on executor type.
SKVector(SKVector &&) noexcept=default
Move constructor.
SKVector(const SKVector &)=default
Copy constructor.
~SKVector()=default
Default destructor.
std::variant< SKVectorSerialV, SKVectorHostDefaultV, SKDefaultVectorV > SKVectorVariant
#define NF_ERROR_EXIT(message)
Macro for printing an error message and aborting the program.
#define NF_ASSERT(condition, message)
Macro for asserting a condition and printing an error message if the condition is false.
const N_Vector & sunNVector(const Vector &vec)
Provides const access to underlying N_Vector.
void initNVector(size_t size, std::shared_ptr< SUNContext > context, Vector &vec)
Initializes a vector wrapper with specified size and context.
int explicitRKSolve(sunrealtype t, N_Vector y, N_Vector ydot, void *userData)
Performs a single explicit Runge-Kutta stage evaluation.
auto SUN_CONTEXT_DELETER
Custom deleter for SUNContext shared pointers.
ARKODE_ERKTableID stringToERKTable(const std::string &key)
Maps dictionary keywords to SUNDIALS Runge-Kutta Butcher tableau identifiers.
void sunNVectorToVectorImpl(const N_Vector &vector, NeoN::Vector< ValueType > &field)
Converts SUNDIALS N_Vector data back to NeoN Vector format.
void sunNVectorToVector(const N_Vector &vector, NeoN::Vector< ValueType > &field)
Dispatcher for N_Vector to field conversion based on executor type.
auto SUN_ARK_DELETER
Custom deleter for explicit-type RK solver memory (ERK, ARK, etc.) held in unique pointers.
void fieldToSunNVectorImpl(const NeoN::Vector< ValueType > &field, N_Vector &vector)
Converts NeoN Vector data to SUNDIALS N_Vector format.
void fieldToSunNVector(const NeoN::Vector< ValueType > &field, N_Vector &vector)
Dispatcher for field to N_Vector conversion based on executor type.
void parallelFor(const Executor &exec, std::pair< localIdx, localIdx > range, Kernel kernel, std::string name="parallelFor")
std::variant< SerialExecutor, CPUExecutor, GPUExecutor > Executor