threader_eager.c (4054B)
1 #include <pthread.h> 2 #include "threader_eager.h" 3 4 typedef struct { 5 AlgList * sols; 6 pthread_mutex_t * sols_mutex; 7 } ThreadData; 8 9 typedef struct { 10 DfsArg * arg; 11 Solver * solver; 12 Threader * threader; 13 AlgList * starts; 14 AlgListNode ** node; 15 pthread_mutex_t * start_mutex; 16 } ThreadInitData; 17 18 static void append_sol(Alg *, void *); 19 static void * instance_thread(void *); 20 static void dispatch(DfsArg *, AlgList *, Solver *, Threader *); 21 static AlgList * possible_starts(DfsArg *, Solver *); 22 static int get_nsol(void *); 23 24 Threader threader_eager = { 25 .append_sol = append_sol, 26 .dispatch = dispatch, 27 .get_nsol = get_nsol, 28 }; 29 30 static void 31 append_sol(Alg *alg, void *threaddata) 32 { 33 ThreadData *td = (ThreadData *)threaddata; 34 35 pthread_mutex_lock(td->sols_mutex); 36 append_alg(td->sols, alg); 37 pthread_mutex_unlock(td->sols_mutex); 38 } 39 40 static AlgList * 41 possible_starts(DfsArg *arg, Solver *solver) 42 { 43 AlgList *ret = new_alglist(); 44 45 if (solver->is_solved(solver->param, arg->cubedata)) { 46 if (arg->opts->min_moves == 0 && arg->d == 0) 47 append_sol(new_alg(""), arg->threaddata); 48 return ret; 49 } 50 51 for (int i = 0; solver->moveset->sorted_moves[i] != NULLMOVE; i++) { 52 Move m = solver->moveset->sorted_moves[i]; 53 Alg *alg = new_alg(""); 54 append_move(alg, m, false); 55 append_alg(ret, alg); 56 free_alg(alg); 57 58 /* TODO: check if step not final */ 59 if (arg->opts->can_niss) { 60 alg = new_alg(""); 61 append_move(alg, m, true); 62 append_alg(ret, alg); 63 free_alg(alg); 64 } 65 } 66 67 return ret; 68 } 69 70 static void * 71 instance_thread(void *arg) 72 { 73 ThreadInitData *tid = (ThreadInitData *)arg; 74 75 while (true) { 76 pthread_mutex_lock(tid->start_mutex); 77 AlgListNode *node = *(tid->node); 78 if (node == NULL) { 79 pthread_mutex_unlock(tid->start_mutex); 80 break; 81 } 82 *(tid->node) = (*(tid->node))->next; 83 pthread_mutex_unlock(tid->start_mutex); 84 85 /* TODO: adjust for longer (arbitrarily long?) starting sequences */ 86 void *data = tid->solver->alloc_cubedata(tid->solver->param); 87 tid->solver->copy_cubedata( 88 tid->solver->param, tid->arg->cubedata, data); 89 bool inv = node->alg->inv[node->alg->len-1]; 90 if (inv) 91 tid->solver->invert_cube( 92 tid->solver->param, data); 93 tid->solver->apply_move( 94 tid->solver->param, data, node->alg->move[0]); 95 96 DfsArg newarg; 97 newarg.cubedata = data; 98 newarg.threaddata = tid->arg->threaddata; 99 newarg.opts = tid->arg->opts; 100 newarg.d = tid->arg->d; 101 newarg.niss = inv; 102 newarg.current_alg = new_alg(""); 103 copy_alg(node->alg, newarg.current_alg); 104 105 dfs(&newarg, tid->solver, tid->threader); 106 107 tid->solver->free_cubedata(tid->solver->param, data); 108 free_alg(newarg.current_alg); 109 } 110 111 return NULL; 112 } 113 114 static void 115 dispatch(DfsArg *arg, AlgList *sols, Solver *solver, Threader *threader) 116 { 117 int nthreads = arg->opts->nthreads; 118 ThreadInitData tid[nthreads]; 119 pthread_t t[nthreads]; 120 121 pthread_mutex_t *sols_mutex = malloc(sizeof(pthread_mutex_t)); 122 pthread_mutex_init(sols_mutex, NULL); 123 124 arg->threaddata = malloc(sizeof(ThreadData)); 125 ThreadData *td = (ThreadData *)arg->threaddata; 126 td->sols = sols; 127 td->sols_mutex = sols_mutex; 128 129 AlgList *starts = possible_starts(arg, solver); 130 AlgListNode *node = starts->first; 131 pthread_mutex_t *start_mutex = malloc(sizeof(pthread_mutex_t)); 132 pthread_mutex_init(start_mutex, NULL); 133 for (int i = 0; i < nthreads; i++) { 134 tid[i].arg = arg; 135 tid[i].solver = solver; 136 tid[i].threader = threader; 137 tid[i].starts = starts; 138 tid[i].node = &node; 139 tid[i].start_mutex = start_mutex; 140 141 pthread_create(&t[i], NULL, instance_thread, &tid[i]); 142 } 143 144 for (int i = 0; i < nthreads; i++) 145 pthread_join(t[i], NULL); 146 147 free(td); 148 free(sols_mutex); 149 free_alglist(starts); 150 free(start_mutex); 151 } 152 153 static int 154 get_nsol(void *threaddata) 155 { 156 ThreadData *td = (ThreadData *)threaddata; 157 158 pthread_mutex_lock(td->sols_mutex); 159 int n = td->sols->len; 160 pthread_mutex_unlock(td->sols_mutex); 161 162 return n; 163 }