NeoN
A framework for CFD software
Loading...
Searching...
No Matches
ginkgo.hpp
Go to the documentation of this file.
1// SPDX-FileCopyrightText: 2024 - 2025 NeoN authors
2//
3// SPDX-License-Identifier: MIT
4
5#pragma once
6
7#if NF_WITH_GINKGO
8
#include <chrono>
#include <utility>

11#include <ginkgo/ginkgo.hpp>
12#include <ginkgo/extensions/kokkos.hpp>
13#include <ginkgo/extensions/config/json_config.hpp>
14
15#include "NeoN/fields/field.hpp"
21
22
23namespace NeoN::la::ginkgo
24{
25
26std::shared_ptr<gko::Executor> getGkoExecutor(Executor exec);
27
28namespace detail
29{
30
31template<typename T>
32gko::array<T> createGkoArray(std::shared_ptr<const gko::Executor> exec, std::span<T> values)
33{
34 return gko::make_array_view(exec, values.size(), values.data());
35}
36
37// template<typename T>
38// gko::detail::const_array_view<T>
39// createConstGkoArray(std::shared_ptr<const gko::Executor> exec, const std::span<const T> values)
40// {
41// return gko::make_const_array_view(exec, values.size(), values.data());
42// }
43
44template<typename ValueType, typename IndexType>
45std::shared_ptr<gko::matrix::Csr<ValueType, IndexType>> createGkoMtx(
46 std::shared_ptr<const gko::Executor> exec, const LinearSystem<ValueType, IndexType>& sys
47)
48{
49 auto nrows = static_cast<gko::dim<2>::dimension_type>(sys.rhs().size());
50 auto mtx = sys.view().matrix;
51 // NOTE we get a const view of the system but need a non const view to vals and indices
52 // auto vals = createConstGkoArray(exec, mtx.values).copy_to_array();
53 auto vals = gko::array<ValueType>::view(
54 exec,
55 static_cast<gko::size_type>(mtx.values.size()),
56 const_cast<ValueType*>(mtx.values.data())
57 );
58 // auto col = createGkoArray(exec, mtx.colIdxs);
59 auto col = gko::array<IndexType>::view(
60 exec,
61 static_cast<gko::size_type>(mtx.colIdxs.size()),
62 const_cast<IndexType*>(mtx.colIdxs.data())
63 );
64 // auto row = createGkoArray(exec, mtx.rowOffs);
65 auto row = gko::array<IndexType>::view(
66 exec,
67 static_cast<gko::size_type>(mtx.rowOffs.size()),
68 const_cast<IndexType*>(mtx.rowOffs.data())
69 );
70 return gko::share(gko::matrix::Csr<ValueType, IndexType>::create(
71 exec, gko::dim<2> {nrows, nrows}, vals, col, row
72 ));
73}
74
75template<typename ValueType>
76std::shared_ptr<gko::matrix::Dense<ValueType>>
77createGkoDense(std::shared_ptr<const gko::Executor> exec, ValueType* ptr, localIdx s)
78{
79 auto size = static_cast<std::size_t>(s);
80 return gko::share(gko::matrix::Dense<ValueType>::create(
81 exec, gko::dim<2> {size, 1}, createGkoArray(exec, std::span {ptr, size}), 1
82 ));
83}
84
85template<typename ValueType>
86std::shared_ptr<gko::matrix::Dense<ValueType>>
87createGkoDense(std::shared_ptr<const gko::Executor> exec, const ValueType* ptr, localIdx s)
88{
89 auto size = static_cast<std::size_t>(s);
90 auto const_array_view = gko::array<ValueType>::const_view(exec, size, ptr);
91 return gko::share(gko::matrix::Dense<ValueType>::create(
92 exec, gko::dim<2> {size, 1}, const_array_view.copy_to_array(), 1
93 ));
94}
95
96}
97gko::config::pnode parse(const Dictionary& dict);
98
99class GinkgoSolver : public SolverFactory::template Register<GinkgoSolver>
100{
101
102 using Base = SolverFactory::template Register<GinkgoSolver>;
103
104public:
105
106 GinkgoSolver(Executor exec, const Dictionary& solverConfig)
107 : Base(exec), gkoExec_(getGkoExecutor(exec)), config_(parse(solverConfig)),
108 factory_(gko::config::parse(
109 config_, gko::config::registry(), gko::config::make_type_descriptor<scalar>()
110 )
111 .on(gkoExec_))
112 {}
113
114 static std::string name() { return "Ginkgo"; }
115
116 static std::string doc() { return "TBD"; }
117
118 static std::string schema() { return "none"; }
119
120 virtual SolverStats
121 solve(const LinearSystem<scalar, localIdx>& sys, Vector<scalar>& x) const final
122 {
123 auto startEval = std::chrono::steady_clock::now();
124 using vec = gko::matrix::Dense<scalar>;
125
126 auto retrieve = [](const auto& in)
127 {
128 auto host = vec::create(in->get_executor()->get_master(), gko::dim<2> {1});
129 scalar res = host->copy_from(in)->at(0);
130 return res;
131 };
132
133 auto nrows = sys.rhs().size();
134
135 auto gkoMtx = detail::createGkoMtx(gkoExec_, sys);
136 auto solver = factory_->generate(gkoMtx);
137
138 std::shared_ptr<const gko::log::Convergence<scalar>> logger =
139 gko::log::Convergence<scalar>::create();
140 solver->add_logger(logger);
141
142 auto rhs = detail::createGkoDense(gkoExec_, sys.rhs().data(), nrows);
143 auto rhs2 = detail::createGkoDense(gkoExec_, sys.rhs().data(), nrows);
144 auto gkoX = detail::createGkoDense(gkoExec_, x.data(), nrows);
145
146 auto one = gko::initialize<vec>({1.0}, gkoExec_);
147 auto neg_one = gko::initialize<vec>({-1.0}, gkoExec_);
148 auto init = gko::initialize<vec>({0.0}, gkoExec_);
149 gkoMtx->apply(one, gkoX, neg_one, rhs2);
150 rhs->compute_norm2(init);
151 scalar initResNorm = retrieve(init);
152
153 solver->apply(rhs, gkoX);
154
155 scalar finalResNorm = retrieve(gko::as<vec>(logger->get_residual_norm()));
156
157 auto numIter = label(logger->get_num_iterations());
158
159 auto endEval = std::chrono::steady_clock::now();
160 auto duration =
161 static_cast<float>(
162 std::chrono::duration_cast<std::chrono::microseconds>(endEval - startEval).count()
163 )
164 / 1000.0;
165 return {numIter, initResNorm, finalResNorm, duration};
166 }
167
168 // TODO why use a smart pointer here?
169 virtual std::unique_ptr<SolverFactory> clone() const final
170 {
171 NF_ERROR_EXIT("Not implemented");
172 return {};
173 }
174
175private:
176
177 std::shared_ptr<const gko::Executor> gkoExec_;
178 gko::config::pnode config_;
179 std::shared_ptr<const gko::LinOpFactory> factory_;
180};
181
182
183}
184
185#endif
A template class for registering derived classes with a base class.
#define NF_ERROR_EXIT(message)
Macro for printing an error message and aborting the program.
Definition error.hpp:110
void solve(Expression< typename VectorType::ElementType > &exp, VectorType &solution, scalar t, scalar dt, const Dictionary &fvSchemes, const Dictionary &fvSolution)
Definition solver.hpp:38
KOKKOS_INLINE_FUNCTION T one()
float scalar
Definition scalar.hpp:16
const std::string & name(const NeoN::Document &doc)
Retrieves the name of a Document.
int32_t label
Definition label.hpp:26