pruning.c (9307B)
1 #define PRUNING_C 2 3 #include "pruning.h" 4 5 #define ENTRIES_PER_GROUP (2*sizeof(entry_group_t)) 6 #define ENTRIES_PER_GROUP_COMPACT (4*sizeof(entry_group_t)) 7 8 static int findchunk(PruneData *pd, int nchunks, uint64_t i); 9 static void genptable_bfs(PruneData *pd, int d, int nt, int nc); 10 static void genptable_fixnasty(PruneData *pd, int d, int nthreads); 11 static void * instance_bfs(void *arg); 12 static void * instance_fixnasty(void *arg); 13 static void ptable_update(PruneData *pd, uint64_t ind, int m); 14 static bool read_ptable_file(PruneData *pd); 15 static bool write_ptable_file(PruneData *pd); 16 17 PruneData *active_pd[256]; 18 19 int 20 findchunk(PruneData *pd, int nchunks, uint64_t i) 21 { 22 uint64_t chunksize; 23 24 chunksize = pd->coord->max / (uint64_t)nchunks; 25 chunksize += ENTRIES_PER_GROUP - (chunksize % ENTRIES_PER_GROUP); 26 27 return MIN(nchunks-1, (int)(i / chunksize)); 28 } 29 30 PruneData * 31 genptable(PruneData *pd, int nthreads) 32 { 33 int d, nchunks, i, maxv; 34 uint64_t oldn; 35 36 for (i = 0; active_pd[i] != NULL; i++) { 37 if (active_pd[i]->coord == pd->coord && 38 active_pd[i]->moveset == pd->moveset && 39 active_pd[i]->compact == pd->compact) 40 return active_pd[i]; 41 } 42 43 init_moveset(pd->moveset); 44 gen_coord(pd->coord); 45 46 pd->ptable = malloc(ptablesize(pd) * sizeof(entry_group_t)); 47 48 if (read_ptable_file(pd)) 49 goto genptable_done; 50 51 if (nthreads < 4) { 52 fprintf(stderr, 53 "--- Warning ---\n" 54 "You are using only %d threads to generate the pruning" 55 "tables. This can take a while.\n" 56 "Unless you did this intentionally, you should re-run" 57 "this command with `-t 4' or more.\n" 58 "---------------\n\n", nthreads 59 ); 60 } 61 62 nchunks = MIN(ptablesize(pd), 100000); 63 fprintf(stderr, "Generating pt_%s_%s with %d threads\n", 64 pd->coord->name, pd->moveset->name, nthreads); 65 66 memset(pd->ptable, ~(uint8_t)0, ptablesize(pd)*sizeof(entry_group_t)); 67 for (i = 0; i < 16; i++) 68 pd->count[i] = 0; 69 70 ptable_update(pd, 0, 0); 71 pd->n = 1; 72 oldn = 0; 73 genptable_fixnasty(pd, 0, nthreads); 74 fprintf(stderr, "Depth %d done, generated %" 75 PRIu64 "\t(%" PRIu64 "/%" PRIu64 ")\n", 76 0, pd->n - oldn, pd->n, pd->coord->max); 77 oldn = pd->n; 78 pd->count[0] = pd->n; 79 80 maxv = pd->compact ? MIN(15, pd->base + 4) : 15; 81 for (d = 0; d < maxv && pd->n < pd->coord->max; d++) { 82 genptable_bfs(pd, d, nthreads, nchunks); 83 genptable_fixnasty(pd, d+1, nthreads); 84 fprintf(stderr, "Depth %d done, generated %" 85 PRIu64 "\t(%" PRIu64 "/%" PRIu64 ")\n", 86 d+1, pd->n - oldn, pd->n, pd->coord->max); 87 pd->count[d+1] = pd->n - oldn; 88 oldn = pd->n; 89 } 90 if (pd->compact) 91 fprintf(stderr, "Compact table, values above " 92 "%d are inaccurate.\n", maxv-1); 93 fprintf(stderr, "Pruning table generated!\n"); 94 95 if (!write_ptable_file(pd)) 96 fprintf(stderr, "Error writing ptable file\n"); 97 98 genptable_done: 99 for (i = 0; active_pd[i] != NULL; i++); 100 return active_pd[i] = pd; 101 } 102 103 static void 104 genptable_bfs(PruneData *pd, int d, int nthreads, int nchunks) 105 { 106 int i; 107 pthread_t t[nthreads]; 108 ThreadDataGenpt td[nthreads]; 109 pthread_mutex_t *mtx[nchunks], *upmtx; 110 111 upmtx = malloc(sizeof(pthread_mutex_t)); 112 pthread_mutex_init(upmtx, NULL); 113 for (i = 0; i < nchunks; i++) { 114 mtx[i] = malloc(sizeof(pthread_mutex_t)); 115 pthread_mutex_init(mtx[i], NULL); 116 } 117 118 for (i = 0; i < nthreads; i++) { 119 td[i].thid = i; 120 td[i].nthreads = nthreads; 121 td[i].pd = pd; 122 td[i].d = d; 123 td[i].nchunks = nchunks; 124 td[i].mutex = mtx; 125 td[i].upmutex = upmtx; 126 pthread_create(&t[i], NULL, instance_bfs, &td[i]); 127 } 128 129 for (i = 0; i < nthreads; i++) 130 pthread_join(t[i], NULL); 131 132 free(upmtx); 133 for (i = 0; i < nchunks; i++) 134 free(mtx[i]); 135 } 136 137 static void 138 genptable_fixnasty(PruneData *pd, int d, int nthreads) 139 { 140 int i; 141 pthread_t t[nthreads]; 142 ThreadDataGenpt td[nthreads]; 143 pthread_mutex_t *upmtx; 144 145 if (pd->coord->type != SYMCOMP_COORD) 146 return; 147 148 upmtx = malloc(sizeof(pthread_mutex_t)); 149 pthread_mutex_init(upmtx, NULL); 150 for (i = 0; i < nthreads; i++) { 151 td[i].thid = i; 152 td[i].nthreads = nthreads; 153 td[i].pd = pd; 154 td[i].d = d; 155 td[i].upmutex = upmtx; 156 pthread_create(&t[i], NULL, instance_fixnasty, &td[i]); 157 } 158 159 for (i = 0; i < nthreads; i++) 160 pthread_join(t[i], NULL); 161 162 free(upmtx); 163 } 164 165 static void * 166 instance_bfs(void *arg) 167 { 168 ThreadDataGenpt *td; 169 uint64_t i, ii, blocksize, rmin, rmax, updated; 170 int j, pval, ichunk, oldc, newc; 171 Move *ms; 172 173 td = (ThreadDataGenpt *)arg; 174 ms = td->pd->moveset->sorted_moves; 175 blocksize = td->pd->coord->max / (uint64_t)td->nthreads; 176 rmin = ((uint64_t)td->thid) * blocksize; 177 rmax = td->thid == td->nthreads - 1 ? 178 td->pd->coord->max : 179 ((uint64_t)td->thid + 1) * blocksize; 180 181 if (td->pd->compact) { 182 if (td->d <= td->pd->base) { 183 oldc = 1; 184 newc = 1; 185 } else { 186 oldc = td->d - td->pd->base; 187 newc = td->d - td->pd->base; 188 } 189 } else { 190 oldc = td->d; 191 newc = td->d + 1; 192 } 193 194 updated = 0; 195 for (i = rmin; i < rmax; i++) { 196 ichunk = findchunk(td->pd, td->nchunks, i); 197 pthread_mutex_lock(td->mutex[ichunk]); 198 pval = ptableval(td->pd, i); 199 pthread_mutex_unlock(td->mutex[ichunk]); 200 if (pval == oldc) { 201 for (j = 0; ms[j] != NULLMOVE; j++) { 202 ii = move_coord(td->pd->coord, ms[j], i, NULL); 203 ichunk = findchunk(td->pd, td->nchunks, ii); 204 pthread_mutex_lock(td->mutex[ichunk]); 205 pval = ptableval(td->pd, ii); 206 if (pval > newc) { 207 ptable_update(td->pd, ii, newc); 208 updated++; 209 } 210 pthread_mutex_unlock(td->mutex[ichunk]); 211 } 212 if (td->pd->compact && td->d <= td->pd->base) { 213 ichunk = findchunk(td->pd, td->nchunks, i); 214 pthread_mutex_lock(td->mutex[ichunk]); 215 ptable_update(td->pd, i, 0); 216 pthread_mutex_unlock(td->mutex[ichunk]); 217 } 218 } 219 } 220 221 pthread_mutex_lock(td->upmutex); 222 td->pd->n += updated; 223 pthread_mutex_unlock(td->upmutex); 224 225 return NULL; 226 } 227 228 static void * 229 instance_fixnasty(void *arg) 230 { 231 ThreadDataGenpt *td; 232 uint64_t i, ii, blocksize, rmin, rmax, updated, ss, M; 233 int j, oldc; 234 Trans t; 235 236 td = (ThreadDataGenpt *)arg; 237 238 /* We know type = SYMCOMP_COORD */ 239 M = td->pd->coord->base[1]->max; 240 blocksize = (td->pd->coord->base[0]->max / td->nthreads) * M; 241 rmin = ((uint64_t)td->thid) * blocksize; 242 rmax = td->thid == td->nthreads - 1 ? 243 td->pd->coord->max : 244 ((uint64_t)td->thid + 1) * blocksize; 245 246 if (td->pd->compact) { 247 if (td->d <= td->pd->base) 248 oldc = 1; 249 else 250 oldc = td->d - td->pd->base; 251 } else { 252 oldc = td->d; 253 } 254 255 updated = 0; 256 for (i = rmin; i < rmax; i++) { 257 if (ptableval(td->pd, i) == oldc) { 258 ss = td->pd->coord->base[0]->selfsim[i/M]; 259 for (j = 0; j < td->pd->coord->base[0]->tgrp->n; j++) { 260 t = td->pd->coord->base[0]->tgrp->t[j]; 261 if (t == uf || !(ss & ((uint64_t)1<<t))) 262 continue; 263 ii = trans_coord(td->pd->coord, t, i); 264 if (ptableval(td->pd, ii) > oldc) { 265 ptable_update(td->pd, ii, oldc); 266 updated++; 267 } 268 } 269 } 270 } 271 272 pthread_mutex_lock(td->upmutex); 273 td->pd->n += updated; 274 pthread_mutex_unlock(td->upmutex); 275 276 return NULL; 277 } 278 279 void 280 print_ptable(PruneData *pd) 281 { 282 uint64_t i; 283 284 printf("Table %s_%s\n", pd->coord->name, pd->moveset->name); 285 286 if (pd->compact) { 287 printf("Compract table with base value: %d\n", pd->base); 288 printf("Values above %d are inaccurate.\n", pd->base + 3); 289 } 290 291 for (i = 0; i < 16; i++) 292 printf("%2" PRIu64 "\t%10" PRIu64 "\n", i, pd->count[i]); 293 } 294 295 uint64_t 296 ptablesize(PruneData *pd) 297 { 298 uint64_t e; 299 300 e = pd->compact ? ENTRIES_PER_GROUP_COMPACT : ENTRIES_PER_GROUP; 301 302 return (pd->coord->max + e - 1) / e; 303 } 304 305 static void 306 ptable_update(PruneData *pd, uint64_t ind, int n) 307 { 308 int sh; 309 entry_group_t f, mask; 310 uint64_t i, e, b; 311 312 e = pd->compact ? ENTRIES_PER_GROUP_COMPACT : ENTRIES_PER_GROUP; 313 b = pd->compact ? 2 : 4; 314 f = pd->compact ? 3 : 15; 315 316 sh = b * (ind % e); 317 mask = f << sh; 318 i = ind / e; 319 320 pd->ptable[i] &= ~mask; 321 pd->ptable[i] |= (((entry_group_t)n) & f) << sh; 322 } 323 324 int 325 ptableval(PruneData *pd, uint64_t ind) 326 { 327 int sh; 328 uint64_t e; 329 entry_group_t m; 330 331 if (pd->compact) { 332 e = ENTRIES_PER_GROUP_COMPACT; 333 m = 3; 334 sh = (ind % e) * 2; 335 } else { 336 e = ENTRIES_PER_GROUP; 337 m = 15; 338 sh = (ind % e) * 4; 339 } 340 341 return (pd->ptable[ind/e] & (m << sh)) >> sh; 342 } 343 344 static bool 345 read_ptable_file(PruneData *pd) 346 { 347 init_env(); 348 349 FILE *f; 350 char fname[strlen(tabledir)+256]; 351 int i; 352 uint64_t r; 353 354 strcpy(fname, tabledir); 355 strcat(fname, "/pt_"); 356 strcat(fname, pd->coord->name); 357 strcat(fname, "_"); 358 strcat(fname, pd->moveset->name); 359 360 if ((f = fopen(fname, "rb")) == NULL) 361 return false; 362 363 r = fread(&(pd->base), sizeof(int), 1, f); 364 for (i = 0; i < 16; i++) 365 r += fread(&(pd->count[i]), sizeof(uint64_t), 1, f); 366 r += fread(pd->ptable, sizeof(entry_group_t), ptablesize(pd), f); 367 368 fclose(f); 369 370 return r == 17 + ptablesize(pd); 371 } 372 373 static bool 374 write_ptable_file(PruneData *pd) 375 { 376 init_env(); 377 378 FILE *f; 379 char fname[strlen(tabledir)+256]; 380 int i; 381 uint64_t w; 382 383 strcpy(fname, tabledir); 384 strcat(fname, "/pt_"); 385 strcat(fname, pd->coord->name); 386 strcat(fname, "_"); 387 strcat(fname, pd->moveset->name); 388 389 if ((f = fopen(fname, "wb")) == NULL) 390 return false; 391 392 w = fwrite(&(pd->base), sizeof(int), 1, f); 393 for (i = 0; i < 16; i++) 394 w += fwrite(&(pd->count[i]), sizeof(uint64_t), 1, f); 395 w += fwrite(pd->ptable, sizeof(entry_group_t), ptablesize(pd), f); 396 fclose(f); 397 398 return w == 17 + ptablesize(pd); 399 } 400