12#include <Kokkos_Core.hpp>
13#include <petscvec_kokkos.hpp>
24namespace NeoN::la::petscSolverContext
27template<
typename ValueType>
28class petscSolverContext
45 petscSolverContext(Executor exec)
46 : init_(false), updated_(false), exec_(exec), Amat_(nullptr), sol_(nullptr), rhs_(nullptr),
52 virtual ~petscSolverContext()
64 bool initialized() const noexcept {
return init_; }
68 bool updated() const noexcept {
return updated_; }
// Builds the PETSc objects for `sys`.
// - COO index arrays are filled from host copies of the matrix (row offsets and
//   column indices) and of the rhs.
// - PETSc is initialized.
// - The matrix and rhs are created with Kokkos types when the executor name is
//   "GPUExecutor"; the other visible branch uses VECSEQ / MATSEQAIJ.
// - The solution vector is duplicated from the rhs, and COO preallocation is done.
// - A KSP is created with Amat_ as both operator and preconditioner matrix.
// NOTE(review): the PetscErrorCode results of the PETSc calls below are discarded;
// consider PetscCall()-style checking so failures are not silently ignored.
71 void initialize(
const LinearSystem<scalar, localIdx>& sys)
73 std::size_t size = sys.matrix().values().size();
74 std::size_t nrows = sys.rhs().size();
// NOTE(review): these are runtime-sized arrays (VLAs). That is a compiler
// extension, not standard C++. They also live on the stack with one entry per
// matrix nonzero, so large systems risk a stack overflow;
// std::vector<PetscInt> would be the portable choice.
75 PetscInt colIdx[size];
76 PetscInt rowIdx[size];
77 PetscInt rhsIdx[nrows];
81 auto hostLS = sys.copyToHost();
85 auto rowPtrHost = hostLS.matrix().rowOffs().view();
86 auto colIdxHost = hostLS.matrix().colIdxs().view();
87 auto rhsHost = sys.rhs().copyToHost();
// rhs COO indices are simply 0 .. nrows-1.
89 for (
size_t index = 0; index < nrows; ++index)
91 rhsIdx[index] =
static_cast<PetscInt
>(index);
// Column index of each nonzero, copied from the host column-index array.
96 for (
size_t index = 0; index < size; ++index)
98 colIdx[index] =
static_cast<PetscInt
>(colIdxHost[index]);
// Expand the row offsets into one row index per nonzero.
// NOTE(review): the visible advance uses a single `if` and a single reassignment
// of rowOffset. If the hidden lines do not loop, empty rows (equal consecutive
// offsets) would be attributed to the wrong row; a `while` would be needed.
// Confirm.
104 size_t rowOffset = rowPtrHost[rowI + 1];
105 for (
size_t index = 0; index < size; ++index)
107 if (index == rowOffset)
110 rowOffset = rowPtrHost[rowI + 1];
112 rowIdx[index] = rowI;
// NOTE(review): PETSc is initialized here, on every initialize() call, without
// argc/argv, so command-line options never reach PETSc from this path. Confirm
// whether initialization belongs at program startup instead.
117 PetscInitialize(NULL, NULL, 0, NULL);
119 MatCreate(PETSC_COMM_WORLD, &Amat_);
120 MatSetSizes(Amat_, sys.matrix().nRows(), sys.rhs().size(), PETSC_DECIDE, PETSC_DECIDE);
// NOTE(review): the matrix (and the KSP below) use PETSC_COMM_WORLD, while rhs_
// (and sol_, duplicated from it) use PETSC_COMM_SELF. These only agree on a
// single rank; confirm this context is serial-only.
122 VecCreate(PETSC_COMM_SELF, &rhs_);
123 VecSetSizes(rhs_, PETSC_DECIDE, nrows);
// The PETSc backend is chosen by comparing the active executor's name() string.
125 std::string execName = std::visit([](
const auto& e) {
return e.name(); }, exec_);
127 if (execName ==
"GPUExecutor")
129 VecSetType(rhs_, VECKOKKOS);
130 MatSetType(Amat_, MATAIJKOKKOS);
134 VecSetType(rhs_, VECSEQ);
135 MatSetType(Amat_, MATSEQAIJ);
137 VecDuplicate(rhs_, &sol_);
139 VecSetPreallocationCOO(rhs_, nrows, rhsIdx);
// NOTE(review): MatSetPreallocationCOO(A, n, coo_i, coo_j) expects row indices in
// coo_i and column indices in coo_j. Here colIdx is passed as coo_i and rowIdx as
// coo_j, which assembles the transpose of the matrix unless that is intended.
// Confirm.
140 MatSetPreallocationCOO(Amat_, size, colIdx, rowIdx);
142 KSPCreate(PETSC_COMM_WORLD, &ksp_);
143 KSPSetFromOptions(ksp_);
144 KSPSetOperators(ksp_, Amat_, Amat_);
151 void update() {
NF_ERROR_EXIT(
"Mesh changes not supported"); }
153 [[nodiscard]] Mat& AMat() {
return Amat_; }
155 [[nodiscard]] Vec& rhs() {
return rhs_; }
157 [[nodiscard]] Vec& sol() {
return sol_; }
159 [[nodiscard]] KSP& ksp() {
return ksp_; }
#define NF_ERROR_EXIT(message)
Macro for printing an error message and aborting the program.
std::variant< SerialExecutor, CPUExecutor, GPUExecutor > Executor