darknet  v3
writing.c
Go to the documentation of this file.
1 #include "darknet.h"
2 
3 void train_writing(char *cfgfile, char *weightfile)
4 {
5  char *backup_directory = "/home/pjreddie/backup/";
6  srand(time(0));
7  float avg_loss = -1;
8  char *base = basecfg(cfgfile);
9  printf("%s\n", base);
10  network net = parse_network_cfg(cfgfile);
11  if(weightfile){
12  load_weights(&net, weightfile);
13  }
14  printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
15  int imgs = net.batch*net.subdivisions;
16  list *plist = get_paths("figures.list");
17  char **paths = (char **)list_to_array(plist);
18  clock_t time;
19  int N = plist->size;
20  printf("N: %d\n", N);
21  image out = get_network_image(net);
22 
23  data train, buffer;
24 
25  load_args args = {0};
26  args.w = net.w;
27  args.h = net.h;
28  args.out_w = out.w;
29  args.out_h = out.h;
30  args.paths = paths;
31  args.n = imgs;
32  args.m = N;
33  args.d = &buffer;
34  args.type = WRITING_DATA;
35 
36  pthread_t load_thread = load_data_in_thread(args);
37  int epoch = (*net.seen)/N;
38  while(get_current_batch(net) < net.max_batches || net.max_batches == 0){
39  time=clock();
40  pthread_join(load_thread, 0);
41  train = buffer;
42  load_thread = load_data_in_thread(args);
43  printf("Loaded %lf seconds\n",sec(clock()-time));
44 
45  time=clock();
46  float loss = train_network(net, train);
47 
48  /*
49  image pred = float_to_image(64, 64, 1, out);
50  print_image(pred);
51  */
52 
53  /*
54  image im = float_to_image(256, 256, 3, train.X.vals[0]);
55  image lab = float_to_image(64, 64, 1, train.y.vals[0]);
56  image pred = float_to_image(64, 64, 1, out);
57  show_image(im, "image");
58  show_image(lab, "label");
59  print_image(lab);
60  show_image(pred, "pred");
61  cvWaitKey(0);
62  */
63 
64  if(avg_loss == -1) avg_loss = loss;
65  avg_loss = avg_loss*.9 + loss*.1;
66  printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen);
67  free_data(train);
68  if(get_current_batch(net)%100 == 0){
69  char buff[256];
70  sprintf(buff, "%s/%s_batch_%ld.weights", backup_directory, base, get_current_batch(net));
71  save_weights(net, buff);
72  }
73  if(*net.seen/N > epoch){
74  epoch = *net.seen/N;
75  char buff[256];
76  sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
77  save_weights(net, buff);
78  }
79  }
80 }
81 
82 void test_writing(char *cfgfile, char *weightfile, char *filename)
83 {
84  network net = parse_network_cfg(cfgfile);
85  if(weightfile){
86  load_weights(&net, weightfile);
87  }
88  set_batch_network(&net, 1);
89  srand(2222222);
90  clock_t time;
91  char buff[256];
92  char *input = buff;
93  while(1){
94  if(filename){
95  strncpy(input, filename, 256);
96  }else{
97  printf("Enter Image Path: ");
98  fflush(stdout);
99  input = fgets(input, 256, stdin);
100  if(!input) return;
101  strtok(input, "\n");
102  }
103 
104  image im = load_image_color(input, 0, 0);
105  resize_network(&net, im.w, im.h);
106  printf("%d %d %d\n", im.h, im.w, im.c);
107  float *X = im.data;
108  time=clock();
109  network_predict(net, X);
110  printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
111  image pred = get_network_image(net);
112 
113  image upsampled = resize_image(pred, im.w, im.h);
114  image thresh = threshold_image(upsampled, .5);
115  pred = thresh;
116 
117  show_image(pred, "prediction");
118  show_image(im, "orig");
119 #ifdef OPENCV
120  cvWaitKey(0);
121  cvDestroyAllWindows();
122 #endif
123 
124  free_image(upsampled);
125  free_image(thresh);
126  free_image(im);
127  if (filename) break;
128  }
129 }
130 
/*
 * Command-line dispatcher for the "writing" subcommand.
 *
 *   argv[2]  mode: "train" or "test"
 *   argv[3]  network cfg file
 *   argv[4]  optional weights file
 *   argv[5]  optional image path (test mode only; without it, test reads
 *            image paths from stdin)
 *
 * Prints usage when fewer than 4 arguments are given and reports an unknown
 * mode instead of silently doing nothing.
 */
void run_writing(int argc, char **argv)
{
    if (argc < 4) {
        fprintf(stderr, "usage: %s %s [train/test] [cfg] [weights (optional)] [image (optional, test only)]\n", argv[0], argv[1]);
        return;
    }

    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;
    char *filename = (argc > 5) ? argv[5] : 0;
    if (0 == strcmp(argv[2], "train")) {
        train_writing(cfg, weights);
    } else if (0 == strcmp(argv[2], "test")) {
        test_writing(cfg, weights, filename);
    } else {
        fprintf(stderr, "unknown mode \"%s\"; expected train or test\n", argv[2]);
    }
}
144 
float decay
Definition: darknet.h:447
char ** paths
Definition: darknet.h:553
int out_w
Definition: darknet.h:560
image resize_image(image im, int w, int h)
Definition: image.c:1351
int batch
Definition: darknet.h:436
pthread_t load_data_in_thread(load_args args)
Definition: data.c:1135
int out_h
Definition: darknet.h:561
void set_batch_network(network *net, int b)
Definition: network.c:339
int w
Definition: darknet.h:559
float learning_rate
Definition: darknet.h:445
float momentum
Definition: darknet.h:446
void free_data(data d)
Definition: data.c:665
int show_image(image p, const char *name, int ms)
Definition: image.c:575
char * basecfg(char *cfgfile)
Definition: utils.c:179
void ** list_to_array(list *l)
Definition: list.c:82
size_t * seen
Definition: darknet.h:437
float train_network(network *net, data d)
Definition: network.c:314
int size
Definition: darknet.h:603
void test_writing(char *cfgfile, char *weightfile, char *filename)
Definition: writing.c:82
Definition: darknet.h:512
int h
Definition: darknet.h:558
data_type type
Definition: darknet.h:580
void save_weights(network *net, char *filename)
Definition: parser.c:1080
network * parse_network_cfg(char *filename)
Definition: parser.c:742
network_predict
Definition: darknet.py:79
image threshold_image(image im, float thresh)
Definition: image.c:1228
image get_network_image(network *net)
Definition: network.c:466
void run_writing(int argc, char **argv)
Definition: writing.c:131
int h
Definition: darknet.h:514
int max_batches
Definition: darknet.h:453
int resize_network(network *net, int w, int h)
Definition: network.c:358
int m
Definition: darknet.h:556
data * d
Definition: darknet.h:577
free_image
Definition: darknet.py:95
int subdivisions
Definition: darknet.h:440
void train_writing(char *cfgfile, char *weightfile)
Definition: writing.c:3
image load_image_color(char *filename, int w, int h)
Definition: image.c:1486
float get_current_rate(network *net)
Definition: network.c:90
float sec(clock_t clocks)
Definition: utils.c:232
int n
Definition: darknet.h:555
void * load_thread(void *ptr)
Definition: data.c:1090
int c
Definition: darknet.h:515
int w
Definition: darknet.h:513
Definition: darknet.h:602
size_t get_current_batch(network *net)
Definition: network.c:63
int h
Definition: darknet.h:468
list * get_paths(char *filename)
Definition: data.c:12
void load_weights(network *net, char *filename)
Definition: parser.c:1308
int w
Definition: darknet.h:468
Definition: darknet.h:538
float * data
Definition: darknet.h:516