Gyselalib++
 
mpitransposealltoall.hpp
// SPDX-License-Identifier: MIT
#pragma once
#include <numeric>

#include <ddc/ddc.hpp>

#include "ddc_alias_inline_functions.hpp"
#include "ddc_aliases.hpp"
#include "ddc_helper.hpp"
#include "impitranspose.hpp"
#include "mpilayout.hpp"
#include "mpitools.hpp"
#include "transpose.hpp"

/**
 * A class describing an operator for converting from/to different MPI layouts using AlltoAll.
 */
template <class Layout1, class Layout2>
class MPITransposeAllToAll : public IMPITranspose<Layout1, Layout2>
{
public:
    /// The type of the index range of the first MPI layout.
    using idx_range_type1 = typename Layout1::discrete_domain_type;
    /// The type of the index range of the second MPI layout.
    using idx_range_type2 = typename Layout2::discrete_domain_type;
    /// The type of the distributed sub-index range of the first MPI layout.
    using distributed_idx_range_type1 = typename Layout1::distributed_sub_idx_range;
    /// The type of the distributed sub-index range of the second MPI layout.
    using distributed_idx_range_type2 = typename Layout2::distributed_sub_idx_range;

private:
    using layout_1_mpi_dims = ddcHelper::
            apply_template_to_type_seq_t<MPIDim, typename Layout1::distributed_type_seq>;
    using layout_2_mpi_dims = ddcHelper::
            apply_template_to_type_seq_t<MPIDim, typename Layout2::distributed_type_seq>;
    using layout_1_mpi_idx_range_type
            = ddc::detail::convert_type_seq_to_discrete_domain_t<layout_1_mpi_dims>;
    using layout_2_mpi_idx_range_type
            = ddc::detail::convert_type_seq_to_discrete_domain_t<layout_2_mpi_dims>;

private:
    int m_comm_size;
    Layout1 m_layout_1;
    Layout2 m_layout_2;
    idx_range_type1 m_local_idx_range_1;
    idx_range_type2 m_local_idx_range_2;
    layout_1_mpi_idx_range_type m_layout_1_mpi_idx_range;
    layout_2_mpi_idx_range_type m_layout_2_mpi_idx_range;

public:
    /// A constructor for the transpose operator.
    template <class IdxRange>
    MPITransposeAllToAll(IdxRange global_idx_range, MPI_Comm comm)
        : IMPITranspose<Layout1, Layout2>(comm)
    {
        static_assert(
                std::is_same_v<
                        IdxRange,
                        idx_range_type1> || std::is_same_v<IdxRange, idx_range_type2>,
                "The initialisation global idx_range should be described by one of the layouts");
        idx_range_type1 global_idx_range_layout_1(global_idx_range);
        idx_range_type2 global_idx_range_layout_2(global_idx_range);
        int rank;
        MPI_Comm_size(comm, &m_comm_size);
        MPI_Comm_rank(comm, &rank);
        distributed_idx_range_type1 distrib_idx_range(global_idx_range_layout_1);
        assert(m_comm_size <= distrib_idx_range.size());
        // Ensure that the load balancing is good
        assert(distrib_idx_range.size() % m_comm_size == 0);
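        // For illustration (hypothetical sizes): with m_comm_size == 4 and a distributed
        // index range of extent 128, each rank owns 128 / 4 == 32 of the distributed
        // indices; an extent of 130 would fail the assertion above since 130 % 4 != 0.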
        m_local_idx_range_1
                = m_layout_1.distribute_idx_range(global_idx_range_layout_1, m_comm_size, rank);
        m_local_idx_range_2
                = m_layout_2.distribute_idx_range(global_idx_range_layout_2, m_comm_size, rank);
        m_layout_1_mpi_idx_range = get_distribution(
                distributed_idx_range_type1(m_local_idx_range_1),
                distributed_idx_range_type1(global_idx_range_layout_1));
        m_layout_2_mpi_idx_range = get_distribution(
                distributed_idx_range_type2(m_local_idx_range_2),
                distributed_idx_range_type2(global_idx_range_layout_2));
    }

    /// Getter for the local index range.
    template <class Layout>
    auto get_local_idx_range()
    {
        static_assert(
                std::is_same_v<Layout, Layout1> || std::is_same_v<Layout, Layout2>,
                "Transpose class does not handle requested layout");
        if constexpr (std::is_same_v<Layout, Layout1>) {
            return m_local_idx_range_1;
        } else if constexpr (std::is_same_v<Layout, Layout2>) {
            return m_local_idx_range_2;
        }
    }

    /// An operator which transposes from one layout to another.
    template <
            class ElementType,
            class InIdxRange,
            class IdxRangeOut,
            class MemSpace,
            class ExecSpace>
    void operator()(
            ExecSpace const& execution_space,
            Field<ElementType, IdxRangeOut, MemSpace> recv_field,
            ConstField<ElementType, InIdxRange, MemSpace> send_field)
    {
        static_assert(!std::is_same_v<InIdxRange, IdxRangeOut>);
        static_assert(
                (std::is_same_v<InIdxRange, typename Layout1::discrete_domain_type>)
                || (std::is_same_v<InIdxRange, typename Layout2::discrete_domain_type>));
        static_assert(
                (std::is_same_v<IdxRangeOut, typename Layout1::discrete_domain_type>)
                || (std::is_same_v<IdxRangeOut, typename Layout2::discrete_domain_type>));
        using OutLayout = std::conditional_t<
                std::is_same_v<IdxRangeOut, typename Layout1::discrete_domain_type>,
                Layout1,
                Layout2>;
        this->transpose_to<OutLayout>(execution_space, recv_field, send_field);
    }

    /// An operator which transposes from one layout to another.
    template <class OutLayout, class ElementType, class MemSpace, class ExecSpace, class InIdxRange>
    void transpose_to(
            ExecSpace const& execution_space,
            Field<ElementType, typename OutLayout::discrete_domain_type, MemSpace> recv_field,
            ConstField<ElementType, InIdxRange, MemSpace> send_field)
    {
        using InLayout = std::conditional_t<std::is_same_v<OutLayout, Layout1>, Layout2, Layout1>;
        /*****************************************************************
         * Validate input
         *****************************************************************/
        static_assert((std::is_same_v<InLayout, Layout1>) || (std::is_same_v<InLayout, Layout2>));
        static_assert((std::is_same_v<OutLayout, Layout1>) || (std::is_same_v<OutLayout, Layout2>));
        static_assert(Kokkos::SpaceAccessibility<ExecSpace, MemSpace>::accessible);

        static_assert(std::is_same_v<InIdxRange, typename InLayout::discrete_domain_type>);
        using IdxRangeOut = typename OutLayout::discrete_domain_type;

        /*****************************************************************
         * Define groups of tags to build necessary transform
         *****************************************************************/
        // Currently distributed dims will be gathered
        using dims_to_gather = typename InLayout::distributed_type_seq;
        // Future distributed dims must be scattered
        using dims_to_scatter = typename OutLayout::distributed_type_seq;
        // Get MPI tags (e.g. MPIDim<DimX> indicates where DimX will be gathered from)
        using gather_mpi_dims = ddcHelper::
                apply_template_to_type_seq_t<MPIDim, typename InLayout::distributed_type_seq>;
        using scatter_mpi_dims = ddcHelper::
                apply_template_to_type_seq_t<MPIDim, typename OutLayout::distributed_type_seq>;
        // Get complete set of tags from index ranges
        using input_ordered_dims = ddc::to_type_seq_t<InIdxRange>;
        using output_ordered_dims = ddc::to_type_seq_t<IdxRangeOut>;
        // Find tags that are neither scattered nor gathered
        using batch_dims = ddc::type_seq_remove_t<
                ddc::type_seq_remove_t<input_ordered_dims, dims_to_scatter>,
                dims_to_gather>;
        // Insert the MPI tags into the received data to introduce the dimension split
        using input_mpi_idx_range_tags
                = insert_mpi_tags_into_seq_t<scatter_mpi_dims, input_ordered_dims>;
        using output_mpi_idx_range_tags
                = insert_mpi_tags_into_seq_t<gather_mpi_dims, output_ordered_dims>;
        // During the alltoall call the MPI tags must come first to correctly send/receive data,
        // but the data itself is not reordered
        using input_alltoall_dim_order
                = ddc::type_seq_merge_t<scatter_mpi_dims, input_ordered_dims>;
        using output_alltoall_dim_order
                = ddc::type_seq_merge_t<gather_mpi_dims, input_ordered_dims>;
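        // Illustration with hypothetical dimensions (DimX distributed in InLayout, DimVx
        // distributed in OutLayout, data indexed over (DimX, DimY, DimVx)):
        //   gather_mpi_dims           = (MPIDim<DimX>)
        //   scatter_mpi_dims          = (MPIDim<DimVx>)
        //   batch_dims                = (DimY)
        //   input_alltoall_dim_order  = (MPIDim<DimVx>, DimX, DimY, DimVx)
        //   output_alltoall_dim_order = (MPIDim<DimX>, DimX, DimY, DimVx)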

        /*****************************************************************
         * Define useful index range types
         *****************************************************************/
        // Get the index ranges of objects which are distributed across processes
        using gather_idx_range_type = typename InLayout::distributed_sub_idx_range;
        // Get the index ranges of objects which will be distributed across processes
        using scatter_idx_range_type = typename OutLayout::distributed_sub_idx_range;
        // Get the index range of objects that are not distributed across processes
        using batch_idx_range_type = ddc::detail::convert_type_seq_to_discrete_domain_t<batch_dims>;
        using gather_mpi_idx_range_type
                = ddc::detail::convert_type_seq_to_discrete_domain_t<gather_mpi_dims>;
        using scatter_mpi_idx_range_type
                = ddc::detail::convert_type_seq_to_discrete_domain_t<scatter_mpi_dims>;
        // Get the index range containing MPI tags which can be used to describe the function input
        using input_mpi_idx_range_type
                = ddc::detail::convert_type_seq_to_discrete_domain_t<input_mpi_idx_range_tags>;
        // Get the index range containing MPI tags which can be used to describe the function output
        using output_mpi_idx_range_type
                = ddc::detail::convert_type_seq_to_discrete_domain_t<output_mpi_idx_range_tags>;
        // Get the index range that should be used when scattering data
        using input_alltoall_idx_range_type
                = ddc::detail::convert_type_seq_to_discrete_domain_t<input_alltoall_dim_order>;
        // Get the index range that should be used when gathering data
        using output_alltoall_idx_range_type
                = ddc::detail::convert_type_seq_to_discrete_domain_t<output_alltoall_dim_order>;

        /*****************************************************************
         * Build index ranges
         *****************************************************************/
        // Collect the artificial dimension describing the MPI rank where the scattered information
        // will be sent to or where the gathered information will be collected from
        gather_mpi_idx_range_type gather_mpi_idx_range;
        scatter_mpi_idx_range_type scatter_mpi_idx_range;
        if constexpr (std::is_same_v<InLayout, Layout1>) {
            gather_mpi_idx_range = m_layout_1_mpi_idx_range;
            scatter_mpi_idx_range = m_layout_2_mpi_idx_range;
        } else {
            gather_mpi_idx_range = m_layout_2_mpi_idx_range;
            scatter_mpi_idx_range = m_layout_1_mpi_idx_range;
        }

        // Collect the useful subindex ranges described in the fields
        batch_idx_range_type batch_idx_range(get_idx_range(send_field));
        scatter_idx_range_type scatter_idx_range(get_idx_range(recv_field));
        gather_idx_range_type gather_idx_range(get_idx_range(send_field));

        // Build the index ranges describing the function inputs but including the MPI rank information
        input_mpi_idx_range_type input_mpi_idx_range(
                scatter_mpi_idx_range,
                scatter_idx_range,
                gather_idx_range,
                batch_idx_range);
        output_mpi_idx_range_type output_mpi_idx_range(
                gather_mpi_idx_range,
                scatter_idx_range,
                gather_idx_range,
                batch_idx_range);
        assert(input_mpi_idx_range.size() == send_field.size());
        assert(output_mpi_idx_range.size() == recv_field.size());

        // Create the index ranges used during the alltoall call (the MPI rank index range is first in this layout)
        input_alltoall_idx_range_type input_alltoall_idx_range(input_mpi_idx_range);
        output_alltoall_idx_range_type output_alltoall_idx_range(output_mpi_idx_range);

        /*****************************************************************
         * Create views on the function inputs with the index ranges including the MPI rank information
         *****************************************************************/
        ConstField<ElementType, input_mpi_idx_range_type, MemSpace>
                send_mpi_field(send_field.data_handle(), input_mpi_idx_range);
        Field<ElementType, output_mpi_idx_range_type, MemSpace>
                recv_mpi_field(recv_field.data_handle(), output_mpi_idx_range);

        /*****************************************************************
         * Transpose data (both on the rank and between ranks)
         *****************************************************************/
        // Create views or copies of the function inputs such that they are laid out on the
        // index range used during the alltoall call
        auto alltoall_send_buffer = ddcHelper::create_transpose_mirror_view_and_copy<
                input_alltoall_idx_range_type>(execution_space, send_mpi_field);
        auto alltoall_recv_buffer = ddcHelper::create_transpose_mirror<
                output_alltoall_idx_range_type>(execution_space, recv_mpi_field);

        // Call the MPI AlltoAll routine
        call_all_to_all(
                execution_space,
                get_field(alltoall_recv_buffer),
                get_const_field(alltoall_send_buffer));

        // If alltoall_recv_buffer owns its data (not a view) then copy the results back to
        // recv_mpi_field which is a view on recv_field, the function output
        if constexpr (!ddc::is_borrowed_chunk_v<decltype(alltoall_recv_buffer)>) {
            transpose_layout(
                    execution_space,
                    recv_mpi_field,
                    get_const_field(alltoall_recv_buffer));
        }
    }

private:
    template <
            class ElementType,
            class MPIRecvIdxRange,
            class MPISendIdxRange,
            class MemSpace,
            class ExecSpace>
    void call_all_to_all(
            ExecSpace const& execution_space,
            Field<ElementType, MPIRecvIdxRange, MemSpace> recv_field,
            ConstField<ElementType, MPISendIdxRange, MemSpace> send_field)
    {
        // No Cuda-aware MPI yet
        auto send_buffer = ddc::create_mirror_view_and_copy(send_field);
        auto recv_buffer = ddc::create_mirror_view(recv_field);
        MPI_Alltoall(
                send_buffer.data_handle(),
                send_buffer.size() / m_comm_size,
                MPI_type_descriptor_t<ElementType>,
                recv_buffer.data_handle(),
                recv_buffer.size() / m_comm_size,
                MPI_type_descriptor_t<ElementType>,
                // Assumed: the communicator passed to the constructor is stored as m_comm
                // in the IMPITranspose base class.
                IMPITranspose<Layout1, Layout2>::m_comm);
        if constexpr (!ddc::is_borrowed_chunk_v<decltype(recv_buffer)>) {
            ddc::parallel_deepcopy(execution_space, recv_field, recv_buffer);
        }
    }

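    /**
     * Build the index range over the artificial MPI dimensions: for each distributed
     * dimension the extent is (global extent) / (local extent), i.e. the number of
     * ranks amongst which that dimension is split. For illustration (hypothetical
     * sizes): a global extent of 128 split into local chunks of 32 gives an MPI
     * dimension of extent 4.
     */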
    template <class... DistributedDims>
    IdxRange<MPIDim<DistributedDims>...> get_distribution(
            IdxRange<DistributedDims...> local_idx_range,
            IdxRange<DistributedDims...> global_idx_range)
    {
        Idx<MPIDim<DistributedDims>...> start(Idx<MPIDim<DistributedDims>> {0}...);
        IdxStep<MPIDim<DistributedDims>...> size(IdxStep<MPIDim<DistributedDims>> {
                ddc::select<DistributedDims>(global_idx_range).size()
                / ddc::select<DistributedDims>(local_idx_range).size()}...);
        return IdxRange<MPIDim<DistributedDims>...>(start, size);
    }
};
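
A minimal usage sketch follows (illustrative only: the grid types GridX and GridVx, the layout aliases, the DFieldMem allocations and the exact MPILayout template arguments are assumptions for the example, not names fixed by this header):

// Two layouts over the same (GridX, GridVx) index range, distributed over
// GridX and GridVx respectively.
using LayoutX = MPILayout<IdxRange<GridX, GridVx>, GridX>;
using LayoutVx = MPILayout<IdxRange<GridX, GridVx>, GridVx>;

// global_idx_range: an IdxRange<GridX, GridVx> covering the full domain.
MPITransposeAllToAll<LayoutX, LayoutVx> transpose(global_idx_range, MPI_COMM_WORLD);

// Allocate data on the local index range of each layout.
DFieldMem<IdxRange<GridX, GridVx>> data_x(transpose.get_local_idx_range<LayoutX>());
DFieldMem<IdxRange<GridX, GridVx>> data_vx(transpose.get_local_idx_range<LayoutVx>());

// ... fill data_x ...

// Redistribute from the GridX-distributed layout to the GridVx-distributed layout.
transpose.transpose_to<LayoutVx>(
        Kokkos::DefaultExecutionSpace(),
        get_field(data_vx),
        get_const_field(data_x));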