darknet  v3
yolo.c
Go to the documentation of this file.
1 #include "darknet.h"
2 
/* The 20 PASCAL VOC class labels, indexed by class id. */
char *voc_names[] = {
    "aeroplane", "bicycle", "bird", "boat",
    "bottle", "bus", "car", "cat",
    "chair", "cow", "diningtable", "dog",
    "horse", "motorbike", "person", "pottedplant",
    "sheep", "sofa", "train", "tvmonitor"
};
4 
5 void train_yolo(char *cfgfile, char *weightfile)
6 {
7  char *train_images = "/data/voc/train.txt";
8  char *backup_directory = "/home/pjreddie/backup/";
9  srand(time(0));
10  char *base = basecfg(cfgfile);
11  printf("%s\n", base);
12  float avg_loss = -1;
13  network *net = load_network(cfgfile, weightfile, 0);
14  printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
15  int imgs = net->batch*net->subdivisions;
16  int i = *net->seen/imgs;
17  data train, buffer;
18 
19 
20  layer l = net->layers[net->n - 1];
21 
22  int side = l.side;
23  int classes = l.classes;
24  float jitter = l.jitter;
25 
26  list *plist = get_paths(train_images);
27  //int N = plist->size;
28  char **paths = (char **)list_to_array(plist);
29 
30  load_args args = {0};
31  args.w = net->w;
32  args.h = net->h;
33  args.paths = paths;
34  args.n = imgs;
35  args.m = plist->size;
36  args.classes = classes;
37  args.jitter = jitter;
38  args.num_boxes = side;
39  args.d = &buffer;
40  args.type = REGION_DATA;
41 
42  args.angle = net->angle;
43  args.exposure = net->exposure;
44  args.saturation = net->saturation;
45  args.hue = net->hue;
46 
47  pthread_t load_thread = load_data_in_thread(args);
48  clock_t time;
49  //while(i*imgs < N*120){
50  while(get_current_batch(net) < net->max_batches){
51  i += 1;
52  time=clock();
53  pthread_join(load_thread, 0);
54  train = buffer;
55  load_thread = load_data_in_thread(args);
56 
57  printf("Loaded: %lf seconds\n", sec(clock()-time));
58 
59  time=clock();
60  float loss = train_network(net, train);
61  if (avg_loss < 0) avg_loss = loss;
62  avg_loss = avg_loss*.9 + loss*.1;
63 
64  printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs);
65  if(i%1000==0 || (i < 1000 && i%100 == 0)){
66  char buff[256];
67  sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
68  save_weights(net, buff);
69  }
70  free_data(train);
71  }
72  char buff[256];
73  sprintf(buff, "%s/%s_final.weights", backup_directory, base);
74  save_weights(net, buff);
75 }
76 
77 void print_yolo_detections(FILE **fps, char *id, int total, int classes, int w, int h, detection *dets)
78 {
79  int i, j;
80  for(i = 0; i < total; ++i){
81  float xmin = dets[i].bbox.x - dets[i].bbox.w/2.;
82  float xmax = dets[i].bbox.x + dets[i].bbox.w/2.;
83  float ymin = dets[i].bbox.y - dets[i].bbox.h/2.;
84  float ymax = dets[i].bbox.y + dets[i].bbox.h/2.;
85 
86  if (xmin < 0) xmin = 0;
87  if (ymin < 0) ymin = 0;
88  if (xmax > w) xmax = w;
89  if (ymax > h) ymax = h;
90 
91  for(j = 0; j < classes; ++j){
92  if (dets[i].prob[j]) fprintf(fps[j], "%s %f %f %f %f %f\n", id, dets[i].prob[j],
93  xmin, ymin, xmax, ymax);
94  }
95  }
96 }
97 
98 void validate_yolo(char *cfg, char *weights)
99 {
100  network *net = load_network(cfg, weights, 0);
101  set_batch_network(net, 1);
102  fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
103  srand(time(0));
104 
105  char *base = "results/comp4_det_test_";
106  //list *plist = get_paths("data/voc.2007.test");
107  list *plist = get_paths("/home/pjreddie/data/voc/2007_test.txt");
108  //list *plist = get_paths("data/voc.2012.test");
109  char **paths = (char **)list_to_array(plist);
110 
111  layer l = net->layers[net->n-1];
112  int classes = l.classes;
113 
114  int j;
115  FILE **fps = calloc(classes, sizeof(FILE *));
116  for(j = 0; j < classes; ++j){
117  char buff[1024];
118  snprintf(buff, 1024, "%s%s.txt", base, voc_names[j]);
119  fps[j] = fopen(buff, "w");
120  }
121 
122  int m = plist->size;
123  int i=0;
124  int t;
125 
126  float thresh = .001;
127  int nms = 1;
128  float iou_thresh = .5;
129 
130  int nthreads = 8;
131  image *val = calloc(nthreads, sizeof(image));
132  image *val_resized = calloc(nthreads, sizeof(image));
133  image *buf = calloc(nthreads, sizeof(image));
134  image *buf_resized = calloc(nthreads, sizeof(image));
135  pthread_t *thr = calloc(nthreads, sizeof(pthread_t));
136 
137  load_args args = {0};
138  args.w = net->w;
139  args.h = net->h;
140  args.type = IMAGE_DATA;
141 
142  for(t = 0; t < nthreads; ++t){
143  args.path = paths[i+t];
144  args.im = &buf[t];
145  args.resized = &buf_resized[t];
146  thr[t] = load_data_in_thread(args);
147  }
148  time_t start = time(0);
149  for(i = nthreads; i < m+nthreads; i += nthreads){
150  fprintf(stderr, "%d\n", i);
151  for(t = 0; t < nthreads && i+t-nthreads < m; ++t){
152  pthread_join(thr[t], 0);
153  val[t] = buf[t];
154  val_resized[t] = buf_resized[t];
155  }
156  for(t = 0; t < nthreads && i+t < m; ++t){
157  args.path = paths[i+t];
158  args.im = &buf[t];
159  args.resized = &buf_resized[t];
160  thr[t] = load_data_in_thread(args);
161  }
162  for(t = 0; t < nthreads && i+t-nthreads < m; ++t){
163  char *path = paths[i+t-nthreads];
164  char *id = basecfg(path);
165  float *X = val_resized[t].data;
166  network_predict(net, X);
167  int w = val[t].w;
168  int h = val[t].h;
169  int nboxes = 0;
170  detection *dets = get_network_boxes(net, w, h, thresh, 0, 0, 0, &nboxes);
171  if (nms) do_nms_sort(dets, l.side*l.side*l.n, classes, iou_thresh);
172  print_yolo_detections(fps, id, l.side*l.side*l.n, classes, w, h, dets);
173  free_detections(dets, nboxes);
174  free(id);
175  free_image(val[t]);
176  free_image(val_resized[t]);
177  }
178  }
179  fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start));
180 }
181 
182 void validate_yolo_recall(char *cfg, char *weights)
183 {
184  network *net = load_network(cfg, weights, 0);
185  set_batch_network(net, 1);
186  fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
187  srand(time(0));
188 
189  char *base = "results/comp4_det_test_";
190  list *plist = get_paths("data/voc.2007.test");
191  char **paths = (char **)list_to_array(plist);
192 
193  layer l = net->layers[net->n-1];
194  int classes = l.classes;
195  int side = l.side;
196 
197  int j, k;
198  FILE **fps = calloc(classes, sizeof(FILE *));
199  for(j = 0; j < classes; ++j){
200  char buff[1024];
201  snprintf(buff, 1024, "%s%s.txt", base, voc_names[j]);
202  fps[j] = fopen(buff, "w");
203  }
204 
205  int m = plist->size;
206  int i=0;
207 
208  float thresh = .001;
209  float iou_thresh = .5;
210  float nms = 0;
211 
212  int total = 0;
213  int correct = 0;
214  int proposals = 0;
215  float avg_iou = 0;
216 
217  for(i = 0; i < m; ++i){
218  char *path = paths[i];
219  image orig = load_image_color(path, 0, 0);
220  image sized = resize_image(orig, net->w, net->h);
221  char *id = basecfg(path);
222  network_predict(net, sized.data);
223 
224  int nboxes = 0;
225  detection *dets = get_network_boxes(net, orig.w, orig.h, thresh, 0, 0, 1, &nboxes);
226  if (nms) do_nms_obj(dets, side*side*l.n, 1, nms);
227 
228  char labelpath[4096];
229  find_replace(path, "images", "labels", labelpath);
230  find_replace(labelpath, "JPEGImages", "labels", labelpath);
231  find_replace(labelpath, ".jpg", ".txt", labelpath);
232  find_replace(labelpath, ".JPEG", ".txt", labelpath);
233 
234  int num_labels = 0;
235  box_label *truth = read_boxes(labelpath, &num_labels);
236  for(k = 0; k < side*side*l.n; ++k){
237  if(dets[k].objectness > thresh){
238  ++proposals;
239  }
240  }
241  for (j = 0; j < num_labels; ++j) {
242  ++total;
243  box t = {truth[j].x, truth[j].y, truth[j].w, truth[j].h};
244  float best_iou = 0;
245  for(k = 0; k < side*side*l.n; ++k){
246  float iou = box_iou(dets[k].bbox, t);
247  if(dets[k].objectness > thresh && iou > best_iou){
248  best_iou = iou;
249  }
250  }
251  avg_iou += best_iou;
252  if(best_iou > iou_thresh){
253  ++correct;
254  }
255  }
256 
257  fprintf(stderr, "%5d %5d %5d\tRPs/Img: %.2f\tIOU: %.2f%%\tRecall:%.2f%%\n", i, correct, total, (float)proposals/(i+1), avg_iou*100/total, 100.*correct/total);
258  free_detections(dets, nboxes);
259  free(id);
260  free_image(orig);
261  free_image(sized);
262  }
263 }
264 
265 void test_yolo(char *cfgfile, char *weightfile, char *filename, float thresh)
266 {
267  image **alphabet = load_alphabet();
268  network *net = load_network(cfgfile, weightfile, 0);
269  layer l = net->layers[net->n-1];
270  set_batch_network(net, 1);
271  srand(2222222);
272  clock_t time;
273  char buff[256];
274  char *input = buff;
275  float nms=.4;
276  while(1){
277  if(filename){
278  strncpy(input, filename, 256);
279  } else {
280  printf("Enter Image Path: ");
281  fflush(stdout);
282  input = fgets(input, 256, stdin);
283  if(!input) return;
284  strtok(input, "\n");
285  }
286  image im = load_image_color(input,0,0);
287  image sized = resize_image(im, net->w, net->h);
288  float *X = sized.data;
289  time=clock();
290  network_predict(net, X);
291  printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
292 
293  int nboxes = 0;
294  detection *dets = get_network_boxes(net, 1, 1, thresh, 0, 0, 0, &nboxes);
295  if (nms) do_nms_sort(dets, l.side*l.side*l.n, l.classes, nms);
296 
297  draw_detections(im, dets, l.side*l.side*l.n, thresh, voc_names, alphabet, 20);
298  save_image(im, "predictions");
299  show_image(im, "predictions", 0);
300  free_detections(dets, nboxes);
301  free_image(im);
302  free_image(sized);
303  if (filename) break;
304  }
305 }
306 
307 void run_yolo(int argc, char **argv)
308 {
309  char *prefix = find_char_arg(argc, argv, "-prefix", 0);
310  float thresh = find_float_arg(argc, argv, "-thresh", .2);
311  int cam_index = find_int_arg(argc, argv, "-c", 0);
312  int frame_skip = find_int_arg(argc, argv, "-s", 0);
313  if(argc < 4){
314  fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
315  return;
316  }
317 
318  int avg = find_int_arg(argc, argv, "-avg", 1);
319  char *cfg = argv[3];
320  char *weights = (argc > 4) ? argv[4] : 0;
321  char *filename = (argc > 5) ? argv[5]: 0;
322  if(0==strcmp(argv[2], "test")) test_yolo(cfg, weights, filename, thresh);
323  else if(0==strcmp(argv[2], "train")) train_yolo(cfg, weights);
324  else if(0==strcmp(argv[2], "valid")) validate_yolo(cfg, weights);
325  else if(0==strcmp(argv[2], "recall")) validate_yolo_recall(cfg, weights);
326  else if(0==strcmp(argv[2], "demo")) demo(cfg, weights, thresh, cam_index, filename, voc_names, 20, frame_skip, prefix, avg, .5, 0,0,0,0);
327 }
free_detections
Definition: darknet.py:73
void train_yolo(char *cfgfile, char *weightfile)
Definition: yolo.c:5
float hue
Definition: darknet.h:576
float decay
Definition: darknet.h:447
char ** paths
Definition: darknet.h:553
image resize_image(image im, int w, int h)
Definition: image.c:1351
int batch
Definition: darknet.h:436
image * im
Definition: darknet.h:578
pthread_t load_data_in_thread(load_args args)
Definition: data.c:1135
void set_batch_network(network *net, int b)
Definition: network.c:339
int n
Definition: darknet.h:142
void test_yolo(char *cfgfile, char *weightfile, char *filename, float thresh)
Definition: yolo.c:265
box_label * read_boxes(char *filename, int *n)
Definition: data.c:139
int w
Definition: darknet.h:559
float learning_rate
Definition: darknet.h:445
float hue
Definition: darknet.h:478
float momentum
Definition: darknet.h:446
float h
Definition: darknet.h:520
void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const char *filename, char **names, int classes, int frame_skip, char *prefix, int avg, float hier_thresh, int w, int h, int fps, int fullscreen)
Definition: demo.c:359
void free_data(data d)
Definition: data.c:665
int show_image(image p, const char *name, int ms)
Definition: image.c:575
char * find_char_arg(int argc, char **argv, char *arg, char *def)
Definition: utils.c:163
char * basecfg(char *cfgfile)
Definition: utils.c:179
void ** list_to_array(list *l)
Definition: list.c:82
size_t * seen
Definition: darknet.h:437
float train_network(network *net, data d)
Definition: network.c:314
int size
Definition: darknet.h:603
Definition: darknet.h:512
void save_image(image p, const char *name)
Definition: image.c:717
void validate_yolo(char *cfg, char *weights)
Definition: yolo.c:98
int h
Definition: darknet.h:558
float jitter
Definition: darknet.h:571
data_type type
Definition: darknet.h:580
void save_weights(network *net, char *filename)
Definition: parser.c:1080
int side
Definition: darknet.h:146
do_nms_sort
Definition: darknet.py:92
void run_yolo(int argc, char **argv)
Definition: yolo.c:307
network_predict
Definition: darknet.py:79
float exposure
Definition: darknet.h:575
float y
Definition: darknet.h:586
image ** load_alphabet()
Definition: image.c:223
void validate_yolo_recall(char *cfg, char *weights)
Definition: yolo.c:182
float w
Definition: darknet.h:520
int h
Definition: darknet.h:514
int max_batches
Definition: darknet.h:453
char * voc_names[]
Definition: yolo.c:3
void print_yolo_detections(FILE **fps, char *id, int total, int classes, int w, int h, detection *dets)
Definition: yolo.c:77
float x
Definition: darknet.h:520
int num_boxes
Definition: darknet.h:564
int m
Definition: darknet.h:556
layer * layers
Definition: darknet.h:441
data * d
Definition: darknet.h:577
free_image
Definition: darknet.py:95
float w
Definition: darknet.h:586
int subdivisions
Definition: darknet.h:440
image load_image_color(char *filename, int w, int h)
Definition: image.c:1486
float find_float_arg(int argc, char **argv, char *arg, float def)
Definition: utils.c:148
int classes
Definition: darknet.h:566
float get_current_rate(network *net)
Definition: network.c:90
float saturation
Definition: darknet.h:477
float h
Definition: darknet.h:586
float sec(clock_t clocks)
Definition: utils.c:232
float saturation
Definition: darknet.h:574
float x
Definition: darknet.h:586
box bbox
Definition: darknet.h:524
int find_int_arg(int argc, char **argv, char *arg, int def)
Definition: utils.c:133
network * load_network(char *cfg, char *weights, int clear)
Definition: network.c:53
int n
Definition: darknet.h:555
void * load_thread(void *ptr)
Definition: data.c:1090
float box_iou(box a, box b)
Definition: box.c:179
int w
Definition: darknet.h:513
char * path
Definition: darknet.h:554
Definition: darknet.h:602
image * resized
Definition: darknet.h:579
int classes
Definition: darknet.h:172
float jitter
Definition: darknet.h:163
size_t get_current_batch(network *net)
Definition: network.c:63
void draw_detections(image im, detection *dets, int num, float thresh, char **names, image **alphabet, int classes)
Definition: image.c:239
void find_replace(char *str, char *orig, char *rep, char *output)
Definition: utils.c:216
int h
Definition: darknet.h:468
int n
Definition: darknet.h:435
get_network_boxes
Definition: darknet.py:65
float y
Definition: darknet.h:520
float angle
Definition: darknet.h:572
list * get_paths(char *filename)
Definition: data.c:12
int w
Definition: darknet.h:468
list classes
Definition: voc_label.py:9
do_nms_obj
Definition: darknet.py:89
Definition: darknet.h:538
Definition: darknet.h:519
Definition: darknet.h:119
float * data
Definition: darknet.h:516
float exposure
Definition: darknet.h:476
float angle
Definition: darknet.h:474