darknet  v3
swag.c
Go to the documentation of this file.
1 #include "darknet.h"
2 #include <sys/time.h>
3 
4 void train_swag(char *cfgfile, char *weightfile)
5 {
6  char *train_images = "data/voc.0712.trainval";
7  char *backup_directory = "/home/pjreddie/backup/";
8  srand(time(0));
9  char *base = basecfg(cfgfile);
10  printf("%s\n", base);
11  float avg_loss = -1;
12  network net = parse_network_cfg(cfgfile);
13  if(weightfile){
14  load_weights(&net, weightfile);
15  }
16  printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
17  int imgs = net.batch*net.subdivisions;
18  int i = *net.seen/imgs;
19  data train, buffer;
20 
21  layer l = net.layers[net.n - 1];
22 
23  int side = l.side;
24  int classes = l.classes;
25  float jitter = l.jitter;
26 
27  list *plist = get_paths(train_images);
28  //int N = plist->size;
29  char **paths = (char **)list_to_array(plist);
30 
31  load_args args = {0};
32  args.w = net.w;
33  args.h = net.h;
34  args.paths = paths;
35  args.n = imgs;
36  args.m = plist->size;
37  args.classes = classes;
38  args.jitter = jitter;
39  args.num_boxes = side;
40  args.d = &buffer;
41  args.type = REGION_DATA;
42 
43  pthread_t load_thread = load_data_in_thread(args);
44  clock_t time;
45  //while(i*imgs < N*120){
46  while(get_current_batch(net) < net.max_batches){
47  i += 1;
48  time=clock();
49  pthread_join(load_thread, 0);
50  train = buffer;
51  load_thread = load_data_in_thread(args);
52 
53  printf("Loaded: %lf seconds\n", sec(clock()-time));
54 
55  time=clock();
56  float loss = train_network(net, train);
57  if (avg_loss < 0) avg_loss = loss;
58  avg_loss = avg_loss*.9 + loss*.1;
59 
60  printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs);
61  if(i%1000==0 || i == 600){
62  char buff[256];
63  sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
64  save_weights(net, buff);
65  }
66  free_data(train);
67  }
68  char buff[256];
69  sprintf(buff, "%s/%s_final.weights", backup_directory, base);
70  save_weights(net, buff);
71 }
72 
/*
 * Command-line dispatcher for the "swag" example.
 *
 * Expected arguments: argv[2] is the mode (only "train" is implemented),
 * argv[3] is the cfg file, and the optional argv[4] is a weights file.
 * Prints usage when arguments are missing and an error for an unknown mode.
 */
void run_swag(int argc, char **argv)
{
    if(argc < 4){
        fprintf(stderr, "usage: %s %s [train] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }

    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : NULL;
    if(0==strcmp(argv[2], "train")){
        train_swag(cfg, weights);
    } else {
        /* Previously an unrecognized mode was silently ignored. */
        fprintf(stderr, "unknown mode '%s': only 'train' is supported\n", argv[2]);
    }
}
float decay
Definition: darknet.h:447
char ** paths
Definition: darknet.h:553
int batch
Definition: darknet.h:436
pthread_t load_data_in_thread(load_args args)
Definition: data.c:1135
int w
Definition: darknet.h:559
void train_swag(char *cfgfile, char *weightfile)
Definition: swag.c:4
float learning_rate
Definition: darknet.h:445
float momentum
Definition: darknet.h:446
void free_data(data d)
Definition: data.c:665
char * basecfg(char *cfgfile)
Definition: utils.c:179
void ** list_to_array(list *l)
Definition: list.c:82
size_t * seen
Definition: darknet.h:437
float train_network(network *net, data d)
Definition: network.c:314
int size
Definition: darknet.h:603
int h
Definition: darknet.h:558
float jitter
Definition: darknet.h:571
data_type type
Definition: darknet.h:580
void run_swag(int argc, char **argv)
Definition: swag.c:73
void save_weights(network *net, char *filename)
Definition: parser.c:1080
network * parse_network_cfg(char *filename)
Definition: parser.c:742
int side
Definition: darknet.h:146
int max_batches
Definition: darknet.h:453
int num_boxes
Definition: darknet.h:564
int m
Definition: darknet.h:556
layer * layers
Definition: darknet.h:441
data * d
Definition: darknet.h:577
int subdivisions
Definition: darknet.h:440
int classes
Definition: darknet.h:566
float get_current_rate(network *net)
Definition: network.c:90
float sec(clock_t clocks)
Definition: utils.c:232
int n
Definition: darknet.h:555
void * load_thread(void *ptr)
Definition: data.c:1090
Definition: darknet.h:602
int classes
Definition: darknet.h:172
float jitter
Definition: darknet.h:163
size_t get_current_batch(network *net)
Definition: network.c:63
int h
Definition: darknet.h:468
int n
Definition: darknet.h:435
list * get_paths(char *filename)
Definition: data.c:12
void load_weights(network *net, char *filename)
Definition: parser.c:1308
int w
Definition: darknet.h:468
list classes
Definition: voc_label.py:9
Definition: darknet.h:538
Definition: darknet.h:119