darknet  v3
go.c
Go to the documentation of this file.
1 #include "darknet.h"
2 
3 #include <assert.h>
4 #include <math.h>
5 #include <unistd.h>
6 
/* Display flag: when non-zero, print_board numbers the rows 19..1 from the top. */
int inverted = 1;
/* Display flag: when non-zero, print_board shifts column letters past 'H' by one. */
int noi = 1;
/* Number of candidate moves tracked/highlighted when a board is printed. */
static const int nind = 10;
/* Forward declarations; both functions are defined further down in this file. */
int legal_go(float *b, float *ko, int p, int r, int c);
int check_ko(float *x, float *ko);
12 
/* A set of move records loaded from a file by load_go_moves. */
typedef struct {
    char **data;   /* array of n record buffers, each allocated by fgetgo */
    int n;         /* number of records in data */
} moves;
17 
/*
 * Read one fixed-size (96 byte) move record from fp.
 *
 * Returns a freshly malloc'd 96-byte buffer that the caller must free, or 0
 * when fp is NULL, the stream is at end-of-file, memory is exhausted, or fewer
 * than 96 bytes remain (a trailing partial record is discarded).
 * The buffer is raw bytes and is NOT NUL-terminated.
 */
char *fgetgo(FILE *fp)
{
    const size_t size = 96;
    if(!fp || feof(fp)) return 0;

    char *line = malloc(size*sizeof(char));
    if(!line) return 0;

    /* A short read means EOF or an I/O error: drop the partial record. */
    if(size != fread(line, sizeof(char), size, fp)){
        free(line);
        return 0;
    }

    return line;
}
30 
31 moves load_go_moves(char *filename)
32 {
33  moves m;
34  m.n = 128;
35  m.data = calloc(128, sizeof(char*));
36  FILE *fp = fopen(filename, "rb");
37  int count = 0;
38  char *line = 0;
39  while ((line = fgetgo(fp))) {
40  if (count >= m.n) {
41  m.n *= 2;
42  m.data = realloc(m.data, m.n*sizeof(char*));
43  }
44  m.data[count] = line;
45  ++count;
46  }
47  printf("%d\n", count);
48  m.n = count;
49  m.data = realloc(m.data, count*sizeof(char*));
50  return m;
51 }
52 
/*
 * Unpack a bit-packed position into two float planes.
 *
 * s holds 91 bytes, four points per byte, two bits per point: the even bit
 * marks plane 0 ("me"), the odd bit marks plane 1 ("you"). board receives
 * 2*19*19 floats: both planes are zeroed, then each marked point is set to 1.
 * If both bits of a point are set, only plane 0 is marked.
 */
void string_to_board(char *s, float *board)
{
    const int cells = 19*19;
    int cell;

    memset(board, 0, 2*cells*sizeof(float));
    for(cell = 0; cell < cells; ++cell){
        char packed = s[cell/4];
        int shift = 2*(cell%4);
        if((packed >> shift) & 1){
            board[cell] = 1;
        } else if((packed >> (shift + 1)) & 1){
            board[cell + cells] = 1;
        }
    }
}
70 
/*
 * Pack the first two planes of a board into 91 bytes (inverse layout of
 * string_to_board): four points per byte, two bits per point. A point whose
 * plane-0 value equals 1 sets the even bit, a plane-1 value equal to 1 sets
 * the odd bit. s must have room for 19*19/4+1 bytes; it is cleared first.
 */
void board_to_string(char *s, float *board)
{
    const int cells = 19*19;
    int cell;

    memset(s, 0, (cells/4 + 1)*sizeof(char));
    for(cell = 0; cell < cells; ++cell){
        int byte = cell/4;
        int shift = 2*(cell%4);
        if(board[cell] == 1) s[byte] = s[byte] | (1 << shift);
        if(board[cell + cells] == 1) s[byte] = s[byte] | (1 << (shift + 1));
    }
}
87 
/*
 * Report who holds point i: 1 if plane 0 is non-zero there, -1 if only
 * plane 1 (offset 19*19) is non-zero, 0 if the point is empty.
 */
static int occupied(float *b, int i)
{
    if(b[i] != 0) return 1;
    return (b[i + 19*19] != 0) ? -1 : 0;
}
94 
96 {
97  data d = {0};
98  d.X = make_matrix(n, 19*19*3);
99  d.y = make_matrix(n, 19*19+2);
100  int i, j;
101  for(i = 0; i < n; ++i){
102  float *board = d.X.vals[i];
103  float *label = d.y.vals[i];
104  char *b = m.data[rand()%m.n];
105  int player = b[0] - '0';
106  int result = b[1] - '0';
107  int row = b[2];
108  int col = b[3];
109  string_to_board(b+4, board);
110  if(player > 0) for(j = 0; j < 19*19; ++j) board[19*19*2 + j] = 1;
111  label[19*19+1] = (player==result);
112  if(row >= 19 || col >= 19){
113  label[19*19] = 1;
114  } else {
115  label[col + 19*row] = 1;
116  if(occupied(board, col + 19*row)) printf("hey\n");
117  }
118 
119  int flip = rand()%2;
120  int rotate = rand()%4;
121  image in = float_to_image(19, 19, 3, board);
122  image out = float_to_image(19, 19, 1, label);
123  if(flip){
124  flip_image(in);
125  flip_image(out);
126  }
127  rotate_image_cw(in, rotate);
128  rotate_image_cw(out, rotate);
129  }
130  return d;
131 }
132 
133 
134 void train_go(char *cfgfile, char *weightfile, char *filename, int *gpus, int ngpus, int clear)
135 {
136  int i;
137  float avg_loss = -1;
138  char *base = basecfg(cfgfile);
139  printf("%s\n", base);
140  printf("%d\n", ngpus);
141  network **nets = calloc(ngpus, sizeof(network*));
142 
143  srand(time(0));
144  int seed = rand();
145  for(i = 0; i < ngpus; ++i){
146  srand(seed);
147 #ifdef GPU
148  cuda_set_device(gpus[i]);
149 #endif
150  nets[i] = load_network(cfgfile, weightfile, clear);
151  nets[i]->learning_rate *= ngpus;
152  }
153  network *net = nets[0];
154  printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
155 
156  char *backup_directory = "/home/pjreddie/backup/";
157 
158  char buff[256];
159  moves m = load_go_moves(filename);
160  //moves m = load_go_moves("games.txt");
161 
162  int N = m.n;
163  printf("Moves: %d\n", N);
164  int epoch = (*net->seen)/N;
165  while(get_current_batch(net) < net->max_batches || net->max_batches == 0){
166  double time=what_time_is_it_now();
167 
168  data train = random_go_moves(m, net->batch*net->subdivisions*ngpus);
169  printf("Loaded: %lf seconds\n", what_time_is_it_now() - time);
170  time=what_time_is_it_now();
171 
172  float loss = 0;
173 #ifdef GPU
174  if(ngpus == 1){
175  loss = train_network(net, train);
176  } else {
177  loss = train_networks(nets, ngpus, train, 10);
178  }
179 #else
180  loss = train_network(net, train);
181 #endif
182  free_data(train);
183 
184  if(avg_loss == -1) avg_loss = loss;
185  avg_loss = avg_loss*.95 + loss*.05;
186  printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net->seen)/N, loss, avg_loss, get_current_rate(net), what_time_is_it_now()-time, *net->seen);
187  if(*net->seen/N > epoch){
188  epoch = *net->seen/N;
189  char buff[256];
190  sprintf(buff, "%s/%s_%d.weights", backup_directory,base, epoch);
191  save_weights(net, buff);
192 
193  }
194  if(get_current_batch(net)%1000 == 0){
195  char buff[256];
196  sprintf(buff, "%s/%s.backup",backup_directory,base);
197  save_weights(net, buff);
198  }
199  if(get_current_batch(net)%10000 == 0){
200  char buff[256];
201  sprintf(buff, "%s/%s_%ld.backup",backup_directory,base,get_current_batch(net));
202  save_weights(net, buff);
203  }
204  }
205  sprintf(buff, "%s/%s.weights", backup_directory, base);
206  save_weights(net, buff);
207 
208  free_network(net);
209  free(base);
210 }
211 
/*
 * Flood-fill from (row, col) across connected points held by `side`,
 * adding one to lib[] for each point reached. visited[] stops a point from
 * being counted twice within one fill.
 */
static void propagate_liberty(float *board, int *lib, int *visited, int row, int col, int side)
{
    static const int drow[4] = {1, -1, 0, 0};
    static const int dcol[4] = {0, 0, 1, -1};
    int k;

    int on_board = (row >= 0 && row <= 18 && col >= 0 && col <= 18);
    if (!on_board) return;

    int cell = 19*row + col;
    if (occupied(board, cell) != side || visited[cell]) return;

    visited[cell] = 1;
    lib[cell] += 1;
    for (k = 0; k < 4; ++k) {
        propagate_liberty(board, lib, visited, row + drow[k], col + dcol[k], side);
    }
}
225 
226 
/*
 * Count liberties for every stone on the board.
 *
 * Returns a calloc'd array of 19*19 ints (caller frees). For each empty
 * point, every group touching it gets one added to each of its stones, so a
 * stone's entry ends up as the number of distinct empty points adjacent to
 * its group. Entries for empty points stay 0.
 */
static int *calculate_liberties(float *board)
{
    /* Neighbour order: left, right, up, down. */
    static const int drow[4] = {0, 0, -1, 1};
    static const int dcol[4] = {-1, 1, 0, 0};
    int *lib = calloc(19*19, sizeof(int));
    int visited[19*19];
    int row, col, k;

    for(row = 0; row < 19; ++row){
        for(col = 0; col < 19; ++col){
            /* Fresh visited map per empty point: a group bordering this
               point on two sides must still be credited only once. */
            memset(visited, 0, sizeof(visited));
            if(occupied(board, row*19 + col)) continue;

            for(k = 0; k < 4; ++k){
                int nr = row + drow[k];
                int nc = col + dcol[k];
                if(nr < 0 || nr > 18 || nc < 0 || nc > 18) continue;
                int side = occupied(board, nr*19 + nc);
                if(side) propagate_liberty(board, lib, visited, nr, nc, side);
            }
        }
    }
    return lib;
}
246 
247 void print_board(FILE *stream, float *board, int player, int *indexes)
248 {
249  int i,j,n;
250  fprintf(stream, " ");
251  for(i = 0; i < 19; ++i){
252  fprintf(stream, "%c ", 'A' + i + 1*(i > 7 && noi));
253  }
254  fprintf(stream, "\n");
255  for(j = 0; j < 19; ++j){
256  fprintf(stream, "%2d", (inverted) ? 19-j : j+1);
257  for(i = 0; i < 19; ++i){
258  int index = j*19 + i;
259  if(indexes){
260  int found = 0;
261  for(n = 0; n < nind; ++n){
262  if(index == indexes[n]){
263  found = 1;
264  /*
265  if(n == 0) fprintf(stream, "\uff11");
266  else if(n == 1) fprintf(stream, "\uff12");
267  else if(n == 2) fprintf(stream, "\uff13");
268  else if(n == 3) fprintf(stream, "\uff14");
269  else if(n == 4) fprintf(stream, "\uff15");
270  */
271  fprintf(stream, " %d", n+1);
272  }
273  }
274  if(found) continue;
275  }
276  //if(board[index]*-swap > 0) fprintf(stream, "\u25C9 ");
277  //else if(board[index]*-swap < 0) fprintf(stream, "\u25EF ");
278  if (occupied(board, index) == player) fprintf(stream, " X");
279  else if (occupied(board, index) ==-player) fprintf(stream, " O");
280  else fprintf(stream, " .");
281  }
282  fprintf(stream, "\n");
283  }
284 }
285 
/*
 * Switch the point of view of a 3-plane board in place: planes 0 and 1 are
 * exchanged and every value v in plane 2 becomes 1 - v.
 */
void flip_board(float *board)
{
    float *first = board;
    float *second = board + 19*19;
    float *third = board + 2*19*19;
    int k;

    for(k = 0; k < 19*19; ++k){
        float held = first[k];
        first[k] = second[k];
        second[k] = held;
        third[k] = 1 - third[k];
    }
}
296 
297 float predict_move2(network *net, float *board, float *move, int multi)
298 {
299  float *output = network_predict(net, board);
300  copy_cpu(19*19+1, output, 1, move, 1);
301  float result = output[19*19 + 1];
302  int i;
303  if(multi){
304  image bim = float_to_image(19, 19, 3, board);
305  for(i = 1; i < 8; ++i){
306  rotate_image_cw(bim, i);
307  if(i >= 4) flip_image(bim);
308 
309  float *output = network_predict(net, board);
310  image oim = float_to_image(19, 19, 1, output);
311  result += output[19*19 + 1];
312 
313  if(i >= 4) flip_image(oim);
314  rotate_image_cw(oim, -i);
315 
316  axpy_cpu(19*19+1, 1, output, 1, move, 1);
317 
318  if(i >= 4) flip_image(bim);
319  rotate_image_cw(bim, -i);
320  }
321  result = result/8;
322  scal_cpu(19*19+1, 1./8., move, 1);
323  }
324  for(i = 0; i < 19*19; ++i){
325  if(board[i] || board[i+19*19]) move[i] = 0;
326  }
327  return result;
328 }
329 
/*
 * Remove the group of side p containing (r, c) if its stones are recorded
 * with exactly one liberty in lib. Recurses through neighbours; points that
 * are off-board, not held by p, or have lib != 1 stop the walk.
 */
static void remove_connected(float *b, int *lib, int p, int r, int c)
{
    int on_board = (r >= 0 && r < 19 && c >= 0 && c < 19);
    if (!on_board) return;

    int cell = r*19 + c;
    if (occupied(b, cell) != p || lib[cell] != 1) return;

    /* Clear the point in both planes, then continue with the neighbours. */
    b[cell] = 0;
    b[19*19 + cell] = 0;
    remove_connected(b, lib, p, r+1, c);
    remove_connected(b, lib, p, r-1, c);
    remove_connected(b, lib, p, r, c+1);
    remove_connected(b, lib, p, r, c-1);
}
342 
343 
/*
 * Place a stone for side p (plane 0 if p > 0, plane 1 otherwise) at (r, c)
 * and remove any adjacent opposing groups whose liberty count, computed
 * before the stone was placed, was exactly one.
 */
void move_go(float *b, int p, int r, int c)
{
    int *liberties = calculate_liberties(b);
    int plane = (p > 0) ? 0 : 19*19;
    int foe = -p;

    b[plane + r*19 + c] = 1;

    remove_connected(b, liberties, foe, r+1, c);
    remove_connected(b, liberties, foe, r-1, c);
    remove_connected(b, liberties, foe, r, c+1);
    remove_connected(b, liberties, foe, r, c-1);

    free(liberties);
}
355 
/* Return 1 when the two 3-plane boards are byte-for-byte identical, else 0. */
int compare_board(float *a, float *b)
{
    return memcmp(a, b, 19*19*3*sizeof(float)) == 0;
}
361 
/*
 * One node of the Monte-Carlo search tree. Per-move arrays have 19*19+1
 * entries (the last is the pass move) and are allocated in expand().
 *
 * NOTE(review): visit_count and total_count were absent from this listing
 * although expand(), select_mcts(), free_mcts(), run_mcts() and pick_move()
 * all use them; they are restored here with the types those uses imply
 * (calloc(..., sizeof(int)) and %d) -- confirm placement against upstream.
 */
typedef struct mcts_tree{
    float *board;                 /* position at this node (3 planes of 19*19) */
    struct mcts_tree **children;  /* child node per move, 0 until expanded */
    float *prior;                 /* network move scores */
    int *visit_count;             /* times each move was selected */
    float *value;                 /* summed backed-up value per move */
    float *mean;                  /* value / visit_count per move */
    float *prob;                  /* selection score computed in select_mcts */
    int total_count;              /* total selections made from this node */
    float result;                 /* network value estimate, scaled to [-1, 1] */
    int done;                     /* set when this node follows two passes */
    int pass;                     /* set when this node was reached by a pass */
} mcts_tree;
375 
376 void free_mcts(mcts_tree *root)
377 {
378  if(!root) return;
379  int i;
380  free(root->board);
381  for(i = 0; i < 19*19+1; ++i){
382  if(root->children[i]) free_mcts(root->children[i]);
383  }
384  free(root->children);
385  free(root->prior);
386  free(root->visit_count);
387  free(root->value);
388  free(root->mean);
389  free(root->prob);
390  free(root);
391 }
392 
393 float *network_predict_rotations(network *net, float *next)
394 {
395  int n = net->batch;
396  float *in = calloc(19*19*3*n, sizeof(float));
397  image im = float_to_image(19, 19, 3, next);
398  int i,j;
399  int *inds = random_index_order(0, 8);
400  for(j = 0; j < n; ++j){
401  i = inds[j];
402  rotate_image_cw(im, i);
403  if(i >= 4) flip_image(im);
404  memcpy(in + 19*19*3*j, im.data, 19*19*3*sizeof(float));
405  if(i >= 4) flip_image(im);
406  rotate_image_cw(im, -i);
407  }
408  float *pred = network_predict(net, in);
409  for(j = 0; j < n; ++j){
410  i = inds[j];
411  image im = float_to_image(19, 19, 1, pred + j*(19*19 + 2));
412  if(i >= 4) flip_image(im);
413  rotate_image_cw(im, -i);
414  if(j > 0){
415  axpy_cpu(19*19+2, 1, im.data, 1, pred, 1);
416  }
417  }
418  free(in);
419  free(inds);
420  scal_cpu(19*19+2, 1./n, pred, 1);
421  return pred;
422 }
423 
424 mcts_tree *expand(float *next, float *ko, network *net)
425 {
426  mcts_tree *root = calloc(1, sizeof(mcts_tree));
427  root->board = next;
428  root->children = calloc(19*19+1, sizeof(mcts_tree*));
429  root->prior = calloc(19*19 + 1, sizeof(float));
430  root->prob = calloc(19*19 + 1, sizeof(float));
431  root->mean = calloc(19*19 + 1, sizeof(float));
432  root->value = calloc(19*19 + 1, sizeof(float));
433  root->visit_count = calloc(19*19 + 1, sizeof(int));
434  root->total_count = 1;
435  int i;
436  float *pred = network_predict_rotations(net, next);
437  copy_cpu(19*19+1, pred, 1, root->prior, 1);
438  float val = 2*pred[19*19 + 1] - 1;
439  root->result = val;
440  for(i = 0; i < 19*19+1; ++i) {
441  root->visit_count[i] = 0;
442  root->value[i] = 0;
443  root->mean[i] = val;
444  if(i < 19*19 && occupied(next, i)){
445  root->value[i] = -1;
446  root->mean[i] = -1;
447  root->prior[i] = 0;
448  }
449  }
450  //print_board(stderr, next, flip?-1:1, 0);
451  return root;
452 }
453 
/* Return a newly allocated copy of a 3-plane board; the caller frees it. */
float *copy_board(float *board)
{
    const int len = 19*19*3;
    float *dup = calloc(len, sizeof(float));
    copy_cpu(len, board, 1, dup, 1);
    return dup;
}
460 
461 float select_mcts(mcts_tree *root, network *net, float *prev, float cpuct)
462 {
463  if(root->done) return -root->result;
464  int i;
465  float max = -1000;
466  int max_i = 0;
467  for(i = 0; i < 19*19+1; ++i){
468  root->prob[i] = root->mean[i] + cpuct*root->prior[i] * sqrt(root->total_count) / (1. + root->visit_count[i]);
469  if(root->prob[i] > max){
470  max = root->prob[i];
471  max_i = i;
472  }
473  }
474  float val;
475  i = max_i;
476  root->visit_count[i]++;
477  root->total_count++;
478  if (root->children[i]) {
479  val = select_mcts(root->children[i], net, root->board, cpuct);
480  } else {
481  if(max_i < 19*19 && !legal_go(root->board, prev, 1, max_i/19, max_i%19)) {
482  root->mean[i] = -1;
483  root->value[i] = -1;
484  root->prior[i] = 0;
485  --root->total_count;
486  return select_mcts(root, net, prev, cpuct);
487  //printf("Detected ko\n");
488  //getchar();
489  } else {
490  float *next = copy_board(root->board);
491  if (max_i < 19*19) {
492  move_go(next, 1, max_i / 19, max_i % 19);
493  }
494  flip_board(next);
495  root->children[i] = expand(next, root->board, net);
496  val = -root->children[i]->result;
497  if(max_i == 19*19){
498  root->children[i]->pass = 1;
499  if (root->pass){
500  root->children[i]->done = 1;
501  }
502  }
503  }
504  }
505  root->value[i] += val;
506  root->mean[i] = root->value[i]/root->visit_count[i];
507  return -val;
508 }
509 
510 mcts_tree *run_mcts(mcts_tree *tree, network *net, float *board, float *ko, int player, int n, float cpuct, float secs)
511 {
512  int i;
513  double t = what_time_is_it_now();
514  if(player < 0) flip_board(board);
515  if(!tree) tree = expand(copy_board(board), ko, net);
516  assert(compare_board(tree->board, board));
517  for(i = 0; i < n; ++i){
518  if (secs > 0 && (what_time_is_it_now() - t) > secs) break;
519  int max_i = max_int_index(tree->visit_count, 19*19+1);
520  if (tree->visit_count[max_i] >= n) break;
521  select_mcts(tree, net, ko, cpuct);
522  }
523  if(player < 0) flip_board(board);
524  //fprintf(stderr, "%f Seconds\n", what_time_is_it_now() - t);
525  return tree;
526 }
527 
529 {
530  if(index < 0 || index > 19*19 || !tree || !tree->children[index]) {
531  free_mcts(tree);
532  tree = 0;
533  } else {
534  mcts_tree *swap = tree;
535  tree = tree->children[index];
536  swap->children[index] = 0;
537  free_mcts(swap);
538  }
539  return tree;
540 }
541 
/* A move chosen by pick_move together with its evaluations. */
typedef struct {
    float value;   /* node result mapped to [0,1]: (result+1)/2 */
    float mcts;    /* searched mean for the move mapped to [0,1]: (mean+1)/2 */
    int row;       /* index / 19; equals 19 for a pass */
    int col;       /* index % 19 */
} move;
548 
549 move pick_move(mcts_tree *tree, float temp, int player)
550 {
551  int i;
552  float probs[19*19+1] = {0};
553  move m = {0};
554  double sum = 0;
555  /*
556  for(i = 0; i < 19*19+1; ++i){
557  probs[i] = tree->visit_count[i];
558  }
559  */
560  //softmax(probs, 19*19+1, temp, 1, probs);
561  for(i = 0; i < 19*19+1; ++i){
562  sum += pow(tree->visit_count[i], 1./temp);
563  }
564  for(i = 0; i < 19*19+1; ++i){
565  probs[i] = pow(tree->visit_count[i], 1./temp) / sum;
566  }
567 
568  int index = sample_array(probs, 19*19+1);
569  m.row = index / 19;
570  m.col = index % 19;
571  m.value = (tree->result+1.)/2.;
572  m.mcts = (tree->mean[index]+1.)/2.;
573 
574  int indexes[nind];
575  top_k(probs, 19*19+1, nind, indexes);
576  print_board(stderr, tree->board, player, indexes);
577 
578  fprintf(stderr, "%d %d, Result: %f, Prior: %f, Prob: %f, Mean Value: %f, Child Result: %f, Visited: %d\n", index/19, index%19, tree->result, tree->prior[index], probs[index], tree->mean[index], (tree->children[index])?tree->children[index]->result:0, tree->visit_count[index]);
579  int ind = max_index(probs, 19*19+1);
580  fprintf(stderr, "%d %d, Result: %f, Prior: %f, Prob: %f, Mean Value: %f, Child Result: %f, Visited: %d\n", ind/19, ind%19, tree->result, tree->prior[ind], probs[ind], tree->mean[ind], (tree->children[ind])?tree->children[ind]->result:0, tree->visit_count[ind]);
581  ind = max_index(tree->prior, 19*19+1);
582  fprintf(stderr, "%d %d, Result: %f, Prior: %f, Prob: %f, Mean Value: %f, Child Result: %f, Visited: %d\n", ind/19, ind%19, tree->result, tree->prior[ind], probs[ind], tree->mean[ind], (tree->children[ind])?tree->children[ind]->result:0, tree->visit_count[ind]);
583  return m;
584 }
585 
586 /*
587  float predict_move(network *net, float *board, float *move, int multi, float *ko, float temp)
588  {
589 
590  int i;
591 
592  int max_v = 0;
593  int max_i = 0;
594  for(i = 0; i < 19*19+1; ++i){
595  if(root->visit_count[i] > max_v){
596  max_v = root->visit_count[i];
597  max_i = i;
598  }
599  }
600  fprintf(stderr, "%f Seconds\n", what_time_is_it_now() - t);
601  int ind = max_index(root->mean, 19*19+1);
602  fprintf(stderr, "%d %d, Result: %f, Prior: %f, Prob: %f, Mean Value: %f, Child Result: %f, Visited: %d\n", max_i/19, max_i%19, root->result, root->prior[max_i], root->prob[max_i], root->mean[max_i], (root->children[max_i])?root->children[max_i]->result:0, root->visit_count[max_i]);
603  fprintf(stderr, "%d %d, Result: %f, Prior: %f, Prob: %f, Mean Value: %f, Child Result: %f, Visited: %d\n", ind/19, ind%19, root->result, root->prior[ind], root->prob[ind], root->mean[ind], (root->children[ind])?root->children[ind]->result:0, root->visit_count[ind]);
604  ind = max_index(root->prior, 19*19+1);
605  fprintf(stderr, "%d %d, Result: %f, Prior: %f, Prob: %f, Mean Value: %f, Child Result: %f, Visited: %d\n", ind/19, ind%19, root->result, root->prior[ind], root->prob[ind], root->mean[ind], (root->children[ind])?root->children[ind]->result:0, root->visit_count[ind]);
606  if(root->result < -.9 && root->mean[max_i] < -.9) return -1000.f;
607 
608  float val = root->result;
609  free_mcts(root);
610  return val;
611  }
612  */
613 
/*
 * Decide whether the neighbour at (r, c) keeps a stone of side p alive:
 * an empty neighbour, an opposing group with at most one liberty (it would
 * be captured), or a friendly group with more than one liberty. Off-board
 * neighbours never help.
 */
static int makes_safe_go(float *b, int *lib, int p, int r, int c){
    if (r < 0 || r >= 19 || c < 0 || c >= 19) return 0;

    int cell = r*19 + c;
    int stone = occupied(b, cell);

    if (stone == -p) return (lib[cell] > 1) ? 0 : 1;
    if (stone == 0) return 1;
    return (lib[cell] > 1) ? 1 : 0;
}
624 
/* Return 1 when a stone of side p at (r, c) would have no neighbour that
   makes it safe (see makes_safe_go), 0 otherwise. */
int suicide_go(float *b, int p, int r, int c)
{
    int *liberties = calculate_liberties(b);
    int safe = makes_safe_go(b, liberties, p, r+1, c)
            || makes_safe_go(b, liberties, p, r-1, c)
            || makes_safe_go(b, liberties, p, r, c+1)
            || makes_safe_go(b, liberties, p, r, c-1);
    free(liberties);
    return !safe;
}
636 
/*
 * Return 1 when position x repeats position ko, 0 otherwise (always 0 for a
 * NULL ko). x is compared from ko's point of view: a copy is flipped first
 * when the first entries of the two third planes differ.
 */
int check_ko(float *x, float *ko)
{
    float view[19*19*3];

    if(!ko) return 0;

    copy_cpu(19*19*3, x, 1, view, 1);
    if(view[19*19*2] != ko[19*19*2]) flip_board(view);
    return compare_board(view, ko) ? 1 : 0;
}
646 
/*
 * Return 1 when side p may play at (r, c) on board b: the point is empty,
 * the resulting position does not repeat `ko`, and the move is not suicide.
 * b itself is left untouched; the move is tried on a local copy.
 */
int legal_go(float *b, float *ko, int p, int r, int c)
{
    float trial[19*19*3];

    if (occupied(b, r*19+c)) return 0;

    copy_cpu(19*19*3, b, 1, trial, 1);
    move_go(trial, p, r, c);
    if (check_ko(trial, ko)) return 0;

    return suicide_go(b, p, r, c) ? 0 : 1;
}
657 
658 /*
659  move generate_move(mcts_tree *root, network *net, int player, float *board, int multi, float temp, float *ko, int print)
660  {
661  move m = {0};
662 //root = run_mcts(tree, network *net, float *board, float *ko, int n, float cpuct)
663 int i, j;
664 int empty = 1;
665 for(i = 0; i < 19*19; ++i){
666 if (occupied(board, i)) {
667 empty = 0;
668 break;
669 }
670 }
671 if(empty) {
672 m.value = .5;
673 m.mcts = .5;
674 m.row = 3;
675 m.col = 15;
676 return m;
677 }
678 
679 float move[362];
680 if (player < 0) flip_board(board);
681 float result = predict_move(net, board, move, multi, ko, temp);
682 if (player < 0) flip_board(board);
683 if(result == -1000.f) return -2;
684 
685 for(i = 0; i < 19; ++i){
686 for(j = 0; j < 19; ++j){
687 if (!legal_go(board, ko, player, i, j)) move[i*19 + j] = 0;
688 }
689 }
690 
691 int indexes[nind];
692 top_k(move, 19*19+1, nind, indexes);
693 
694 
695 int max = max_index(move, 19*19+1);
696 int row = max / 19;
697 int col = max % 19;
698 int index = sample_array(move, 19*19+1);
699 
700 if(print){
701 top_k(move, 19*19+1, nind, indexes);
702 for(i = 0; i < nind; ++i){
703 if (!move[indexes[i]]) indexes[i] = -1;
704 }
705 print_board(stderr, board, 1, indexes);
706 fprintf(stderr, "%s To Move\n", player > 0 ? "X" : "O");
707 fprintf(stderr, "%.2f%% Win Chance\n", (result+1)/2*100);
708 for(i = 0; i < nind; ++i){
709 int index = indexes[i];
710 int row = index / 19;
711 int col = index % 19;
712 if(row == 19){
713 fprintf(stderr, "%d: Pass, %.2f%%\n", i+1, move[index]*100);
714 } else {
715 fprintf(stderr, "%d: %c %d, %.2f%%\n", i+1, col + 'A' + 1*(col > 7 && noi), (inverted)?19 - row : row+1, move[index]*100);
716 }
717 }
718 }
719 if (row == 19) return -1;
720 
721 if (suicide_go(board, player, row, col)){
722 return -1;
723 }
724 
725 if (suicide_go(board, player, index/19, index%19)){
726 index = max;
727 }
728 if (index == 19*19) return -1;
729 return index;
730 }
731 */
732 
733 void valid_go(char *cfgfile, char *weightfile, int multi, char *filename)
734 {
735  srand(time(0));
736  char *base = basecfg(cfgfile);
737  printf("%s\n", base);
738  network *net = load_network(cfgfile, weightfile, 0);
739  set_batch_network(net, 1);
740  printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
741 
742  float *board = calloc(19*19*3, sizeof(float));
743  float *move = calloc(19*19+2, sizeof(float));
744  // moves m = load_go_moves("/home/pjreddie/backup/go.test");
745  moves m = load_go_moves(filename);
746 
747  int N = m.n;
748  int i,j;
749  int correct = 0;
750  for (i = 0; i <N; ++i) {
751  char *b = m.data[i];
752  int player = b[0] - '0';
753  //int result = b[1] - '0';
754  int row = b[2];
755  int col = b[3];
756  int truth = col + 19*row;
757  string_to_board(b+4, board);
758  if(player > 0) for(j = 0; j < 19*19; ++j) board[19*19*2 + j] = 1;
759  predict_move2(net, board, move, multi);
760  int index = max_index(move, 19*19+1);
761  if(index == truth) ++correct;
762  printf("%d Accuracy %f\n", i, (float) correct/(i+1));
763  }
764 }
765 
/*
 * Write the position to fp as a command script: komi, boardsize,
 * clear_board, then one "play black"/"play white" line per stone (plane 0
 * stones as black). Column letters skip 'I'; rows are numbered 19..1 from
 * the top. Returns the number of commands written.
 */
int print_game(float *board, FILE *fp)
{
    int row, col;
    int count = 3;   /* the three setup commands below */

    fprintf(fp, "komi 6.5\n");
    fprintf(fp, "boardsize 19\n");
    fprintf(fp, "clear_board\n");

    for(row = 0; row < 19; ++row){
        for(col = 0; col < 19; ++col){
            int stone = occupied(board, row*19 + col);
            if(!stone) continue;
            if(stone == 1){
                fprintf(fp, "play black %c%d\n", 'A'+col+(col>=8), 19-row);
            } else if(stone == -1){
                fprintf(fp, "play white %c%d\n", 'A'+col+(col>=8), 19-row);
            }
            ++count;
        }
    }
    return count;
}
782 
783 
/*
 * Poll standard input without blocking.
 *
 * NOTE(review): the header line was absent from this listing; the signature
 * is reconstructed from the call `while(!stdin_ready())` in ponder --
 * confirm against the upstream file.
 *
 * Returns 1 when select() reports stdin readable, 0 otherwise. A select()
 * failure (-1) is reported as "not ready" rather than as input.
 */
int stdin_ready(void)
{
    fd_set readfds;
    FD_ZERO(&readfds);
    FD_SET(STDIN_FILENO, &readfds);

    /* Zero timeout: return immediately. */
    struct timeval timeout;
    timeout.tv_sec = 0;
    timeout.tv_usec = 0;

    if (select(STDIN_FILENO + 1, &readfds, NULL, NULL, &timeout) > 0){
        return 1;
    }
    return 0;
}
799 
800 mcts_tree *ponder(mcts_tree *tree, network *net, float *b, float *ko, int player, float cpuct)
801 {
802  double t = what_time_is_it_now();
803  int count = 0;
804  if (tree) count = tree->total_count;
805  while(!stdin_ready()){
806  if (what_time_is_it_now() - t > 120) break;
807  tree = run_mcts(tree, net, b, ko, player, 100000, cpuct, .1);
808  }
809  fprintf(stderr, "Pondered %d moves...\n", tree->total_count - count);
810  return tree;
811 }
812 
813 void engine_go(char *filename, char *weightfile, int mcts_iters, float secs, float temp, float cpuct, int anon, int resign)
814 {
815  mcts_tree *root = 0;
816  network *net = load_network(filename, weightfile, 0);
817  set_batch_network(net, 1);
818  srand(time(0));
819  float *board = calloc(19*19*3, sizeof(float));
820  flip_board(board);
821  float *one = calloc(19*19*3, sizeof(float));
822  float *two = calloc(19*19*3, sizeof(float));
823  int ponder_player = 0;
824  int passed = 0;
825  int move_num = 0;
826  int main_time = 0;
827  int byo_yomi_time = 0;
828  int byo_yomi_stones = 0;
829  int black_time_left = 0;
830  int black_stones_left = 0;
831  int white_time_left = 0;
832  int white_stones_left = 0;
833  float orig_time = secs;
834  int old_ponder = 0;
835  while(1){
836  if(ponder_player){
837  root = ponder(root, net, board, two, ponder_player, cpuct);
838  }
839  old_ponder = ponder_player;
840  ponder_player = 0;
841  char buff[256];
842  int id = 0;
843  int has_id = (scanf("%d", &id) == 1);
844  scanf("%s", buff);
845  if (feof(stdin)) break;
846  fprintf(stderr, "%s\n", buff);
847  char ids[256];
848  sprintf(ids, "%d", id);
849  //fprintf(stderr, "%s\n", buff);
850  if (!has_id) ids[0] = 0;
851  if (!strcmp(buff, "protocol_version")){
852  printf("=%s 2\n\n", ids);
853  } else if (!strcmp(buff, "name")){
854  if(anon){
855  printf("=%s The Fool!\n\n", ids);
856  }else{
857  printf("=%s DarkGo\n\n", ids);
858  }
859  } else if (!strcmp(buff, "time_settings")){
860  ponder_player = old_ponder;
861  scanf("%d %d %d", &main_time, &byo_yomi_time, &byo_yomi_stones);
862  printf("=%s \n\n", ids);
863  } else if (!strcmp(buff, "time_left")){
864  ponder_player = old_ponder;
865  char color[256];
866  int time = 0, stones = 0;
867  scanf("%s %d %d", color, &time, &stones);
868  if (color[0] == 'b' || color[0] == 'B'){
869  black_time_left = time;
870  black_stones_left = stones;
871  } else {
872  white_time_left = time;
873  white_stones_left = stones;
874  }
875  printf("=%s \n\n", ids);
876  } else if (!strcmp(buff, "version")){
877  if(anon){
878  printf("=%s :-DDDD\n\n", ids);
879  }else {
880  printf("=%s 1.0. Want more DarkGo? You can find me on OGS, unlimited games, no waiting! https://online-go.com/user/view/434218\n\n", ids);
881  }
882  } else if (!strcmp(buff, "known_command")){
883  char comm[256];
884  scanf("%s", comm);
885  int known = (!strcmp(comm, "protocol_version") ||
886  !strcmp(comm, "name") ||
887  !strcmp(comm, "version") ||
888  !strcmp(comm, "known_command") ||
889  !strcmp(comm, "list_commands") ||
890  !strcmp(comm, "quit") ||
891  !strcmp(comm, "boardsize") ||
892  !strcmp(comm, "clear_board") ||
893  !strcmp(comm, "komi") ||
894  !strcmp(comm, "final_status_list") ||
895  !strcmp(comm, "play") ||
896  !strcmp(comm, "genmove_white") ||
897  !strcmp(comm, "genmove_black") ||
898  !strcmp(comm, "fixed_handicap") ||
899  !strcmp(comm, "genmove"));
900  if(known) printf("=%s true\n\n", ids);
901  else printf("=%s false\n\n", ids);
902  } else if (!strcmp(buff, "list_commands")){
903  printf("=%s protocol_version\nshowboard\nname\nversion\nknown_command\nlist_commands\nquit\nboardsize\nclear_board\nkomi\nplay\ngenmove_black\ngenmove_white\ngenmove\nfinal_status_list\nfixed_handicap\n\n", ids);
904  } else if (!strcmp(buff, "quit")){
905  break;
906  } else if (!strcmp(buff, "boardsize")){
907  int boardsize = 0;
908  scanf("%d", &boardsize);
909  //fprintf(stderr, "%d\n", boardsize);
910  if(boardsize != 19){
911  printf("?%s unacceptable size\n\n", ids);
912  } else {
913  root = move_mcts(root, -1);
914  memset(board, 0, 3*19*19*sizeof(float));
915  flip_board(board);
916  move_num = 0;
917  printf("=%s \n\n", ids);
918  }
919  } else if (!strcmp(buff, "fixed_handicap")){
920  int handicap = 0;
921  scanf("%d", &handicap);
922  int indexes[] = {72, 288, 300, 60, 180, 174, 186, 66, 294};
923  int i;
924  for(i = 0; i < handicap; ++i){
925  board[indexes[i]] = 1;
926  ++move_num;
927  }
928  root = move_mcts(root, -1);
929  } else if (!strcmp(buff, "clear_board")){
930  passed = 0;
931  memset(board, 0, 3*19*19*sizeof(float));
932  flip_board(board);
933  move_num = 0;
934  root = move_mcts(root, -1);
935  printf("=%s \n\n", ids);
936  } else if (!strcmp(buff, "komi")){
937  float komi = 0;
938  scanf("%f", &komi);
939  printf("=%s \n\n", ids);
940  } else if (!strcmp(buff, "showboard")){
941  printf("=%s \n", ids);
942  print_board(stdout, board, 1, 0);
943  printf("\n");
944  } else if (!strcmp(buff, "play") || !strcmp(buff, "black") || !strcmp(buff, "white")){
945  ++move_num;
946  char color[256];
947  if(!strcmp(buff, "play"))
948  {
949  scanf("%s ", color);
950  } else {
951  scanf(" ");
952  color[0] = buff[0];
953  }
954  char c;
955  int r;
956  int count = scanf("%c%d", &c, &r);
957  int player = (color[0] == 'b' || color[0] == 'B') ? 1 : -1;
958  if((c == 'p' || c == 'P') && count < 2) {
959  passed = 1;
960  printf("=%s \n\n", ids);
961  char *line = fgetl(stdin);
962  free(line);
963  fflush(stdout);
964  fflush(stderr);
965  root = move_mcts(root, 19*19);
966  continue;
967  } else {
968  passed = 0;
969  }
970  if(c >= 'A' && c <= 'Z') c = c - 'A';
971  if(c >= 'a' && c <= 'z') c = c - 'a';
972  if(c >= 8) --c;
973  r = 19 - r;
974  fprintf(stderr, "move: %d %d\n", r, c);
975 
976  float *swap = two;
977  two = one;
978  one = swap;
979  move_go(board, player, r, c);
980  copy_cpu(19*19*3, board, 1, one, 1);
981  if(root) fprintf(stderr, "Prior: %f\n", root->prior[r*19 + c]);
982  if(root) fprintf(stderr, "Mean: %f\n", root->mean[r*19 + c]);
983  if(root) fprintf(stderr, "Result: %f\n", root->result);
984  root = move_mcts(root, r*19 + c);
985  if(root) fprintf(stderr, "Visited: %d\n", root->total_count);
986  else fprintf(stderr, "NOT VISITED\n");
987 
988  printf("=%s \n\n", ids);
989  //print_board(stderr, board, 1, 0);
990  } else if (!strcmp(buff, "genmove") || !strcmp(buff, "genmove_black") || !strcmp(buff, "genmove_white")){
991  ++move_num;
992  int player = 0;
993  if(!strcmp(buff, "genmove")){
994  char color[256];
995  scanf("%s", color);
996  player = (color[0] == 'b' || color[0] == 'B') ? 1 : -1;
997  } else if (!strcmp(buff, "genmove_black")){
998  player = 1;
999  } else {
1000  player = -1;
1001  }
1002  if(player > 0){
1003  if(black_time_left <= 30) secs = 2.5;
1004  else secs = orig_time;
1005  } else {
1006  if(white_time_left <= 30) secs = 2.5;
1007  else secs = orig_time;
1008  }
1009  ponder_player = -player;
1010 
1011  //tree = generate_move(net, player, board, multi, .1, two, 1);
1012  double t = what_time_is_it_now();
1013  root = run_mcts(root, net, board, two, player, mcts_iters, cpuct, secs);
1014  fprintf(stderr, "%f Seconds\n", what_time_is_it_now() - t);
1015  move m = pick_move(root, temp, player);
1016  root = move_mcts(root, m.row*19 + m.col);
1017 
1018 
1019  if(move_num > resign && m.value < .1 && m.mcts < .1){
1020  printf("=%s resign\n\n", ids);
1021  } else if(m.row == 19){
1022  printf("=%s pass\n\n", ids);
1023  passed = 0;
1024  } else {
1025  int row = m.row;
1026  int col = m.col;
1027 
1028  float *swap = two;
1029  two = one;
1030  one = swap;
1031 
1032  move_go(board, player, row, col);
1033  copy_cpu(19*19*3, board, 1, one, 1);
1034  row = 19 - row;
1035  if (col >= 8) ++col;
1036  printf("=%s %c%d\n\n", ids, 'A' + col, row);
1037  }
1038 
1039  } else if (!strcmp(buff, "p")){
1040  //print_board(board, 1, 0);
1041  } else if (!strcmp(buff, "final_status_list")){
1042  char type[256];
1043  scanf("%s", type);
1044  fprintf(stderr, "final_status\n");
1045  char *line = fgetl(stdin);
1046  free(line);
1047  if(type[0] == 'd' || type[0] == 'D'){
1048  int i;
1049  FILE *f = fopen("game.txt", "w");
1050  int count = print_game(board, f);
1051  fprintf(f, "%s final_status_list dead\n", ids);
1052  fclose(f);
1053  FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
1054  for(i = 0; i < count; ++i){
1055  free(fgetl(p));
1056  free(fgetl(p));
1057  }
1058  char *l = 0;
1059  while((l = fgetl(p))){
1060  printf("%s\n", l);
1061  free(l);
1062  }
1063  } else {
1064  printf("?%s unknown command\n\n", ids);
1065  }
1066  } else if (!strcmp(buff, "kgs-genmove_cleanup")){
1067  char type[256];
1068  scanf("%s", type);
1069  fprintf(stderr, "kgs-genmove_cleanup\n");
1070  char *line = fgetl(stdin);
1071  free(line);
1072  int i;
1073  FILE *f = fopen("game.txt", "w");
1074  int count = print_game(board, f);
1075  fprintf(f, "%s kgs-genmove_cleanup %s\n", ids, type);
1076  fclose(f);
1077  FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
1078  for(i = 0; i < count; ++i){
1079  free(fgetl(p));
1080  free(fgetl(p));
1081  }
1082  char *l = 0;
1083  while((l = fgetl(p))){
1084  printf("%s\n", l);
1085  free(l);
1086  }
1087  } else {
1088  char *line = fgetl(stdin);
1089  free(line);
1090  printf("?%s unknown command\n\n", ids);
1091  }
1092  fflush(stdout);
1093  fflush(stderr);
1094  }
1095  printf("%d %d %d\n",passed, black_stones_left, white_stones_left);
1096 }
1097 
1098 void test_go(char *cfg, char *weights, int multi)
1099 {
1100  int i;
1101  network *net = load_network(cfg, weights, 0);
1102  set_batch_network(net, 1);
1103  srand(time(0));
1104  float *board = calloc(19*19*3, sizeof(float));
1105  flip_board(board);
1106  float *move = calloc(19*19+1, sizeof(float));
1107  int color = 1;
1108  while(1){
1109  float result = predict_move2(net, board, move, multi);
1110  printf("%.2f%% Win Chance\n", (result+1)/2*100);
1111 
1112  int indexes[nind];
1113  int row, col;
1114  top_k(move, 19*19+1, nind, indexes);
1115  print_board(stderr, board, color, indexes);
1116  for(i = 0; i < nind; ++i){
1117  int index = indexes[i];
1118  row = index / 19;
1119  col = index % 19;
1120  if(row == 19){
1121  printf("%d: Pass, %.2f%%\n", i+1, move[index]*100);
1122  } else {
1123  printf("%d: %c %d, %.2f%%\n", i+1, col + 'A' + 1*(col > 7 && noi), (inverted)?19 - row : row+1, move[index]*100);
1124  }
1125  }
1126  //if(color == 1) printf("\u25EF Enter move: ");
1127  //else printf("\u25C9 Enter move: ");
1128  if(color == 1) printf("X Enter move: ");
1129  else printf("O Enter move: ");
1130 
1131  char c;
1132  char *line = fgetl(stdin);
1133  int picked = 1;
1134  int dnum = sscanf(line, "%d", &picked);
1135  int cnum = sscanf(line, "%c", &c);
1136  if (strlen(line) == 0 || dnum) {
1137  --picked;
1138  if (picked < nind){
1139  int index = indexes[picked];
1140  row = index / 19;
1141  col = index % 19;
1142  if(row < 19){
1143  move_go(board, 1, row, col);
1144  }
1145  }
1146  } else if (cnum){
1147  if (c <= 'T' && c >= 'A'){
1148  int num = sscanf(line, "%c %d", &c, &row);
1149  row = (inverted)?19 - row : row-1;
1150  col = c - 'A';
1151  if (col > 7 && noi) col -= 1;
1152  if (num == 2) move_go(board, 1, row, col);
1153  } else if (c == 'p') {
1154  // Pass
1155  } else if(c=='b' || c == 'w'){
1156  char g;
1157  int num = sscanf(line, "%c %c %d", &g, &c, &row);
1158  row = (inverted)?19 - row : row-1;
1159  col = c - 'A';
1160  if (col > 7 && noi) col -= 1;
1161  if (num == 3) {
1162  int mc = (g == 'b') ? 1 : -1;
1163  if (mc == color) {
1164  board[row*19 + col] = 1;
1165  } else {
1166  board[19*19 + row*19 + col] = 1;
1167  }
1168  }
1169  } else if(c == 'c'){
1170  char g;
1171  int num = sscanf(line, "%c %c %d", &g, &c, &row);
1172  row = (inverted)?19 - row : row-1;
1173  col = c - 'A';
1174  if (col > 7 && noi) col -= 1;
1175  if (num == 3) {
1176  board[row*19 + col] = 0;
1177  board[19*19 + row*19 + col] = 0;
1178  }
1179  }
1180  }
1181  free(line);
1182  flip_board(board);
1183  color = -color;
1184  }
1185 }
1186 
/*
 * Scores the given position by delegating to an external GNU Go process.
 *
 * Writes the game to "game.txt" via print_game, appends a "final_score"
 * command, then runs "./gnugo --mode gtp < game.txt" and scans its output
 * for a line of the form "= <P>+<score>".
 *
 * Returns the parsed score, negated when <P> is 'W'. Returns 0 when no such
 * line is found, or when game.txt / the gnugo pipe cannot be opened.
 */
float score_game(float *board)
{
    int i;
    FILE *f = fopen("game.txt", "w");
    if(!f){
        fprintf(stderr, "score_game: could not open game.txt for writing\n");
        return 0;
    }
    int count = print_game(board, f);
    fprintf(f, "final_score\n");
    fclose(f);

    FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
    if(!p){
        fprintf(stderr, "score_game: could not start ./gnugo\n");
        return 0;
    }
    /* Discard two output lines per unit of print_game's return value
     * (presumably one reply plus a blank line per replayed command --
     * confirm against print_game). */
    for(i = 0; i < count; ++i){
        free(fgetl(p));
        free(fgetl(p));
    }
    char *l = 0;
    float score = 0;
    char player = 0;
    while((l = fgetl(p))){
        fprintf(stderr, "%s \t", l);
        int n = sscanf(l, "= %c+%f", &player, &score);
        free(l);
        if (n == 2) break;
    }
    if(player == 'W') score = -score;
    pclose(p);
    return score;
}
1212 
1213 void self_go(char *filename, char *weightfile, char *f2, char *w2, int multi)
1214 {
1215  mcts_tree *tree1 = 0;
1216  mcts_tree *tree2 = 0;
1217  network *net = load_network(filename, weightfile, 0);
1218  //set_batch_network(net, 1);
1219 
1220  network *net2;
1221  if (f2) {
1222  net2 = parse_network_cfg(f2);
1223  if(w2){
1224  load_weights(net2, w2);
1225  }
1226  } else {
1227  net2 = calloc(1, sizeof(network));
1228  *net2 = *net;
1229  }
1230  srand(time(0));
1231  char boards[600][93];
1232  int count = 0;
1233  //set_batch_network(net, 1);
1234  //set_batch_network(net2, 1);
1235  float *board = calloc(19*19*3, sizeof(float));
1236  flip_board(board);
1237  float *one = calloc(19*19*3, sizeof(float));
1238  float *two = calloc(19*19*3, sizeof(float));
1239  int done = 0;
1240  int player = 1;
1241  int p1 = 0;
1242  int p2 = 0;
1243  int total = 0;
1244  float temp = .1;
1245  int mcts_iters = 500;
1246  float cpuct = 5;
1247  while(1){
1248  if (done){
1249  tree1 = move_mcts(tree1, -1);
1250  tree2 = move_mcts(tree2, -1);
1251  float score = score_game(board);
1252  if((score > 0) == (total%2==0)) ++p1;
1253  else ++p2;
1254  ++total;
1255  fprintf(stderr, "Total: %d, Player 1: %f, Player 2: %f\n", total, (float)p1/total, (float)p2/total);
1256  sleep(1);
1257  /*
1258  int i = (score > 0)? 0 : 1;
1259  int j;
1260  for(; i < count; i += 2){
1261  for(j = 0; j < 93; ++j){
1262  printf("%c", boards[i][j]);
1263  }
1264  printf("\n");
1265  }
1266  */
1267  memset(board, 0, 3*19*19*sizeof(float));
1268  flip_board(board);
1269  player = 1;
1270  done = 0;
1271  count = 0;
1272  fflush(stdout);
1273  fflush(stderr);
1274  }
1275  //print_board(stderr, board, 1, 0);
1276  //sleep(1);
1277 
1278  if ((total%2==0) == (player==1)){
1279  //mcts_iters = 4500;
1280  cpuct = 5;
1281  } else {
1282  //mcts_iters = 500;
1283  cpuct = 1;
1284  }
1285  network *use = ((total%2==0) == (player==1)) ? net : net2;
1286  mcts_tree *t = ((total%2==0) == (player==1)) ? tree1 : tree2;
1287  t = run_mcts(t, use, board, two, player, mcts_iters, cpuct, 0);
1288  move m = pick_move(t, temp, player);
1289  if(((total%2==0) == (player==1))) tree1 = t;
1290  else tree2 = t;
1291 
1292  tree1 = move_mcts(tree1, m.row*19 + m.col);
1293  tree2 = move_mcts(tree2, m.row*19 + m.col);
1294 
1295  if(m.row == 19){
1296  done = 1;
1297  continue;
1298  }
1299  int row = m.row;
1300  int col = m.col;
1301 
1302  float *swap = two;
1303  two = one;
1304  one = swap;
1305 
1306  if(player < 0) flip_board(board);
1307  boards[count][0] = row;
1308  boards[count][1] = col;
1309  board_to_string(boards[count] + 2, board);
1310  if(player < 0) flip_board(board);
1311  ++count;
1312 
1313  move_go(board, player, row, col);
1314  copy_cpu(19*19*3, board, 1, one, 1);
1315 
1316  player = -player;
1317  }
1318 }
1319 
1320 void run_go(int argc, char **argv)
1321 {
1322  //boards_go();
1323  if(argc < 4){
1324  fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
1325  return;
1326  }
1327 
1328  char *gpu_list = find_char_arg(argc, argv, "-gpus", 0);
1329  int *gpus = 0;
1330  int gpu = 0;
1331  int ngpus = 0;
1332  if(gpu_list){
1333  printf("%s\n", gpu_list);
1334  int len = strlen(gpu_list);
1335  ngpus = 1;
1336  int i;
1337  for(i = 0; i < len; ++i){
1338  if (gpu_list[i] == ',') ++ngpus;
1339  }
1340  gpus = calloc(ngpus, sizeof(int));
1341  for(i = 0; i < ngpus; ++i){
1342  gpus[i] = atoi(gpu_list);
1343  gpu_list = strchr(gpu_list, ',')+1;
1344  }
1345  } else {
1346  gpu = gpu_index;
1347  gpus = &gpu;
1348  ngpus = 1;
1349  }
1350  int clear = find_arg(argc, argv, "-clear");
1351 
1352  char *cfg = argv[3];
1353  char *weights = (argc > 4) ? argv[4] : 0;
1354  char *c2 = (argc > 5) ? argv[5] : 0;
1355  char *w2 = (argc > 6) ? argv[6] : 0;
1356  int multi = find_arg(argc, argv, "-multi");
1357  int anon = find_arg(argc, argv, "-anon");
1358  int iters = find_int_arg(argc, argv, "-iters", 500);
1359  int resign = find_int_arg(argc, argv, "-resign", 175);
1360  float cpuct = find_float_arg(argc, argv, "-cpuct", 5);
1361  float temp = find_float_arg(argc, argv, "-temp", .1);
1362  float time = find_float_arg(argc, argv, "-time", 0);
1363  if(0==strcmp(argv[2], "train")) train_go(cfg, weights, c2, gpus, ngpus, clear);
1364  else if(0==strcmp(argv[2], "valid")) valid_go(cfg, weights, multi, c2);
1365  else if(0==strcmp(argv[2], "self")) self_go(cfg, weights, c2, w2, multi);
1366  else if(0==strcmp(argv[2], "test")) test_go(cfg, weights, multi);
1367  else if(0==strcmp(argv[2], "engine")) engine_go(cfg, weights, iters, time, temp, cpuct, anon, resign);
1368 }
1369 
1370 
Definition: go.c:542
int pass
Definition: go.c:373
int done
Definition: go.c:372
float decay
Definition: darknet.h:447
int batch
Definition: darknet.h:436
float * copy_board(float *board)
Definition: go.c:454
float * board
Definition: go.c:363
int find_arg(int argc, char *argv[], char *arg)
Definition: utils.c:120
void set_batch_network(network *net, int b)
Definition: network.c:339
int suicide_go(float *b, int p, int r, int c)
Definition: go.c:625
void run_go(int argc, char **argv)
Definition: go.c:1320
float learning_rate
Definition: darknet.h:445
int max_index(float *a, int n)
Definition: utils.c:619
float predict_move2(network *net, float *board, float *move, int multi)
Definition: go.c:297
float momentum
Definition: darknet.h:446
float * mean
Definition: go.c:368
float * prior
Definition: go.c:365
void free_data(data d)
Definition: data.c:665
char * find_char_arg(int argc, char **argv, char *arg, char *def)
Definition: utils.c:163
char * basecfg(char *cfgfile)
Definition: utils.c:179
int col
Definition: go.c:546
size_t * seen
Definition: darknet.h:437
float train_network(network *net, data d)
Definition: network.c:314
void top_k(float *a, int n, int k, int *index)
Definition: utils.c:237
int * random_index_order(int min, int max)
Definition: utils.c:97
float * value
Definition: go.c:367
float value
Definition: go.c:543
void move_go(float *b, int p, int r, int c)
Definition: go.c:344
Definition: darknet.h:512
void test_go(char *cfg, char *weights, int multi)
Definition: go.c:1098
Definition: go.c:362
moves load_go_moves(char *filename)
Definition: go.c:31
void free_network(network *net)
Definition: network.c:716
void flip_image(image a)
Definition: image.c:349
image float_to_image(int w, int h, int c, float *data)
Definition: image.c:774
void save_weights(network *net, char *filename)
Definition: parser.c:1080
int noi
Definition: go.c:8
network * parse_network_cfg(char *filename)
Definition: parser.c:742
void valid_go(char *cfgfile, char *weightfile, int multi, char *filename)
Definition: go.c:733
network_predict
Definition: darknet.py:79
float select_mcts(mcts_tree *root, network *net, float *prev, float cpuct)
Definition: go.c:461
mcts_tree * run_mcts(mcts_tree *tree, network *net, float *board, float *ko, int player, int n, float cpuct, float secs)
Definition: go.c:510
int print_game(float *board, FILE *fp)
Definition: go.c:766
int max_batches
Definition: darknet.h:453
int * visit_count
Definition: go.c:366
void free_mcts(mcts_tree *root)
Definition: go.c:376
void string_to_board(char *s, float *board)
Definition: go.c:53
void engine_go(char *filename, char *weightfile, int mcts_iters, float secs, float temp, float cpuct, int anon, int resign)
Definition: go.c:813
Definition: darknet.h:42
int row
Definition: go.c:545
move pick_move(mcts_tree *tree, float temp, int player)
Definition: go.c:549
float mcts
Definition: go.c:544
Definition: go.c:13
char * fgetl(FILE *fp)
Definition: utils.c:335
char ** data
Definition: go.c:14
int subdivisions
Definition: darknet.h:440
void axpy_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY)
Definition: blas.c:178
float find_float_arg(int argc, char **argv, char *arg, float def)
Definition: utils.c:148
struct mcts_tree mcts_tree
char * fgetgo(FILE *fp)
Definition: go.c:18
float get_current_rate(network *net)
Definition: network.c:90
float result
Definition: go.c:371
int inverted
Definition: go.c:7
void scal_cpu(int N, float ALPHA, float *X, int INCX)
Definition: blas.c:184
struct mcts_tree ** children
Definition: go.c:364
float * network_predict_rotations(network *net, float *next)
Definition: go.c:393
mcts_tree * ponder(mcts_tree *tree, network *net, float *b, float *ko, int player, float cpuct)
Definition: go.c:800
int find_int_arg(int argc, char **argv, char *arg, int def)
Definition: utils.c:133
network * load_network(char *cfg, char *weights, int clear)
Definition: network.c:53
int check_ko(float *x, float *ko)
Definition: go.c:637
mcts_tree * move_mcts(mcts_tree *tree, int index)
Definition: go.c:528
mcts_tree * expand(float *next, float *ko, network *net)
Definition: go.c:424
void cuda_set_device(int n)
Definition: cuda.c:176
float score_game(float *board)
Definition: go.c:1187
void copy_cpu(int N, float *X, int INCX, float *Y, int INCY)
Definition: blas.c:226
matrix X
Definition: darknet.h:540
float * prob
Definition: go.c:369
void print_board(FILE *stream, float *board, int player, int *indexes)
Definition: go.c:247
int gpu_index
Definition: cuda.c:1
void self_go(char *filename, char *weightfile, char *f2, char *w2, int multi)
Definition: go.c:1213
void rotate_image_cw(image im, int times)
Definition: image.c:328
size_t get_current_batch(network *net)
Definition: network.c:63
void flip_board(float *board)
Definition: go.c:286
float ** vals
Definition: darknet.h:534
int stdin_ready()
Definition: go.c:784
matrix make_matrix(int rows, int cols)
Definition: matrix.c:91
void train_go(char *cfgfile, char *weightfile, char *filename, int *gpus, int ngpus, int clear)
Definition: go.c:134
int compare_board(float *a, float *b)
Definition: go.c:356
int max_int_index(int *a, int n)
Definition: utils.c:605
int total_count
Definition: go.c:370
int legal_go(float *b, float *ko, int p, int r, int c)
Definition: go.c:647
int n
Definition: go.c:15
int sample_array(float *a, int n)
Definition: utils.c:592
void load_weights(network *net, char *filename)
Definition: parser.c:1308
void board_to_string(char *s, float *board)
Definition: go.c:71
data random_go_moves(moves m, int n)
Definition: go.c:95
Definition: darknet.h:538
double what_time_is_it_now()
Definition: utils.c:27
float * data
Definition: darknet.h:516
matrix y
Definition: darknet.h:541