gt: parameterize space

main
Bryce Allen 3 years ago
parent 4f82b8359b
commit 5bcf1382ba

@ -79,7 +79,8 @@ static const gt::gtensor<double, 1> stencil5 = {1.0 / 12.0, -2.0 / 3.0, 0.0,
*
* Size of the result will be size of z with minus 4 in second dimension.
*/
inline auto stencil2d_1d_5(const gt::gtensor_device<double, 2>& z,
template <typename S>
inline auto stencil2d_1d_5(const gt::ext::gtensor2<double, 2, S>& z,
const gt::gtensor<double, 1>& stencil)
{
return stencil(0) * z.view(_s(0, -4), _all) +
@ -113,8 +114,9 @@ void set_rank_device(int n_ranks, int rank)
}
// exchange in first dimension, staging into contiguous buffers on device
template <typename S>
void boundary_exchange_x(MPI_Comm comm, int world_size, int rank,
gt::gtensor_device<double, 2>& d_z, int n_bnd,
gt::ext::gtensor2<double, 2, S>& d_z, int n_bnd,
bool stage_host = false)
{
auto buf_shape = gt::shape(n_bnd, d_z.shape(1));
@ -235,6 +237,8 @@ void boundary_exchange_x(MPI_Comm comm, int world_size, int rank,
int main(int argc, char** argv)
{
using S = gt::space::managed;
// Note: domain will be n_global x n_global plus ghost points in one dimension
int n_global = 8 * 1024;
bool stage_host = false;
@ -286,11 +290,11 @@ int main(int argc, char** argv)
}
auto h_z = gt::empty<double>({n_local_with_ghost, n_global});
auto d_z = gt::empty_device<double>({n_local_with_ghost, n_global});
gt::ext::gtensor2<double, 2, S> d_z(h_z.shape());
auto h_dzdx_numeric = gt::empty<double>({n_local, n_global});
auto h_dzdx_actual = gt::empty<double>({n_local, n_global});
auto d_dzdx_numeric = gt::empty_device<double>({n_local, n_global});
gt::ext::gtensor2<double, 2, S> d_dzdx_numeric(h_dzdx_numeric.shape());
double lx = 8;
double dx = lx / n_global;
@ -347,8 +351,8 @@ int main(int argc, char** argv)
for (int i = 0; i < n_warmup + n_iter; i++) {
clock_gettime(CLOCK_MONOTONIC, &start);
boundary_exchange_x(MPI_COMM_WORLD, world_size, world_rank, d_z, n_bnd,
stage_host);
boundary_exchange_x<S>(MPI_COMM_WORLD, world_size, world_rank, d_z, n_bnd,
stage_host);
clock_gettime(CLOCK_MONOTONIC, &end);
iter_time =
((end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) * 1.0e-9);
@ -358,7 +362,7 @@ int main(int argc, char** argv)
}
// do some calculation, to try to more closely simulate what happens in GENE
d_dzdx_numeric = stencil2d_1d_5(d_z, stencil5) * scale;
d_dzdx_numeric = stencil2d_1d_5<S>(d_z, stencil5) * scale;
gt::synchronize();
}
printf("%d/%d exchange time %0.8f ms\n", world_rank, world_size,

Loading…
Cancel
Save