Gyselalib++
itimestepper.hpp
1 // SPDX-License-Identifier: MIT
2 #pragma once
3 #include <array>
4 #include <type_traits>
5 
6 #include <ddc/ddc.hpp>
7 
8 #include "multipatch_field.hpp"
9 #include "multipatch_field_mem.hpp"
10 
// The superclass from which all timestepping methods inherit.
//
// Template parameters:
//   FieldMem      - field-memory type holding the values of the function being evolved.
//   DerivFieldMem - field-memory type holding the derivatives of that function
//                   (defaults to FieldMem).
//   ExecSpace     - Kokkos execution space used by the scheme
//                   (defaults to Kokkos::DefaultExecutionSpace).
// NOTE(review): original line 22 (the `class ...` line naming this template) is absent from
// this listing, presumably dropped by the documentation renderer.
18 template <
19  class FieldMem,
20  class DerivFieldMem = FieldMem,
21  class ExecSpace = Kokkos::DefaultExecutionSpace>
23 {
// FieldMem and DerivFieldMem must each be a ddc chunk, a vector field or a multipatch field mem.
24  static_assert(
25  (ddc::is_chunk_v<FieldMem>) or (is_vector_field_v<FieldMem>)
26  or (is_multipatch_field_mem_v<FieldMem>));
27  static_assert(
28  (ddc::is_chunk_v<DerivFieldMem>) or (is_vector_field_v<DerivFieldMem>)
29  or (is_multipatch_field_mem_v<DerivFieldMem>));
30 
// Values and derivatives must share the same discrete_domain_type, unless both are multipatch
// field mems. NOTE(review): the std::is_same_v operand names `discrete_domain_type` for both
// types even in the multipatch case, so multipatch types must expose it too (IdxRange below
// requires it as well).
31  static_assert(
32  (std::is_same_v<
33  typename FieldMem::discrete_domain_type,
34  typename DerivFieldMem::discrete_domain_type>)
35  || (is_multipatch_field_mem_v<FieldMem> && is_multipatch_field_mem_v<DerivFieldMem>));
36 
// The memory spaces of both field types must be accessible from ExecSpace.
37  static_assert(
38  Kokkos::SpaceAccessibility<ExecSpace, typename FieldMem::memory_space>::accessible,
39  "MemorySpace has to be accessible for ExecutionSpace.");
40  static_assert(
41  Kokkos::SpaceAccessibility<ExecSpace, typename DerivFieldMem::memory_space>::accessible,
42  "MemorySpace has to be accessible for ExecutionSpace.");
43 
44 public:
// The type of the index range on which the values of the function are defined.
46  using IdxRange = typename FieldMem::discrete_domain_type;
47 
48 
// The type of the values of the function being evolved.
50  using ValField = typename FieldMem::span_type;
51 
// The constant type of the values of the function being evolved.
53  using ValConstField = typename FieldMem::view_type;
54 
// The type of the derivatives values of the function being evolved.
56  using DerivField = typename DerivFieldMem::span_type;
57 
// The constant type of the derivatives values of the function being evolved.
59  using DerivConstField = typename DerivFieldMem::view_type;
60 
61 public:
78  void update(ValField y, double dt, std::function<void(DerivField, ValConstField)> dy_calculator)
79  const
80  {
81  update(ExecSpace(), y, dt, dy_calculator);
82  }
83 
102  void update(
103  ExecSpace const& exec_space,
104  ValField y,
105  double dt,
106  std::function<void(DerivField, ValConstField)> dy_calculator) const
107  {
108  static_assert(ddc::is_chunk_v<FieldMem>);
109  using Idx = typename IdxRange::discrete_element_type;
110  update(exec_space, y, dt, dy_calculator, [&](ValField y, DerivConstField dy, double dt) {
111  ddc::parallel_for_each(
112  exec_space,
113  get_idx_range(y),
114  KOKKOS_LAMBDA(Idx const idx) { y(idx) = y(idx) + dy(idx) * dt; });
115  });
116  }
117 
133  virtual void update(
134  ExecSpace const& exec_space,
135  ValField y,
136  double dt,
137  std::function<void(DerivField, ValConstField)> dy_calculator,
138  std::function<void(ValField, DerivConstField, double)> y_update) const = 0;
139 
140 protected:
147  void copy(ValField copy_to, ValConstField copy_from) const
148  {
149  if constexpr (ddc::is_chunk_v<ValField>) {
150  ddc::parallel_deepcopy(copy_to, copy_from);
151  } else {
152  ddcHelper::deepcopy(copy_to, copy_from);
153  }
154  }
155 
163  template <class DerivFieldType, class Idx, class... DDims>
164  KOKKOS_FUNCTION static void fill_k_total(DerivFieldType k_total, Idx i, Coord<DDims...> new_val)
165  {
166  static_assert(
167  (std::is_same_v<DerivField, DerivFieldType>)
168  || (is_multipatch_field_v<DerivField>));
169  ((ddcHelper::get<DDims>(k_total)(i) = ddc::get<DDims>(new_val)), ...);
170  }
171 
181  template <class FuncType, class... T>
182  void assemble_k_total(ExecSpace const& exec_space, DerivField k_total, FuncType func, T... k)
183  const
184  {
185  static_assert(std::conjunction_v<std::is_same<T, DerivField>...>);
186  std::size_t constexpr n_args = sizeof...(T);
187  using element_type = typename DerivField::element_type;
188  static_assert(
189  std::is_invocable_r_v<element_type, FuncType, std::array<element_type, n_args>>);
190  std::array<DerivField, n_args> k_arr({k...});
191  if constexpr (is_vector_field_v<DerivField>) {
192  assemble_vector_field_k_total(exec_space, k_total, func, k_arr);
193  } else if constexpr (is_multipatch_field_v<DerivField>) {
194  assemble_multipatch_field_k_total(exec_space, k_total, func, k_arr);
195  } else {
196  assemble_field_k_total(exec_space, k_total, func, k_arr);
197  }
198  }
199 
200 public:
210  template <class FieldType, class FuncType, std::size_t n_args>
212  ExecSpace const& exec_space,
213  FieldType k_total,
214  FuncType func,
215  std::array<FieldType, n_args> k_arr) const
216  {
217  static_assert(ddc::is_chunk_v<FieldType>);
218  using Idx = typename FieldType::discrete_domain_type::discrete_element_type;
219  ddc::parallel_for_each(
220  exec_space,
221  get_idx_range(k_total),
222  KOKKOS_LAMBDA(Idx const i) {
223  std::array<double, n_args> k_elems;
224  for (int j(0); j < n_args; ++j) {
225  k_elems[j] = k_arr[j](i);
226  }
227  k_total(i) = func(k_elems);
228  });
229  }
230 
240  template <class FieldType, class FuncType, std::size_t n_args>
242  ExecSpace const& exec_space,
243  FieldType k_total,
244  FuncType func,
245  std::array<FieldType, n_args> k_arr) const
246  {
247  static_assert(is_vector_field_v<FieldType>);
248  using element_type = typename FieldType::element_type;
249  using Idx = typename FieldType::discrete_domain_type::discrete_element_type;
250  ddc::parallel_for_each(
251  exec_space,
252  get_idx_range(k_total),
253  KOKKOS_LAMBDA(Idx const i) {
254  std::array<element_type, n_args> k_elems;
255  for (int j(0); j < n_args; ++j) {
256  k_elems[j] = k_arr[j](i);
257  }
258  fill_k_total(k_total, i, func(k_elems));
259  });
260  }
261 
262 private:
// Assemble k_total on the single patch `Patch` of a multipatch field.
// k_total and every k_arr[i] are restricted to Patch via `.template get<Patch>()`. The
// per-patch fields are then forwarded to assemble_vector_field_k_total when the per-patch type
// T<Patch> is a vector field, and to assemble_field_k_total otherwise.
// NOTE(review): original line 280 (the declaration of the k_total parameter) is absent from
// this listing; from its use below it is presumably a MultipatchField<T, Patches...>.
271  template <
272  class Patch,
273  template <typename P>
274  typename T,
275  class... Patches,
276  class FuncType,
277  std::size_t n_args>
278  void assemble_multipatch_field_k_total_on_patch(
279  ExecSpace const& exec_space,
281  FuncType func,
282  std::array<MultipatchField<T, Patches...>, n_args> k_arr) const
283  {
284  using FieldType = T<Patch>;
// Only scalar fields or vector fields are supported on an individual patch.
285  static_assert((ddc::is_chunk_v<FieldType>) or (is_vector_field_v<FieldType>));
// Default-constructed first (so FieldType must be default-constructible), then filled below.
286  std::array<FieldType, n_args> k_arr_on_patch;
287  FieldType k_total_on_patch = k_total.template get<Patch>();
288  for (std::size_t i(0); i < n_args; ++i) {
289  k_arr_on_patch[i] = k_arr[i].template get<Patch>();
290  }
291  if constexpr (is_vector_field_v<FieldType>) {
292  assemble_vector_field_k_total(exec_space, k_total_on_patch, func, k_arr_on_patch);
293  } else {
294  assemble_field_k_total(exec_space, k_total_on_patch, func, k_arr_on_patch);
295  }
296  }
297 
// Assemble k_total on every patch of a multipatch field.
// A comma fold expression calls assemble_multipatch_field_k_total_on_patch once for each patch
// type in Patches..., in order.
// NOTE(review): original line 314 (the declaration of the k_total parameter) is absent from
// this listing; presumably a MultipatchField<T, Patches...>.
306  template <
307  template <typename P>
308  typename T,
309  class... Patches,
310  class FuncType,
311  std::size_t n_args>
312  void assemble_multipatch_field_k_total(
313  ExecSpace const& exec_space,
315  FuncType func,
316  std::array<MultipatchField<T, Patches...>, n_args> k_arr) const
317  {
318  ((assemble_multipatch_field_k_total_on_patch<Patches>(exec_space, k_total, func, k_arr)),
319  ...);
320  }
321 };
See DerivFieldMemImplementation.
Definition: derivative_field.hpp:10
See DerivFieldImplementation.
Definition: derivative_field.hpp:20
The superclass from which all timestepping methods inherit.
Definition: itimestepper.hpp:23
typename FieldMem::discrete_domain_type IdxRange
The type of the index range on which the values of the function are defined.
Definition: itimestepper.hpp:46
void assemble_vector_field_k_total(ExecSpace const &exec_space, FieldType k_total, FuncType func, std::array< FieldType, n_args > k_arr) const
Calculate func(k_arr[0], k_arr[1], ...) when FieldType is a VectorField.
Definition: itimestepper.hpp:241
void assemble_k_total(ExecSpace const &exec_space, DerivField k_total, FuncType func, T... k) const
A method to assemble multiple derivative fields into one.
Definition: itimestepper.hpp:182
typename DerivFieldMem::view_type DerivConstField
The constant type of the derivatives values of the function being evolved.
Definition: itimestepper.hpp:59
void update(ExecSpace const &exec_space, ValField y, double dt, std::function< void(DerivField, ValConstField)> dy_calculator) const
Carry out one step of the timestepping scheme.
Definition: itimestepper.hpp:102
virtual void update(ExecSpace const &exec_space, ValField y, double dt, std::function< void(DerivField, ValConstField)> dy_calculator, std::function< void(ValField, DerivConstField, double)> y_update) const =0
Carry out one step of the timestepping scheme.
typename FieldMem::span_type ValField
The type of the values of the function being evolved.
Definition: itimestepper.hpp:50
void update(ValField y, double dt, std::function< void(DerivField, ValConstField)> dy_calculator) const
Carry out one step of the timestepping scheme.
Definition: itimestepper.hpp:78
void copy(ValField copy_to, ValConstField copy_from) const
Make a copy of the values of the function being evolved.
Definition: itimestepper.hpp:147
void assemble_field_k_total(ExecSpace const &exec_space, FieldType k_total, FuncType func, std::array< FieldType, n_args > k_arr) const
Calculate func(k_arr[0], k_arr[1], ...) when FieldType is a Field (ddc::ChunkSpan).
Definition: itimestepper.hpp:211
typename FieldMem::view_type ValConstField
The constant type of the values of the function being evolved.
Definition: itimestepper.hpp:53
static KOKKOS_FUNCTION void fill_k_total(DerivFieldType k_total, Idx i, Coord< DDims... > new_val)
A method to fill an element of a vector field.
Definition: itimestepper.hpp:164
A class to store field objects on patches.
Definition: multipatch_field.hpp:30
Base tag for a patch.
Definition: patch.hpp:14
A class which describes the real space in the temporal direction.
Definition: geometry.hpp:45