#include <utility>

11#include <ginkgo/ginkgo.hpp>
12#include <ginkgo/extensions/kokkos.hpp>
13#include <ginkgo/extensions/config/json_config.hpp>
23namespace NeoN::la::ginkgo
26std::shared_ptr<gko::Executor> getGkoExecutor(Executor exec);
32gko::array<T> createGkoArray(std::shared_ptr<const gko::Executor> exec, std::span<T> values)
34 return gko::make_array_view(exec, values.size(), values.data());
44template<
typename ValueType,
typename IndexType>
45std::shared_ptr<gko::matrix::Csr<ValueType, IndexType>> createGkoMtx(
46 std::shared_ptr<const gko::Executor> exec,
const LinearSystem<ValueType, IndexType>& sys
49 auto nrows =
static_cast<gko::dim<2>::dimension_type
>(sys.rhs().size());
50 auto mtx = sys.view().matrix;
53 auto vals = gko::array<ValueType>::view(
55 static_cast<gko::size_type
>(mtx.values.size()),
56 const_cast<ValueType*
>(mtx.values.data())
59 auto col = gko::array<IndexType>::view(
61 static_cast<gko::size_type
>(mtx.colIdxs.size()),
62 const_cast<IndexType*
>(mtx.colIdxs.data())
65 auto row = gko::array<IndexType>::view(
67 static_cast<gko::size_type
>(mtx.rowOffs.size()),
68 const_cast<IndexType*
>(mtx.rowOffs.data())
70 return gko::share(gko::matrix::Csr<ValueType, IndexType>::create(
71 exec, gko::dim<2> {nrows, nrows}, vals, col, row
75template<
typename ValueType>
76std::shared_ptr<gko::matrix::Dense<ValueType>>
77createGkoDense(std::shared_ptr<const gko::Executor> exec, ValueType* ptr, localIdx s)
79 auto size =
static_cast<std::size_t
>(s);
80 return gko::share(gko::matrix::Dense<ValueType>::create(
81 exec, gko::dim<2> {size, 1}, createGkoArray(exec, std::span {ptr, size}), 1
85template<
typename ValueType>
86std::shared_ptr<gko::matrix::Dense<ValueType>>
87createGkoDense(std::shared_ptr<const gko::Executor> exec,
const ValueType* ptr, localIdx s)
89 auto size =
static_cast<std::size_t
>(s);
90 auto const_array_view = gko::array<ValueType>::const_view(exec, size, ptr);
91 return gko::share(gko::matrix::Dense<ValueType>::create(
92 exec, gko::dim<2> {size, 1}, const_array_view.copy_to_array(), 1
97gko::config::pnode parse(
const Dictionary& dict);
99class GinkgoSolver :
public SolverFactory::template
Register<GinkgoSolver>
// Construct the solver from a NeoN executor and a solver dictionary:
//  - Base is initialised with the NeoN executor,
//  - gkoExec_ is the Ginkgo executor returned by getGkoExecutor(exec),
//  - config_ is the gko::config::pnode produced by parse(solverConfig),
//  - factory_ is built by gko::config::parse from config_, a default-constructed
//    gko::config::registry() and a type descriptor for `scalar`.
// Members are initialised in declaration order (gkoExec_, config_, factory_), so
// config_ is already populated when factory_ is parsed from it.
106 GinkgoSolver(Executor exec,
const Dictionary& solverConfig)
107 : Base(exec), gkoExec_(getGkoExecutor(exec)), config_(parse(solverConfig)),
108 factory_(gko::config::parse(
// NOTE(review): only the value type is passed to make_type_descriptor; confirm the
// index type the configuration resolves to matches the localIdx used for the Csr
// matrix in solve().
109 config_, gko::config::registry(), gko::config::make_type_descriptor<
scalar>()
/// @brief Name identifying this solver (presumably the key used by its
///        SolverFactory registration — confirm).
/// @return the literal "Ginkgo".
static std::string name() { return "Ginkgo"; }
/// @brief Human-readable description of this solver.
/// @return currently the placeholder "TBD".
static std::string doc() { return "TBD"; }
/// @brief Description of the configuration schema this solver accepts.
/// @return currently the literal "none".
static std::string schema() { return "none"; }
// Solve sys (matrix A, right-hand side b) for x with a Ginkgo solver generated from
// factory_, and report solver statistics.
//
// x is wrapped without copying (createGkoDense on a mutable pointer), so the
// solution is written straight into x. Its current contents act as x0 for the
// initial-residual computation and are what the solver receives as its starting x.
// The returned values are, in order:
//  - the iteration count from a Convergence logger,
//  - the 2-norm of the initial residual A*x0 - b,
//  - the final residual norm recorded by that logger,
//  - the elapsed time measured with steady_clock (microsecond count).
121 solve(
const LinearSystem<scalar, localIdx>& sys, Vector<scalar>& x)
const final
123 auto startEval = std::chrono::steady_clock::now();
124 using vec = gko::matrix::Dense<scalar>;
// Copy the first entry of a (possibly device-resident) Dense object to the host.
126 auto retrieve = [](
const auto& in)
128 auto host = vec::create(in->get_executor()->get_master(), gko::dim<2> {1});
129 scalar res = host->copy_from(in)->at(0);
133 auto nrows = sys.rhs().size();
// Csr wrapper around the system matrix; generate() builds the configured solver for it.
135 auto gkoMtx = detail::createGkoMtx(gkoExec_, sys);
136 auto solver = factory_->generate(gkoMtx);
// Convergence logger: source of the iteration count and the final residual norm.
138 std::shared_ptr<const gko::log::Convergence<scalar>> logger =
139 gko::log::Convergence<scalar>::create();
140 solver->add_logger(logger);
// rhs and rhs2 are both created from sys.rhs(). rhs is the right-hand side handed to
// the solver; rhs2 is scratch space that is overwritten with the initial residual.
// NOTE(review): presumably sys.rhs().data() yields a const pointer here, selecting the
// copying createGkoDense overload so sys itself is never modified — confirm.
142 auto rhs = detail::createGkoDense(gkoExec_, sys.rhs().data(), nrows);
143 auto rhs2 = detail::createGkoDense(gkoExec_, sys.rhs().data(), nrows);
144 auto gkoX = detail::createGkoDense(gkoExec_, x.data(), nrows);
146 auto one = gko::initialize<vec>({1.0}, gkoExec_);
147 auto neg_one = gko::initialize<vec>({-1.0}, gkoExec_);
148 auto init = gko::initialize<vec>({0.0}, gkoExec_);
// Advanced apply: rhs2 := 1 * A * x0 + (-1) * rhs2 = A*x0 - b (the initial residual).
149 gkoMtx->apply(one, gkoX, neg_one, rhs2);
// Take the norm of the residual rhs2. rhs still holds b, and its norm would be ||b||,
// not the initial residual.
150 rhs2->compute_norm2(init);
151 scalar initResNorm = retrieve(init);
153 solver->apply(rhs, gkoX);
155 scalar finalResNorm = retrieve(gko::as<vec>(logger->get_residual_norm()));
157 auto numIter =
label(logger->get_num_iterations());
159 auto endEval = std::chrono::steady_clock::now();
162 std::chrono::duration_cast<std::chrono::microseconds>(endEval - startEval).count()
165 return {numIter, initResNorm, finalResNorm, duration};
// Virtual clone hook, declared final for GinkgoSolver.
// NOTE(review): the body is not part of this listing — confirm it returns a usable
// copy of this solver (including gkoExec_, config_ and factory_) rather than an
// empty pointer.
169 virtual std::unique_ptr<SolverFactory> clone() const final
// Ginkgo executor obtained from the NeoN executor via getGkoExecutor in the
// constructor; all Ginkgo objects in solve() are created on it.
// Declaration order (gkoExec_, config_, factory_) is relied on by the constructor's
// initializer list: factory_ is parsed from the already-initialised config_.
177 std::shared_ptr<const gko::Executor> gkoExec_;
// Ginkgo property tree produced by parse(solverConfig).
178 gko::config::pnode config_;
// Solver factory parsed from config_; solve() calls generate() on it for each system.
179 std::shared_ptr<const gko::LinOpFactory> factory_;
A template class for registering derived classes with a base class.
#define NF_ERROR_EXIT(message)
Macro for printing an error message and aborting the program.
void solve(Expression< typename VectorType::ElementType > &exp, VectorType &solution, scalar t, scalar dt, const Dictionary &fvSchemes, const Dictionary &fvSolution)
KOKKOS_INLINE_FUNCTION T one()
const std::string & name(const NeoN::Document &doc)
Retrieves the name of a Document.