13#include <sundials/sundials_nvector.h>
14#include <sundials/sundials_core.hpp>
15#include <nvector/nvector_serial.h>
16#include <nvector/nvector_kokkos.hpp>
17#include <arkode/arkode_arkstep.h>
18#include <arkode/arkode_erkstep.h>
24namespace NeoN::sundials
32inline auto SUN_CONTEXT_DELETER = [](SUNContext* ctx)
45inline auto SUN_ARK_DELETER = [](
char* ark)
49 void* arkodMem =
reinterpret_cast<void*
>(ark);
50 ARKodeFree(&arkodMem);
60inline ARKODE_ERKTableID stringToERKTable(
const std::string& key)
62 if (key ==
"Forward-Euler")
return ARKODE_FORWARD_EULER_1_1;
65 NF_ERROR_EXIT(
"Currently unsupported until field time step-stage indexing resolved.");
66 return ARKODE_HEUN_EULER_2_1_2;
68 if (key ==
"Midpoint")
70 NF_ERROR_EXIT(
"Currently unsupported until field time step-stage indexing resolved.");
71 return ARKODE_EXPLICIT_MIDPOINT_EULER_2_1_2;
74 "Unsupported Runge-Kutta time integration method selectied: " + key +
".\n"
75 +
"Supported methods are: Forward-Euler, Heun, Midpoint."
77 return ARKODE_ERK_NONE;
88template<
typename SKVectorType,
typename ValueType>
91 auto view = ::sundials::kokkos::GetVec<SKVectorType>(vector)->View();
92 auto fieldView = field.
view();
94 field.
exec(), field.
range(), KOKKOS_LAMBDA(
const localIdx i) { view(i) = fieldView[i]; }
105template<
typename ValueType>
109 if (std::holds_alternative<NeoN::GPUExecutor>(field.
exec()))
111 fieldToSunNVectorImpl<::sundials::kokkos::Vector<Kokkos::DefaultExecutionSpace>>(
116 if (std::holds_alternative<NeoN::CPUExecutor>(field.
exec()))
118 fieldToSunNVectorImpl<::sundials::kokkos::Vector<Kokkos::DefaultHostExecutionSpace>>(
123 if (std::holds_alternative<NeoN::SerialExecutor>(field.
exec()))
125 fieldToSunNVectorImpl<::sundials::kokkos::Vector<Kokkos::Serial>>(field, vector);
139template<
typename SKVectorType,
typename ValueType>
142 auto view = ::sundials::kokkos::GetVec<SKVectorType>(vector)->View();
143 ValueType* fieldData = field.
data();
145 field.
exec(), field.
range(), KOKKOS_LAMBDA(
const localIdx i) { fieldData[i] = view(i); }
155template<
typename ValueType>
158 if (std::holds_alternative<NeoN::GPUExecutor>(field.
exec()))
160 sunNVectorToVectorImpl<::sundials::kokkos::Vector<Kokkos::DefaultExecutionSpace>>(
165 if (std::holds_alternative<NeoN::CPUExecutor>(field.
exec()))
167 sunNVectorToVectorImpl<::sundials::kokkos::Vector<Kokkos::DefaultHostExecutionSpace>>(
172 if (std::holds_alternative<NeoN::SerialExecutor>(field.
exec()))
174 sunNVectorToVectorImpl<::sundials::kokkos::Vector<Kokkos::Serial>>(vector, field);
194template<
typename SolutionVectorType>
195int explicitRKSolve([[maybe_unused]] sunrealtype t, N_Vector y, N_Vector ydot,
void* userData)
198 using ValueType =
typename SolutionVectorType::VectorValueType;
201 sunrealtype* yDotArray = N_VGetArrayPointer(ydot);
202 sunrealtype* yArray = N_VGetArrayPointer(y);
205 yDotArray !=
nullptr && yArray !=
nullptr && pdeExpre !=
nullptr,
206 "Failed to dereference pointers in sundails."
209 auto size =
static_cast<localIdx>(N_VGetLength(y));
212 if (std::holds_alternative<NeoN::GPUExecutor>(pdeExpre->
exec()))
216 NeoN::sundials::fieldToSunNVector(source, ydot);
230template<
typename Vector>
231void initNVector(
size_t size, std::shared_ptr<SUNContext> context, Vector& vec)
233 vec.initNVector(size, context);
242template<
typename Vector>
243const N_Vector& sunNVector(
const Vector& vec)
245 return vec.sunNVector();
254template<
typename Vector>
255N_Vector& sunNVector(Vector& vec)
257 return vec.sunNVector();
266template<
typename ValueType>
272 ~SKVectorSerial() =
default;
273 SKVectorSerial(
const SKVectorSerial& other)
274 : kvector_(other.kvector_), svector_(other.kvector_) {};
275 SKVectorSerial(SKVectorSerial&& other) noexcept
276 : kvector_(std::move(other.kvector_)), svector_(std::move(other.svector_)) {};
277 SKVectorSerial& operator=(
const SKVectorSerial& other) =
delete;
278 SKVectorSerial& operator=(SKVectorSerial&& other) =
delete;
281 using KVector = ::sundials::kokkos::Vector<Kokkos::Serial>;
282 void initNVector(
size_t size, std::shared_ptr<SUNContext> context)
284 kvector_ = KVector(size, *context);
287 const N_Vector& sunNVector()
const {
return svector_; };
288 N_Vector& sunNVector() {
return svector_; };
293 N_Vector svector_ {
nullptr};
301template<
typename ValueType>
302class SKVectorHostDefault
306 using KVector = ::sundials::kokkos::Vector<Kokkos::DefaultHostExecutionSpace>;
308 SKVectorHostDefault() =
default;
309 ~SKVectorHostDefault() =
default;
310 SKVectorHostDefault(
const SKVectorHostDefault& other)
311 : kvector_(other.kvector_), svector_(other.kvector_) {};
312 SKVectorHostDefault(SKVectorHostDefault&& other) noexcept
313 : kvector_(std::move(other.kvector_)), svector_(std::move(other.svector_)) {};
314 SKVectorHostDefault& operator=(
const SKVectorHostDefault& other) =
delete;
315 SKVectorHostDefault& operator=(SKVectorHostDefault&& other) =
delete;
317 void initNVector(
size_t size, std::shared_ptr<SUNContext> context)
319 kvector_ = KVector(size, *context);
322 const N_Vector& sunNVector()
const {
return svector_; };
323 N_Vector& sunNVector() {
return svector_; };
328 N_Vector svector_ {
nullptr};
336template<
typename ValueType>
341 using KVector = ::sundials::kokkos::Vector<Kokkos::DefaultExecutionSpace>;
343 SKVectorDefault() =
default;
344 ~SKVectorDefault() =
default;
345 SKVectorDefault(
const SKVectorDefault& other)
346 : kvector_(other.kvector_), svector_(other.kvector_) {};
347 SKVectorDefault(SKVectorDefault&& other) noexcept
348 : kvector_(std::move(other.kvector_)), svector_(std::move(other.svector_)) {};
349 SKVectorDefault& operator=(
const SKVectorDefault& other) =
delete;
350 SKVectorDefault& operator=(SKVectorDefault&& other) =
delete;
352 void initNVector(
size_t size, std::shared_ptr<SUNContext> context)
354 kvector_ = KVector(size, *context);
358 const N_Vector& sunNVector()
const {
return svector_; };
360 N_Vector& sunNVector() {
return svector_; };
365 N_Vector svector_ {
nullptr};
374template<
typename ValueType>
379 using SKVectorSerialV = SKVectorSerial<ValueType>;
380 using SKVectorHostDefaultV = SKVectorHostDefault<ValueType>;
381 using SKDefaultVectorV = SKVectorDefault<ValueType>;
382 using SKVectorVariant = std::variant<SKVectorSerialV, SKVectorHostDefaultV, SKDefaultVectorV>;
387 SKVector() { vector_.template emplace<SKVectorHostDefaultV>(); };
392 ~SKVector() =
default;
398 SKVector(
const SKVector&) =
default;
403 SKVector& operator=(
const SKVector&) =
delete;
409 SKVector(SKVector&&) noexcept = default;
414 SKVector& operator=(SKVector&&) noexcept = delete;
420 void setExecutor(const
NeoN::Executor& exec)
422 if (std::holds_alternative<NeoN::GPUExecutor>(exec))
424 vector_.template emplace<SKDefaultVectorV>();
427 if (std::holds_alternative<NeoN::CPUExecutor>(exec))
429 vector_.template emplace<SKVectorHostDefaultV>();
432 if (std::holds_alternative<NeoN::SerialExecutor>(exec))
434 vector_.template emplace<SKVectorSerialV>();
439 "Unsupported NeoN executor: "
440 << std::visit([](
const auto& e) {
return e.name(); }, exec) <<
"."
449 void initNVector(
size_t size, std::shared_ptr<SUNContext> context)
452 [size, &context](
auto& vec) { detail::initNVector(size, context, vec); }, vector_
460 const N_Vector& sunNVector()
const
463 [](
const auto& vec) ->
const N_Vector& {
return detail::sunNVector(vec); }, vector_
471 N_Vector& sunNVector()
473 return std::visit([](
auto& vec) -> N_Vector& {
return detail::sunNVector(vec); }, vector_);
480 const SKVectorVariant& variant()
const {
return vector_; }
486 SKVectorVariant& variant() {
return vector_; }
490 SKVectorVariant vector_;
A class to contain the data and executors for a field and define some basic operations.
ValueType * data()
Direct access to the underlying field data.
std::pair< localIdx, localIdx > range() const
Gets the range of the field.
const Executor & exec() const
Gets the executor associated with the field.
View< ValueType > view() &&=delete
Vector< ValueType > explicitOperation(localIdx nCells) const
const Executor & exec() const
#define NF_ERROR_EXIT(message)
Macro for printing an error message and aborting the program.
#define NF_ASSERT(condition, message)
Macro for asserting a condition and printing an error message if the condition is false.
SpatialOperator< scalar > source(fvcc::VolumeField< scalar > &coeff, fvcc::VolumeField< scalar > &phi)
void parallelFor(const Executor &exec, std::pair< localIdx, localIdx > range, Kernel kernel, std::string name="parallelFor")