fix mpi init/set device order
This commit is contained in:
@@ -47,29 +47,31 @@ void set_rank_device(int n_ranks, int rank) {
|
|||||||
|
|
||||||
int main(int argc, char **argv) {
|
int main(int argc, char **argv) {
|
||||||
int n = 1024;
|
int n = 1024;
|
||||||
int world_size, world_rank;
|
int world_size, world_rank, device_id;
|
||||||
uint32_t vendor_id;
|
uint32_t vendor_id;
|
||||||
|
|
||||||
double a = 2.0;
|
double a = 2.0;
|
||||||
double sum = 0.0;
|
double sum = 0.0;
|
||||||
|
|
||||||
auto x = gt::empty<double>({n});
|
|
||||||
auto y = gt::empty<double>({n});
|
|
||||||
auto d_x = gt::empty_device<double>({n});
|
|
||||||
auto d_y = gt::empty_device<double>({n});
|
|
||||||
|
|
||||||
MPI_Init(NULL, NULL);
|
MPI_Init(NULL, NULL);
|
||||||
|
|
||||||
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
|
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
|
||||||
MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
|
MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
|
||||||
|
|
||||||
|
set_rank_device(world_size, world_rank);
|
||||||
|
|
||||||
|
auto x = gt::empty<double>({n});
|
||||||
|
auto y = gt::empty<double>({n});
|
||||||
|
auto d_x = gt::empty_device<double>({n});
|
||||||
|
auto d_y = gt::empty_device<double>({n});
|
||||||
|
|
||||||
for (int i=0; i<n; i++) {
|
for (int i=0; i<n; i++) {
|
||||||
x[i] = i+1;
|
x[i] = i+1;
|
||||||
y[i] = -i-1;
|
y[i] = -i-1;
|
||||||
}
|
}
|
||||||
|
|
||||||
set_rank_device(world_size, world_rank);
|
device_id = gt::backend::clib::device_get();
|
||||||
vendor_id = gt::backend::clib::device_get_vendor_id(gt::backend::clib::device_get());
|
vendor_id = gt::backend::clib::device_get_vendor_id(device_id);
|
||||||
|
|
||||||
gt::blas::handle_t* h = gt::blas::create();
|
gt::blas::handle_t* h = gt::blas::create();
|
||||||
|
|
||||||
@@ -87,7 +89,7 @@ int main(int argc, char **argv) {
|
|||||||
//printf("%f\n", y[i]);
|
//printf("%f\n", y[i]);
|
||||||
sum += y[i];
|
sum += y[i];
|
||||||
}
|
}
|
||||||
printf("%d/%d [%x] SUM = %f\n", world_rank, world_size, vendor_id, sum);
|
printf("%d/%d [%d:0x%08x] SUM = %f\n", world_rank, world_size, device_id, vendor_id, sum);
|
||||||
|
|
||||||
MPI_Finalize();
|
MPI_Finalize();
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user