diff --git a/mpi_daxpy_gt.cc b/mpi_daxpy_gt.cc index d9d53a6..d44aed1 100644 --- a/mpi_daxpy_gt.cc +++ b/mpi_daxpy_gt.cc @@ -47,29 +47,31 @@ void set_rank_device(int n_ranks, int rank) { int main(int argc, char **argv) { int n = 1024; - int world_size, world_rank; + int world_size, world_rank, device_id; uint32_t vendor_id; double a = 2.0; double sum = 0.0; - auto x = gt::empty({n}); - auto y = gt::empty({n}); - auto d_x = gt::empty_device({n}); - auto d_y = gt::empty_device({n}); - MPI_Init(NULL, NULL); MPI_Comm_size(MPI_COMM_WORLD, &world_size); MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); + set_rank_device(world_size, world_rank); + + auto x = gt::empty({n}); + auto y = gt::empty({n}); + auto d_x = gt::empty_device({n}); + auto d_y = gt::empty_device({n}); + for (int i=0; i