この問題は、まぁ、エレガントな解き方がいっぱいあったようですが、SIMD化の勉強も兼ねて（言い訳）力ずくで解いてしまいました。
#include<stdlib.h>
#include<stdio.h>
#include<emmintrin.h>
/* Number of integers read from crossing.txt. */
#define NUM_NODE 314159
/* Elements per work block. Must be a multiple of 4: blocks are moved with
   aligned 128-bit loads of 4 ints each. */
#define BLOCK 1024
#ifdef _OPENMP
#include<omp.h>
/* Thread count requested for the OpenMP parallel loop (num_threads clause). */
#define NTHREADS 32
#endif
/* SSE2-only stand-in for the SSE4.1 lane extract: byte-shift lane `imm`
   (0..3) down to lane 0 (_mm_srli_si128 shifts by bytes, hence 4 * imm),
   then read lane 0 as an int. `imm` must be a compile-time constant because
   _mm_srli_si128 takes an immediate.
   NOTE(review): identifiers starting with an underscore are reserved, and this
   name collides with the real SSE4.1 intrinsic declared in <smmintrin.h>;
   including that header (e.g. via <immintrin.h>) after this line would break
   the build. Consider renaming. */
#define _mm_extract_epi32(x, imm) _mm_cvtsi128_si32(_mm_srli_si128((x), 4 * (imm)))
/*
 * Inversion counting for the values in crossing.txt.
 *
 * The program counts pairs of positions (a, b) with a < b and
 * dest[b] < dest[a] and prints the total as the number of cross points.
 * The array is split into full blocks of BLOCK elements. For each block, the
 * pairs inside the block are counted with a scalar scan, and the pairs between
 * the block and every later element are counted with SSE2 compares. The final
 * partial block (fewer than BLOCK elements) only needs the scalar scan.
 */

#if BLOCK % 4 != 0
#error "BLOCK must be a multiple of 4: blocks are loaded as whole __m128i vectors"
#endif

/* Number of pairs lo <= a < b < hi with dest[b] < dest[a] (plain O(len^2) scan). */
static unsigned long long count_inversions_scalar(const int *dest, int lo, int hi)
{
    unsigned long long count = 0;
    int a, b;

    for (a = lo; a < hi; a++) {
        for (b = a + 1; b < hi; b++) {
            if (dest[b] < dest[a]) {
                count++;
            }
        }
    }
    return count;
}

/* Sum of the four 32-bit lanes of v (non-negative counts), widened to 64 bits. */
static unsigned long long hsum_epi32(__m128i v)
{
    int lane[4];

    _mm_storeu_si128((__m128i *)lane, v);
    return (unsigned long long)lane[0] + lane[1] + lane[2] + lane[3];
}

/*
 * Inversions whose first element lies in the block [start, start + BLOCK):
 * the pairs inside the block, plus the pairs (a, b) with a in the block and
 * start + BLOCK <= b < n.
 *
 * Requirements: dest is 16-byte aligned, start is a multiple of 4, and
 * start + BLOCK <= n. Together these keep every aligned load below valid.
 */
static unsigned long long count_block_inversions(const int *dest, int start, int n)
{
    /* Local copy of the block. __m128i objects are 16-byte aligned, and as a
       local it is automatically private to each OpenMP thread. */
    __m128i blk[BLOCK / 4];
    const __m128i one = _mm_set1_epi32(1);
    const __m128i *src = (const __m128i *)(dest + start);
    __m128i acc = _mm_setzero_si128();
    int b, k;

    for (k = 0; k < BLOCK / 4; k++) {
        blk[k] = _mm_load_si128(src + k);
    }

    /* Take four later elements per step, broadcast each one, and compare it
       against the whole block. Each lane of acc grows by at most BLOCK / 4 per
       later element, so a lane can reach at most n * BLOCK / 4. That is about
       8.0e7 for the configured sizes, well below INT_MAX. */
    for (b = start + BLOCK; b + 4 <= n; b += 4) {
        const __m128i later = _mm_load_si128((const __m128i *)(dest + b));
        const __m128i v0 = _mm_shuffle_epi32(later, _MM_SHUFFLE(0, 0, 0, 0));
        const __m128i v1 = _mm_shuffle_epi32(later, _MM_SHUFFLE(1, 1, 1, 1));
        const __m128i v2 = _mm_shuffle_epi32(later, _MM_SHUFFLE(2, 2, 2, 2));
        const __m128i v3 = _mm_shuffle_epi32(later, _MM_SHUFFLE(3, 3, 3, 3));

        for (k = 0; k < BLOCK / 4; k++) {
            const __m128i cur = blk[k];
            /* cmplt gives all-ones where the later value is smaller than the
               block value; AND with 1 turns each hit into a count of 1. */
            __m128i hits = _mm_and_si128(one, _mm_cmplt_epi32(v0, cur));
            hits = _mm_add_epi32(hits, _mm_and_si128(one, _mm_cmplt_epi32(v1, cur)));
            hits = _mm_add_epi32(hits, _mm_and_si128(one, _mm_cmplt_epi32(v2, cur)));
            hits = _mm_add_epi32(hits, _mm_and_si128(one, _mm_cmplt_epi32(v3, cur)));
            acc = _mm_add_epi32(acc, hits);
        }
    }

    /* The remaining 0-3 later elements, which do not fill a whole vector. */
    for (; b < n; b++) {
        const __m128i later = _mm_set1_epi32(dest[b]);
        for (k = 0; k < BLOCK / 4; k++) {
            acc = _mm_add_epi32(acc, _mm_and_si128(one, _mm_cmplt_epi32(later, blk[k])));
        }
    }

    return count_inversions_scalar(dest, start, start + BLOCK) + hsum_epi32(acc);
}

/* Read NUM_NODE integers from crossing.txt, count inversions, print the total. */
int main(void)
{
    FILE *fp;
    int *dest;
    int i;
    const int nblocks = NUM_NODE / BLOCK;
    unsigned long long nCross = 0;

    /* Alignment of 32 (>= 16) makes every 4-element group starting at a
       multiple of 4 a valid target for _mm_load_si128. */
    dest = _mm_malloc(sizeof *dest * NUM_NODE, 32);
    if (dest == NULL) {
        fprintf(stderr, "out of memory\n");
        return EXIT_FAILURE;
    }

    fp = fopen("crossing.txt", "r");
    if (fp == NULL) {
        perror("crossing.txt");
        _mm_free(dest);
        return EXIT_FAILURE;
    }
    for (i = 0; i < NUM_NODE; i++) {
        if (fscanf(fp, "%d", &dest[i]) != 1) {
            fprintf(stderr, "crossing.txt: expected %d integers, read %d\n", NUM_NODE, i);
            fclose(fp);
            _mm_free(dest);
            return EXIT_FAILURE;
        }
    }
    fclose(fp);

    /* Earlier blocks have more later elements to compare against, so the work
       per block is uneven. schedule(dynamic) balances it across threads. All
       per-block state lives in count_block_inversions, so nothing needs an
       explicit private() clause. */
#pragma omp parallel for reduction(+: nCross) schedule(dynamic) num_threads(NTHREADS)
    for (i = 0; i < nblocks; i++) {
        nCross += count_block_inversions(dest, i * BLOCK, NUM_NODE);
    }

    /* Trailing partial block: only the pairs inside it are still uncounted.
       This is empty when NUM_NODE is a multiple of BLOCK. */
    nCross += count_inversions_scalar(dest, nblocks * BLOCK, NUM_NODE);

    fprintf(stdout, "Numer of cross point is %llu\n", nCross);
    _mm_free(dest);
    return 0;
}
0 件のコメント:
コメントを投稿