parent
							
								
									2434b39b53
								
							
						
					
					
						commit
						d791b81cb6
					
				@ -0,0 +1,30 @@
 | 
				
			||||
# Build script for the gtensor / gt-blas port of the MPI daxpy test.
# 3.18 is needed for the TARGET_DIRECTORY option of
# set_source_files_properties() used below.
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)

# create project
project(mpi-daxpy-test)

# mpi_daxpy_gt.cc includes gt-blas/blas.h, so gt-blas must be enabled.
set(GTENSOR_ENABLE_BLAS ON CACHE BOOL "Enable gtblas")

# add dependencies
include(cmake/CPM.cmake)
# NOTE(review): GIT_TAG main follows a moving branch; pin a tag or commit if
# reproducible builds are needed.
CPMFindPackage(NAME gtensor
               GITHUB_REPOSITORY wdmapp/gtensor
               GIT_TAG main
               OPTIONS "GTENSOR_ENABLE_BLAS ON")

find_package(MPI REQUIRED)

add_executable(mpi_daxpy_gt)
target_sources(mpi_daxpy_gt PRIVATE mpi_daxpy_gt.cc)
# Use the keyword signature: an executable has no consumers, so its
# dependencies are PRIVATE.
target_link_libraries(mpi_daxpy_gt
                      PRIVATE
                        gtensor::gtensor
                        MPI::MPI_CXX)

# The source has a .cc extension; tell CMake which compiler to use for it
# depending on the device backend gtensor was configured with.
if("${GTENSOR_DEVICE}" STREQUAL "cuda")
  set_source_files_properties(mpi_daxpy_gt.cc
                              TARGET_DIRECTORY mpi_daxpy_gt
                              PROPERTIES LANGUAGE CUDA)
else()
  set_source_files_properties(mpi_daxpy_gt.cc
                              TARGET_DIRECTORY mpi_daxpy_gt
                              PROPERTIES LANGUAGE CXX)
endif()
 | 
				
			||||
@ -0,0 +1,21 @@
 | 
				
			||||
# Bootstrap script: locates (or downloads) the pinned CPM.cmake release and
# includes it so that the CPM* commands become available to the caller.
set(CPM_DOWNLOAD_VERSION 0.32.1)

# Pick where the CPM.cmake script is stored:
#   1. the CPM_SOURCE_CACHE CMake variable, if set;
#   2. the CPM_SOURCE_CACHE environment variable, if defined;
#   3. otherwise a location inside the build tree.
if(CPM_SOURCE_CACHE)
  # Expand relative path. This is important if the provided path contains a tilde (~)
  get_filename_component(CPM_SOURCE_CACHE "${CPM_SOURCE_CACHE}" ABSOLUTE)
  set(CPM_DOWNLOAD_LOCATION "${CPM_SOURCE_CACHE}/cpm/CPM_${CPM_DOWNLOAD_VERSION}.cmake")
elseif(DEFINED ENV{CPM_SOURCE_CACHE})
  set(CPM_DOWNLOAD_LOCATION "$ENV{CPM_SOURCE_CACHE}/cpm/CPM_${CPM_DOWNLOAD_VERSION}.cmake")
else()
  set(CPM_DOWNLOAD_LOCATION "${CMAKE_BINARY_DIR}/cmake/CPM_${CPM_DOWNLOAD_VERSION}.cmake")
endif()

if(NOT (EXISTS "${CPM_DOWNLOAD_LOCATION}"))
  message(STATUS "Downloading CPM.cmake to ${CPM_DOWNLOAD_LOCATION}")
  file(DOWNLOAD
       https://github.com/cpm-cmake/CPM.cmake/releases/download/v${CPM_DOWNLOAD_VERSION}/CPM.cmake
       "${CPM_DOWNLOAD_LOCATION}"
       STATUS cpm_download_status
  )
  # STATUS is a two-element list: numeric code (0 == success) and message.
  list(GET cpm_download_status 0 cpm_download_code)
  if(NOT cpm_download_code EQUAL 0)
    list(GET cpm_download_status 1 cpm_download_message)
    # Remove the partial/empty file so the next configure retries the
    # download instead of including a broken script.
    file(REMOVE "${CPM_DOWNLOAD_LOCATION}")
    message(FATAL_ERROR
            "Failed to download CPM.cmake v${CPM_DOWNLOAD_VERSION}: ${cpm_download_message}")
  endif()
endif()

include("${CPM_DOWNLOAD_LOCATION}")
 | 
				
			||||
@ -0,0 +1,95 @@
 | 
				
			||||
/*
 | 
				
			||||
 * =====================================================================================
 | 
				
			||||
 *
 | 
				
			||||
 *       Filename:  mpi_daxpy_gt.cc
 | 
				
			||||
 *
 | 
				
			||||
 *    Description:  Port to gtensor / gt-blas
 | 
				
			||||
 *
 | 
				
			||||
 *        Version:  1.0
 | 
				
			||||
 *        Created:  05/20/2019 10:33:30 AM
 | 
				
			||||
 *       Revision:  none
 | 
				
			||||
 *       Compiler:  gcc
 | 
				
			||||
 *
 | 
				
			||||
 *         Author:  YOUR NAME (), 
 | 
				
			||||
 *   Organization:  
 | 
				
			||||
 *
 | 
				
			||||
 * =====================================================================================
 | 
				
			||||
 */
 | 
				
			||||
 | 
				
			||||
#include <mpi.h>
 | 
				
			||||
#include <stdio.h>
 | 
				
			||||
#include <stdlib.h>
 | 
				
			||||
 | 
				
			||||
#include "gtensor/gtensor.h"
 | 
				
			||||
#include "gt-blas/blas.h"
 | 
				
			||||
 | 
				
			||||
/*
 * Select the device used by this MPI rank.
 *
 * n_ranks: total number of ranks
 * rank:    index of the calling rank
 *
 * When there are more ranks than devices, the rank count must be a multiple
 * of the device count; consecutive ranks then share a device
 * (device = rank / ranks_per_device). Otherwise the rank index is used
 * directly as the device index.
 *
 * Prints an error and exits with EXIT_FAILURE if no device is reported or
 * if the ranks cannot be divided evenly among the devices.
 */
void set_rank_device(int n_ranks, int rank) {
    int n_devices = gt::backend::device_get_count();

    // Guard the modulo/division below against a zero (or negative) count.
    if (n_devices <= 0) {
        printf("ERROR: No devices available (device count = %d)\n", n_devices);
        exit(EXIT_FAILURE);
    }

    int device = rank;

    if (n_ranks > n_devices) {
        if (n_ranks % n_devices != 0) {
            printf("ERROR: Number of ranks (%d) not a multiple of number of GPUs (%d)\n",
                   n_ranks, n_devices);
            exit(EXIT_FAILURE);
        }
        int ranks_per_device = n_ranks / n_devices;
        device = rank / ranks_per_device;
    }

    gt::backend::device_set(device);
}
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
int main(int argc, char **argv) {
 | 
				
			||||
    int n = 1024;
 | 
				
			||||
    int world_size, world_rank;
 | 
				
			||||
 | 
				
			||||
    double a = 2.0;
 | 
				
			||||
    double sum = 0.0;
 | 
				
			||||
 | 
				
			||||
    auto x = gt::empty<double>({n});
 | 
				
			||||
    auto y = gt::empty<double>({n});
 | 
				
			||||
    auto d_x = gt::empty_device<double>({n});
 | 
				
			||||
    auto d_y = gt::empty_device<double>({n});
 | 
				
			||||
 | 
				
			||||
    MPI_Init(NULL, NULL);
 | 
				
			||||
 | 
				
			||||
    MPI_Comm_size(MPI_COMM_WORLD, &world_size);
 | 
				
			||||
    MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
 | 
				
			||||
 | 
				
			||||
    for (int i=0; i<n; i++) {
 | 
				
			||||
        x[i] =  i+1;
 | 
				
			||||
        y[i] = -i-1;
 | 
				
			||||
    }
 | 
				
			||||
 | 
				
			||||
    set_rank_device(world_size, world_rank);
 | 
				
			||||
 | 
				
			||||
    gt::blas::handle_t* h = gt::blas::create();
 | 
				
			||||
 | 
				
			||||
    gt::copy(x, d_x);
 | 
				
			||||
    gt::copy(y, d_y);
 | 
				
			||||
 | 
				
			||||
    gt::blas::axpy(h, a, d_x, d_y);
 | 
				
			||||
 | 
				
			||||
    gt::synchronize();
 | 
				
			||||
 | 
				
			||||
    gt::copy(d_y, y);
 | 
				
			||||
    
 | 
				
			||||
    sum = 0.0;
 | 
				
			||||
    for (int i=0; i<n; i++) {
 | 
				
			||||
        //printf("%f\n", y[i]);
 | 
				
			||||
        sum += y[i];
 | 
				
			||||
    }
 | 
				
			||||
    printf("%d/%d SUM = %f\n", world_rank, world_size, sum);
 | 
				
			||||
 | 
				
			||||
    MPI_Finalize();
 | 
				
			||||
 | 
				
			||||
    return EXIT_SUCCESS;
 | 
				
			||||
}
 | 
				
			||||
					Loading…
					
					
				
		Reference in new issue