10#include <Kokkos_Core.hpp>
11#include <petscvec_kokkos.hpp>
22namespace NeoN::la::petscSolverContext
25template<
typename ValueType>
26class petscSolverContext
43 petscSolverContext(Executor exec)
44 : init_(false), updated_(false), exec_(exec), Amat_(nullptr), sol_(nullptr), rhs_(nullptr),
50 virtual ~petscSolverContext()
62 bool initialized() const noexcept {
return init_; }
66 bool updated() const noexcept {
return updated_; }
// Builds the PETSc objects held by this context from `sys`:
//   - COO index arrays for the matrix and the right-hand side;
//   - the matrix Amat_;
//   - the vectors rhs_ and sol_ (sol_ is a VecDuplicate of rhs_);
//   - the KSP solver ksp_, with Amat_ passed for both matrix arguments.
// Kokkos-backed PETSc types are selected when the executor's name() is
// "GPUExecutor"; sequential host types are used otherwise.
// NOTE(review): none of the PETSc return codes below are checked (no PetscCall),
// so a failed create/size/preallocate call goes unnoticed.
69 void initialize(
const LinearSystem<scalar, localIdx>& sys)
// size: number of stored matrix values; nrows: length of the right-hand side.
71 std::size_t size = sys.matrix().values().size();
72 std::size_t nrows = sys.rhs().size();
// NOTE(review): these are variable-length arrays, a GCC/Clang extension rather
// than standard C++. They sit on the stack with one entry per stored matrix
// value, so large systems risk a stack overflow. Prefer std::vector<PetscInt>
// and pass .data() to PETSc.
73 PetscInt colIdx[size];
74 PetscInt rowIdx[size];
75 PetscInt rhsIdx[nrows];
// The index arrays are filled on the host, so take a host copy of the system.
79 auto hostLS = sys.copyToHost();
83 auto rowPtrHost = hostLS.matrix().rowPtrs().view();
84 auto colIdxHost = hostLS.matrix().colIdxs().view();
// NOTE(review): rhsHost is not used in any visible line; confirm it is needed.
85 auto rhsHost = sys.rhs().copyToHost();
// Right-hand-side entries are addressed by their own position: rhsIdx[i] = i.
87 for (
size_t index = 0; index < nrows; ++index)
89 rhsIdx[index] =
static_cast<PetscInt
>(index);
// Column indices are copied verbatim, converted to PetscInt.
94 for (
size_t index = 0; index < size; ++index)
96 colIdx[index] =
static_cast<PetscInt
>(colIdxHost[index]);
// Expand the row-pointer array into one row index per stored value (COO rows).
// NOTE(review): `rowI` is declared, and advanced inside the `if`, on lines not
// shown here. A single `if` can move to the next row at most once per value, so
// empty rows (rowPtrHost[r] == rowPtrHost[r + 1]) would not be skipped and
// entries would be attributed to the wrong row. A `while` would handle this.
// Confirm against the full body.
102 size_t rowOffset = rowPtrHost[rowI + 1];
103 for (
size_t index = 0; index < size; ++index)
105 if (index == rowOffset)
108 rowOffset = rowPtrHost[rowI + 1];
110 rowIdx[index] = rowI;
// NOTE(review): PetscInitialize runs on every initialize() call and receives no
// argc/argv. Confirm where the matching PetscFinalize happens.
115 PetscInitialize(NULL, NULL, 0, NULL);
// NOTE(review): the matrix (and the KSP below) use PETSC_COMM_WORLD, while the
// vectors use PETSC_COMM_SELF. This is likely only consistent in a serial run;
// confirm.
// MatSetSizes passes sys.rhs().size() as the local column count, which assumes
// a square system.
117 MatCreate(PETSC_COMM_WORLD, &Amat_);
118 MatSetSizes(Amat_, sys.matrix().nRows(), sys.rhs().size(), PETSC_DECIDE, PETSC_DECIDE);
120 VecCreate(PETSC_COMM_SELF, &rhs_);
121 VecSetSizes(rhs_, PETSC_DECIDE, nrows);
// Pick Kokkos-backed or sequential PETSc types from the executor's name.
123 std::string execName = std::visit([](
const auto& e) {
return e.name(); }, exec_);
125 if (execName ==
"GPUExecutor")
127 VecSetType(rhs_, VECKOKKOS);
128 MatSetType(Amat_, MATAIJKOKKOS);
132 VecSetType(rhs_, VECSEQ);
133 MatSetType(Amat_, MATSEQAIJ);
// sol_ takes the same layout and type as rhs_.
135 VecDuplicate(rhs_, &sol_);
// NOTE(review): PETSc documents MatSetPreallocationCOO(A, n, coo_i, coo_j) with
// coo_i as ROW indices and coo_j as COLUMN indices. Here colIdx is passed
// first, so the assembled pattern is the transpose of the CSR matrix. This
// should most likely be MatSetPreallocationCOO(Amat_, size, rowIdx, colIdx).
137 VecSetPreallocationCOO(rhs_, nrows, rhsIdx);
138 MatSetPreallocationCOO(Amat_, size, colIdx, rowIdx);
// Create the Krylov solver, apply runtime options, and attach Amat_ as both
// the operator and the preconditioning matrix.
140 KSPCreate(PETSC_COMM_WORLD, &ksp_);
141 KSPSetFromOptions(ksp_);
142 KSPSetOperators(ksp_, Amat_, Amat_);
149 void update() {
NF_ERROR_EXIT(
"Mesh changes not supported"); }
151 [[nodiscard]] Mat& AMat() {
return Amat_; }
153 [[nodiscard]] Vec& rhs() {
return rhs_; }
155 [[nodiscard]] Vec& sol() {
return sol_; }
157 [[nodiscard]] KSP& ksp() {
return ksp_; }
#define NF_ERROR_EXIT(message)
Macro for printing an error message and aborting the program.
std::variant< SerialExecutor, CPUExecutor, GPUExecutor > Executor