gt and cmake fixes
This commit is contained in:
@@ -17,9 +17,11 @@ find_package(MPI REQUIRED)
|
|||||||
add_executable(mpi_daxpy_gt)
|
add_executable(mpi_daxpy_gt)
|
||||||
target_sources(mpi_daxpy_gt PRIVATE mpi_daxpy_gt.cc)
|
target_sources(mpi_daxpy_gt PRIVATE mpi_daxpy_gt.cc)
|
||||||
target_link_libraries(mpi_daxpy_gt gtensor::gtensor)
|
target_link_libraries(mpi_daxpy_gt gtensor::gtensor)
|
||||||
|
target_link_libraries(mpi_daxpy_gt gtensor::blas)
|
||||||
target_link_libraries(mpi_daxpy_gt MPI::MPI_CXX)
|
target_link_libraries(mpi_daxpy_gt MPI::MPI_CXX)
|
||||||
|
|
||||||
if ("${GTENSOR_DEVICE}" STREQUAL "cuda")
|
if ("${GTENSOR_DEVICE}" STREQUAL "cuda")
|
||||||
|
enable_language(CUDA)
|
||||||
set_source_files_properties(mpi_daxpy_gt.cc
|
set_source_files_properties(mpi_daxpy_gt.cc
|
||||||
TARGET_DIRECTORY mpi_daxpy_gt
|
TARGET_DIRECTORY mpi_daxpy_gt
|
||||||
PROPERTIES LANGUAGE CUDA)
|
PROPERTIES LANGUAGE CUDA)
|
||||||
|
|||||||
@@ -25,9 +25,8 @@
|
|||||||
|
|
||||||
void set_rank_device(int n_ranks, int rank) {
|
void set_rank_device(int n_ranks, int rank) {
|
||||||
int n_devices, device, ranks_per_device;
|
int n_devices, device, ranks_per_device;
|
||||||
size_t memory_per_rank;
|
|
||||||
|
|
||||||
n_devices = gt::backend::device_get_count();
|
n_devices = gt::backend::clib::device_get_count();
|
||||||
|
|
||||||
if (n_ranks > n_devices) {
|
if (n_ranks > n_devices) {
|
||||||
if (n_ranks % n_devices != 0) {
|
if (n_ranks % n_devices != 0) {
|
||||||
@@ -42,14 +41,14 @@ void set_rank_device(int n_ranks, int rank) {
|
|||||||
device = rank;
|
device = rank;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
gt::backend::clib::device_set(device);
|
||||||
gt::backend::device_set(device);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
int main(int argc, char **argv) {
|
int main(int argc, char **argv) {
|
||||||
int n = 1024;
|
int n = 1024;
|
||||||
int world_size, world_rank;
|
int world_size, world_rank;
|
||||||
|
uint32_t vendor_id;
|
||||||
|
|
||||||
double a = 2.0;
|
double a = 2.0;
|
||||||
double sum = 0.0;
|
double sum = 0.0;
|
||||||
@@ -70,6 +69,7 @@ int main(int argc, char **argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
set_rank_device(world_size, world_rank);
|
set_rank_device(world_size, world_rank);
|
||||||
|
vendor_id = gt::backend::clib::device_get_vendor_id(gt::backend::clib::device_get());
|
||||||
|
|
||||||
gt::blas::handle_t* h = gt::blas::create();
|
gt::blas::handle_t* h = gt::blas::create();
|
||||||
|
|
||||||
@@ -87,7 +87,7 @@ int main(int argc, char **argv) {
|
|||||||
//printf("%f\n", y[i]);
|
//printf("%f\n", y[i]);
|
||||||
sum += y[i];
|
sum += y[i];
|
||||||
}
|
}
|
||||||
printf("%d/%d SUM = %f\n", world_rank, world_size, sum);
|
printf("%d/%d [%x] SUM = %f\n", world_rank, world_size, vendor_id, sum);
|
||||||
|
|
||||||
MPI_Finalize();
|
MPI_Finalize();
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user