darknet  v3
super.c
Go to the documentation of this file.
1 #include "darknet.h"
2 
3 void train_super(char *cfgfile, char *weightfile, int clear)
4 {
5  char *train_images = "/data/imagenet/imagenet1k.train.list";
6  char *backup_directory = "/home/pjreddie/backup/";
7  srand(time(0));
8  char *base = basecfg(cfgfile);
9  printf("%s\n", base);
10  float avg_loss = -1;
11  network *net = load_network(cfgfile, weightfile, clear);
12  printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
13  int imgs = net->batch*net->subdivisions;
14  int i = *net->seen/imgs;
15  data train, buffer;
16 
17 
18  list *plist = get_paths(train_images);
19  //int N = plist->size;
20  char **paths = (char **)list_to_array(plist);
21 
22  load_args args = {0};
23  args.w = net->w;
24  args.h = net->h;
25  args.scale = 4;
26  args.paths = paths;
27  args.n = imgs;
28  args.m = plist->size;
29  args.d = &buffer;
30  args.type = SUPER_DATA;
31 
32  pthread_t load_thread = load_data_in_thread(args);
33  clock_t time;
34  //while(i*imgs < N*120){
35  while(get_current_batch(net) < net->max_batches){
36  i += 1;
37  time=clock();
38  pthread_join(load_thread, 0);
39  train = buffer;
40  load_thread = load_data_in_thread(args);
41 
42  printf("Loaded: %lf seconds\n", sec(clock()-time));
43 
44  time=clock();
45  float loss = train_network(net, train);
46  if (avg_loss < 0) avg_loss = loss;
47  avg_loss = avg_loss*.9 + loss*.1;
48 
49  printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs);
50  if(i%1000==0){
51  char buff[256];
52  sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
53  save_weights(net, buff);
54  }
55  if(i%100==0){
56  char buff[256];
57  sprintf(buff, "%s/%s.backup", backup_directory, base);
58  save_weights(net, buff);
59  }
60  free_data(train);
61  }
62  char buff[256];
63  sprintf(buff, "%s/%s_final.weights", backup_directory, base);
64  save_weights(net, buff);
65 }
66 
67 void test_super(char *cfgfile, char *weightfile, char *filename)
68 {
69  network *net = load_network(cfgfile, weightfile, 0);
70  set_batch_network(net, 1);
71  srand(2222222);
72 
73  clock_t time;
74  char buff[256];
75  char *input = buff;
76  while(1){
77  if(filename){
78  strncpy(input, filename, 256);
79  }else{
80  printf("Enter Image Path: ");
81  fflush(stdout);
82  input = fgets(input, 256, stdin);
83  if(!input) return;
84  strtok(input, "\n");
85  }
86  image im = load_image_color(input, 0, 0);
87  resize_network(net, im.w, im.h);
88  printf("%d %d\n", im.w, im.h);
89 
90  float *X = im.data;
91  time=clock();
92  network_predict(net, X);
93  image out = get_network_image(net);
94  printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
95  save_image(out, "out");
96  show_image(out, "out", 0);
97 
98  free_image(im);
99  if (filename) break;
100  }
101 }
102 
103 
/*
 * Command-line dispatcher for the super-resolution tool.
 *
 * Expected arguments:
 *   argv[2] - mode: "train" or "test"
 *   argv[3] - cfg file
 *   argv[4] - weights (optional)
 *   argv[5] - image filename for test mode (optional; prompts on stdin if absent)
 * A "-clear" flag is located with find_arg and forwarded to train_super.
 * No "valid" mode exists here: validate_super is not implemented in this file.
 */
void run_super(int argc, char **argv)
{
    if(argc < 4){
        fprintf(stderr, "usage: %s %s [train/test] [cfg] [weights (optional)] [filename (optional)]\n", argv[0], argv[1]);
        return;
    }

    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;
    char *filename = (argc > 5) ? argv[5] : 0;
    int clear = find_arg(argc, argv, "-clear");
    if(0==strcmp(argv[2], "train")) train_super(cfg, weights, clear);
    else if(0==strcmp(argv[2], "test")) test_super(cfg, weights, filename);
    else fprintf(stderr, "unknown mode \"%s\": expected train or test\n", argv[2]);
}
void train_super(char *cfgfile, char *weightfile, int clear)
Definition: super.c:3
float decay
Definition: darknet.h:447
char ** paths
Definition: darknet.h:553
int batch
Definition: darknet.h:436
pthread_t load_data_in_thread(load_args args)
Definition: data.c:1135
int find_arg(int argc, char *argv[], char *arg)
Definition: utils.c:120
void set_batch_network(network *net, int b)
Definition: network.c:339
int w
Definition: darknet.h:559
float learning_rate
Definition: darknet.h:445
float momentum
Definition: darknet.h:446
void test_super(char *cfgfile, char *weightfile, char *filename)
Definition: super.c:67
void free_data(data d)
Definition: data.c:665
int show_image(image p, const char *name, int ms)
Definition: image.c:575
char * basecfg(char *cfgfile)
Definition: utils.c:179
void ** list_to_array(list *l)
Definition: list.c:82
size_t * seen
Definition: darknet.h:437
float train_network(network *net, data d)
Definition: network.c:314
int size
Definition: darknet.h:603
Definition: darknet.h:512
void save_image(image p, const char *name)
Definition: image.c:717
int h
Definition: darknet.h:558
data_type type
Definition: darknet.h:580
void save_weights(network *net, char *filename)
Definition: parser.c:1080
network_predict
Definition: darknet.py:79
image get_network_image(network *net)
Definition: network.c:466
int h
Definition: darknet.h:514
int max_batches
Definition: darknet.h:453
int resize_network(network *net, int w, int h)
Definition: network.c:358
int m
Definition: darknet.h:556
data * d
Definition: darknet.h:577
free_image
Definition: darknet.py:95
int subdivisions
Definition: darknet.h:440
image load_image_color(char *filename, int w, int h)
Definition: image.c:1486
float get_current_rate(network *net)
Definition: network.c:90
int scale
Definition: darknet.h:568
float sec(clock_t clocks)
Definition: utils.c:232
network * load_network(char *cfg, char *weights, int clear)
Definition: network.c:53
int n
Definition: darknet.h:555
void * load_thread(void *ptr)
Definition: data.c:1090
int w
Definition: darknet.h:513
void run_super(int argc, char **argv)
Definition: super.c:104
Definition: darknet.h:602
size_t get_current_batch(network *net)
Definition: network.c:63
int h
Definition: darknet.h:468
list * get_paths(char *filename)
Definition: data.c:12
int w
Definition: darknet.h:468
Definition: darknet.h:538
float * data
Definition: darknet.h:516