8#include <ginkgo/ginkgo.hpp>
9#include <ginkgo/extensions/kokkos.hpp>
19namespace NeoN::la::ginkgo
22std::shared_ptr<gko::Executor> getGkoExecutor(Executor exec);
28gko::array<T> createGkoArray(std::shared_ptr<const gko::Executor> exec, std::span<T> values)
30 return gko::make_array_view(exec, values.size(), values.data());
40template<
typename ValueType,
typename IndexType>
41std::shared_ptr<gko::matrix::Csr<ValueType, IndexType>> createGkoMtx(
42 std::shared_ptr<const gko::Executor> exec,
const LinearSystem<ValueType, IndexType>& sys
45 auto nrows =
static_cast<gko::dim<2>::dimension_type
>(sys.rhs().size());
46 auto mtx = sys.view().matrix;
49 auto vals = gko::array<ValueType>::view(
51 static_cast<gko::size_type
>(mtx.values.size()),
52 const_cast<ValueType*
>(mtx.values.data())
55 auto col = gko::array<IndexType>::view(
57 static_cast<gko::size_type
>(mtx.colIdxs.size()),
58 const_cast<IndexType*
>(mtx.colIdxs.data())
61 auto row = gko::array<IndexType>::view(
63 static_cast<gko::size_type
>(mtx.rowOffs.size()),
64 const_cast<IndexType*
>(mtx.rowOffs.data())
66 return gko::share(gko::matrix::Csr<ValueType, IndexType>::create(
67 exec, gko::dim<2> {nrows, nrows}, vals, col, row
71template<
typename ValueType>
72std::shared_ptr<gko::matrix::Dense<ValueType>>
73createGkoDense(std::shared_ptr<const gko::Executor> exec, ValueType* ptr, localIdx s)
75 auto size =
static_cast<std::size_t
>(s);
76 return gko::share(gko::matrix::Dense<ValueType>::create(
77 exec, gko::dim<2> {size, 1}, createGkoArray(exec, std::span {ptr, size}), 1
81template<
typename ValueType>
82std::shared_ptr<gko::matrix::Dense<ValueType>>
83createGkoDense(std::shared_ptr<const gko::Executor> exec,
const ValueType* ptr, localIdx s)
85 auto size =
static_cast<std::size_t
>(s);
86 auto const_array_view = gko::array<ValueType>::const_view(exec, size, ptr);
87 return gko::share(gko::matrix::Dense<ValueType>::create(
88 exec, gko::dim<2> {size, 1}, const_array_view.copy_to_array(), 1
93gko::config::pnode parse(
const Dictionary& dict);
95class GinkgoSolver :
public SolverFactory::template
Register<GinkgoSolver>
102 GinkgoSolver(Executor exec,
const Dictionary& solverConfig)
103 : Base(exec), gkoExec_(getGkoExecutor(exec)), config_(parse(solverConfig)),
104 factory_(gko::config::parse(
105 config_, gko::config::registry(), gko::config::make_type_descriptor<
scalar>()
/** @brief Identifier of this solver; always the string "Ginkgo". */
static std::string name() { return "Ginkgo"; }
/** @brief Documentation string of this solver; currently the placeholder "TBD". */
static std::string doc() { return "TBD"; }
/** @brief Schema string of this solver; currently the literal "none". */
static std::string schema() { return "none"; }
117 solve(
const LinearSystem<scalar, localIdx>& sys, Vector<scalar>& x)
const final
119 using vec = gko::matrix::Dense<scalar>;
121 auto retrieve = [](
const auto& in)
123 auto host = vec::create(in->get_executor()->get_master(), gko::dim<2> {1});
124 scalar res = host->copy_from(in)->at(0);
128 auto nrows = sys.rhs().size();
130 auto gkoMtx = detail::createGkoMtx(gkoExec_, sys);
131 auto solver = factory_->generate(gkoMtx);
133 std::shared_ptr<const gko::log::Convergence<scalar>> logger =
134 gko::log::Convergence<scalar>::create();
135 solver->add_logger(logger);
137 auto rhs = detail::createGkoDense(gkoExec_, sys.rhs().data(), nrows);
138 auto rhs2 = detail::createGkoDense(gkoExec_, sys.rhs().data(), nrows);
139 auto gkoX = detail::createGkoDense(gkoExec_, x.data(), nrows);
141 auto one = gko::initialize<vec>({1.0}, gkoExec_);
142 auto neg_one = gko::initialize<vec>({-1.0}, gkoExec_);
143 auto init = gko::initialize<vec>({0.0}, gkoExec_);
144 gkoMtx->apply(one, gkoX, neg_one, rhs2);
145 rhs->compute_norm2(init);
146 scalar initResNorm = retrieve(init);
148 solver->apply(rhs, gkoX);
150 scalar finalResNorm = retrieve(gko::as<vec>(logger->get_residual_norm()));
152 auto numIter =
label(logger->get_num_iterations());
154 return {numIter, initResNorm, finalResNorm};
158 virtual std::unique_ptr<SolverFactory> clone() const final
166 std::shared_ptr<const gko::Executor> gkoExec_;
167 gko::config::pnode config_;
168 std::shared_ptr<const gko::LinOpFactory> factory_;
// Cross-reference notes (apparently tooltip text carried over from a generated
// documentation listing; kept for reference, not part of the code):
//   - A template class for registering derived classes with a base class.
//   - #define NF_ERROR_EXIT(message)
//     Macro for printing an error message and aborting the program.
//   - void solve(Expression<typename VectorType::ElementType>& exp, VectorType& solution,
//                scalar t, scalar dt, const Dictionary& fvSchemes, const Dictionary& fvSolution)
//   - KOKKOS_INLINE_FUNCTION T one()
//   - const std::string& name(const NeoN::Document& doc)
//     Retrieves the name of a Document.