gt: parameterize space
This commit is contained in:
@@ -79,7 +79,8 @@ static const gt::gtensor<double, 1> stencil5 = {1.0 / 12.0, -2.0 / 3.0, 0.0,
|
||||
*
|
||||
* Size of the result will be size of z with minus 4 in second dimension.
|
||||
*/
|
||||
inline auto stencil2d_1d_5(const gt::gtensor_device<double, 2>& z,
|
||||
template <typename S>
|
||||
inline auto stencil2d_1d_5(const gt::ext::gtensor2<double, 2, S>& z,
|
||||
const gt::gtensor<double, 1>& stencil)
|
||||
{
|
||||
return stencil(0) * z.view(_s(0, -4), _all) +
|
||||
@@ -113,8 +114,9 @@ void set_rank_device(int n_ranks, int rank)
|
||||
}
|
||||
|
||||
// exchange in first dimension, staging into contiguous buffers on device
|
||||
template <typename S>
|
||||
void boundary_exchange_x(MPI_Comm comm, int world_size, int rank,
|
||||
gt::gtensor_device<double, 2>& d_z, int n_bnd,
|
||||
gt::ext::gtensor2<double, 2, S>& d_z, int n_bnd,
|
||||
bool stage_host = false)
|
||||
{
|
||||
auto buf_shape = gt::shape(n_bnd, d_z.shape(1));
|
||||
@@ -235,6 +237,8 @@ void boundary_exchange_x(MPI_Comm comm, int world_size, int rank,
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
using S = gt::space::managed;
|
||||
|
||||
// Note: domain will be n_global x n_global plus ghost points in one dimension
|
||||
int n_global = 8 * 1024;
|
||||
bool stage_host = false;
|
||||
@@ -286,11 +290,11 @@ int main(int argc, char** argv)
|
||||
}
|
||||
|
||||
auto h_z = gt::empty<double>({n_local_with_ghost, n_global});
|
||||
auto d_z = gt::empty_device<double>({n_local_with_ghost, n_global});
|
||||
gt::ext::gtensor2<double, 2, S> d_z(h_z.shape());
|
||||
|
||||
auto h_dzdx_numeric = gt::empty<double>({n_local, n_global});
|
||||
auto h_dzdx_actual = gt::empty<double>({n_local, n_global});
|
||||
auto d_dzdx_numeric = gt::empty_device<double>({n_local, n_global});
|
||||
gt::ext::gtensor2<double, 2, S> d_dzdx_numeric(h_dzdx_numeric.shape());
|
||||
|
||||
double lx = 8;
|
||||
double dx = lx / n_global;
|
||||
@@ -347,7 +351,7 @@ int main(int argc, char** argv)
|
||||
|
||||
for (int i = 0; i < n_warmup + n_iter; i++) {
|
||||
clock_gettime(CLOCK_MONOTONIC, &start);
|
||||
boundary_exchange_x(MPI_COMM_WORLD, world_size, world_rank, d_z, n_bnd,
|
||||
boundary_exchange_x<S>(MPI_COMM_WORLD, world_size, world_rank, d_z, n_bnd,
|
||||
stage_host);
|
||||
clock_gettime(CLOCK_MONOTONIC, &end);
|
||||
iter_time =
|
||||
@@ -358,7 +362,7 @@ int main(int argc, char** argv)
|
||||
}
|
||||
|
||||
// do some calculation, to try to more closely simulate what happens in GENE
|
||||
d_dzdx_numeric = stencil2d_1d_5(d_z, stencil5) * scale;
|
||||
d_dzdx_numeric = stencil2d_1d_5<S>(d_z, stencil5) * scale;
|
||||
gt::synchronize();
|
||||
}
|
||||
printf("%d/%d exchange time %0.8f ms\n", world_rank, world_size,
|
||||
|
||||
Reference in New Issue
Block a user