82 assert(
size() == m_subdiag.extent(1));
84 assert(
size() == m_diag.extent(1));
86 assert(
size() == m_uppdiag.extent(1));
89 int const tmp_mat_size =
size();
90 bool is_diagdom =
false;
91 bool is_symmetric =
false;
93 DKokkosView2D subdiag_proxy = m_subdiag;
94 DKokkosView2D diag_proxy = m_diag;
95 DKokkosView2D uppdiag_proxy = m_uppdiag;
97 Kokkos::MDRangePolicy<Kokkos::Rank<2>> batch_policy({0, 0}, {tmp_batch_size, tmp_mat_size});
98 Kokkos::parallel_reduce(
101 KOKKOS_LAMBDA(
int batch_idx,
int k,
bool& check_diag_dom) {
102 check_diag_dom = check_diag_dom
103 && (Kokkos::abs(subdiag_proxy(batch_idx, k))
104 + Kokkos::abs(uppdiag_proxy(batch_idx, k))
105 <= Kokkos::abs(diag_proxy(batch_idx, k)));
107 Kokkos::LAnd<bool>(is_diagdom));
108 Kokkos::parallel_reduce(
111 KOKKOS_LAMBDA(
int batch_idx,
int k,
bool& check_sym) {
115 subdiag_proxy(batch_idx, k) - uppdiag_proxy(batch_idx, k))
118 Kokkos::LAnd<bool>(is_symmetric));
120 return (is_diagdom || is_symmetric);
143 assert(
size() == b.extent(1));
145 int const tmp_batch_size = m_subdiag.extent(0);
146 int const tmp_mat_size = m_subdiag.extent(1);
151 Kokkos::DefaultExecutionSpace::memory_space>
const
156 Kokkos::DefaultExecutionSpace::memory_space>
const
159 DKokkosView2D subdiag_proxy = m_subdiag;
160 DKokkosView2D diag_proxy = m_diag;
161 DKokkosView2D uppdiag_proxy = m_uppdiag;
163 Kokkos::parallel_for(
164 "Tridiagonal solver",
165 Kokkos::RangePolicy<ExecSpace>(0, tmp_batch_size),
166 KOKKOS_LAMBDA(
const int batch_idx) {
168 cprim(batch_idx, 0) = uppdiag_proxy(batch_idx, 0) / diag_proxy(batch_idx, 0);
169 dprim(batch_idx, 0) = b(batch_idx, 0) / diag_proxy(batch_idx, 0);
170 for (
int i = 1; i < tmp_mat_size; i++) {
172 = uppdiag_proxy(batch_idx, i)
173 / (diag_proxy(batch_idx, i)
174 - subdiag_proxy(batch_idx, i) * cprim(batch_idx, i - 1));
177 - subdiag_proxy(batch_idx, i) * dprim(batch_idx, i - 1))
178 / (diag_proxy(batch_idx, i)
179 - subdiag_proxy(batch_idx, i) * cprim(batch_idx, i - 1));
182 b(batch_idx, tmp_mat_size - 1) = dprim(batch_idx, tmp_mat_size - 1);
183 for (
int i = tmp_mat_size - 2; i >= 0; i--) {
185 = dprim(batch_idx, i) - cprim(batch_idx, i) * b(batch_idx, i + 1);