Add MPI_Wtime timing counters, fix make clean

main
Bryce Allen 5 years ago
parent 3dd6045f2e
commit 3ebd09725e

@ -21,7 +21,7 @@ mpienv: mpienv.f90
.PHONY: clean .PHONY: clean
clean: clean:
rm -rf daxpy mpi_daxpy rm -rf daxpy mpi_daxpy daxpy_nvtx mpi_daxpy_nvtx
.PHONY: force .PHONY: force
force: clean all force: clean all

@ -66,12 +66,19 @@ void set_rank_device(int n_ranks, int rank) {
int main(int argc, char **argv) { int main(int argc, char **argv) {
int n = 1024; int n = 32*1024*1024;
int world_size, world_rank; int world_size, world_rank;
double a = 2.0; double a = 2.0;
double sum = 0.0; double sum = 0.0;
double start_time = 0.0;
double end_time = 0.0;
double k_start_time = 0.0;
double k_end_time = 0.0;
double g_start_time = 0.0;
double g_end_time = 0.0;
//double *x, *y, *d_x, *d_y; //double *x, *y, *d_x, *d_y;
double *m_x, *m_y; double *m_x, *m_y;
@ -113,6 +120,7 @@ int main(int argc, char **argv) {
//CHECK("setDevice", cudaSetDevice(0)); //CHECK("setDevice", cudaSetDevice(0));
cudaProfilerStart(); cudaProfilerStart();
start_time = MPI_Wtime();
CHECK( "cublas", cublasCreate(&handle) ); CHECK( "cublas", cublasCreate(&handle) );
@ -155,12 +163,14 @@ int main(int argc, char **argv) {
MEMINFO("m_x", m_x, sizeof(m_x)); MEMINFO("m_x", m_x, sizeof(m_x));
MEMINFO("m_y", m_y, sizeof(m_y)); MEMINFO("m_y", m_y, sizeof(m_y));
k_start_time = MPI_Wtime();
nvtxRangePushA("cublasDaxpy"); nvtxRangePushA("cublasDaxpy");
CHECK("daxpy", CHECK("daxpy",
cublasDaxpy(handle, n, &a, m_x, 1, m_y, 1) ); cublasDaxpy(handle, n, &a, m_x, 1, m_y, 1) );
CHECK("daxpy sync", cudaDeviceSynchronize()); CHECK("daxpy sync", cudaDeviceSynchronize());
nvtxRangePop(); nvtxRangePop();
k_end_time = MPI_Wtime();
/* /*
CHECK("y = d_y", CHECK("y = d_y",
@ -182,6 +192,7 @@ int main(int argc, char **argv) {
nvtxRangePop(); nvtxRangePop();
printf("%d/%d SUM = %f\n", world_rank, world_size, sum); printf("%d/%d SUM = %f\n", world_rank, world_size, sum);
g_start_time = MPI_Wtime();
nvtxRangePushA("allGather"); nvtxRangePushA("allGather");
nvtxRangePushA("x"); nvtxRangePushA("x");
MPI_Allgather(m_x, n, MPI_DOUBLE, m_allx, n, MPI_DOUBLE, MPI_COMM_WORLD); MPI_Allgather(m_x, n, MPI_DOUBLE, m_allx, n, MPI_DOUBLE, MPI_COMM_WORLD);
@ -190,6 +201,7 @@ int main(int argc, char **argv) {
MPI_Allgather(m_y, n, MPI_DOUBLE, m_ally, n, MPI_DOUBLE, MPI_COMM_WORLD); MPI_Allgather(m_y, n, MPI_DOUBLE, m_ally, n, MPI_DOUBLE, MPI_COMM_WORLD);
nvtxRangePop(); nvtxRangePop();
nvtxRangePop(); nvtxRangePop();
g_end_time = MPI_Wtime();
sum = 0.0; sum = 0.0;
nvtxRangePushA("allSum"); nvtxRangePushA("allSum");
@ -211,10 +223,15 @@ int main(int argc, char **argv) {
nvtxRangePop(); nvtxRangePop();
end_time = MPI_Wtime();
cudaProfilerStop(); cudaProfilerStop();
cublasDestroy(handle); cublasDestroy(handle);
MPI_Finalize(); MPI_Finalize();
printf("total time: %0.3f\n", end_time-start_time);
printf("kernel time: %0.3f\n", k_end_time-k_start_time);
printf("gather time: %0.3f\n", g_end_time-g_start_time);
return EXIT_SUCCESS; return EXIT_SUCCESS;
} }

Loading…
Cancel
Save