darknet  v3
captcha.c
Go to the documentation of this file.
1 #include "darknet.h"
2 
3 void fix_data_captcha(data d, int mask)
4 {
5  matrix labels = d.y;
6  int i, j;
7  for(i = 0; i < d.y.rows; ++i){
8  for(j = 0; j < d.y.cols; j += 2){
9  if (mask){
10  if(!labels.vals[i][j]){
11  labels.vals[i][j] = SECRET_NUM;
12  labels.vals[i][j+1] = SECRET_NUM;
13  }else if(labels.vals[i][j+1]){
14  labels.vals[i][j] = 0;
15  }
16  } else{
17  if (labels.vals[i][j]) {
18  labels.vals[i][j+1] = 0;
19  } else {
20  labels.vals[i][j+1] = 1;
21  }
22  }
23  }
24  }
25 }
26 
27 void train_captcha(char *cfgfile, char *weightfile)
28 {
29  srand(time(0));
30  float avg_loss = -1;
31  char *base = basecfg(cfgfile);
32  printf("%s\n", base);
33  network *net = load_network(cfgfile, weightfile, 0);
34  printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
35  int imgs = 1024;
36  int i = *net->seen/imgs;
37  int solved = 1;
38  list *plist;
39  char **labels = get_labels("/data/captcha/reimgs.labels.list");
40  if (solved){
41  plist = get_paths("/data/captcha/reimgs.solved.list");
42  }else{
43  plist = get_paths("/data/captcha/reimgs.raw.list");
44  }
45  char **paths = (char **)list_to_array(plist);
46  printf("%d\n", plist->size);
47  clock_t time;
48  pthread_t load_thread;
49  data train;
50  data buffer;
51 
52  load_args args = {0};
53  args.w = net->w;
54  args.h = net->h;
55  args.paths = paths;
56  args.classes = 26;
57  args.n = imgs;
58  args.m = plist->size;
59  args.labels = labels;
60  args.d = &buffer;
62 
63  load_thread = load_data_in_thread(args);
64  while(1){
65  ++i;
66  time=clock();
67  pthread_join(load_thread, 0);
68  train = buffer;
69  fix_data_captcha(train, solved);
70 
71  /*
72  image im = float_to_image(256, 256, 3, train.X.vals[114]);
73  show_image(im, "training");
74  cvWaitKey(0);
75  */
76 
77  load_thread = load_data_in_thread(args);
78  printf("Loaded: %lf seconds\n", sec(clock()-time));
79  time=clock();
80  float loss = train_network(net, train);
81  if(avg_loss == -1) avg_loss = loss;
82  avg_loss = avg_loss*.9 + loss*.1;
83  printf("%d: %f, %f avg, %lf seconds, %ld images\n", i, loss, avg_loss, sec(clock()-time), *net->seen);
84  free_data(train);
85  if(i%100==0){
86  char buff[256];
87  sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
88  save_weights(net, buff);
89  }
90  }
91 }
92 
93 void test_captcha(char *cfgfile, char *weightfile, char *filename)
94 {
95  network *net = load_network(cfgfile, weightfile, 0);
96  set_batch_network(net, 1);
97  srand(2222222);
98  int i = 0;
99  char **names = get_labels("/data/captcha/reimgs.labels.list");
100  char buff[256];
101  char *input = buff;
102  int indexes[26];
103  while(1){
104  if(filename){
105  strncpy(input, filename, 256);
106  }else{
107  //printf("Enter Image Path: ");
108  //fflush(stdout);
109  input = fgets(input, 256, stdin);
110  if(!input) return;
111  strtok(input, "\n");
112  }
113  image im = load_image_color(input, net->w, net->h);
114  float *X = im.data;
115  float *predictions = network_predict(net, X);
116  top_predictions(net, 26, indexes);
117  //printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
118  for(i = 0; i < 26; ++i){
119  int index = indexes[i];
120  if(i != 0) printf(", ");
121  printf("%s %f", names[index], predictions[index]);
122  }
123  printf("\n");
124  fflush(stdout);
125  free_image(im);
126  if (filename) break;
127  }
128 }
129 
130 void valid_captcha(char *cfgfile, char *weightfile, char *filename)
131 {
132  char **labels = get_labels("/data/captcha/reimgs.labels.list");
133  network *net = load_network(cfgfile, weightfile, 0);
134  list *plist = get_paths("/data/captcha/reimgs.fg.list");
135  char **paths = (char **)list_to_array(plist);
136  int N = plist->size;
137  int outputs = net->outputs;
138 
139  set_batch_network(net, 1);
140  srand(2222222);
141  int i, j;
142  for(i = 0; i < N; ++i){
143  if (i%100 == 0) fprintf(stderr, "%d\n", i);
144  image im = load_image_color(paths[i], net->w, net->h);
145  float *X = im.data;
146  float *predictions = network_predict(net, X);
147  //printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
148  int truth = -1;
149  for(j = 0; j < 13; ++j){
150  if (strstr(paths[i], labels[j])) truth = j;
151  }
152  if (truth == -1){
153  fprintf(stderr, "bad: %s\n", paths[i]);
154  return;
155  }
156  printf("%d, ", truth);
157  for(j = 0; j < outputs; ++j){
158  if (j != 0) printf(", ");
159  printf("%f", predictions[j]);
160  }
161  printf("\n");
162  fflush(stdout);
163  free_image(im);
164  if (filename) break;
165  }
166 }
167 
168 /*
169  void train_captcha(char *cfgfile, char *weightfile)
170  {
171  float avg_loss = -1;
172  srand(time(0));
173  char *base = basecfg(cfgfile);
174  printf("%s\n", base);
175  network net = parse_network_cfg(cfgfile);
176  if(weightfile){
177  load_weights(&net, weightfile);
178  }
179  printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
180  int imgs = 1024;
181  int i = net->seen/imgs;
182  list *plist = get_paths("/data/captcha/train.auto5");
183  char **paths = (char **)list_to_array(plist);
184  printf("%d\n", plist->size);
185  clock_t time;
186  while(1){
187  ++i;
188  time=clock();
189  data train = load_data_captcha(paths, imgs, plist->size, 10, 200, 60);
190  translate_data_rows(train, -128);
191  scale_data_rows(train, 1./128);
192  printf("Loaded: %lf seconds\n", sec(clock()-time));
193  time=clock();
194  float loss = train_network(net, train);
195  net->seen += imgs;
196  if(avg_loss == -1) avg_loss = loss;
197  avg_loss = avg_loss*.9 + loss*.1;
198  printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net->seen);
199  free_data(train);
200  if(i%10==0){
201  char buff[256];
202  sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
203  save_weights(net, buff);
204  }
205  }
206  }
207 
208  void decode_captcha(char *cfgfile, char *weightfile)
209  {
210  setbuf(stdout, NULL);
211  srand(time(0));
212  network net = parse_network_cfg(cfgfile);
213  set_batch_network(&net, 1);
214  if(weightfile){
215  load_weights(&net, weightfile);
216  }
217  char filename[256];
218  while(1){
219  printf("Enter filename: ");
220  fgets(filename, 256, stdin);
221  strtok(filename, "\n");
222  image im = load_image_color(filename, 300, 57);
223  scale_image(im, 1./255.);
224  float *X = im.data;
225  float *predictions = network_predict(net, X);
226  image out = float_to_image(300, 57, 1, predictions);
227  show_image(out, "decoded");
228 #ifdef OPENCV
229 cvWaitKey(0);
230 #endif
231 free_image(im);
232 }
233 }
234 
235 void encode_captcha(char *cfgfile, char *weightfile)
236 {
237 float avg_loss = -1;
238 srand(time(0));
239 char *base = basecfg(cfgfile);
240 printf("%s\n", base);
241 network net = parse_network_cfg(cfgfile);
242 if(weightfile){
243  load_weights(&net, weightfile);
244 }
245 printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
246 int imgs = 1024;
247 int i = net->seen/imgs;
248 list *plist = get_paths("/data/captcha/encode.list");
249 char **paths = (char **)list_to_array(plist);
250 printf("%d\n", plist->size);
251 clock_t time;
252 while(1){
253  ++i;
254  time=clock();
255  data train = load_data_captcha_encode(paths, imgs, plist->size, 300, 57);
256  scale_data_rows(train, 1./255);
257  printf("Loaded: %lf seconds\n", sec(clock()-time));
258  time=clock();
259  float loss = train_network(net, train);
260  net->seen += imgs;
261  if(avg_loss == -1) avg_loss = loss;
262  avg_loss = avg_loss*.9 + loss*.1;
263  printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net->seen);
264  free_matrix(train.X);
265  if(i%100==0){
266  char buff[256];
267  sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
268  save_weights(net, buff);
269  }
270 }
271 }
272 
273 void validate_captcha(char *cfgfile, char *weightfile)
274 {
275  srand(time(0));
276  char *base = basecfg(cfgfile);
277  printf("%s\n", base);
278  network net = parse_network_cfg(cfgfile);
279  if(weightfile){
280  load_weights(&net, weightfile);
281  }
282  int numchars = 37;
283  list *plist = get_paths("/data/captcha/solved.hard");
284  char **paths = (char **)list_to_array(plist);
285  int imgs = plist->size;
286  data valid = load_data_captcha(paths, imgs, 0, 10, 200, 60);
287  translate_data_rows(valid, -128);
288  scale_data_rows(valid, 1./128);
289  matrix pred = network_predict_data(net, valid);
290  int i, k;
291  int correct = 0;
292  int total = 0;
293  int accuracy = 0;
294  for(i = 0; i < imgs; ++i){
295  int allcorrect = 1;
296  for(k = 0; k < 10; ++k){
297  char truth = int_to_alphanum(max_index(valid.y.vals[i]+k*numchars, numchars));
298  char prediction = int_to_alphanum(max_index(pred.vals[i]+k*numchars, numchars));
299  if (truth != prediction) allcorrect=0;
300  if (truth != '.' && truth == prediction) ++correct;
301  if (truth != '.' || truth != prediction) ++total;
302  }
303  accuracy += allcorrect;
304  }
305  printf("Word Accuracy: %f, Char Accuracy %f\n", (float)accuracy/imgs, (float)correct/total);
306  free_data(valid);
307 }
308 
309 void test_captcha(char *cfgfile, char *weightfile)
310 {
311  setbuf(stdout, NULL);
312  srand(time(0));
313  //char *base = basecfg(cfgfile);
314  //printf("%s\n", base);
315  network net = parse_network_cfg(cfgfile);
316  set_batch_network(&net, 1);
317  if(weightfile){
318  load_weights(&net, weightfile);
319  }
320  char filename[256];
321  while(1){
322  //printf("Enter filename: ");
323  fgets(filename, 256, stdin);
324  strtok(filename, "\n");
325  image im = load_image_color(filename, 200, 60);
326  translate_image(im, -128);
327  scale_image(im, 1/128.);
328  float *X = im.data;
329  float *predictions = network_predict(net, X);
330  print_letters(predictions, 10);
331  free_image(im);
332  }
333 }
334  */
/*
 * Command-line entry for the captcha example.
 *
 * argv[2]: mode -- "train", "test" or "valid".
 * argv[3]: network cfg file.
 * argv[4]: optional weights file.
 * argv[5]: optional image filename (used by "test" and "valid").
 *
 * Prints usage when fewer than 4 arguments are given, and an error when the
 * mode is not recognised.
 */
void run_captcha(int argc, char **argv)
{
    if(argc < 4){
        fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)] [filename (optional)]\n", argv[0], argv[1]);
        return;
    }

    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;
    char *filename = (argc > 5) ? argv[5]: 0;
    if(0==strcmp(argv[2], "train")) train_captcha(cfg, weights);
    else if(0==strcmp(argv[2], "test")) test_captcha(cfg, weights, filename);
    else if(0==strcmp(argv[2], "valid")) valid_captcha(cfg, weights, filename);
    else fprintf(stderr, "unknown captcha mode: %s (expected train/test/valid)\n", argv[2]);
}
353 
float decay
Definition: darknet.h:447
char ** paths
Definition: darknet.h:553
int rows
Definition: darknet.h:533
pthread_t load_data_in_thread(load_args args)
Definition: data.c:1135
void set_batch_network(network *net, int b)
Definition: network.c:339
int w
Definition: darknet.h:559
float learning_rate
Definition: darknet.h:445
int cols
Definition: darknet.h:533
float momentum
Definition: darknet.h:446
void free_data(data d)
Definition: data.c:665
char * basecfg(char *cfgfile)
Definition: utils.c:179
void ** list_to_array(list *l)
Definition: list.c:82
size_t * seen
Definition: darknet.h:437
float train_network(network *net, data d)
Definition: network.c:314
int size
Definition: darknet.h:603
image
Definition: darknet.h:512
int h
Definition: darknet.h:558
data_type type
Definition: darknet.h:580
void save_weights(network *net, char *filename)
Definition: parser.c:1080
network_predict (cross-reference resolves to the Python binding, not the C function called in captcha.c)
Definition: darknet.py:79
void valid_captcha(char *cfgfile, char *weightfile, char *filename)
Definition: captcha.c:130
int m
Definition: darknet.h:556
data * d
Definition: darknet.h:577
free_image (cross-reference resolves to the Python binding, not the C function called in captcha.c)
Definition: darknet.py:95
image load_image_color(char *filename, int w, int h)
Definition: image.c:1486
int classes
Definition: darknet.h:566
void test_captcha(char *cfgfile, char *weightfile, char *filename)
Definition: captcha.c:93
void fix_data_captcha(data d, int mask)
Definition: captcha.c:3
float sec(clock_t clocks)
Definition: utils.c:232
network * load_network(char *cfg, char *weights, int clear)
Definition: network.c:53
char ** get_labels(char *filename)
Definition: data.c:657
int n
Definition: darknet.h:555
void * load_thread(void *ptr) (name collision: in captcha.c, load_thread is a local pthread_t variable, not this function)
Definition: data.c:1090
void run_captcha(int argc, char **argv)
Definition: captcha.c:335
list
Definition: darknet.h:602
char ** labels
Definition: darknet.h:557
int outputs
Definition: darknet.h:465
#define SECRET_NUM
Definition: darknet.h:8
float ** vals
Definition: darknet.h:534
void train_captcha(char *cfgfile, char *weightfile)
Definition: captcha.c:27
int h
Definition: darknet.h:468
list * get_paths(char *filename)
Definition: data.c:12
void top_predictions(network *net, int n, int *index)
Definition: network.c:491
int w
Definition: darknet.h:468
data
Definition: darknet.h:538
float * data
Definition: darknet.h:516
matrix y
Definition: darknet.h:541