NeoN
A framework for CFD software
Loading...
Searching...
No Matches
ginkgo.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// SPDX-FileCopyrightText: 2024 NeoN authors
3
4#pragma once
5
6#if NF_WITH_GINKGO
7
#include <utility>

#include <ginkgo/ginkgo.hpp>
#include <ginkgo/extensions/kokkos.hpp>

#include "NeoN/fields/field.hpp"
17
18
19namespace NeoN::la::ginkgo
20{
21
22std::shared_ptr<gko::Executor> getGkoExecutor(Executor exec);
23
24namespace detail
25{
26
27template<typename T>
28gko::array<T> createGkoArray(std::shared_ptr<const gko::Executor> exec, std::span<T> values)
29{
30 return gko::make_array_view(exec, values.size(), values.data());
31}
32
33// template<typename T>
34// gko::detail::const_array_view<T>
35// createConstGkoArray(std::shared_ptr<const gko::Executor> exec, const std::span<const T> values)
36// {
37// return gko::make_const_array_view(exec, values.size(), values.data());
38// }
39
40template<typename ValueType, typename IndexType>
41std::shared_ptr<gko::matrix::Csr<ValueType, IndexType>> createGkoMtx(
42 std::shared_ptr<const gko::Executor> exec, const LinearSystem<ValueType, IndexType>& sys
43)
44{
45 auto nrows = static_cast<gko::dim<2>::dimension_type>(sys.rhs().size());
46 auto mtx = sys.view().matrix;
47 // NOTE we get a const view of the system but need a non const view to vals and indices
48 // auto vals = createConstGkoArray(exec, mtx.values).copy_to_array();
49 auto vals = gko::array<ValueType>::view(
50 exec,
51 static_cast<gko::size_type>(mtx.values.size()),
52 const_cast<ValueType*>(mtx.values.data())
53 );
54 // auto col = createGkoArray(exec, mtx.colIdxs);
55 auto col = gko::array<IndexType>::view(
56 exec,
57 static_cast<gko::size_type>(mtx.colIdxs.size()),
58 const_cast<IndexType*>(mtx.colIdxs.data())
59 );
60 // auto row = createGkoArray(exec, mtx.rowOffs);
61 auto row = gko::array<IndexType>::view(
62 exec,
63 static_cast<gko::size_type>(mtx.rowOffs.size()),
64 const_cast<IndexType*>(mtx.rowOffs.data())
65 );
66 return gko::share(gko::matrix::Csr<ValueType, IndexType>::create(
67 exec, gko::dim<2> {nrows, nrows}, vals, col, row
68 ));
69}
70
71template<typename ValueType>
72std::shared_ptr<gko::matrix::Dense<ValueType>>
73createGkoDense(std::shared_ptr<const gko::Executor> exec, ValueType* ptr, localIdx s)
74{
75 auto size = static_cast<std::size_t>(s);
76 return gko::share(gko::matrix::Dense<ValueType>::create(
77 exec, gko::dim<2> {size, 1}, createGkoArray(exec, std::span {ptr, size}), 1
78 ));
79}
80
81template<typename ValueType>
82std::shared_ptr<gko::matrix::Dense<ValueType>>
83createGkoDense(std::shared_ptr<const gko::Executor> exec, const ValueType* ptr, localIdx s)
84{
85 auto size = static_cast<std::size_t>(s);
86 auto const_array_view = gko::array<ValueType>::const_view(exec, size, ptr);
87 return gko::share(gko::matrix::Dense<ValueType>::create(
88 exec, gko::dim<2> {size, 1}, const_array_view.copy_to_array(), 1
89 ));
90}
91
92}
93gko::config::pnode parse(const Dictionary& dict);
94
95class GinkgoSolver : public SolverFactory::template Register<GinkgoSolver>
96{
97
98 using Base = SolverFactory::template Register<GinkgoSolver>;
99
100public:
101
102 GinkgoSolver(Executor exec, const Dictionary& solverConfig)
103 : Base(exec), gkoExec_(getGkoExecutor(exec)), config_(parse(solverConfig)),
104 factory_(gko::config::parse(
105 config_, gko::config::registry(), gko::config::make_type_descriptor<scalar>()
106 )
107 .on(gkoExec_))
108 {}
109
110 static std::string name() { return "Ginkgo"; }
111
112 static std::string doc() { return "TBD"; }
113
114 static std::string schema() { return "none"; }
115
116 virtual SolverStats
117 solve(const LinearSystem<scalar, localIdx>& sys, Vector<scalar>& x) const final
118 {
119 using vec = gko::matrix::Dense<scalar>;
120
121 auto retrieve = [](const auto& in)
122 {
123 auto host = vec::create(in->get_executor()->get_master(), gko::dim<2> {1});
124 scalar res = host->copy_from(in)->at(0);
125 return res;
126 };
127
128 auto nrows = sys.rhs().size();
129
130 auto gkoMtx = detail::createGkoMtx(gkoExec_, sys);
131 auto solver = factory_->generate(gkoMtx);
132
133 std::shared_ptr<const gko::log::Convergence<scalar>> logger =
134 gko::log::Convergence<scalar>::create();
135 solver->add_logger(logger);
136
137 auto rhs = detail::createGkoDense(gkoExec_, sys.rhs().data(), nrows);
138 auto rhs2 = detail::createGkoDense(gkoExec_, sys.rhs().data(), nrows);
139 auto gkoX = detail::createGkoDense(gkoExec_, x.data(), nrows);
140
141 auto one = gko::initialize<vec>({1.0}, gkoExec_);
142 auto neg_one = gko::initialize<vec>({-1.0}, gkoExec_);
143 auto init = gko::initialize<vec>({0.0}, gkoExec_);
144 gkoMtx->apply(one, gkoX, neg_one, rhs2);
145 rhs->compute_norm2(init);
146 scalar initResNorm = retrieve(init);
147
148 solver->apply(rhs, gkoX);
149
150 scalar finalResNorm = retrieve(gko::as<vec>(logger->get_residual_norm()));
151
152 auto numIter = label(logger->get_num_iterations());
153
154 return {numIter, initResNorm, finalResNorm};
155 }
156
157 // TODO why use a smart pointer here?
158 virtual std::unique_ptr<SolverFactory> clone() const final
159 {
160 NF_ERROR_EXIT("Not implemented");
161 return {};
162 }
163
164private:
165
166 std::shared_ptr<const gko::Executor> gkoExec_;
167 gko::config::pnode config_;
168 std::shared_ptr<const gko::LinOpFactory> factory_;
169};
170
171
172}
173
174#endif
A template class for registering derived classes with a base class.
#define NF_ERROR_EXIT(message)
Macro for printing an error message and aborting the program.
Definition error.hpp:108
void solve(Expression< typename VectorType::ElementType > &exp, VectorType &solution, scalar t, scalar dt, const Dictionary &fvSchemes, const Dictionary &fvSolution)
Definition solver.hpp:36
KOKKOS_INLINE_FUNCTION T one()
float scalar
Definition scalar.hpp:14
const std::string & name(const NeoN::Document &doc)
Retrieves the name of a Document.
int32_t label
Definition label.hpp:24