Gyselalib++
 
Loading...
Searching...
No Matches
rk4.hpp
1// SPDX-License-Identifier: MIT
2#pragma once
3
4#include "ddc_alias_inline_functions.hpp"
5#include "ddc_aliases.hpp"
6#include "ddc_helper.hpp"
7#include "itimestepper.hpp"
8#include "vector_field_common.hpp"
9
10
32template <
33 class FieldMem,
34 class DerivFieldMem = FieldMem,
35 class ExecSpace = Kokkos::DefaultExecutionSpace>
36class RK4 : public ITimeStepper<FieldMem, DerivFieldMem, ExecSpace>
37{
39
40public:
41 using typename base_type::IdxRange;
42
43 using typename base_type::ValConstField;
44 using typename base_type::ValField;
45
46 using typename base_type::DerivConstField;
47 using typename base_type::DerivField;
48
49private:
50 IdxRange const m_idx_range;
51
52public:
54
55public:
60 explicit RK4(IdxRange idx_range) : m_idx_range(idx_range) {}
61
77 void update(
78 ExecSpace const& exec_space,
79 ValField y,
80 double dt,
81 std::function<void(DerivField, ValConstField)> dy_calculator,
82 std::function<void(ValField, DerivConstField, double)> y_update) const final
83 {
84 static_assert(
85 Kokkos::SpaceAccessibility<ExecSpace, typename FieldMem::memory_space>::accessible,
86 "MemorySpace has to be accessible for ExecutionSpace.");
87 static_assert(
88 Kokkos::SpaceAccessibility<ExecSpace, typename DerivFieldMem::memory_space>::
89 accessible,
90 "MemorySpace has to be accessible for ExecutionSpace.");
91 FieldMem y_prime_alloc(m_idx_range);
92 DerivFieldMem k1_alloc(m_idx_range);
93 DerivFieldMem k2_alloc(m_idx_range);
94 DerivFieldMem k3_alloc(m_idx_range);
95 DerivFieldMem k4_alloc(m_idx_range);
96 DerivFieldMem k_total_alloc(m_idx_range);
97
98 ValField y_prime = get_field(y_prime_alloc);
99 DerivField k1 = get_field(k1_alloc);
100 DerivField k2 = get_field(k2_alloc);
101 DerivField k3 = get_field(k3_alloc);
102 DerivField k4 = get_field(k4_alloc);
103 DerivField k_total = get_field(k_total_alloc);
104
105
106 // Save initial conditions
107 base_type::copy(y_prime, get_const_field(y));
108
109 // --------- Calculate k1 ------------
110 // k1 = f(y)
111 dy_calculator(k1, get_const_field(y));
112
113 // --------- Calculate k2 ------------
114 // Calculate y_new := y_n + h/2*k_1
115 y_update(y_prime, get_const_field(k1), 0.5 * dt);
116
117 // Calculate k2 = f(y_new)
118 dy_calculator(k2, get_const_field(y_prime));
119
120 // --------- Calculate k3 ------------
121 // Collect initial conditions
122 base_type::copy(y_prime, get_const_field(y));
123
124 // Calculate y_new := y_n + h/2*k_2
125 y_update(y_prime, get_const_field(k2), 0.5 * dt);
126
127 // Calculate k3 = f(y_new)
128 dy_calculator(k3, get_const_field(y_prime));
129
130 // --------- Calculate k3 ------------
131 // Collect initial conditions
132 base_type::copy(y_prime, get_const_field(y));
133
134 // Calculate y_new := y_n + h*k_3
135 y_update(y_prime, get_const_field(k3), dt);
136
137 // Calculate k4 = f(y_new)
138 dy_calculator(k4, get_const_field(y_prime));
139
140 // --------- Update y ------------
141 // Calculation of step
142 // k_total = k1 + 4 * k2 + k3
143 using element_type = typename DerivField::element_type;
145 exec_space,
146 k_total,
147 KOKKOS_LAMBDA(std::array<element_type, 4> k) {
148 return k[0] + 2 * k[1] + 2 * k[2] + k[3];
149 },
150 k1,
151 k2,
152 k3,
153 k4);
154
155 // Calculate y_{n+1} := y_n + (k1 + 2 * k2 + 2 * k3 + k4) * h/6
156 y_update(y, get_const_field(k_total), dt / 6.);
157 }
158};
See DerivFieldMemImplementation.
Definition derivative_field.hpp:10
See DerivFieldImplementation.
Definition derivative_field.hpp:20
The superclass from which all timestepping methods inherit.
Definition itimestepper.hpp:23
typename FieldMem::discrete_domain_type IdxRange
The type of the index range on which the values of the function are defined.
Definition itimestepper.hpp:46
typename DerivFieldMem::span_type DerivField
The type of the derivatives of the function being evolved.
Definition itimestepper.hpp:56
void assemble_k_total(ExecSpace const &exec_space, DerivField k_total, FuncType func, T... k) const
A method to assemble multiple derivative fields into one.
Definition itimestepper.hpp:182
typename DerivFieldMem::view_type DerivConstField
The constant type of the derivatives values of the function being evolved.
Definition itimestepper.hpp:59
typename FieldMem::span_type ValField
The type of the values of the function being evolved.
Definition itimestepper.hpp:50
void update(ValField y, double dt, std::function< void(DerivField, ValConstField)> dy_calculator) const
Carry out one step of the timestepping scheme.
Definition itimestepper.hpp:78
void copy(ValField copy_to, ValConstField copy_from) const
Make a copy of the values of the function being evolved.
Definition itimestepper.hpp:147
typename FieldMem::view_type ValConstField
The constant type of the values of the function being evolved.
Definition itimestepper.hpp:53
A class which provides an implementation of a fourth-order Runge-Kutta method.
Definition rk4.hpp:37
RK4(IdxRange idx_range)
Create an RK4 object.
Definition rk4.hpp:60
typename DerivFieldMem::view_type DerivConstField
The constant type of the derivatives values of the function being evolved.
Definition itimestepper.hpp:59
void update(ExecSpace const &exec_space, ValField y, double dt, std::function< void(DerivField, ValConstField)> dy_calculator, std::function< void(ValField, DerivConstField, double)> y_update) const final
Carry out one step of the Runge-Kutta scheme.
Definition rk4.hpp:77
typename FieldMem::span_type ValField
The type of the values of the function being evolved.
Definition itimestepper.hpp:50
typename FieldMem::view_type ValConstField
The constant type of the values of the function being evolved.
Definition itimestepper.hpp:53