NeoN
A framework for CFD software
Loading...
Searching...
No Matches
sundials.hpp
Go to the documentation of this file.
1// SPDX-FileCopyrightText: 2023 - 2025 NeoN authors
2//
3// SPDX-License-Identifier: MIT
4
5#pragma once
6
7#if NN_WITH_SUNDIALS
8
9#include <concepts>
10#include <functional>
11#include <memory>
12
13#include <sundials/sundials_nvector.h>
14#include <sundials/sundials_core.hpp>
15#include <nvector/nvector_serial.h>
16#include <nvector/nvector_kokkos.hpp>
17#include <arkode/arkode_arkstep.h>
18#include <arkode/arkode_erkstep.h>
19
20#include "NeoN/core/error.hpp"
22#include "NeoN/fields/field.hpp"
23
24namespace NeoN::sundials
25{
26
32inline auto SUN_CONTEXT_DELETER = [](SUNContext* ctx)
33{
34 if (ctx != nullptr)
35 {
36 SUNContext_Free(ctx);
37 }
38};
39
45inline auto SUN_ARK_DELETER = [](char* ark)
46{
47 if (ark != nullptr)
48 {
49 void* arkodMem = reinterpret_cast<void*>(ark);
50 ARKodeFree(&arkodMem);
51 }
52};
53
60inline ARKODE_ERKTableID stringToERKTable(const std::string& key)
61{
62 if (key == "Forward-Euler") return ARKODE_FORWARD_EULER_1_1;
63 if (key == "Heun")
64 {
65 NF_ERROR_EXIT("Currently unsupported until field time step-stage indexing resolved.");
66 return ARKODE_HEUN_EULER_2_1_2;
67 }
68 if (key == "Midpoint")
69 {
70 NF_ERROR_EXIT("Currently unsupported until field time step-stage indexing resolved.");
71 return ARKODE_EXPLICIT_MIDPOINT_EULER_2_1_2;
72 }
74 "Unsupported Runge-Kutta time integration method selectied: " + key + ".\n"
75 + "Supported methods are: Forward-Euler, Heun, Midpoint."
76 );
77 return ARKODE_ERK_NONE; // avoids compiler warnings.
78}
79
88template<typename SKVectorType, typename ValueType>
89void fieldToSunNVectorImpl(const NeoN::Vector<ValueType>& field, N_Vector& vector)
90{
91 auto view = ::sundials::kokkos::GetVec<SKVectorType>(vector)->View();
92 auto fieldView = field.view();
94 field.exec(), field.range(), KOKKOS_LAMBDA(const localIdx i) { view(i) = fieldView[i]; }
95 );
96};
97
105template<typename ValueType>
106void fieldToSunNVector(const NeoN::Vector<ValueType>& field, N_Vector& vector)
107{
108 // CHECK FOR N_Vector on correct space in DEBUG
109 if (std::holds_alternative<NeoN::GPUExecutor>(field.exec()))
110 {
111 fieldToSunNVectorImpl<::sundials::kokkos::Vector<Kokkos::DefaultExecutionSpace>>(
112 field, vector
113 );
114 return;
115 }
116 if (std::holds_alternative<NeoN::CPUExecutor>(field.exec()))
117 {
118 fieldToSunNVectorImpl<::sundials::kokkos::Vector<Kokkos::DefaultHostExecutionSpace>>(
119 field, vector
120 );
121 return;
122 }
123 if (std::holds_alternative<NeoN::SerialExecutor>(field.exec()))
124 {
125 fieldToSunNVectorImpl<::sundials::kokkos::Vector<Kokkos::Serial>>(field, vector);
126 return;
127 }
128 NF_ERROR_EXIT("Unsupported NeoN executor for field.");
129};
130
139template<typename SKVectorType, typename ValueType>
140void sunNVectorToVectorImpl(const N_Vector& vector, NeoN::Vector<ValueType>& field)
141{
142 auto view = ::sundials::kokkos::GetVec<SKVectorType>(vector)->View();
143 ValueType* fieldData = field.data();
145 field.exec(), field.range(), KOKKOS_LAMBDA(const localIdx i) { fieldData[i] = view(i); }
146 );
147};
148
155template<typename ValueType>
156void sunNVectorToVector(const N_Vector& vector, NeoN::Vector<ValueType>& field)
157{
158 if (std::holds_alternative<NeoN::GPUExecutor>(field.exec()))
159 {
160 sunNVectorToVectorImpl<::sundials::kokkos::Vector<Kokkos::DefaultExecutionSpace>>(
161 vector, field
162 );
163 return;
164 }
165 if (std::holds_alternative<NeoN::CPUExecutor>(field.exec()))
166 {
167 sunNVectorToVectorImpl<::sundials::kokkos::Vector<Kokkos::DefaultHostExecutionSpace>>(
168 vector, field
169 );
170 return;
171 }
172 if (std::holds_alternative<NeoN::SerialExecutor>(field.exec()))
173 {
174 sunNVectorToVectorImpl<::sundials::kokkos::Vector<Kokkos::Serial>>(vector, field);
175 return;
176 }
177 NF_ERROR_EXIT("Unsupported NeoN executor for field.");
178};
179
194template<typename SolutionVectorType>
195int explicitRKSolve([[maybe_unused]] sunrealtype t, N_Vector y, N_Vector ydot, void* userData)
196{
197 // Pointer wrangling
198 using ValueType = typename SolutionVectorType::VectorValueType;
200 reinterpret_cast<NeoN::dsl::Expression<ValueType>*>(userData);
201 sunrealtype* yDotArray = N_VGetArrayPointer(ydot);
202 sunrealtype* yArray = N_VGetArrayPointer(y);
203
204 NF_ASSERT(
205 yDotArray != nullptr && yArray != nullptr && pdeExpre != nullptr,
206 "Failed to dereference pointers in sundails."
207 );
208
209 auto size = static_cast<localIdx>(N_VGetLength(y));
210 // Copy initial value from y to source.
211 NeoN::Vector<NeoN::scalar> source = pdeExpre->explicitOperation(size) * -1.0; // compute spatial
212 fence(pdeExpre->exec());
213 NeoN::sundials::fieldToSunNVector(source, ydot); // assign rhs to ydot.
214 return 0;
215}
216
217namespace detail
218{
219
227template<typename Vector>
228void initNVector(size_t size, std::shared_ptr<SUNContext> context, Vector& vec)
229{
230 vec.initNVector(size, context);
231}
232
239template<typename Vector>
240const N_Vector& sunNVector(const Vector& vec)
241{
242 return vec.sunNVector();
243}
244
251template<typename Vector>
252N_Vector& sunNVector(Vector& vec)
253{
254 return vec.sunNVector();
255}
256}
257
263template<typename ValueType>
264class SKVectorSerial
265{
266public:
267
268 SKVectorSerial() {};
269 ~SKVectorSerial() = default;
270 SKVectorSerial(const SKVectorSerial& other)
271 : kvector_(other.kvector_), svector_(other.kvector_) {};
272 SKVectorSerial(SKVectorSerial&& other) noexcept
273 : kvector_(std::move(other.kvector_)), svector_(std::move(other.svector_)) {};
274 SKVectorSerial& operator=(const SKVectorSerial& other) = delete;
275 SKVectorSerial& operator=(SKVectorSerial&& other) = delete;
276
277
278 using KVector = ::sundials::kokkos::Vector<Kokkos::Serial>;
279 void initNVector(size_t size, std::shared_ptr<SUNContext> context)
280 {
281 kvector_ = KVector(size, *context);
282 svector_ = kvector_;
283 };
284 const N_Vector& sunNVector() const { return svector_; };
285 N_Vector& sunNVector() { return svector_; };
286
287private:
288
289 KVector kvector_ {};
290 N_Vector svector_ {nullptr};
291};
292
298template<typename ValueType>
299class SKVectorHostDefault
300{
301public:
302
303 using KVector = ::sundials::kokkos::Vector<Kokkos::DefaultHostExecutionSpace>;
304
305 SKVectorHostDefault() = default;
306 ~SKVectorHostDefault() = default;
307 SKVectorHostDefault(const SKVectorHostDefault& other)
308 : kvector_(other.kvector_), svector_(other.kvector_) {};
309 SKVectorHostDefault(SKVectorHostDefault&& other) noexcept
310 : kvector_(std::move(other.kvector_)), svector_(std::move(other.svector_)) {};
311 SKVectorHostDefault& operator=(const SKVectorHostDefault& other) = delete;
312 SKVectorHostDefault& operator=(SKVectorHostDefault&& other) = delete;
313
314 void initNVector(size_t size, std::shared_ptr<SUNContext> context)
315 {
316 kvector_ = KVector(size, *context);
317 svector_ = kvector_;
318 };
319 const N_Vector& sunNVector() const { return svector_; };
320 N_Vector& sunNVector() { return svector_; };
321
322private:
323
324 KVector kvector_ {};
325 N_Vector svector_ {nullptr};
326};
327
333template<typename ValueType>
334class SKVectorDefault
335{
336public:
337
338 using KVector = ::sundials::kokkos::Vector<Kokkos::DefaultExecutionSpace>;
339
340 SKVectorDefault() = default;
341 ~SKVectorDefault() = default;
342 SKVectorDefault(const SKVectorDefault& other)
343 : kvector_(other.kvector_), svector_(other.kvector_) {};
344 SKVectorDefault(SKVectorDefault&& other) noexcept
345 : kvector_(std::move(other.kvector_)), svector_(std::move(other.svector_)) {};
346 SKVectorDefault& operator=(const SKVectorDefault& other) = delete;
347 SKVectorDefault& operator=(SKVectorDefault&& other) = delete;
348
349 void initNVector(size_t size, std::shared_ptr<SUNContext> context)
350 {
351 kvector_ = KVector(size, *context);
352 svector_ = kvector_;
353 };
354
355 const N_Vector& sunNVector() const { return svector_; };
356
357 N_Vector& sunNVector() { return svector_; };
358
359private:
360
361 KVector kvector_ {};
362 N_Vector svector_ {nullptr};
363};
364
371template<typename ValueType>
372class SKVector
373{
374public:
375
376 using SKVectorSerialV = SKVectorSerial<ValueType>;
377 using SKVectorHostDefaultV = SKVectorHostDefault<ValueType>;
378 using SKDefaultVectorV = SKVectorDefault<ValueType>;
379 using SKVectorVariant = std::variant<SKVectorSerialV, SKVectorHostDefaultV, SKDefaultVectorV>;
380
384 SKVector() { vector_.template emplace<SKVectorHostDefaultV>(); };
385
389 ~SKVector() = default;
390
395 SKVector(const SKVector&) = default;
396
400 SKVector& operator=(const SKVector&) = delete;
401
406 SKVector(SKVector&&) noexcept = default;
407
411 SKVector& operator=(SKVector&&) noexcept = delete;
412
417 void setExecutor(const NeoN::Executor& exec)
418 {
419 if (std::holds_alternative<NeoN::GPUExecutor>(exec))
420 {
421 vector_.template emplace<SKDefaultVectorV>();
422 return;
423 }
424 if (std::holds_alternative<NeoN::CPUExecutor>(exec))
425 {
426 vector_.template emplace<SKVectorHostDefaultV>();
427 return;
428 }
429 if (std::holds_alternative<NeoN::SerialExecutor>(exec))
430 {
431 vector_.template emplace<SKVectorSerialV>();
432 return;
433 }
434
436 "Unsupported NeoN executor: "
437 << std::visit([](const auto& e) { return e.name(); }, exec) << "."
438 );
439 }
440
446 void initNVector(size_t size, std::shared_ptr<SUNContext> context)
447 {
448 std::visit(
449 [size, &context](auto& vec) { detail::initNVector(size, context, vec); }, vector_
450 );
451 }
452
457 const N_Vector& sunNVector() const
458 {
459 return std::visit(
460 [](const auto& vec) -> const N_Vector& { return detail::sunNVector(vec); }, vector_
461 );
462 }
463
468 N_Vector& sunNVector()
469 {
470 return std::visit([](auto& vec) -> N_Vector& { return detail::sunNVector(vec); }, vector_);
471 }
472
477 const SKVectorVariant& variant() const { return vector_; }
478
483 SKVectorVariant& variant() { return vector_; }
484
485private:
486
487 SKVectorVariant vector_;
488};
489}
490
491#endif
A class to contain the data and executors for a field and define some basic operations.
Definition vector.hpp:30
ValueType * data()
Direct access to the underlying field data.
Definition vector.hpp:217
std::pair< localIdx, localIdx > range() const
Gets the range of the field.
Definition vector.hpp:305
const Executor & exec() const
Gets the executor associated with the field.
Definition vector.hpp:229
View< ValueType > view() &&=delete
Vector< ValueType > explicitOperation(localIdx nCells) const
const Executor & exec() const
#define NF_ERROR_EXIT(message)
Macro for printing an error message and aborting the program.
Definition error.hpp:110
#define NF_ASSERT(condition, message)
Macro for asserting a condition and printing an error message if the condition is false.
Definition error.hpp:144
SpatialOperator< scalar > source(fvcc::VolumeField< scalar > &coeff, fvcc::VolumeField< scalar > &phi)
Definition array.hpp:20
void parallelFor(const Executor &exec, std::pair< localIdx, localIdx > range, Kernel kernel, std::string name="parallelFor")
void fence(const Executor &exec)
Definition executor.hpp:21
int32_t localIdx
Definition label.hpp:32