Gyselalib++
 
Loading...
Searching...
No Matches
itimestepper.hpp
// SPDX-License-Identifier: MIT
#pragma once
#include <array>
#include <cstddef>
#include <functional>
#include <type_traits>
#include <utility>

#include <ddc/ddc.hpp>

#include "multipatch_field.hpp"
#include "multipatch_field_mem.hpp"

18template <
19 class FieldMem,
20 class DerivFieldMem = FieldMem,
21 class ExecSpace = Kokkos::DefaultExecutionSpace>
23{
24 static_assert(
25 (ddc::is_chunk_v<FieldMem>) or (is_vector_field_v<FieldMem>)
26 or (is_multipatch_field_mem_v<FieldMem>));
27 static_assert(
28 (ddc::is_chunk_v<DerivFieldMem>) or (is_vector_field_v<DerivFieldMem>)
29 or (is_multipatch_field_mem_v<DerivFieldMem>));
30
31 static_assert(
32 (std::is_same_v<
33 typename FieldMem::discrete_domain_type,
34 typename DerivFieldMem::discrete_domain_type>)
35 || (is_multipatch_field_mem_v<FieldMem> && is_multipatch_field_mem_v<DerivFieldMem>));
36
37 static_assert(
38 Kokkos::SpaceAccessibility<ExecSpace, typename FieldMem::memory_space>::accessible,
39 "MemorySpace has to be accessible for ExecutionSpace.");
40 static_assert(
41 Kokkos::SpaceAccessibility<ExecSpace, typename DerivFieldMem::memory_space>::accessible,
42 "MemorySpace has to be accessible for ExecutionSpace.");
43
44public:
46 using IdxRange = typename FieldMem::discrete_domain_type;
47
48
50 using ValField = typename FieldMem::span_type;
51
53 using ValConstField = typename FieldMem::view_type;
54
56 using DerivField = typename DerivFieldMem::span_type;
57
59 using DerivConstField = typename DerivFieldMem::view_type;
60
61public:
78 void update(ValField y, double dt, std::function<void(DerivField, ValConstField)> dy_calculator)
79 const
80 {
81 update(ExecSpace(), y, dt, dy_calculator);
82 }
83
102 void update(
103 ExecSpace const& exec_space,
104 ValField y,
105 double dt,
106 std::function<void(DerivField, ValConstField)> dy_calculator) const
107 {
108 static_assert(ddc::is_chunk_v<FieldMem>);
109 using Idx = typename IdxRange::discrete_element_type;
110 update(exec_space, y, dt, dy_calculator, [&](ValField y, DerivConstField dy, double dt) {
111 ddc::parallel_for_each(
112 exec_space,
113 get_idx_range(y),
114 KOKKOS_LAMBDA(Idx const idx) { y(idx) = y(idx) + dy(idx) * dt; });
115 });
116 }
117
133 virtual void update(
134 ExecSpace const& exec_space,
135 ValField y,
136 double dt,
137 std::function<void(DerivField, ValConstField)> dy_calculator,
138 std::function<void(ValField, DerivConstField, double)> y_update) const = 0;
139
140protected:
147 void copy(ValField copy_to, ValConstField copy_from) const
148 {
149 if constexpr (ddc::is_chunk_v<ValField>) {
150 ddc::parallel_deepcopy(copy_to, copy_from);
151 } else {
152 ddcHelper::deepcopy(copy_to, copy_from);
153 }
154 }
155
163 template <class DerivFieldType, class Idx, class... DDims>
164 KOKKOS_FUNCTION static void fill_k_total(DerivFieldType k_total, Idx i, Coord<DDims...> new_val)
165 {
166 static_assert(
167 (std::is_same_v<DerivField, DerivFieldType>)
168 || (is_multipatch_field_v<DerivField>));
169 ((ddcHelper::get<DDims>(k_total)(i) = ddc::get<DDims>(new_val)), ...);
170 }
171
181 template <class FuncType, class... T>
182 void assemble_k_total(ExecSpace const& exec_space, DerivField k_total, FuncType func, T... k)
183 const
184 {
185 static_assert(std::conjunction_v<std::is_same<T, DerivField>...>);
186 std::size_t constexpr n_args = sizeof...(T);
187 using element_type = typename DerivField::element_type;
188 static_assert(
189 std::is_invocable_r_v<element_type, FuncType, std::array<element_type, n_args>>);
190 std::array<DerivField, n_args> k_arr({k...});
191 if constexpr (is_vector_field_v<DerivField>) {
192 assemble_vector_field_k_total(exec_space, k_total, func, k_arr);
193 } else if constexpr (is_multipatch_field_v<DerivField>) {
194 assemble_multipatch_field_k_total(exec_space, k_total, func, k_arr);
195 } else {
196 assemble_field_k_total(exec_space, k_total, func, k_arr);
197 }
198 }
199
200public:
210 template <class FieldType, class FuncType, std::size_t n_args>
212 ExecSpace const& exec_space,
213 FieldType k_total,
214 FuncType func,
215 std::array<FieldType, n_args> k_arr) const
216 {
217 static_assert(ddc::is_chunk_v<FieldType>);
218 using Idx = typename FieldType::discrete_domain_type::discrete_element_type;
219 ddc::parallel_for_each(
220 exec_space,
221 get_idx_range(k_total),
222 KOKKOS_LAMBDA(Idx const i) {
223 std::array<double, n_args> k_elems;
224 for (int j(0); j < n_args; ++j) {
225 k_elems[j] = k_arr[j](i);
226 }
227 k_total(i) = func(k_elems);
228 });
229 }
230
240 template <class FieldType, class FuncType, std::size_t n_args>
242 ExecSpace const& exec_space,
243 FieldType k_total,
244 FuncType func,
245 std::array<FieldType, n_args> k_arr) const
246 {
247 static_assert(is_vector_field_v<FieldType>);
248 using element_type = typename FieldType::element_type;
249 using Idx = typename FieldType::discrete_domain_type::discrete_element_type;
250 ddc::parallel_for_each(
251 exec_space,
252 get_idx_range(k_total),
253 KOKKOS_LAMBDA(Idx const i) {
254 std::array<element_type, n_args> k_elems;
255 for (int j(0); j < n_args; ++j) {
256 k_elems[j] = k_arr[j](i);
257 }
258 fill_k_total(k_total, i, func(k_elems));
259 });
260 }
261
262private:
271 template <
272 class Patch,
273 template <typename P>
274 typename T,
275 class... Patches,
276 class FuncType,
277 std::size_t n_args>
278 void assemble_multipatch_field_k_total_on_patch(
279 ExecSpace const& exec_space,
281 FuncType func,
282 std::array<MultipatchField<T, Patches...>, n_args> k_arr) const
283 {
284 using FieldType = T<Patch>;
285 static_assert((ddc::is_chunk_v<FieldType>) or (is_vector_field_v<FieldType>));
286 std::array<FieldType, n_args> k_arr_on_patch;
287 FieldType k_total_on_patch = k_total.template get<Patch>();
288 for (std::size_t i(0); i < n_args; ++i) {
289 k_arr_on_patch[i] = k_arr[i].template get<Patch>();
290 }
291 if constexpr (is_vector_field_v<FieldType>) {
292 assemble_vector_field_k_total(exec_space, k_total_on_patch, func, k_arr_on_patch);
293 } else {
294 assemble_field_k_total(exec_space, k_total_on_patch, func, k_arr_on_patch);
295 }
296 }
297
306 template <
307 template <typename P>
308 typename T,
309 class... Patches,
310 class FuncType,
311 std::size_t n_args>
312 void assemble_multipatch_field_k_total(
313 ExecSpace const& exec_space,
315 FuncType func,
316 std::array<MultipatchField<T, Patches...>, n_args> k_arr) const
317 {
318 ((assemble_multipatch_field_k_total_on_patch<Patches>(exec_space, k_total, func, k_arr)),
319 ...);
320 }
321};
See DerivFieldMemImplementation.
Definition derivative_field.hpp:10
See DerivFieldImplementation.
Definition derivative_field.hpp:20
The superclass from which all timestepping methods inherit.
Definition itimestepper.hpp:23
typename FieldMem::discrete_domain_type IdxRange
The type of the index range on which the values of the function are defined.
Definition itimestepper.hpp:46
void assemble_vector_field_k_total(ExecSpace const &exec_space, FieldType k_total, FuncType func, std::array< FieldType, n_args > k_arr) const
Calculate func(k_arr[0], k_arr[1], ...) when FieldType is a VectorField.
Definition itimestepper.hpp:241
void assemble_k_total(ExecSpace const &exec_space, DerivField k_total, FuncType func, T... k) const
A method to assemble multiple derivative fields into one.
Definition itimestepper.hpp:182
typename DerivFieldMem::view_type DerivConstField
The constant type of the derivatives values of the function being evolved.
Definition itimestepper.hpp:59
void update(ExecSpace const &exec_space, ValField y, double dt, std::function< void(DerivField, ValConstField)> dy_calculator) const
Carry out one step of the timestepping scheme.
Definition itimestepper.hpp:102
virtual void update(ExecSpace const &exec_space, ValField y, double dt, std::function< void(DerivField, ValConstField)> dy_calculator, std::function< void(ValField, DerivConstField, double)> y_update) const =0
Carry out one step of the timestepping scheme.
typename FieldMem::span_type ValField
The type of the values of the function being evolved.
Definition itimestepper.hpp:50
void update(ValField y, double dt, std::function< void(DerivField, ValConstField)> dy_calculator) const
Carry out one step of the timestepping scheme.
Definition itimestepper.hpp:78
void copy(ValField copy_to, ValConstField copy_from) const
Make a copy of the values of the function being evolved.
Definition itimestepper.hpp:147
void assemble_field_k_total(ExecSpace const &exec_space, FieldType k_total, FuncType func, std::array< FieldType, n_args > k_arr) const
Calculate func(k_arr[0], k_arr[1], ...) when FieldType is a Field (ddc::ChunkSpan).
Definition itimestepper.hpp:211
typename FieldMem::view_type ValConstField
The constant type of the values of the function being evolved.
Definition itimestepper.hpp:53
static KOKKOS_FUNCTION void fill_k_total(DerivFieldType k_total, Idx i, Coord< DDims... > new_val)
A method to fill an element of a vector field.
Definition itimestepper.hpp:164
A class to store field objects on patches.
Definition multipatch_field.hpp:30
Base tag for a patch.
Definition patch.hpp:14
A class which describes the real space in the temporal direction.
Definition geometry.hpp:44