diff --git a/mpi_daxpy_nvtx.cc b/mpi_daxpy_nvtx.cc index 54662af..21d6977 100644 --- a/mpi_daxpy_nvtx.cc +++ b/mpi_daxpy_nvtx.cc @@ -124,8 +124,8 @@ int main(int argc, char **argv) { CHECK( "m_x", cudaMallocManaged((void**)&m_x, n*sizeof(*m_x)) ); CHECK( "m_y", cudaMallocManaged((void**)&m_y, n*sizeof(*m_y)) ); - CHECK( "m_allx", cudaMallocManaged((void**)&m_allx, n*sizeof(*m_allx)) ); - CHECK( "m_ally", cudaMallocManaged((void**)&m_ally, n*sizeof(*m_ally)) ); + CHECK( "m_allx", cudaMallocManaged((void**)&m_allx, n*sizeof(*m_allx)*world_size) ); + CHECK( "m_ally", cudaMallocManaged((void**)&m_ally, n*sizeof(*m_ally)*world_size) ); nvtxRangePushA("initializeArrays"); for (int i=0; i