Gyselalib++
 
Loading...
Searching...
No Matches
quadrature.hpp
1// SPDX-License-Identifier: MIT
2
3#pragma once
4#include <cassert>
5
6#include <ddc/ddc.hpp>
7
8#include <Kokkos_Core.hpp>
9
10#include "ddc_alias_inline_functions.hpp"
11#include "ddc_aliases.hpp"
12#include "ddc_helper.hpp"
13
24template <
25 class IdxRangeQuadrature,
26 class IdxRangeTotal = IdxRangeQuadrature,
27 class MemorySpace = Kokkos::DefaultExecutionSpace::memory_space>
29{
30private:
32 using IdxQuadrature = typename IdxRangeQuadrature::discrete_element_type;
33
34 using QuadConstField = DConstField<IdxRangeQuadrature, MemorySpace>;
35
36 QuadConstField m_coefficients;
37
38public:
44 explicit Quadrature(QuadConstField coeffs) : m_coefficients(coeffs) {}
45
59 template <class ExecutionSpace, class IntegratorFunction>
60 double operator()(ExecutionSpace exec_space, IntegratorFunction integrated_function) const
61 {
62 static_assert(
63 Kokkos::SpaceAccessibility<ExecutionSpace, MemorySpace>::accessible,
64 "Execution space is not compatible with memory space where coefficients are found");
65 static_assert(
66 std::is_invocable_v<IntegratorFunction, IdxQuadrature>,
67 "The object passed to Quadrature::operator() is not defined on the quadrature "
68 "idx_range.");
69
70 QuadConstField const coeff_proxy = m_coefficients;
71
72 // This fence helps avoid a CPU seg fault. See #290 for more details
73 exec_space.fence();
74 // This condition is necessary to execute in serial, even in a device activated build.
75 // Without it a seg fault appears
76 if constexpr (std::is_same_v<ExecutionSpace, Kokkos::DefaultHostExecutionSpace>) {
77 return ddc::transform_reduce(
78 get_idx_range(coeff_proxy),
79 0.0,
80 ddc::reducer::sum<double>(),
81 KOKKOS_LAMBDA(IdxQuadrature const ix) {
82 return coeff_proxy(ix) * integrated_function(ix);
83 });
84 } else {
85 return ddc::parallel_transform_reduce(
86 exec_space,
87 get_idx_range(coeff_proxy),
88 0.0,
89 ddc::reducer::sum<double>(),
90 KOKKOS_LAMBDA(IdxQuadrature const ix) {
91 return coeff_proxy(ix) * integrated_function(ix);
92 });
93 }
94 }
95
111 template <class ExecutionSpace, class BatchIdxRange, class IntegratorFunction>
113 ExecutionSpace exec_space,
114 Field<double, BatchIdxRange, MemorySpace> const result,
115 IntegratorFunction integrated_function) const
116 {
117 static_assert(
118 Kokkos::SpaceAccessibility<ExecutionSpace, MemorySpace>::accessible,
119 "Execution space is not compatible with memory space where coefficients are found");
120 static_assert(
121 std::is_same_v<ExecutionSpace, Kokkos::DefaultExecutionSpace>,
122 "Kokkos::TeamPolicy only works with the default execution space. Please use "
123 "DefaultExecutionSpace to call this batched operator.");
124 using ExpectedBatchDims = ddc::type_seq_remove_t<
125 ddc::to_type_seq_t<IdxRangeTotal>,
126 ddc::to_type_seq_t<IdxRangeQuadrature>>;
127 static_assert(
128 ddc::type_seq_same_v<ddc::to_type_seq_t<BatchIdxRange>, ExpectedBatchDims>,
129 "The batch idx_range deduced from the type of result does not match the class "
130 "template parameters.");
131
132 // Get useful index types
133 using IdxTotal = typename IdxRangeTotal::discrete_element_type;
134 using IdxBatch = typename BatchIdxRange::discrete_element_type;
135
136 static_assert(
137 std::is_invocable_v<IntegratorFunction, IdxTotal>,
138 "The object passed to Quadrature::operator() is not defined on the total "
139 "idx_range.");
140
141 // Get index ranges
142 IdxRangeQuadrature quad_idx_range(get_idx_range(m_coefficients));
143 BatchIdxRange batch_idx_range(get_idx_range(result));
144
145 QuadConstField const coeff_proxy = m_coefficients;
146 // Loop over batch dimensions
147 Kokkos::parallel_for(
148 Kokkos::TeamPolicy<>(exec_space, batch_idx_range.size(), Kokkos::AUTO),
149 KOKKOS_LAMBDA(const Kokkos::TeamPolicy<>::member_type& team) {
150 const int idx = team.league_rank();
151 IdxBatch ib = to_discrete_element(idx, batch_idx_range);
152
153 // Sum over quadrature dimensions
154 double teamSum = 0;
155 Kokkos::parallel_reduce(
156 Kokkos::TeamThreadRange(team, quad_idx_range.size()),
157 [&](int const& thread_index, double& sum) {
158 IdxQuadrature iq
159 = to_discrete_element(thread_index, quad_idx_range);
160 IdxTotal it(ib, iq);
161 sum += coeff_proxy(iq) * integrated_function(it);
162 },
163 teamSum);
164 result(ib) = teamSum;
165 });
166 }
167
168private:
179 template <class HeadDim, class... Grid1D>
180 KOKKOS_FUNCTION static Idx<HeadDim, Grid1D...> to_discrete_element(
181 int idx,
182 IdxRange<HeadDim, Grid1D...> idx_range)
183 {
184 IdxRange<Grid1D...> subidx_range(idx_range);
185 Idx<HeadDim> head_idx(ddc::select<HeadDim>(idx_range).front() + idx / subidx_range.size());
186 if constexpr (sizeof...(Grid1D) == 0) {
187 return head_idx;
188 } else {
189 Idx<Grid1D...> tail_idx = to_discrete_element(idx % subidx_range.size(), subidx_range);
190 return Idx<HeadDim, Grid1D...>(head_idx, tail_idx);
191 }
192 }
193};
194
195namespace detail {
/**
 * @brief Specialisation of OnMemorySpace for Quadrature objects.
 *
 * NOTE(review): the body of this specialisation is missing from this rendering (a line appears
 * to have been dropped between `{` and `};`). Presumably it declares
 * `using type = Quadrature<IdxRangeQuadrature, IdxRangeTotal, NewMemorySpace>;`, i.e. the same
 * quadrature type with its coefficients stored in NewMemorySpace. Confirm against the repository
 * before relying on it.
 */
196template <class NewMemorySpace, class IdxRangeQuadrature, class IdxRangeTotal, class MemorySpace>
197struct OnMemorySpace<NewMemorySpace, Quadrature<IdxRangeQuadrature, IdxRangeTotal, MemorySpace>>
198{
200};
201} // namespace detail
A class providing an operator for integrating functions defined on a discrete index range.
Definition quadrature.hpp:29
double operator()(ExecutionSpace exec_space, IntegratorFunction integrated_function) const
An operator for calculating the integral of a function defined on a discrete index range.
Definition quadrature.hpp:60
void operator()(ExecutionSpace exec_space, Field< double, BatchIdxRange, MemorySpace > const result, IntegratorFunction integrated_function) const
An operator for calculating the integral of a function defined on a discrete index range by cycling over the batch dimensions.
Definition quadrature.hpp:112
Quadrature(QuadConstField coeffs)
Create a Quadrature object.
Definition quadrature.hpp:44