diff --git a/mpi_stencil2d_gt.cc b/mpi_stencil2d_gt.cc index a3130da..bb8403a 100644 --- a/mpi_stencil2d_gt.cc +++ b/mpi_stencil2d_gt.cc @@ -79,7 +79,8 @@ static const gt::gtensor stencil5 = {1.0 / 12.0, -2.0 / 3.0, 0.0, * * Size of the result will be size of z with minus 4 in second dimension. */ -inline auto stencil2d_1d_5(const gt::gtensor_device& z, +template +inline auto stencil2d_1d_5(const gt::ext::gtensor2& z, const gt::gtensor& stencil) { return stencil(0) * z.view(_s(0, -4), _all) + @@ -113,8 +114,9 @@ void set_rank_device(int n_ranks, int rank) } // exchange in first dimension, staging into contiguous buffers on device +template void boundary_exchange_x(MPI_Comm comm, int world_size, int rank, - gt::gtensor_device& d_z, int n_bnd, + gt::ext::gtensor2& d_z, int n_bnd, bool stage_host = false) { auto buf_shape = gt::shape(n_bnd, d_z.shape(1)); @@ -235,6 +237,8 @@ void boundary_exchange_x(MPI_Comm comm, int world_size, int rank, int main(int argc, char** argv) { + using S = gt::space::managed; + // Note: domain will be n_global x n_global plus ghost points in one dimension int n_global = 8 * 1024; bool stage_host = false; @@ -286,11 +290,11 @@ int main(int argc, char** argv) } auto h_z = gt::empty({n_local_with_ghost, n_global}); - auto d_z = gt::empty_device({n_local_with_ghost, n_global}); + gt::ext::gtensor2 d_z(h_z.shape()); auto h_dzdx_numeric = gt::empty({n_local, n_global}); auto h_dzdx_actual = gt::empty({n_local, n_global}); - auto d_dzdx_numeric = gt::empty_device({n_local, n_global}); + gt::ext::gtensor2 d_dzdx_numeric(h_dzdx_numeric.shape()); double lx = 8; double dx = lx / n_global; @@ -347,8 +351,8 @@ int main(int argc, char** argv) for (int i = 0; i < n_warmup + n_iter; i++) { clock_gettime(CLOCK_MONOTONIC, &start); - boundary_exchange_x(MPI_COMM_WORLD, world_size, world_rank, d_z, n_bnd, - stage_host); + boundary_exchange_x(MPI_COMM_WORLD, world_size, world_rank, d_z, n_bnd, + stage_host); clock_gettime(CLOCK_MONOTONIC, &end); iter_time = ((end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) * 1.0e-9); @@ -358,7 +362,7 @@ int main(int argc, char** argv) } // do some calculation, to try to more closely simulate what happens in GENE - d_dzdx_numeric = stencil2d_1d_5(d_z, stencil5) * scale; + d_dzdx_numeric = stencil2d_1d_5(d_z, stencil5) * scale; gt::synchronize(); } printf("%d/%d exchange time %0.8f ms\n", world_rank, world_size,