darknet  v3
compare.c
Go to the documentation of this file.
1 #ifdef _COMPARE_
2 #include <stdio.h>
3 
4 #include "network.h"
5 #include "detection_layer.h"
6 #include "cost_layer.h"
7 #include "utils.h"
8 #include "parser.h"
9 #include "box.h"
10 
11 void train_compare(char *cfgfile, char *weightfile)
12 {
13  srand(time(0));
14  float avg_loss = -1;
15  char *base = basecfg(cfgfile);
16  char *backup_directory = "/home/pjreddie/backup/";
17  printf("%s\n", base);
18  network net = parse_network_cfg(cfgfile);
19  if(weightfile){
20  load_weights(&net, weightfile);
21  }
22  printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
23  int imgs = 1024;
24  list *plist = get_paths("data/compare.train.list");
25  char **paths = (char **)list_to_array(plist);
26  int N = plist->size;
27  printf("%d\n", N);
28  clock_t time;
29  pthread_t load_thread;
30  data train;
31  data buffer;
32 
33  load_args args = {0};
34  args.w = net.w;
35  args.h = net.h;
36  args.paths = paths;
37  args.classes = 20;
38  args.n = imgs;
39  args.m = N;
40  args.d = &buffer;
41  args.type = COMPARE_DATA;
42 
43  load_thread = load_data_in_thread(args);
44  int epoch = *net.seen/N;
45  int i = 0;
46  while(1){
47  ++i;
48  time=clock();
49  pthread_join(load_thread, 0);
50  train = buffer;
51 
52  load_thread = load_data_in_thread(args);
53  printf("Loaded: %lf seconds\n", sec(clock()-time));
54  time=clock();
55  float loss = train_network(net, train);
56  if(avg_loss == -1) avg_loss = loss;
57  avg_loss = avg_loss*.9 + loss*.1;
58  printf("%.3f: %f, %f avg, %lf seconds, %ld images\n", (float)*net.seen/N, loss, avg_loss, sec(clock()-time), *net.seen);
59  free_data(train);
60  if(i%100 == 0){
61  char buff[256];
62  sprintf(buff, "%s/%s_%d_minor_%d.weights",backup_directory,base, epoch, i);
63  save_weights(net, buff);
64  }
65  if(*net.seen/N > epoch){
66  epoch = *net.seen/N;
67  i = 0;
68  char buff[256];
69  sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
70  save_weights(net, buff);
71  if(epoch%22 == 0) net.learning_rate *= .1;
72  }
73  }
74  pthread_join(load_thread, 0);
75  free_data(buffer);
76  free_network(net);
77  free_ptrs((void**)paths, plist->size);
78  free_list(plist);
79  free(base);
80 }
81 
82 void validate_compare(char *filename, char *weightfile)
83 {
84  int i = 0;
85  network net = parse_network_cfg(filename);
86  if(weightfile){
87  load_weights(&net, weightfile);
88  }
89  srand(time(0));
90 
91  list *plist = get_paths("data/compare.val.list");
92  //list *plist = get_paths("data/compare.val.old");
93  char **paths = (char **)list_to_array(plist);
94  int N = plist->size/2;
95  free_list(plist);
96 
97  clock_t time;
98  int correct = 0;
99  int total = 0;
100  int splits = 10;
101  int num = (i+1)*N/splits - i*N/splits;
102 
103  data val, buffer;
104 
105  load_args args = {0};
106  args.w = net.w;
107  args.h = net.h;
108  args.paths = paths;
109  args.classes = 20;
110  args.n = num;
111  args.m = 0;
112  args.d = &buffer;
113  args.type = COMPARE_DATA;
114 
115  pthread_t load_thread = load_data_in_thread(args);
116  for(i = 1; i <= splits; ++i){
117  time=clock();
118 
119  pthread_join(load_thread, 0);
120  val = buffer;
121 
122  num = (i+1)*N/splits - i*N/splits;
123  char **part = paths+(i*N/splits);
124  if(i != splits){
125  args.paths = part;
126  load_thread = load_data_in_thread(args);
127  }
128  printf("Loaded: %d images in %lf seconds\n", val.X.rows, sec(clock()-time));
129 
130  time=clock();
131  matrix pred = network_predict_data(net, val);
132  int j,k;
133  for(j = 0; j < val.y.rows; ++j){
134  for(k = 0; k < 20; ++k){
135  if(val.y.vals[j][k*2] != val.y.vals[j][k*2+1]){
136  ++total;
137  if((val.y.vals[j][k*2] < val.y.vals[j][k*2+1]) == (pred.vals[j][k*2] < pred.vals[j][k*2+1])){
138  ++correct;
139  }
140  }
141  }
142  }
143  free_matrix(pred);
144  printf("%d: Acc: %f, %lf seconds, %d images\n", i, (float)correct/total, sec(clock()-time), val.X.rows);
145  free_data(val);
146  }
147 }
148 
149 typedef struct {
150  network net;
151  char *filename;
152  int class;
153  int classes;
154  float elo;
155  float *elos;
156 } sortable_bbox;
157 
/* Number of network evaluations made by bbox_comparator / bbox_fight.
 * File-private: nothing outside this file needs it. */
static int total_compares = 0;
/* Class index elo_comparator sorts by; a global because qsort comparators
 * receive no user context. */
static int current_class = 0;
160 
161 int elo_comparator(const void*a, const void *b)
162 {
163  sortable_bbox box1 = *(sortable_bbox*)a;
164  sortable_bbox box2 = *(sortable_bbox*)b;
165  if(box1.elos[current_class] == box2.elos[current_class]) return 0;
166  if(box1.elos[current_class] > box2.elos[current_class]) return -1;
167  return 1;
168 }
169 
170 int bbox_comparator(const void *a, const void *b)
171 {
172  ++total_compares;
173  sortable_bbox box1 = *(sortable_bbox*)a;
174  sortable_bbox box2 = *(sortable_bbox*)b;
175  network net = box1.net;
176  int class = box1.class;
177 
178  image im1 = load_image_color(box1.filename, net.w, net.h);
179  image im2 = load_image_color(box2.filename, net.w, net.h);
180  float *X = calloc(net.w*net.h*net.c, sizeof(float));
181  memcpy(X, im1.data, im1.w*im1.h*im1.c*sizeof(float));
182  memcpy(X+im1.w*im1.h*im1.c, im2.data, im2.w*im2.h*im2.c*sizeof(float));
183  float *predictions = network_predict(net, X);
184 
185  free_image(im1);
186  free_image(im2);
187  free(X);
188  if (predictions[class*2] > predictions[class*2+1]){
189  return 1;
190  }
191  return -1;
192 }
193 
194 void bbox_update(sortable_bbox *a, sortable_bbox *b, int class, int result)
195 {
196  int k = 32;
197  float EA = 1./(1+pow(10, (b->elos[class] - a->elos[class])/400.));
198  float EB = 1./(1+pow(10, (a->elos[class] - b->elos[class])/400.));
199  float SA = result ? 1 : 0;
200  float SB = result ? 0 : 1;
201  a->elos[class] += k*(SA - EA);
202  b->elos[class] += k*(SB - EB);
203 }
204 
205 void bbox_fight(network net, sortable_bbox *a, sortable_bbox *b, int classes, int class)
206 {
207  image im1 = load_image_color(a->filename, net.w, net.h);
208  image im2 = load_image_color(b->filename, net.w, net.h);
209  float *X = calloc(net.w*net.h*net.c, sizeof(float));
210  memcpy(X, im1.data, im1.w*im1.h*im1.c*sizeof(float));
211  memcpy(X+im1.w*im1.h*im1.c, im2.data, im2.w*im2.h*im2.c*sizeof(float));
212  float *predictions = network_predict(net, X);
213  ++total_compares;
214 
215  int i;
216  for(i = 0; i < classes; ++i){
217  if(class < 0 || class == i){
218  int result = predictions[i*2] > predictions[i*2+1];
219  bbox_update(a, b, i, result);
220  }
221  }
222 
223  free_image(im1);
224  free_image(im2);
225  free(X);
226 }
227 
228 void SortMaster3000(char *filename, char *weightfile)
229 {
230  int i = 0;
231  network net = parse_network_cfg(filename);
232  if(weightfile){
233  load_weights(&net, weightfile);
234  }
235  srand(time(0));
236  set_batch_network(&net, 1);
237 
238  list *plist = get_paths("data/compare.sort.list");
239  //list *plist = get_paths("data/compare.val.old");
240  char **paths = (char **)list_to_array(plist);
241  int N = plist->size;
242  free_list(plist);
243  sortable_bbox *boxes = calloc(N, sizeof(sortable_bbox));
244  printf("Sorting %d boxes...\n", N);
245  for(i = 0; i < N; ++i){
246  boxes[i].filename = paths[i];
247  boxes[i].net = net;
248  boxes[i].class = 7;
249  boxes[i].elo = 1500;
250  }
251  clock_t time=clock();
252  qsort(boxes, N, sizeof(sortable_bbox), bbox_comparator);
253  for(i = 0; i < N; ++i){
254  printf("%s\n", boxes[i].filename);
255  }
256  printf("Sorted in %d compares, %f secs\n", total_compares, sec(clock()-time));
257 }
258 
259 void BattleRoyaleWithCheese(char *filename, char *weightfile)
260 {
261  int classes = 20;
262  int i,j;
263  network net = parse_network_cfg(filename);
264  if(weightfile){
265  load_weights(&net, weightfile);
266  }
267  srand(time(0));
268  set_batch_network(&net, 1);
269 
270  list *plist = get_paths("data/compare.sort.list");
271  //list *plist = get_paths("data/compare.small.list");
272  //list *plist = get_paths("data/compare.cat.list");
273  //list *plist = get_paths("data/compare.val.old");
274  char **paths = (char **)list_to_array(plist);
275  int N = plist->size;
276  int total = N;
277  free_list(plist);
278  sortable_bbox *boxes = calloc(N, sizeof(sortable_bbox));
279  printf("Battling %d boxes...\n", N);
280  for(i = 0; i < N; ++i){
281  boxes[i].filename = paths[i];
282  boxes[i].net = net;
283  boxes[i].classes = classes;
284  boxes[i].elos = calloc(classes, sizeof(float));;
285  for(j = 0; j < classes; ++j){
286  boxes[i].elos[j] = 1500;
287  }
288  }
289  int round;
290  clock_t time=clock();
291  for(round = 1; round <= 4; ++round){
292  clock_t round_time=clock();
293  printf("Round: %d\n", round);
294  shuffle(boxes, N, sizeof(sortable_bbox));
295  for(i = 0; i < N/2; ++i){
296  bbox_fight(net, boxes+i*2, boxes+i*2+1, classes, -1);
297  }
298  printf("Round: %f secs, %d remaining\n", sec(clock()-round_time), N);
299  }
300 
301  int class;
302 
303  for (class = 0; class < classes; ++class){
304 
305  N = total;
306  current_class = class;
307  qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
308  N /= 2;
309 
310  for(round = 1; round <= 100; ++round){
311  clock_t round_time=clock();
312  printf("Round: %d\n", round);
313 
314  sorta_shuffle(boxes, N, sizeof(sortable_bbox), 10);
315  for(i = 0; i < N/2; ++i){
316  bbox_fight(net, boxes+i*2, boxes+i*2+1, classes, class);
317  }
318  qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
319  if(round <= 20) N = (N*9/10)/2*2;
320 
321  printf("Round: %f secs, %d remaining\n", sec(clock()-round_time), N);
322  }
323  char buff[256];
324  sprintf(buff, "results/battle_%d.log", class);
325  FILE *outfp = fopen(buff, "w");
326  for(i = 0; i < N; ++i){
327  fprintf(outfp, "%s %f\n", boxes[i].filename, boxes[i].elos[class]);
328  }
329  fclose(outfp);
330  }
331  printf("Tournament in %d compares, %f secs\n", total_compares, sec(clock()-time));
332 }
333 
/*
 * Command dispatcher for the compare tool.
 *
 *   argv[2] - mode: train, valid, sort or battle
 *   argv[3] - network cfg file
 *   argv[4] - optional weights file
 *
 * Prints usage when arguments are missing and reports an unknown mode
 * instead of silently doing nothing.
 */
void run_compare(int argc, char **argv)
{
    if(argc < 4){
        fprintf(stderr, "usage: %s %s [train/valid/sort/battle] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }

    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;
    if(0==strcmp(argv[2], "train")) train_compare(cfg, weights);
    else if(0==strcmp(argv[2], "valid")) validate_compare(cfg, weights);
    else if(0==strcmp(argv[2], "sort")) SortMaster3000(cfg, weights);
    else if(0==strcmp(argv[2], "battle")) BattleRoyaleWithCheese(cfg, weights);
    else fprintf(stderr, "Unknown compare mode: %s\n", argv[2]);
}
354 
355 #endif
float decay
Definition: darknet.h:447
char ** paths
Definition: darknet.h:553
int rows
Definition: darknet.h:533
pthread_t load_data_in_thread(load_args args)
Definition: data.c:1135
void set_batch_network(network *net, int b)
Definition: network.c:339
int w
Definition: darknet.h:559
float learning_rate
Definition: darknet.h:445
float momentum
Definition: darknet.h:446
void free_data(data d)
Definition: data.c:665
char * basecfg(char *cfgfile)
Definition: utils.c:179
void ** list_to_array(list *l)
Definition: list.c:82
size_t * seen
Definition: darknet.h:437
float train_network(network *net, data d)
Definition: network.c:314
int size
Definition: darknet.h:603
void free_list(list *l)
Definition: list.c:67
Definition: darknet.h:512
int h
Definition: darknet.h:558
void free_network(network *net)
Definition: network.c:716
data_type type
Definition: darknet.h:580
void save_weights(network *net, char *filename)
Definition: parser.c:1080
network * parse_network_cfg(char *filename)
Definition: parser.c:742
network_predict
Definition: darknet.py:79
int h
Definition: darknet.h:514
int m
Definition: darknet.h:556
data * d
Definition: darknet.h:577
free_image
Definition: darknet.py:95
void sorta_shuffle(void *arr, size_t n, size_t size, size_t sections)
Definition: utils.c:74
image load_image_color(char *filename, int w, int h)
Definition: image.c:1486
int classes
Definition: darknet.h:566
void shuffle(void *arr, size_t n, size_t size)
Definition: utils.c:85
float sec(clock_t clocks)
Definition: utils.c:232
int n
Definition: darknet.h:555
matrix network_predict_data(network *net, data test)
Definition: network.c:616
void free_matrix(matrix m)
Definition: matrix.c:10
void * load_thread(void *ptr)
Definition: data.c:1090
int c
Definition: darknet.h:515
int w
Definition: darknet.h:513
Definition: darknet.h:602
matrix X
Definition: darknet.h:540
int c
Definition: darknet.h:468
float ** vals
Definition: darknet.h:534
int h
Definition: darknet.h:468
list * get_paths(char *filename)
Definition: data.c:12
free_ptrs
Definition: darknet.py:76
void load_weights(network *net, char *filename)
Definition: parser.c:1308
int w
Definition: darknet.h:468
list classes
Definition: voc_label.py:9
Definition: darknet.h:538
float * data
Definition: darknet.h:516
matrix y
Definition: darknet.h:541