Skip to content

Commit

Permalink
Merge pull request #15 from lattice/feature/quda-interface-optimize
Browse files Browse the repository at this point in the history
Feature/quda interface optimize
  • Loading branch information
detar authored Oct 4, 2017
2 parents cad56a5 + a7654a0 commit f407b32
Show file tree
Hide file tree
Showing 17 changed files with 383 additions and 157 deletions.
7 changes: 6 additions & 1 deletion Make_template_combos
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ endif

FLINKS_HISQ_MILC_GPU = fermion_links_hisq_milc.o \
fermion_links_fn_load_gpu.o \
fermion_links_hisq_load_milc.o fermion_links_hisq_load_gpu.o \
fermion_links_hisq_load_milc.o \
${FLINKS} ks_action_paths_hisq.o su3_mat_op.o stout_smear.o

# Standard QOP combinations
Expand Down Expand Up @@ -206,7 +206,12 @@ endif
# Standard MILC

# Choices here are dslash_fn.o dslash_fn2.o dslash_fn_dblstore.o
ifeq ($(strip ${WANTQUDA}),true)
# When using QUDA, the back links are not used and just add unnecessary overhead
DSLASH_FN_MILC = dslash_fn.o
else
DSLASH_FN_MILC = dslash_fn_dblstore.o
endif

# No other choice
DSLASH_EO = dslash_eo.o
Expand Down
10 changes: 9 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,15 @@ CGEOM +=# -DFIX_IONODE_GEOM
# For now, works only with dslash_fn_dblstore.o
# FEWSUMS Fewer CG reduction calls

KSCGSTORE = -DDBLSTORE_FN -DFEWSUMS -DD_FN_GATHER13
# If we are using QUDA, the backward links are unused, so we should
# avoid unnecessary overhead and use the standard dslash. Note that
# dslash_fn also has hooks in place to offload any dslash_fn_field
# calls to QUDA.
ifeq ($(strip ${WANTQUDA}),true)
KSCGSTORE = -DFEWSUMS
else
KSCGSTORE = -DDBLSTORE_FN -DFEWSUMS -DD_FN_GATHER13
endif

#------------------------------
# Staggered fermion force routines
Expand Down
42 changes: 15 additions & 27 deletions generic/gauge_force_imp_gpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,59 +9,47 @@

/**#define GFTIME**/ /* For timing gauge force calculation */
#include "generic_includes.h" /* definitions files and prototypes */

#include <quda.h>
#include <quda_milc_interface.h>
#include "../include/openmp_defs.h"

#include "../include/generic_quda.h"

// gpu code
void imp_gauge_force_gpu(Real eps, field_offset mom_off)
{

#ifdef GFTIME
int nflop = 153004; /* For Symanzik1 action */
double dtime = -dclock();
#endif

Real **loop_coeff = get_loop_coeff();
//int max_length = get_max_length();
//int nreps = get_nreps();

const int num_loop_types = get_nloop();
double *quda_loop_coeff = (double*)malloc(num_loop_types * sizeof(double));
int i;
#ifdef GFTIME
int nflop = 153004; /* For Symanzik1 action */
double dtime = -dclock();
#endif

site *st;
const Real eb3 = eps*beta/3.0;

initialize_quda();

su3_matrix *links = qudaAllocatePinned(sites_on_node*4*sizeof(su3_matrix));
anti_hermitmat* momentum = qudaAllocatePinned(sites_on_node*4*sizeof(anti_hermitmat));

int dir,j;
site *st;
su3_matrix *links = create_G_from_site_quda();
anti_hermitmat* momentum = create_M_quda();

for(i=0; i<num_loop_types; ++i) quda_loop_coeff[i] = loop_coeff[i][0];

FORALLSITES_OMP(i,st,private(dir)){
for(dir=XUP; dir<=TUP; ++dir){
links[4*i + dir] = st->link[dir];
} // dir
} END_LOOP_OMP

qudaGaugeForce(PRECISION,num_loop_types,quda_loop_coeff,eb3,links,momentum);

FORALLSITES_OMP(i,st,private(dir,j)){
for(dir=XUP; dir<=TUP; ++dir){
for(j=0; j<10; ++j){
FORALLSITES_OMP(i,st,){
for(int dir=XUP; dir<=TUP; ++dir){
for(int j=0; j<10; ++j){
((Real*)&(st->mom[dir]))[j] += ((Real*)(momentum + 4*i+dir))[j];
}
}
} END_LOOP_OMP

free(quda_loop_coeff);
qudaFreePinned(links);
qudaFreePinned(momentum);
destroy_G_quda(links);
destroy_M_quda(momentum);

#ifdef GFTIME
dtime+=dclock();
Expand Down
54 changes: 52 additions & 2 deletions generic/reunitarize2.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
#include "generic_includes.h"
#include "../include/openmp_defs.h"

#ifdef USE_GF_GPU
#include "../include/generic_quda.h"
#endif

#define TOLERANCE (0.0001)
#define MAXERRCOUNT 100
/**#define UNIDEBUG**/
Expand Down Expand Up @@ -201,7 +205,35 @@ int reunit_su3(su3_matrix *c)

} /* reunit_su3 */

void reunitarize() {
#ifdef USE_GF_GPU

/*
 * GPU variant of the reunitarization step.
 *
 * Obtains a link buffer from create_G_from_site_quda(), passes it to
 * qudaUnitarizeSU3() with TOLERANCE, copies the result back into the
 * site structure and releases the buffer.  When GFTIME is defined the
 * elapsed time (measured after initialize_quda()) is printed on node 0.
 */
void reunitarize_gpu() {

  initialize_quda();

#ifdef GFTIME
  /* Start the clock only after QUDA has been initialized */
  double dclock();
  double dtime = -dclock();
#endif

  su3_matrix *gauge = create_G_from_site_quda();

  qudaUnitarizeSU3(PRECISION, gauge, TOLERANCE);

  /* Write the processed links back into the site structure,
     then give the temporary buffer back */
  copy_to_site_from_G_quda(gauge);
  destroy_G_quda(gauge);

#ifdef GFTIME
  dtime += dclock();
  node0_printf("REUNITARIZE: time = %e\n", dtime);
#endif

} /* reunitarize_gpu */

#endif

void reunitarize_cpu() {
register su3_matrix *mat;
register int i,dir;
register site *s;
Expand All @@ -210,7 +242,7 @@ void reunitarize() {

max_deviation = 0.;
av_deviation = 0.;

FORALLSITES_OMP(i,s,private(dir,mat,errors) reduction(+:errcount) ){
#ifdef SCHROED_FUN
for(dir=XUP; dir<=TUP; dir++ ) if(dir==TUP || s->t>0 ){
Expand Down Expand Up @@ -248,3 +280,21 @@ void reunitarize() {

} /* reunitarize2 */

/*
 * Public entry point: choose the reunitarization back end at compile time.
 *
 *   USE_GF_GPU defined, SCHROED_FUN not defined -> reunitarize_gpu()
 *   USE_GF_GPU defined, SCHROED_FUN defined     -> message + reunitarize_cpu()
 *   USE_GF_GPU not defined                      -> reunitarize_cpu()
 */
void reunitarize() {

#if defined(USE_GF_GPU) && !defined(SCHROED_FUN)

  /* Gauge-force GPU support is enabled and no Schroedinger functional
     boundary conditions are in use: take the QUDA path */
  reunitarize_gpu();

#else

#ifdef USE_GF_GPU
  /* GPU build, but Schroedinger functional boundary conditions are
     enabled: announce the fallback before using the CPU routine */
  node0_printf("%s not supported on GPU, using CPU fallback\n", __func__);
#endif
  reunitarize_cpu();

#endif

}
3 changes: 0 additions & 3 deletions generic_ks/Make_template
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ G_KS_ALL = \
fermion_links_fn_load_milc.o \
fermion_links_fn_utilities_gpu.o \
fermion_links_hisq_milc.o \
fermion_links_hisq_load_gpu.o \
fermion_links_hisq_load_milc.o \
fermion_links_hisq_qop.o \
fermion_links_hyp.o \
Expand Down Expand Up @@ -324,8 +323,6 @@ fermion_links_fn_load_milc.o: ../generic_ks/fermion_links_fn_load_milc.c
${CC} -c ${CFLAGS} $<
fermion_links_hisq_milc.o: ../generic_ks/fermion_links_hisq_milc.c
${CC} -c ${CFLAGS} $<
fermion_links_hisq_load_gpu.o: ../generic_ks/fermion_links_hisq_load_gpu.c
${CC} -c ${CFLAGS} $<
fermion_links_hisq_load_milc.o: ../generic_ks/fermion_links_hisq_load_milc.c
${CC} -c ${CFLAGS} $<
fermion_links_hisq_qdp.o: ../generic_ks/fermion_links_hisq_qdp.c
Expand Down
7 changes: 3 additions & 4 deletions generic_ks/d_congrad5_fn_gpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,11 @@ int ks_congrad_parity_gpu(su3_vector *t_src, su3_vector *t_dest,
int num_iters;

// for newer versions of QUDA we need to invalidate the gauge field if the links are new
static imp_ferm_links_t *fn_last = NULL;
if ( fn != fn_last || fresh_fn_links(fn) ){
if ( fn != get_fn_last() || fresh_fn_links(fn) ){
cancel_quda_notification(fn);
fn_last = fn;
set_fn_last(fn);
num_iters = -1;
node0_printf("%s: fn, notify: Signal QUDA to refresh links", myname);
node0_printf("%s: fn, notify: Signal QUDA to refresh links\n", myname);
}

qudaInvert(PRECISION,
Expand Down
51 changes: 47 additions & 4 deletions generic_ks/dslash_fn.c
Original file line number Diff line number Diff line change
Expand Up @@ -299,14 +299,57 @@ void dslash_fn_site_special( field_offset src, field_offset dest,

}

#ifdef USE_CG_GPU
#include "../include/generic_quda.h"

// if using QUDA then we offload the dslash to the GPU
//
// Hands src and dest, together with the fat and long links taken from fn,
// to qudaDslash.  For parity EVEN or ODD a single call is made; for
// EVENANDODD two calls are made, one per parity.  Any other parity value
// prints a message and calls terminate(2).
void dslash_fn_field( su3_vector *src, su3_vector *dest, int parity,
                      fn_links_t *fn) {

  su3_matrix* fatlink = get_fatlinks(fn);
  su3_matrix* longlink = get_lnglinks(fn);

  // for newer versions of QUDA we need to invalidate the gauge field if the links are new.
  // num_iters is always passed by address to qudaDslash, and -1 is stored in it
  // beforehand as the refresh signal, so it must hold a defined value on the
  // no-refresh path as well.
  int num_iters = 0;
  if (fn != get_fn_last() || fresh_fn_links(fn)){
    cancel_quda_notification(fn);
    set_fn_last(fn);
    num_iters = -1;
    node0_printf("%s: fn, notify: Signal QUDA to refresh links\n", __func__);
  }

  // Only the evenodd member is set below; zero the rest so the struct is
  // not passed by value with indeterminate members.
  QudaInvertArgs_t inv_args = {0};
  if (parity != EVENANDODD) {
    switch(parity) {
    case EVEN: inv_args.evenodd = QUDA_EVEN_PARITY; break;
    case ODD: inv_args.evenodd = QUDA_ODD_PARITY; break;
    default: printf("%s: Unrecognised parity\n",__func__); terminate(2);
    }

    qudaDslash(PRECISION, PRECISION, inv_args, fatlink, longlink, u0, src, dest, &num_iters);
  } else { // do both parities as separate calls
    inv_args.evenodd = QUDA_EVEN_PARITY;
    qudaDslash(PRECISION, PRECISION, inv_args, fatlink, longlink, u0, src, dest, &num_iters);
    inv_args.evenodd = QUDA_ODD_PARITY;
    qudaDslash(PRECISION, PRECISION, inv_args, fatlink, longlink, u0, src, dest, &num_iters);
  }

}

#else

// Convenience wrapper around dslash_fn_field_special(): supplies a local
// message-tag array, passes 1 as the start argument, and calls
// cleanup_one_gather_set() on the tags afterwards.
void dslash_fn_field( su3_vector *src, su3_vector *dest, int parity,
                      fn_links_t *fn) {

  msg_tag *tag[16];

  // Apply the operator exactly once; repeating the call pair would redo
  // the whole dslash without changing the result in dest.
  dslash_fn_field_special(src, dest, parity, tag, 1, fn);
  cleanup_one_gather_set(tag);

}

#endif

/* Special dslash for use by congrad. Uses restart_gather_field() when
possible. Next to last argument is an array of message tags, to be set
if this is the first use, otherwise reused. If start=1,use
Expand Down Expand Up @@ -532,7 +575,7 @@ dslash_fn_dir(su3_vector *src, su3_vector *dest, int parity,
{
register int i ;
site *s;
msg_tag *tag[2];
msg_tag *tag[2] = {NULL, NULL};
su3_matrix *fat = get_fatlinks(fn);
su3_matrix *lng = get_lnglinks(fn);
su3_vector tmp;
Expand Down
11 changes: 4 additions & 7 deletions generic_ks/fermion_force_asqtad_gpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,8 @@ fermion_force_oprod_site(Real eps, Real weight1, Real weight2,
msg_tag* mtag[2];

{ // copy the quark-field information to su3_vector fields
v[0] = (su3_vector*)malloc(sites_on_node*sizeof(su3_vector));
v[1] = (su3_vector*)malloc(sites_on_node*sizeof(su3_vector));

if(v[0] == NULL) printf("fermion_force_oprod_site: v[0] not allocated\n");
if(v[1] == NULL) printf("fermion_force_oprod_site: v[1] not allocated\n");
v[0] = (su3_vector*)qudaAllocatePinned(sites_on_node*sizeof(su3_vector));
v[1] = (su3_vector*)qudaAllocatePinned(sites_on_node*sizeof(su3_vector));

FORALLSITES(i,s){
v[0][i] = *(su3_vector*)F_PT(s,x1_off);
Expand Down Expand Up @@ -67,8 +64,8 @@ fermion_force_oprod_site(Real eps, Real weight1, Real weight2,
free(combined_coeff);

// Cleanup
free(v[0]);
free(v[1]);
qudaFreePinned(v[0]);
qudaFreePinned(v[1]);
}

void
Expand Down
Loading

0 comments on commit f407b32

Please sign in to comment.