pruning.c (11922B)
1 #include "pruning.h" 2 3 #define ENTRIES_PER_GROUP (2*sizeof(entry_group_t)) 4 #define ENTRIES_PER_GROUP_COMPACT (4*sizeof(entry_group_t)) 5 6 static int findchunk(PruneData *pd, int nchunks, uint64_t i); 7 static void genptable_bfs(PruneData *pd, int d, int nt, int nc); 8 static void genptable_compress(PruneData *pd); 9 static void genptable_fixnasty(PruneData *pd, int d, int nthreads); 10 static void genptable_setbase(PruneData *pd); 11 static void * instance_bfs(void *arg); 12 static void * instance_fixnasty(void *arg); 13 static void ptable_update(PruneData *pd, Cube cube, int m); 14 static void ptable_update_index(PruneData *pd, uint64_t ind, int m); 15 static int ptableval_index(PruneData *pd, uint64_t ind); 16 static bool read_ptable_file(PruneData *pd); 17 static bool write_ptable_file(PruneData *pd); 18 19 PruneData 20 pd_eofb_HTM = { 21 .filename = "pt_eofb_HTM", 22 .coord = &coord_eofb, 23 .moveset = &moveset_HTM, 24 }; 25 26 PruneData 27 pd_coud_HTM = { 28 .filename = "pt_coud_HTM", 29 .coord = &coord_coud, 30 .moveset = &moveset_HTM, 31 }; 32 33 PruneData 34 pd_cornershtr_HTM = { 35 .filename = "pt_cornershtr_HTM", 36 .coord = &coord_cornershtr, 37 .moveset = &moveset_HTM, 38 }; 39 40 PruneData 41 pd_corners_HTM = { 42 .filename = "pt_corners_HTM", 43 .coord = &coord_corners, 44 .moveset = &moveset_HTM, 45 }; 46 47 PruneData 48 pd_drud_sym16_HTM = { 49 .filename = "pt_drud_sym16_HTM", 50 .coord = &coord_drud_sym16, 51 .moveset = &moveset_HTM, 52 }; 53 54 PruneData 55 pd_drud_eofb = { 56 .filename = "pt_drud_eofb", 57 .coord = &coord_drud_eofb, 58 .moveset = &moveset_eofb, 59 }; 60 61 PruneData 62 pd_drudfin_noE_sym16_drud = { 63 .filename = "pt_drudfin_noE_sym16_drud", 64 .coord = &coord_drudfin_noE_sym16, 65 .moveset = &moveset_drud, 66 }; 67 68 PruneData 69 pd_htr_drud = { 70 .filename = "pt_htr_drud", 71 .coord = &coord_htr_drud, 72 .moveset = &moveset_drud, 73 }; 74 75 PruneData 76 pd_cp_drud = { 77 .filename = "pt_cp_drud", 78 .coord = &coord_cp, 79 .moveset = &moveset_drud, 80 }; 81 82 PruneData 83 pd_htrfin_htr = { 84 .filename = "pt_htrfin_htr", 85 .coord = &coord_htrfin, 86 .moveset = &moveset_htr, 87 }; 88 89 PruneData 90 pd_nxopt31_HTM = { 91 .filename = "pt_nxopt31_HTM", 92 .coord = &coord_nxopt31, 93 .moveset = &moveset_HTM, 94 95 .compact = true, 96 .fallback = &pd_drud_sym16_HTM, 97 .fbmod = BINOM8ON4, 98 }; 99 100 PruneData * all_pd[] = { 101 &pd_eofb_HTM, 102 &pd_coud_HTM, 103 &pd_cornershtr_HTM, 104 &pd_corners_HTM, 105 &pd_drud_sym16_HTM, 106 &pd_drud_eofb, 107 &pd_drudfin_noE_sym16_drud, 108 &pd_htr_drud, 109 &pd_cp_drud, 110 &pd_htrfin_htr, 111 &pd_nxopt31_HTM, 112 NULL 113 }; 114 115 /* Functions *****************************************************************/ 116 117 int 118 findchunk(PruneData *pd, int nchunks, uint64_t i) 119 { 120 uint64_t chunksize; 121 122 chunksize = pd->coord->max / (uint64_t)nchunks; 123 chunksize += ENTRIES_PER_GROUP - (chunksize % ENTRIES_PER_GROUP); 124 125 return MIN(nchunks-1, (int)(i / chunksize)); 126 } 127 128 void 129 free_pd(PruneData *pd) 130 { 131 if (pd->generated) 132 free(pd->ptable); 133 134 pd->generated = false; 135 } 136 137 void 138 genptable(PruneData *pd, int nthreads) 139 { 140 bool compact; 141 int d, nchunks; 142 uint64_t oldn, sz; 143 144 if (pd->generated) 145 return; 146 147 /* TODO: check if memory is enough, otherwise maybe exit gracefully? */ 148 sz = ptablesize(pd) * (pd->compact ? 2 : 1); 149 pd->ptable = malloc(sz * sizeof(entry_group_t)); 150 151 if (read_ptable_file(pd)) { 152 pd->generated = true; 153 return; 154 } 155 156 if (nthreads < 4) { 157 fprintf(stderr, 158 "--- Warning ---\n" 159 "You are using only %d threads to generate the pruning" 160 "tables. This can take a while." 161 "Unless you did this intentionally, you should re-run" 162 "this command with `-t 4' or more.\n" 163 "---------------\n\n", nthreads 164 ); 165 } 166 167 168 /* For the first steps we proceed the same way for compact and not */ 169 compact = pd->compact; 170 pd->compact = false; 171 pd->generated = true; 172 173 nchunks = MIN(ptablesize(pd), 100000); 174 fprintf(stderr, "Cannot load %s, generating it with %d threads\n", 175 pd->filename, nthreads); 176 177 178 memset(pd->ptable, ~(uint8_t)0, ptablesize(pd)*sizeof(entry_group_t)); 179 180 ptable_update(pd, (Cube){0}, 0); 181 pd->n = 1; 182 oldn = 0; 183 genptable_fixnasty(pd, 0, nthreads); 184 fprintf(stderr, "Depth %d done, generated %" 185 PRIu64 "\t(%" PRIu64 "/%" PRIu64 ")\n", 186 0, pd->n - oldn, pd->n, pd->coord->max); 187 oldn = pd->n; 188 pd->count[0] = pd->n; 189 for (d = 0; d < 15 && pd->n < pd->coord->max; d++) { 190 genptable_bfs(pd, d, nthreads, nchunks); 191 genptable_fixnasty(pd, d+1, nthreads); 192 fprintf(stderr, "Depth %d done, generated %" 193 PRIu64 "\t(%" PRIu64 "/%" PRIu64 ")\n", 194 d+1, pd->n - oldn, pd->n, pd->coord->max); 195 pd->count[d+1] = pd->n - oldn; 196 oldn = pd->n; 197 } 198 fprintf(stderr, "Pruning table generated!\n"); 199 200 genptable_setbase(pd); 201 if (compact) 202 genptable_compress(pd); 203 204 if (!write_ptable_file(pd)) 205 fprintf(stderr, "Error writing ptable file\n"); 206 } 207 208 static void 209 genptable_bfs(PruneData *pd, int d, int nthreads, int nchunks) 210 { 211 int i; 212 pthread_t t[nthreads]; 213 ThreadDataGenpt td[nthreads]; 214 pthread_mutex_t *mtx[nchunks], *upmtx; 215 216 upmtx = malloc(sizeof(pthread_mutex_t)); 217 pthread_mutex_init(upmtx, NULL); 218 for (i = 0; i < nchunks; i++) { 219 mtx[i] = malloc(sizeof(pthread_mutex_t)); 220 pthread_mutex_init(mtx[i], NULL); 221 } 222 223 for (i = 0; i < nthreads; i++) { 224 td[i].thid = i; 225 td[i].nthreads = nthreads; 226 td[i].pd = pd; 227 td[i].d = d; 228 td[i].nchunks = nchunks; 229 td[i].mutex = mtx; 230 td[i].upmutex = upmtx; 231 pthread_create(&t[i], NULL, instance_bfs, &td[i]); 232 } 233 234 for (i = 0; i < nthreads; i++) 235 pthread_join(t[i], NULL); 236 237 free(upmtx); 238 for (i = 0; i < nchunks; i++) 239 free(mtx[i]); 240 } 241 242 static void 243 genptable_compress(PruneData *pd) 244 { 245 int val; 246 uint64_t i, j; 247 entry_group_t mask, v; 248 249 fprintf(stderr, "Compressing table to 2 bits per entry\n"); 250 251 for (i = 0; i < pd->coord->max; i += ENTRIES_PER_GROUP_COMPACT) { 252 mask = (entry_group_t)0; 253 for (j = 0; j < ENTRIES_PER_GROUP_COMPACT; j++) { 254 if (i+j >= pd->coord->max) 255 break; 256 val = ptableval_index(pd, i+j) - pd->base; 257 v = (entry_group_t)MIN(3, MAX(0, val)); 258 mask |= v << (2*j); 259 } 260 pd->ptable[i/ENTRIES_PER_GROUP_COMPACT] = mask; 261 } 262 263 pd->compact = true; 264 pd->ptable = realloc(pd->ptable, sizeof(entry_group_t)*ptablesize(pd)); 265 } 266 267 static void 268 genptable_fixnasty(PruneData *pd, int d, int nthreads) 269 { 270 int i; 271 pthread_t t[nthreads]; 272 ThreadDataGenpt td[nthreads]; 273 pthread_mutex_t *upmtx; 274 275 if (pd->coord->tfind == NULL) 276 return; 277 278 upmtx = malloc(sizeof(pthread_mutex_t)); 279 pthread_mutex_init(upmtx, NULL); 280 for (i = 0; i < nthreads; i++) { 281 td[i].thid = i; 282 td[i].nthreads = nthreads; 283 td[i].pd = pd; 284 td[i].d = d; 285 td[i].upmutex = upmtx; 286 pthread_create(&t[i], NULL, instance_fixnasty, &td[i]); 287 } 288 289 for (i = 0; i < nthreads; i++) 290 pthread_join(t[i], NULL); 291 292 free(upmtx); 293 } 294 295 static void 296 genptable_setbase(PruneData *pd) 297 { 298 int i; 299 uint64_t sum, newsum; 300 301 pd->base = 0; 302 sum = pd->count[0] + pd->count[1] + pd->count[2]; 303 for (i = 3; i < 16; i++) { 304 newsum = sum + pd->count[i] - pd->count[i-3]; 305 if (newsum > sum) 306 pd->base = i-3; 307 sum = newsum; 308 } 309 } 310 311 static void * 312 instance_bfs(void *arg) 313 { 314 ThreadDataGenpt *td; 315 uint64_t i, ii, blocksize, rmin, rmax, updated; 316 int j, pval, ichunk; 317 Move *ms; 318 319 td = (ThreadDataGenpt *)arg; 320 ms = td->pd->moveset->sorted_moves; 321 blocksize = td->pd->coord->max / (uint64_t)td->nthreads; 322 rmin = ((uint64_t)td->thid) * blocksize; 323 rmax = td->thid == td->nthreads - 1 ? 324 td->pd->coord->max : 325 ((uint64_t)td->thid + 1) * blocksize; 326 327 updated = 0; 328 for (i = rmin; i < rmax; i++) { 329 ichunk = findchunk(td->pd, td->nchunks, i); 330 pthread_mutex_lock(td->mutex[ichunk]); 331 pval = ptableval_index(td->pd, i); 332 pthread_mutex_unlock(td->mutex[ichunk]); 333 if (pval == td->d) { 334 for (j = 0; ms[j] != NULLMOVE; j++) { 335 ii = td->pd->coord->move(ms[j], i); 336 ichunk = findchunk(td->pd, td->nchunks, ii); 337 pthread_mutex_lock(td->mutex[ichunk]); 338 pval = ptableval_index(td->pd, ii); 339 if (pval > td->d+1) { 340 ptable_update_index(td->pd, 341 ii, td->d+1); 342 updated++; 343 } 344 pthread_mutex_unlock(td->mutex[ichunk]); 345 } 346 } 347 } 348 349 pthread_mutex_lock(td->upmutex); 350 td->pd->n += updated; 351 pthread_mutex_unlock(td->upmutex); 352 353 return NULL; 354 } 355 356 static void * 357 instance_fixnasty(void *arg) 358 { 359 ThreadDataGenpt *td; 360 uint64_t i, ii, nb, blocksize, rmin, rmax, updated; 361 int j, n; 362 Trans t, aux[NTRANS]; 363 364 td = (ThreadDataGenpt *)arg; 365 nb = td->pd->coord->max / td->pd->coord->base->max; 366 blocksize = (td->pd->coord->base->max / td->nthreads) * nb; 367 rmin = ((uint64_t)td->thid) * blocksize; 368 rmax = td->thid == td->nthreads - 1 ? 369 td->pd->coord->max : 370 ((uint64_t)td->thid + 1) * blocksize; 371 372 updated = 0; 373 for (i = rmin; i < rmax; i++) { 374 if (ptableval_index(td->pd, i) == td->d) { 375 if ((n = td->pd->coord->tfind(i, aux)) == 1) 376 continue; 377 378 for (j = 0; j < n; j++) { 379 if ((t = aux[j]) == uf) 380 continue; 381 ii = td->pd->coord->transform(t, i); 382 if (ptableval_index(td->pd, ii) > td->d) { 383 ptable_update_index(td->pd, ii, td->d); 384 updated++; 385 } 386 } 387 } 388 } 389 390 pthread_mutex_lock(td->upmutex); 391 td->pd->n += updated; 392 pthread_mutex_unlock(td->upmutex); 393 394 return NULL; 395 } 396 397 void 398 print_ptable(PruneData *pd) 399 { 400 uint64_t i; 401 402 if (!pd->generated) 403 genptable(pd, 1); /* TODO: set default nthreads somewhere */ 404 405 printf("Table %s\n", pd->filename); 406 printf("Base value: %d\n", pd->base); 407 for (i = 0; i < 16; i++) 408 if (pd->count[i] != 0) 409 printf("%2" PRIu64 "\t%10" PRIu64 "\n", 410 i, pd->count[i]); 411 printf("Total: %" PRIu64 "\n", pd->coord->max); 412 } 413 414 uint64_t 415 ptablesize(PruneData *pd) 416 { 417 uint64_t e; 418 419 e = pd->compact ? ENTRIES_PER_GROUP_COMPACT : ENTRIES_PER_GROUP; 420 421 return (pd->coord->max + e - 1) / e; 422 } 423 424 static void 425 ptable_update(PruneData *pd, Cube cube, int n) 426 { 427 ptable_update_index(pd, pd->coord->index(cube), n); 428 } 429 430 static void 431 ptable_update_index(PruneData *pd, uint64_t ind, int n) 432 { 433 int sh; 434 entry_group_t mask; 435 uint64_t i; 436 437 sh = 4 * (ind % ENTRIES_PER_GROUP); 438 mask = ((entry_group_t)15) << sh; 439 i = ind/ENTRIES_PER_GROUP; 440 441 pd->ptable[i] &= ~mask; 442 pd->ptable[i] |= (((entry_group_t)n)&15) << sh; 443 } 444 445 int 446 ptableval(PruneData *pd, Cube cube) 447 { 448 return ptableval_index(pd, pd->coord->index(cube)); 449 } 450 451 static int 452 ptableval_index(PruneData *pd, uint64_t ind) 453 { 454 int sh, ret; 455 entry_group_t mask; 456 uint64_t i, e; 457 entry_group_t m; 458 459 if (!pd->generated) { 460 fprintf(stderr, "Warning: request pruning table value" 461 " for uninitialized table %s.\n It's fine, but it" 462 " should not happen. Please report bug.\n", 463 pd->filename); 464 genptable(pd, 1); /* TODO: set default or remove this case */ 465 } 466 467 if (pd->compact) { 468 e = ENTRIES_PER_GROUP_COMPACT; 469 m = 3; 470 sh = (ind % e) * 2; 471 } else { 472 e = ENTRIES_PER_GROUP; 473 m = 15; 474 sh = (ind % e) * 4; 475 } 476 477 mask = m << sh; 478 i = ind/e; 479 480 ret = (pd->ptable[i] & mask) >> sh; 481 482 if (pd->compact) { 483 if (ret) 484 ret += pd->base; 485 else 486 ret = ptableval_index(pd->fallback, ind / pd->fbmod); 487 } 488 489 return ret; 490 } 491 492 static bool 493 read_ptable_file(PruneData *pd) 494 { 495 init_env(); 496 497 FILE *f; 498 char fname[strlen(tabledir)+100]; 499 int i; 500 uint64_t r; 501 502 strcpy(fname, tabledir); 503 strcat(fname, "/"); 504 strcat(fname, pd->filename); 505 506 if ((f = fopen(fname, "rb")) == NULL) 507 return false; 508 509 r = fread(&(pd->base), sizeof(int), 1, f); 510 for (i = 0; i < 16; i++) 511 r += fread(&(pd->count[i]), sizeof(uint64_t), 1, f); 512 r += fread(pd->ptable, sizeof(entry_group_t), ptablesize(pd), f); 513 514 fclose(f); 515 516 return r == 17 + ptablesize(pd); 517 } 518 519 static bool 520 write_ptable_file(PruneData *pd) 521 { 522 init_env(); 523 524 FILE *f; 525 char fname[strlen(tabledir)+100]; 526 int i; 527 uint64_t w; 528 529 strcpy(fname, tabledir); 530 strcat(fname, "/"); 531 strcat(fname, pd->filename); 532 533 if ((f = fopen(fname, "wb")) == NULL) 534 return false; 535 536 w = fwrite(&(pd->base), sizeof(int), 1, f); 537 for (i = 0; i < 16; i++) 538 w += fwrite(&(pd->count[i]), sizeof(uint64_t), 1, f); 539 w += fwrite(pd->ptable, sizeof(entry_group_t), ptablesize(pd), f); 540 fclose(f); 541 542 return w == 17 + ptablesize(pd); 543 } 544