Gyselalib++
 
Loading...
Searching...
No Matches
matrix_batch_tridiag.hpp
1// Copyright (C) The DDC development team, see COPYRIGHT.md file
2//
3// SPDX-License-Identifier: MIT
4
5#pragma once
6
7#include <sll/matrix_batch.hpp>
8
9#include <Kokkos_Core.hpp>
10
22template <class ExecSpace>
23class MatrixBatchTridiag : public MatrixBatch<ExecSpace>
24{
25public:
27 using MatrixBatch<ExecSpace>::size;
28 using MatrixBatch<ExecSpace>::batch_size;
29
30private:
34 using DKokkosView2D
35 = Kokkos::View<double**, Kokkos::LayoutRight, typename ExecSpace::memory_space>;
36 DKokkosView2D m_subdiag;
37 DKokkosView2D m_diag;
38 DKokkosView2D m_uppdiag;
39
40public:
55 const int batch_size,
56 const int mat_size,
57 DKokkosView2D const aa,
58 DKokkosView2D const bb,
59 DKokkosView2D const cc)
60
61 : MatrixBatch<ExecSpace>(batch_size, mat_size)
62 , m_subdiag(aa)
63 , m_diag(bb)
64 , m_uppdiag(cc)
65 {
66 }
67
79 bool check_stability() const
80 {
81 assert(batch_size() == m_subdiag.extent(0));
82 assert(size() == m_subdiag.extent(1));
83 assert(batch_size() == m_diag.extent(0));
84 assert(size() == m_diag.extent(1));
85 assert(batch_size() == m_uppdiag.extent(0));
86 assert(size() == m_uppdiag.extent(1));
87
88 int const tmp_batch_size = batch_size();
89 int const tmp_mat_size = size();
90 bool is_diagdom = false;
91 bool is_symmetric = false;
92
93 DKokkosView2D subdiag_proxy = m_subdiag;
94 DKokkosView2D diag_proxy = m_diag;
95 DKokkosView2D uppdiag_proxy = m_uppdiag;
96
97 Kokkos::MDRangePolicy<Kokkos::Rank<2>> batch_policy({0, 0}, {tmp_batch_size, tmp_mat_size});
98 Kokkos::parallel_reduce(
99 "DiagDominant",
100 batch_policy,
101 KOKKOS_LAMBDA(int batch_idx, int k, bool& check_diag_dom) {
102 check_diag_dom = check_diag_dom
103 && (Kokkos::abs(subdiag_proxy(batch_idx, k))
104 + Kokkos::abs(uppdiag_proxy(batch_idx, k))
105 <= Kokkos::abs(diag_proxy(batch_idx, k)));
106 },
107 Kokkos::LAnd<bool>(is_diagdom));
108 Kokkos::parallel_reduce(
109 "Symmetric",
110 batch_policy,
111 KOKKOS_LAMBDA(int batch_idx, int k, bool& check_sym) {
112 check_sym
113 = check_sym
114 && (Kokkos::abs(
115 subdiag_proxy(batch_idx, k) - uppdiag_proxy(batch_idx, k))
116 < 1e-16);
117 },
118 Kokkos::LAnd<bool>(is_symmetric));
119
120 return (is_diagdom || is_symmetric);
121 }
122
130 void setup_solver() final
131 {
132 assert(check_stability());
133 }
134
140 void solve(BatchedRHS const b) const final
141 {
142 assert(batch_size() == b.extent(0));
143 assert(size() == b.extent(1));
144
145 int const tmp_batch_size = m_subdiag.extent(0);
146 int const tmp_mat_size = m_subdiag.extent(1);
147
148 Kokkos::View<
149 double**,
150 Kokkos::LayoutRight,
151 Kokkos::DefaultExecutionSpace::memory_space> const
152 cprim("cprim", batch_size(), size());
153 Kokkos::View<
154 double**,
155 Kokkos::LayoutRight,
156 Kokkos::DefaultExecutionSpace::memory_space> const
157 dprim("dprim", batch_size(), size());
158
159 DKokkosView2D subdiag_proxy = m_subdiag;
160 DKokkosView2D diag_proxy = m_diag;
161 DKokkosView2D uppdiag_proxy = m_uppdiag;
162
163 Kokkos::parallel_for(
164 "Tridiagonal solver",
165 Kokkos::RangePolicy<ExecSpace>(0, tmp_batch_size),
166 KOKKOS_LAMBDA(const int batch_idx) {
167 //ForwardStep
168 cprim(batch_idx, 0) = uppdiag_proxy(batch_idx, 0) / diag_proxy(batch_idx, 0);
169 dprim(batch_idx, 0) = b(batch_idx, 0) / diag_proxy(batch_idx, 0);
170 for (int i = 1; i < tmp_mat_size; i++) {
171 cprim(batch_idx, i)
172 = uppdiag_proxy(batch_idx, i)
173 / (diag_proxy(batch_idx, i)
174 - subdiag_proxy(batch_idx, i) * cprim(batch_idx, i - 1));
175 dprim(batch_idx, i)
176 = (b(batch_idx, i)
177 - subdiag_proxy(batch_idx, i) * dprim(batch_idx, i - 1))
178 / (diag_proxy(batch_idx, i)
179 - subdiag_proxy(batch_idx, i) * cprim(batch_idx, i - 1));
180 }
181 //BackwardStep
182 b(batch_idx, tmp_mat_size - 1) = dprim(batch_idx, tmp_mat_size - 1);
183 for (int i = tmp_mat_size - 2; i >= 0; i--) {
184 b(batch_idx, i)
185 = dprim(batch_idx, i) - cprim(batch_idx, i) * b(batch_idx, i + 1);
186 }
187 });
188 }
189};
A structure for solving a set of independent tridiagonal systems using a direct method.
Definition matrix_batch_tridiag.hpp:24
void solve(BatchedRHS const b) const final
Solve the batched linear problem Ax=b.
Definition matrix_batch_tridiag.hpp:140
bool check_stability() const
Check if the matrices are in the stability area of the solver.
Definition matrix_batch_tridiag.hpp:79
MatrixBatchTridiag(const int batch_size, const int mat_size, DKokkosView2D const aa, DKokkosView2D const bb, DKokkosView2D const cc)
Creates an instance of the MatrixBatchTridiag class.
Definition matrix_batch_tridiag.hpp:54
void setup_solver() final
Perform a pre-process operation on the solver.
Definition matrix_batch_tridiag.hpp:130
MatrixBatch superclass for managing a collection of linear systems.
Definition matrix_batch.hpp:23
std::size_t batch_size() const
Get the batch size of the linear problem.
Definition matrix_batch.hpp:76
std::size_t size() const
Get the size of the square matrix corresponding to a single batch in one of its dimensions.
Definition matrix_batch.hpp:66
Kokkos::View< double **, Kokkos::LayoutRight, ExecSpace > BatchedRHS
The type of a Kokkos::View storing batched right-hand sides. Second dimension is batch dimension.
Definition matrix_batch.hpp:26