nissy-classic

Stable branch of nissy
git clone https://git.tronto.net/nissy-classic
Download | Log | Files | Refs | README | LICENSE

pruning.c (11922B)


      1 #include "pruning.h"
      2 
      3 #define ENTRIES_PER_GROUP              (2*sizeof(entry_group_t))
      4 #define ENTRIES_PER_GROUP_COMPACT      (4*sizeof(entry_group_t))
      5 
      6 static int         findchunk(PruneData *pd, int nchunks, uint64_t i);
      7 static void        genptable_bfs(PruneData *pd, int d, int nt, int nc);
      8 static void        genptable_compress(PruneData *pd);
      9 static void        genptable_fixnasty(PruneData *pd, int d, int nthreads);
     10 static void        genptable_setbase(PruneData *pd);
     11 static void *      instance_bfs(void *arg);
     12 static void *      instance_fixnasty(void *arg);
     13 static void        ptable_update(PruneData *pd, Cube cube, int m);
     14 static void        ptable_update_index(PruneData *pd, uint64_t ind, int m);
     15 static int         ptableval_index(PruneData *pd, uint64_t ind);
     16 static bool        read_ptable_file(PruneData *pd);
     17 static bool        write_ptable_file(PruneData *pd);
     18 
     19 PruneData
     20 pd_eofb_HTM = {
     21 	.filename = "pt_eofb_HTM",
     22 	.coord    = &coord_eofb,
     23 	.moveset  = &moveset_HTM,
     24 };
     25 
     26 PruneData
     27 pd_coud_HTM = {
     28 	.filename = "pt_coud_HTM",
     29 	.coord    = &coord_coud,
     30 	.moveset  = &moveset_HTM,
     31 };
     32 
     33 PruneData
     34 pd_cornershtr_HTM = {
     35 	.filename = "pt_cornershtr_HTM",
     36 	.coord    = &coord_cornershtr,
     37 	.moveset  = &moveset_HTM,
     38 };
     39 
     40 PruneData
     41 pd_corners_HTM = {
     42 	.filename = "pt_corners_HTM",
     43 	.coord    = &coord_corners,
     44 	.moveset  = &moveset_HTM,
     45 };
     46 
     47 PruneData
     48 pd_drud_sym16_HTM = {
     49 	.filename = "pt_drud_sym16_HTM",
     50 	.coord    = &coord_drud_sym16,
     51 	.moveset  = &moveset_HTM,
     52 };
     53 
     54 PruneData
     55 pd_drud_eofb = {
     56 	.filename = "pt_drud_eofb",
     57 	.coord    = &coord_drud_eofb,
     58 	.moveset  = &moveset_eofb,
     59 };
     60 
     61 PruneData
     62 pd_drudfin_noE_sym16_drud = {
     63 	.filename = "pt_drudfin_noE_sym16_drud",
     64 	.coord    = &coord_drudfin_noE_sym16,
     65 	.moveset  = &moveset_drud,
     66 };
     67 
     68 PruneData
     69 pd_htr_drud = {
     70 	.filename = "pt_htr_drud",
     71 	.coord    = &coord_htr_drud,
     72 	.moveset  = &moveset_drud,
     73 };
     74 
     75 PruneData
     76 pd_cp_drud = {
     77 	.filename = "pt_cp_drud",
     78 	.coord    = &coord_cp,
     79 	.moveset  = &moveset_drud,
     80 };
     81 
     82 PruneData
     83 pd_htrfin_htr = {
     84 	.filename = "pt_htrfin_htr",
     85 	.coord    = &coord_htrfin,
     86 	.moveset  = &moveset_htr,
     87 };
     88 
     89 PruneData
     90 pd_nxopt31_HTM = {
     91 	.filename = "pt_nxopt31_HTM",
     92 	.coord    = &coord_nxopt31,
     93 	.moveset  = &moveset_HTM,
     94 
     95 	.compact  = true,
     96 	.fallback = &pd_drud_sym16_HTM,
     97 	.fbmod    = BINOM8ON4,
     98 };
     99 
    100 PruneData * all_pd[] = {
    101 	&pd_eofb_HTM,
    102 	&pd_coud_HTM,
    103 	&pd_cornershtr_HTM,
    104 	&pd_corners_HTM,
    105 	&pd_drud_sym16_HTM,
    106 	&pd_drud_eofb,
    107 	&pd_drudfin_noE_sym16_drud,
    108 	&pd_htr_drud,
    109 	&pd_cp_drud,
    110 	&pd_htrfin_htr,
    111 	&pd_nxopt31_HTM,
    112 	NULL
    113 };
    114 
    115 /* Functions *****************************************************************/
    116 
    117 int
    118 findchunk(PruneData *pd, int nchunks, uint64_t i)
    119 {
    120 	uint64_t chunksize;
    121 
    122 	chunksize = pd->coord->max / (uint64_t)nchunks;
    123 	chunksize += ENTRIES_PER_GROUP - (chunksize % ENTRIES_PER_GROUP);
    124 
    125 	return MIN(nchunks-1, (int)(i / chunksize));
    126 }
    127 
    128 void
    129 free_pd(PruneData *pd)
    130 {
    131 	if (pd->generated)
    132 		free(pd->ptable);
    133 
    134 	pd->generated = false;
    135 }
    136 
    137 void
    138 genptable(PruneData *pd, int nthreads)
    139 {
    140 	bool compact;
    141 	int d, nchunks;
    142 	uint64_t oldn, sz;
    143 
    144 	if (pd->generated)
    145 		return;
    146 
    147 	/* TODO: check if memory is enough, otherwise maybe exit gracefully? */
    148 	sz = ptablesize(pd) * (pd->compact ? 2 : 1);
    149 	pd->ptable = malloc(sz * sizeof(entry_group_t));
    150 
    151 	if (read_ptable_file(pd)) {
    152 		pd->generated = true;
    153 		return;
    154 	}
    155 
    156 	if (nthreads < 4) {
    157 		fprintf(stderr,
    158 			"--- Warning ---\n"
    159 			"You are using only %d threads to generate the pruning"
    160 			"tables. This can take a while."
    161 			"Unless you did this intentionally, you should re-run"
    162 			"this command with `-t 4' or more.\n"
    163 			"---------------\n\n", nthreads
    164 		);
    165 	}
    166 
    167 
    168 	/* For the first steps we proceed the same way for compact and not */
    169 	compact = pd->compact;
    170 	pd->compact = false;
    171 	pd->generated = true;
    172 
    173 	nchunks = MIN(ptablesize(pd), 100000);
    174 	fprintf(stderr, "Cannot load %s, generating it with %d threads\n",
    175 			pd->filename, nthreads); 
    176 
    177 
    178 	memset(pd->ptable, ~(uint8_t)0, ptablesize(pd)*sizeof(entry_group_t));
    179 
    180 	ptable_update(pd, (Cube){0}, 0);
    181 	pd->n = 1;
    182 	oldn = 0;
    183 	genptable_fixnasty(pd, 0, nthreads);
    184 	fprintf(stderr, "Depth %d done, generated %"
    185 		PRIu64 "\t(%" PRIu64 "/%" PRIu64 ")\n",
    186 		0, pd->n - oldn, pd->n, pd->coord->max);
    187 	oldn = pd->n;
    188 	pd->count[0] = pd->n;
    189 	for (d = 0; d < 15 && pd->n < pd->coord->max; d++) {
    190 		genptable_bfs(pd, d, nthreads, nchunks);
    191 		genptable_fixnasty(pd, d+1, nthreads);
    192 		fprintf(stderr, "Depth %d done, generated %"
    193 			PRIu64 "\t(%" PRIu64 "/%" PRIu64 ")\n",
    194 			d+1, pd->n - oldn, pd->n, pd->coord->max);
    195 		pd->count[d+1] = pd->n - oldn;
    196 		oldn = pd->n;
    197 	}
    198 	fprintf(stderr, "Pruning table generated!\n");
    199 	
    200 	genptable_setbase(pd);
    201 	if (compact)
    202 		genptable_compress(pd);
    203 	
    204 	if (!write_ptable_file(pd))
    205 		fprintf(stderr, "Error writing ptable file\n");
    206 }
    207 
    208 static void
    209 genptable_bfs(PruneData *pd, int d, int nthreads, int nchunks)
    210 {
    211 	int i;
    212 	pthread_t t[nthreads];
    213 	ThreadDataGenpt td[nthreads];
    214 	pthread_mutex_t *mtx[nchunks], *upmtx;
    215 
    216 	upmtx = malloc(sizeof(pthread_mutex_t));
    217 	pthread_mutex_init(upmtx, NULL);
    218 	for (i = 0; i < nchunks; i++) {
    219 		mtx[i] = malloc(sizeof(pthread_mutex_t));
    220 		pthread_mutex_init(mtx[i], NULL);
    221 	}
    222 
    223 	for (i = 0; i < nthreads; i++) {
    224 		td[i].thid     = i;
    225 		td[i].nthreads = nthreads;
    226 		td[i].pd       = pd;
    227 		td[i].d        = d;
    228 		td[i].nchunks  = nchunks;
    229 		td[i].mutex    = mtx;
    230 		td[i].upmutex  = upmtx;
    231 		pthread_create(&t[i], NULL, instance_bfs, &td[i]);
    232 	}
    233 
    234 	for (i = 0; i < nthreads; i++)
    235 		pthread_join(t[i], NULL);
    236 
    237 	free(upmtx);
    238 	for (i = 0; i < nchunks; i++)
    239 		free(mtx[i]);
    240 }
    241 
    242 static void
    243 genptable_compress(PruneData *pd)
    244 {
    245 	int val;
    246 	uint64_t i, j;
    247 	entry_group_t mask, v;
    248 
    249 	fprintf(stderr, "Compressing table to 2 bits per entry\n");
    250 
    251 	for (i = 0; i < pd->coord->max; i += ENTRIES_PER_GROUP_COMPACT) {
    252 		mask = (entry_group_t)0;
    253 		for (j = 0; j < ENTRIES_PER_GROUP_COMPACT; j++) {
    254 			if (i+j >= pd->coord->max)
    255 				break;
    256 			val = ptableval_index(pd, i+j) - pd->base;
    257 			v = (entry_group_t)MIN(3, MAX(0, val));
    258 			mask |= v << (2*j);
    259 		}
    260 		pd->ptable[i/ENTRIES_PER_GROUP_COMPACT] = mask;
    261 	}
    262 
    263 	pd->compact = true;
    264 	pd->ptable = realloc(pd->ptable, sizeof(entry_group_t)*ptablesize(pd));
    265 }
    266 
    267 static void
    268 genptable_fixnasty(PruneData *pd, int d, int nthreads)
    269 {
    270 	int i;
    271 	pthread_t t[nthreads];
    272 	ThreadDataGenpt td[nthreads];
    273 	pthread_mutex_t *upmtx;
    274 
    275 	if (pd->coord->tfind == NULL)
    276 		return;
    277 
    278 	upmtx = malloc(sizeof(pthread_mutex_t));
    279 	pthread_mutex_init(upmtx, NULL);
    280 	for (i = 0; i < nthreads; i++) {
    281 		td[i].thid     = i;
    282 		td[i].nthreads = nthreads;
    283 		td[i].pd       = pd;
    284 		td[i].d        = d;
    285 		td[i].upmutex  = upmtx;
    286 		pthread_create(&t[i], NULL, instance_fixnasty, &td[i]);
    287 	}
    288 
    289 	for (i = 0; i < nthreads; i++)
    290 		pthread_join(t[i], NULL);
    291 
    292 	free(upmtx);
    293 }
    294 
    295 static void
    296 genptable_setbase(PruneData *pd)
    297 {
    298 	int i;
    299 	uint64_t sum, newsum;
    300 
    301 	pd->base = 0;
    302 	sum = pd->count[0] + pd->count[1] + pd->count[2];
    303 	for (i = 3; i < 16; i++) {
    304 		newsum = sum + pd->count[i] - pd->count[i-3];
    305 		if (newsum > sum)
    306 			pd->base = i-3;
    307 		sum = newsum;
    308 	}
    309 }
    310 
    311 static void *
    312 instance_bfs(void *arg)
    313 {
    314 	ThreadDataGenpt *td;
    315 	uint64_t i, ii, blocksize, rmin, rmax, updated;
    316 	int j, pval, ichunk;
    317 	Move *ms;
    318 
    319 	td = (ThreadDataGenpt *)arg;
    320 	ms = td->pd->moveset->sorted_moves;
    321 	blocksize = td->pd->coord->max / (uint64_t)td->nthreads;
    322 	rmin = ((uint64_t)td->thid) * blocksize;
    323 	rmax = td->thid == td->nthreads - 1 ?
    324 	       td->pd->coord->max :
    325 	       ((uint64_t)td->thid + 1) * blocksize;
    326 
    327 	updated = 0;
    328 	for (i = rmin; i < rmax; i++) {
    329 		ichunk = findchunk(td->pd, td->nchunks, i);
    330 		pthread_mutex_lock(td->mutex[ichunk]);
    331 		pval = ptableval_index(td->pd, i);
    332 		pthread_mutex_unlock(td->mutex[ichunk]);
    333 		if (pval == td->d) {
    334 			for (j = 0; ms[j] != NULLMOVE; j++) {
    335 				ii = td->pd->coord->move(ms[j], i);
    336 				ichunk = findchunk(td->pd, td->nchunks, ii);
    337 				pthread_mutex_lock(td->mutex[ichunk]);
    338 				pval = ptableval_index(td->pd, ii);
    339 				if (pval > td->d+1) {
    340 					ptable_update_index(td->pd,
    341 					    ii, td->d+1);
    342 					updated++;
    343 				}
    344 				pthread_mutex_unlock(td->mutex[ichunk]);
    345 			}
    346 		}
    347 	}
    348 
    349 	pthread_mutex_lock(td->upmutex);
    350 	td->pd->n += updated;
    351 	pthread_mutex_unlock(td->upmutex);
    352 
    353 	return NULL;
    354 }
    355 
    356 static void *
    357 instance_fixnasty(void *arg)
    358 {
    359 	ThreadDataGenpt *td;
    360 	uint64_t i, ii, nb, blocksize, rmin, rmax, updated;
    361 	int j, n;
    362 	Trans t, aux[NTRANS];
    363 
    364 	td = (ThreadDataGenpt *)arg;
    365 	nb = td->pd->coord->max / td->pd->coord->base->max;
    366 	blocksize = (td->pd->coord->base->max / td->nthreads) * nb;
    367 	rmin = ((uint64_t)td->thid) * blocksize;
    368 	rmax = td->thid == td->nthreads - 1 ?
    369 	       td->pd->coord->max :
    370 	       ((uint64_t)td->thid + 1) * blocksize;
    371 
    372 	updated = 0;
    373 	for (i = rmin; i < rmax; i++) {
    374 		if (ptableval_index(td->pd, i) == td->d) {
    375 			if ((n = td->pd->coord->tfind(i, aux)) == 1)
    376 				continue;
    377 
    378 			for (j = 0; j < n; j++) {
    379 				if ((t = aux[j]) == uf)
    380 					continue;
    381 				ii = td->pd->coord->transform(t, i);
    382 				if (ptableval_index(td->pd, ii) > td->d) {
    383 					ptable_update_index(td->pd, ii, td->d);
    384 					updated++;
    385 				}
    386 			}
    387 		}
    388 	}
    389 
    390 	pthread_mutex_lock(td->upmutex);
    391 	td->pd->n += updated;
    392 	pthread_mutex_unlock(td->upmutex);
    393 
    394 	return NULL;
    395 }
    396 
    397 void
    398 print_ptable(PruneData *pd)
    399 {
    400 	uint64_t i;
    401 
    402 	if (!pd->generated)
    403 		genptable(pd, 1); /* TODO: set default nthreads somewhere */
    404 		
    405 	printf("Table %s\n", pd->filename);
    406 	printf("Base value: %d\n", pd->base);
    407 	for (i = 0; i < 16; i++)
    408 		if (pd->count[i] != 0)
    409 			printf("%2" PRIu64 "\t%10" PRIu64 "\n",
    410 			    i, pd->count[i]);
    411 	printf("Total: %" PRIu64 "\n", pd->coord->max);
    412 }
    413 
    414 uint64_t
    415 ptablesize(PruneData *pd)
    416 {
    417 	uint64_t e;
    418 
    419 	e = pd->compact ? ENTRIES_PER_GROUP_COMPACT : ENTRIES_PER_GROUP;
    420 
    421 	return (pd->coord->max + e - 1) / e;
    422 }
    423 
    424 static void
    425 ptable_update(PruneData *pd, Cube cube, int n)
    426 {
    427 	ptable_update_index(pd, pd->coord->index(cube), n);
    428 }
    429 
    430 static void
    431 ptable_update_index(PruneData *pd, uint64_t ind, int n)
    432 {
    433 	int sh;
    434 	entry_group_t mask;
    435 	uint64_t i;
    436 
    437 	sh = 4 * (ind % ENTRIES_PER_GROUP);
    438 	mask = ((entry_group_t)15) << sh;
    439 	i = ind/ENTRIES_PER_GROUP;
    440 
    441 	pd->ptable[i] &= ~mask;
    442 	pd->ptable[i] |= (((entry_group_t)n)&15) << sh;
    443 }
    444 
    445 int
    446 ptableval(PruneData *pd, Cube cube)
    447 {
    448 	return ptableval_index(pd, pd->coord->index(cube));
    449 }
    450 
    451 static int
    452 ptableval_index(PruneData *pd, uint64_t ind)
    453 {
    454 	int sh, ret;
    455 	entry_group_t mask;
    456 	uint64_t i, e;
    457 	entry_group_t m;
    458 
    459 	if (!pd->generated) {
    460 		fprintf(stderr, "Warning: request pruning table value"
    461 			" for uninitialized table %s.\n It's fine, but it"
    462 			" should not happen. Please report bug.\n",
    463 			pd->filename);
    464 		genptable(pd, 1); /* TODO: set default or remove this case */
    465 	}
    466 
    467 	if (pd->compact) {
    468 		e  = ENTRIES_PER_GROUP_COMPACT;
    469 		m  = 3;
    470 		sh = (ind % e) * 2;
    471 	} else {
    472 		e  = ENTRIES_PER_GROUP;
    473 		m  = 15;
    474 		sh = (ind % e) * 4;
    475 	}
    476 
    477 	mask = m << sh;
    478 	i = ind/e;
    479 
    480 	ret = (pd->ptable[i] & mask) >> sh;
    481 
    482 	if (pd->compact) {
    483 		if (ret)
    484 			ret += pd->base;
    485 		else
    486 			ret = ptableval_index(pd->fallback, ind / pd->fbmod);
    487 	}
    488 
    489 	return ret;
    490 }
    491 
    492 static bool
    493 read_ptable_file(PruneData *pd)
    494 {
    495 	init_env();
    496 
    497 	FILE *f;
    498 	char fname[strlen(tabledir)+100];
    499 	int i;
    500 	uint64_t r;
    501 
    502 	strcpy(fname, tabledir);
    503 	strcat(fname, "/");
    504 	strcat(fname, pd->filename);
    505 
    506 	if ((f = fopen(fname, "rb")) == NULL)
    507 		return false;
    508 
    509 	r = fread(&(pd->base), sizeof(int), 1, f);
    510 	for (i = 0; i < 16; i++)
    511 		r += fread(&(pd->count[i]), sizeof(uint64_t), 1, f);
    512 	r += fread(pd->ptable, sizeof(entry_group_t), ptablesize(pd), f);
    513 
    514 	fclose(f);
    515 
    516 	return r == 17 + ptablesize(pd);
    517 }
    518 
    519 static bool
    520 write_ptable_file(PruneData *pd)
    521 {
    522 	init_env();
    523 
    524 	FILE *f;
    525 	char fname[strlen(tabledir)+100];
    526 	int i;
    527 	uint64_t w;
    528 
    529 	strcpy(fname, tabledir);
    530 	strcat(fname, "/");
    531 	strcat(fname, pd->filename);
    532 
    533 	if ((f = fopen(fname, "wb")) == NULL)
    534 		return false;
    535 
    536 	w = fwrite(&(pd->base), sizeof(int), 1, f);
    537 	for (i = 0; i < 16; i++)
    538 		w += fwrite(&(pd->count[i]), sizeof(uint64_t), 1, f);
    539 	w += fwrite(pd->ptable, sizeof(entry_group_t), ptablesize(pd), f);
    540 	fclose(f);
    541 
    542 	return w == 17 + ptablesize(pd);
    543 }
    544