nissy-nx

A Rubik's cube optimal solver
git clone https://git.tronto.net/nissy-nx
Download | Log | Files | Refs | README | LICENSE

pruning.c (9307B)


      1 #define PRUNING_C
      2 
      3 #include "pruning.h"
      4 
      5 #define ENTRIES_PER_GROUP              (2*sizeof(entry_group_t))
      6 #define ENTRIES_PER_GROUP_COMPACT      (4*sizeof(entry_group_t))
      7 
      8 static int         findchunk(PruneData *pd, int nchunks, uint64_t i);
      9 static void        genptable_bfs(PruneData *pd, int d, int nt, int nc);
     10 static void        genptable_fixnasty(PruneData *pd, int d, int nthreads);
     11 static void *      instance_bfs(void *arg);
     12 static void *      instance_fixnasty(void *arg);
     13 static void        ptable_update(PruneData *pd, uint64_t ind, int m);
     14 static bool        read_ptable_file(PruneData *pd);
     15 static bool        write_ptable_file(PruneData *pd);
     16 
     17 PruneData *active_pd[256];
     18 
     19 int
     20 findchunk(PruneData *pd, int nchunks, uint64_t i)
     21 {
     22 	uint64_t chunksize;
     23 
     24 	chunksize = pd->coord->max / (uint64_t)nchunks;
     25 	chunksize += ENTRIES_PER_GROUP - (chunksize % ENTRIES_PER_GROUP);
     26 
     27 	return MIN(nchunks-1, (int)(i / chunksize));
     28 }
     29 
     30 PruneData *
     31 genptable(PruneData *pd, int nthreads)
     32 {
     33 	int d, nchunks, i, maxv;
     34 	uint64_t oldn;
     35 
     36 	for (i = 0; active_pd[i] != NULL; i++) {
     37 		if (active_pd[i]->coord == pd->coord &&
     38 		    active_pd[i]->moveset == pd->moveset &&
     39 		    active_pd[i]->compact == pd->compact)
     40 			return active_pd[i];
     41 	}
     42 
     43 	init_moveset(pd->moveset);
     44 	gen_coord(pd->coord);
     45 
     46 	pd->ptable = malloc(ptablesize(pd) * sizeof(entry_group_t));
     47 
     48 	if (read_ptable_file(pd))
     49 		goto genptable_done;
     50 
     51 	if (nthreads < 4) {
     52 		fprintf(stderr,
     53 			"--- Warning ---\n"
     54 			"You are using only %d threads to generate the pruning"
     55 			"tables. This can take a while.\n"
     56 			"Unless you did this intentionally, you should re-run"
     57 			"this command with `-t 4' or more.\n"
     58 			"---------------\n\n", nthreads
     59 		);
     60 	}
     61 
     62 	nchunks = MIN(ptablesize(pd), 100000);
     63 	fprintf(stderr, "Generating pt_%s_%s with %d threads\n",
     64 			pd->coord->name, pd->moveset->name, nthreads); 
     65 
     66 	memset(pd->ptable, ~(uint8_t)0, ptablesize(pd)*sizeof(entry_group_t));
     67 	for (i = 0; i < 16; i++)
     68 		pd->count[i] = 0;
     69 
     70 	ptable_update(pd, 0, 0);
     71 	pd->n = 1;
     72 	oldn = 0;
     73 	genptable_fixnasty(pd, 0, nthreads);
     74 	fprintf(stderr, "Depth %d done, generated %"
     75 		PRIu64 "\t(%" PRIu64 "/%" PRIu64 ")\n",
     76 		0, pd->n - oldn, pd->n, pd->coord->max);
     77 	oldn = pd->n;
     78 	pd->count[0] = pd->n;
     79 
     80 	maxv = pd->compact ? MIN(15, pd->base + 4) : 15;
     81 	for (d = 0; d < maxv && pd->n < pd->coord->max; d++) {
     82 		genptable_bfs(pd, d, nthreads, nchunks);
     83 		genptable_fixnasty(pd, d+1, nthreads);
     84 		fprintf(stderr, "Depth %d done, generated %"
     85 			PRIu64 "\t(%" PRIu64 "/%" PRIu64 ")\n",
     86 			d+1, pd->n - oldn, pd->n, pd->coord->max);
     87 		pd->count[d+1] = pd->n - oldn;
     88 		oldn = pd->n;
     89 	}
     90 	if (pd->compact)
     91 		fprintf(stderr, "Compact table, values above "
     92 				"%d are inaccurate.\n", maxv-1);
     93 	fprintf(stderr, "Pruning table generated!\n");
     94 
     95 	if (!write_ptable_file(pd))
     96 		fprintf(stderr, "Error writing ptable file\n");
     97 
     98 genptable_done:
     99 	for (i = 0; active_pd[i] != NULL; i++);
    100 	return active_pd[i] = pd;
    101 }
    102 
    103 static void
    104 genptable_bfs(PruneData *pd, int d, int nthreads, int nchunks)
    105 {
    106 	int i;
    107 	pthread_t t[nthreads];
    108 	ThreadDataGenpt td[nthreads];
    109 	pthread_mutex_t *mtx[nchunks], *upmtx;
    110 
    111 	upmtx = malloc(sizeof(pthread_mutex_t));
    112 	pthread_mutex_init(upmtx, NULL);
    113 	for (i = 0; i < nchunks; i++) {
    114 		mtx[i] = malloc(sizeof(pthread_mutex_t));
    115 		pthread_mutex_init(mtx[i], NULL);
    116 	}
    117 
    118 	for (i = 0; i < nthreads; i++) {
    119 		td[i].thid     = i;
    120 		td[i].nthreads = nthreads;
    121 		td[i].pd       = pd;
    122 		td[i].d        = d;
    123 		td[i].nchunks  = nchunks;
    124 		td[i].mutex    = mtx;
    125 		td[i].upmutex  = upmtx;
    126 		pthread_create(&t[i], NULL, instance_bfs, &td[i]);
    127 	}
    128 
    129 	for (i = 0; i < nthreads; i++)
    130 		pthread_join(t[i], NULL);
    131 
    132 	free(upmtx);
    133 	for (i = 0; i < nchunks; i++)
    134 		free(mtx[i]);
    135 }
    136 
    137 static void
    138 genptable_fixnasty(PruneData *pd, int d, int nthreads)
    139 {
    140 	int i;
    141 	pthread_t t[nthreads];
    142 	ThreadDataGenpt td[nthreads];
    143 	pthread_mutex_t *upmtx;
    144 
    145 	if (pd->coord->type != SYMCOMP_COORD)
    146 		return;
    147 
    148 	upmtx = malloc(sizeof(pthread_mutex_t));
    149 	pthread_mutex_init(upmtx, NULL);
    150 	for (i = 0; i < nthreads; i++) {
    151 		td[i].thid     = i;
    152 		td[i].nthreads = nthreads;
    153 		td[i].pd       = pd;
    154 		td[i].d        = d;
    155 		td[i].upmutex  = upmtx;
    156 		pthread_create(&t[i], NULL, instance_fixnasty, &td[i]);
    157 	}
    158 
    159 	for (i = 0; i < nthreads; i++)
    160 		pthread_join(t[i], NULL);
    161 
    162 	free(upmtx);
    163 }
    164 
    165 static void *
    166 instance_bfs(void *arg)
    167 {
    168 	ThreadDataGenpt *td;
    169 	uint64_t i, ii, blocksize, rmin, rmax, updated;
    170 	int j, pval, ichunk, oldc, newc;
    171 	Move *ms;
    172 
    173 	td = (ThreadDataGenpt *)arg;
    174 	ms = td->pd->moveset->sorted_moves;
    175 	blocksize = td->pd->coord->max / (uint64_t)td->nthreads;
    176 	rmin = ((uint64_t)td->thid) * blocksize;
    177 	rmax = td->thid == td->nthreads - 1 ?
    178 	       td->pd->coord->max :
    179 	       ((uint64_t)td->thid + 1) * blocksize;
    180 
    181 	if (td->pd->compact) {
    182 		if (td->d <= td->pd->base) {
    183 			oldc = 1;
    184 			newc = 1;
    185 		} else {
    186 			oldc = td->d - td->pd->base;
    187 			newc = td->d - td->pd->base;
    188 		}
    189 	} else {
    190 		oldc = td->d;
    191 		newc = td->d + 1;
    192 	}
    193 
    194 	updated = 0;
    195 	for (i = rmin; i < rmax; i++) {
    196 		ichunk = findchunk(td->pd, td->nchunks, i);
    197 		pthread_mutex_lock(td->mutex[ichunk]);
    198 		pval = ptableval(td->pd, i);
    199 		pthread_mutex_unlock(td->mutex[ichunk]);
    200 		if (pval == oldc) {
    201 			for (j = 0; ms[j] != NULLMOVE; j++) {
    202 				ii = move_coord(td->pd->coord, ms[j], i, NULL);
    203 				ichunk = findchunk(td->pd, td->nchunks, ii);
    204 				pthread_mutex_lock(td->mutex[ichunk]);
    205 				pval = ptableval(td->pd, ii);
    206 				if (pval > newc) {
    207 					ptable_update(td->pd, ii, newc);
    208 					updated++;
    209 				}
    210 				pthread_mutex_unlock(td->mutex[ichunk]);
    211 			}
    212 			if (td->pd->compact && td->d <= td->pd->base) {
    213 				ichunk = findchunk(td->pd, td->nchunks, i);
    214 				pthread_mutex_lock(td->mutex[ichunk]);
    215 				ptable_update(td->pd, i, 0);
    216 				pthread_mutex_unlock(td->mutex[ichunk]);
    217 			}
    218 		}
    219 	}
    220 
    221 	pthread_mutex_lock(td->upmutex);
    222 	td->pd->n += updated;
    223 	pthread_mutex_unlock(td->upmutex);
    224 
    225 	return NULL;
    226 }
    227 
    228 static void *
    229 instance_fixnasty(void *arg)
    230 {
    231 	ThreadDataGenpt *td;
    232 	uint64_t i, ii, blocksize, rmin, rmax, updated, ss, M;
    233 	int j, oldc;
    234 	Trans t;
    235 
    236 	td = (ThreadDataGenpt *)arg;
    237 
    238 	/* We know type = SYMCOMP_COORD */
    239 	M = td->pd->coord->base[1]->max;
    240 	blocksize = (td->pd->coord->base[0]->max / td->nthreads) * M;
    241 	rmin = ((uint64_t)td->thid) * blocksize;
    242 	rmax = td->thid == td->nthreads - 1 ?
    243 	       td->pd->coord->max :
    244 	       ((uint64_t)td->thid + 1) * blocksize;
    245 
    246 	if (td->pd->compact) {
    247 		if (td->d <= td->pd->base)
    248 			oldc = 1;
    249 		else
    250 			oldc = td->d - td->pd->base;
    251 	} else {
    252 		oldc = td->d;
    253 	}
    254 
    255 	updated = 0;
    256 	for (i = rmin; i < rmax; i++) {
    257 		if (ptableval(td->pd, i) == oldc) {
    258 			ss = td->pd->coord->base[0]->selfsim[i/M];
    259 			for (j = 0; j < td->pd->coord->base[0]->tgrp->n; j++) {
    260 				t = td->pd->coord->base[0]->tgrp->t[j];
    261 				if (t == uf || !(ss & ((uint64_t)1<<t)))
    262 					continue;
    263 				ii = trans_coord(td->pd->coord, t, i);
    264 				if (ptableval(td->pd, ii) > oldc) {
    265 					ptable_update(td->pd, ii, oldc);
    266 					updated++;
    267 				}
    268 			}
    269 		}
    270 	}
    271 
    272 	pthread_mutex_lock(td->upmutex);
    273 	td->pd->n += updated;
    274 	pthread_mutex_unlock(td->upmutex);
    275 
    276 	return NULL;
    277 }
    278 
    279 void
    280 print_ptable(PruneData *pd)
    281 {
    282 	uint64_t i;
    283 
    284 	printf("Table %s_%s\n", pd->coord->name, pd->moveset->name);
    285 
    286 	if (pd->compact) {
    287 		printf("Compract table with base value: %d\n", pd->base);
    288 		printf("Values above %d are inaccurate.\n", pd->base + 3);
    289 	}
    290 
    291 	for (i = 0; i < 16; i++)
    292 		printf("%2" PRIu64 "\t%10" PRIu64 "\n", i, pd->count[i]);
    293 }
    294 
    295 uint64_t
    296 ptablesize(PruneData *pd)
    297 {
    298 	uint64_t e;
    299 
    300 	e = pd->compact ? ENTRIES_PER_GROUP_COMPACT : ENTRIES_PER_GROUP;
    301 
    302 	return (pd->coord->max + e - 1) / e;
    303 }
    304 
    305 static void
    306 ptable_update(PruneData *pd, uint64_t ind, int n)
    307 {
    308 	int sh;
    309 	entry_group_t f, mask;
    310 	uint64_t i, e, b;
    311 
    312 	e = pd->compact ? ENTRIES_PER_GROUP_COMPACT : ENTRIES_PER_GROUP;
    313 	b = pd->compact ? 2 : 4;
    314 	f = pd->compact ? 3 : 15;
    315 
    316 	sh = b * (ind % e);
    317 	mask = f << sh;
    318 	i = ind / e;
    319 
    320 	pd->ptable[i] &= ~mask;
    321 	pd->ptable[i] |= (((entry_group_t)n) & f) << sh;
    322 }
    323 
    324 int
    325 ptableval(PruneData *pd, uint64_t ind)
    326 {
    327 	int sh;
    328 	uint64_t e;
    329 	entry_group_t m;
    330 
    331 	if (pd->compact) {
    332 		e  = ENTRIES_PER_GROUP_COMPACT;
    333 		m  = 3;
    334 		sh = (ind % e) * 2;
    335 	} else {
    336 		e  = ENTRIES_PER_GROUP;
    337 		m  = 15;
    338 		sh = (ind % e) * 4;
    339 	}
    340 
    341 	return (pd->ptable[ind/e] & (m << sh)) >> sh;
    342 }
    343 
    344 static bool
    345 read_ptable_file(PruneData *pd)
    346 {
    347 	init_env();
    348 
    349 	FILE *f;
    350 	char fname[strlen(tabledir)+256];
    351 	int i;
    352 	uint64_t r;
    353 
    354 	strcpy(fname, tabledir);
    355 	strcat(fname, "/pt_");
    356 	strcat(fname, pd->coord->name);
    357 	strcat(fname, "_");
    358 	strcat(fname, pd->moveset->name);
    359 
    360 	if ((f = fopen(fname, "rb")) == NULL)
    361 		return false;
    362 
    363 	r = fread(&(pd->base), sizeof(int), 1, f);
    364 	for (i = 0; i < 16; i++)
    365 		r += fread(&(pd->count[i]), sizeof(uint64_t), 1, f);
    366 	r += fread(pd->ptable, sizeof(entry_group_t), ptablesize(pd), f);
    367 
    368 	fclose(f);
    369 
    370 	return r == 17 + ptablesize(pd);
    371 }
    372 
    373 static bool
    374 write_ptable_file(PruneData *pd)
    375 {
    376 	init_env();
    377 
    378 	FILE *f;
    379 	char fname[strlen(tabledir)+256];
    380 	int i;
    381 	uint64_t w;
    382 
    383 	strcpy(fname, tabledir);
    384 	strcat(fname, "/pt_");
    385 	strcat(fname, pd->coord->name);
    386 	strcat(fname, "_");
    387 	strcat(fname, pd->moveset->name);
    388 
    389 	if ((f = fopen(fname, "wb")) == NULL)
    390 		return false;
    391 
    392 	w = fwrite(&(pd->base), sizeof(int), 1, f);
    393 	for (i = 0; i < 16; i++)
    394 		w += fwrite(&(pd->count[i]), sizeof(uint64_t), 1, f);
    395 	w += fwrite(pd->ptable, sizeof(entry_group_t), ptablesize(pd), f);
    396 	fclose(f);
    397 
    398 	return w == 17 + ptablesize(pd);
    399 }
    400