darknet  v3
rnn.c
Go to the documentation of this file.
1 #include "darknet.h"
2 
3 #include <math.h>
4 
/*
 * One minibatch produced by the get_*_data helpers below: two heap buffers of
 * identical size. `x` is copied into the network input and `y` into the
 * network truth; the caller owns and frees both.
 */
typedef struct {
    float *x;  /* one-hot inputs  */
    float *y;  /* one-hot targets */
} float_pair;
9 
10 unsigned char **load_files(char *filename, int *n)
11 {
12  list *paths = get_paths(filename);
13  *n = paths->size;
14  unsigned char **contents = calloc(*n, sizeof(char *));
15  int i;
16  node *x = paths->front;
17  for(i = 0; i < *n; ++i){
18  contents[i] = read_file((char *)x->val);
19  x = x->next;
20  }
21  return contents;
22 }
23 
/*
 * Reads whitespace-separated decimal integers (token ids) from `filename`.
 *
 * filename: path of the token file.
 * read:     out-param; receives the number of integers parsed.
 * returns:  heap array holding *read ints; the caller frees it. Parsing stops
 *           at end of file or at the first item "%d" cannot parse.
 *
 * Prints a message and exits the process if the file cannot be opened or
 * memory cannot be allocated.
 */
int *read_tokenized_data(char *filename, size_t *read)
{
    size_t size = 512;
    size_t count = 0;
    FILE *fp = fopen(filename, "r");
    if(!fp){
        fprintf(stderr, "Couldn't open file: %s\n", filename);
        exit(EXIT_FAILURE);
    }
    int *d = calloc(size, sizeof(int));
    if(!d){
        fclose(fp);
        fprintf(stderr, "Out of memory reading %s\n", filename);
        exit(EXIT_FAILURE);
    }
    int n;
    while(fscanf(fp, "%d", &n) == 1){
        if(count == size){
            size *= 2;
            /* Grow through a temporary so the old buffer is not lost on failure. */
            int *grown = realloc(d, size*sizeof(int));
            if(!grown){
                free(d);
                fclose(fp);
                fprintf(stderr, "Out of memory reading %s\n", filename);
                exit(EXIT_FAILURE);
            }
            d = grown;
        }
        d[count++] = n;
    }
    fclose(fp);
    /* Trim to fit. Skip when empty: realloc(p, 0) may free p and return NULL. */
    if(count > 0){
        int *trimmed = realloc(d, count*sizeof(int));
        if(trimmed) d = trimmed;  /* a failed shrink leaves the larger buffer valid */
    }
    *read = count;
    return d;
}
46 
/*
 * Reads a vocabulary file: each line (as returned by fgetl) is one token.
 * The special line "<NEWLINE>" stands for a literal "\n" token.
 *
 * filename: path of the vocabulary file.
 * read:     out-param; receives the number of tokens read.
 * returns:  heap array of *read strings, each the buffer fgetl returned.
 *
 * Prints a message and exits the process if the file cannot be opened or
 * memory cannot be allocated.
 */
char **read_tokens(char *filename, size_t *read)
{
    size_t size = 512;
    size_t count = 0;
    FILE *fp = fopen(filename, "r");
    if(!fp){
        fprintf(stderr, "Couldn't open file: %s\n", filename);
        exit(EXIT_FAILURE);
    }
    char **d = calloc(size, sizeof(char *));
    if(!d){
        fclose(fp);
        fprintf(stderr, "Out of memory reading %s\n", filename);
        exit(EXIT_FAILURE);
    }
    char *line;
    while((line = fgetl(fp)) != 0){
        if(count == size){
            size *= 2;
            char **grown = realloc(d, size*sizeof(char *));
            if(!grown){
                free(d);
                fclose(fp);
                fprintf(stderr, "Out of memory reading %s\n", filename);
                exit(EXIT_FAILURE);
            }
            d = grown;
        }
        /* Rewrite in place instead of swapping in a string literal: a line equal
         * to "<NEWLINE>" has room for 10 bytes, so "\n" fits. Every entry then
         * keeps pointing at the buffer fgetl returned (none is discarded). */
        if(0 == strcmp(line, "<NEWLINE>")) strcpy(line, "\n");
        d[count++] = line;
    }
    fclose(fp);
    /* Trim to fit. Skip when empty: realloc(p, 0) may free p and return NULL. */
    if(count > 0){
        char **trimmed = realloc(d, count*sizeof(char *));
        if(trimmed) d = trimmed;
    }
    *read = count;
    return d;
}
68 
69 
70 float_pair get_rnn_token_data(int *tokens, size_t *offsets, int characters, size_t len, int batch, int steps)
71 {
72  float *x = calloc(batch * steps * characters, sizeof(float));
73  float *y = calloc(batch * steps * characters, sizeof(float));
74  int i,j;
75  for(i = 0; i < batch; ++i){
76  for(j = 0; j < steps; ++j){
77  int curr = tokens[(offsets[i])%len];
78  int next = tokens[(offsets[i] + 1)%len];
79 
80  x[(j*batch + i)*characters + curr] = 1;
81  y[(j*batch + i)*characters + next] = 1;
82 
83  offsets[i] = (offsets[i] + 1) % len;
84 
85  if(curr >= characters || curr < 0 || next >= characters || next < 0){
86  error("Bad char");
87  }
88  }
89  }
90  float_pair p;
91  p.x = x;
92  p.y = y;
93  return p;
94 }
95 
96 float_pair get_seq2seq_data(char **source, char **dest, int n, int characters, size_t len, int batch, int steps)
97 {
98  int i,j;
99  float *x = calloc(batch * steps * characters, sizeof(float));
100  float *y = calloc(batch * steps * characters, sizeof(float));
101  for(i = 0; i < batch; ++i){
102  int index = rand()%n;
103  //int slen = strlen(source[index]);
104  //int dlen = strlen(dest[index]);
105  for(j = 0; j < steps; ++j){
106  unsigned char curr = source[index][j];
107  unsigned char next = dest[index][j];
108 
109  x[(j*batch + i)*characters + curr] = 1;
110  y[(j*batch + i)*characters + next] = 1;
111 
112  if(curr > 255 || curr <= 0 || next > 255 || next <= 0){
113  /*text[(index+j+2)%len] = 0;
114  printf("%ld %d %d %d %d\n", index, j, len, (int)text[index+j], (int)text[index+j+1]);
115  printf("%s", text+index);
116  */
117  error("Bad char");
118  }
119  }
120  }
121  float_pair p;
122  p.x = x;
123  p.y = y;
124  return p;
125 }
126 
127 float_pair get_rnn_data(unsigned char *text, size_t *offsets, int characters, size_t len, int batch, int steps)
128 {
129  float *x = calloc(batch * steps * characters, sizeof(float));
130  float *y = calloc(batch * steps * characters, sizeof(float));
131  int i,j;
132  for(i = 0; i < batch; ++i){
133  for(j = 0; j < steps; ++j){
134  unsigned char curr = text[(offsets[i])%len];
135  unsigned char next = text[(offsets[i] + 1)%len];
136 
137  x[(j*batch + i)*characters + curr] = 1;
138  y[(j*batch + i)*characters + next] = 1;
139 
140  offsets[i] = (offsets[i] + 1) % len;
141 
142  if(curr > 255 || curr <= 0 || next > 255 || next <= 0){
143  /*text[(index+j+2)%len] = 0;
144  printf("%ld %d %d %d %d\n", index, j, len, (int)text[index+j], (int)text[index+j+1]);
145  printf("%s", text+index);
146  */
147  error("Bad char");
148  }
149  }
150  }
151  float_pair p;
152  p.x = x;
153  p.y = y;
154  return p;
155 }
156 
157 void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear, int tokenized)
158 {
159  srand(time(0));
160  unsigned char *text = 0;
161  int *tokens = 0;
162  size_t size;
163  if(tokenized){
164  tokens = read_tokenized_data(filename, &size);
165  } else {
166  text = read_file(filename);
167  size = strlen((const char*)text);
168  }
169 
170  char *backup_directory = "/home/pjreddie/backup/";
171  char *base = basecfg(cfgfile);
172  fprintf(stderr, "%s\n", base);
173  float avg_loss = -1;
174  network *net = load_network(cfgfile, weightfile, clear);
175 
176  int inputs = net->inputs;
177  fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g, Inputs: %d %d %d\n", net->learning_rate, net->momentum, net->decay, inputs, net->batch, net->time_steps);
178  int batch = net->batch;
179  int steps = net->time_steps;
180  if(clear) *net->seen = 0;
181  int i = (*net->seen)/net->batch;
182 
183  int streams = batch/steps;
184  size_t *offsets = calloc(streams, sizeof(size_t));
185  int j;
186  for(j = 0; j < streams; ++j){
187  offsets[j] = rand_size_t()%size;
188  }
189 
190  clock_t time;
191  while(get_current_batch(net) < net->max_batches){
192  i += 1;
193  time=clock();
194  float_pair p;
195  if(tokenized){
196  p = get_rnn_token_data(tokens, offsets, inputs, size, streams, steps);
197  }else{
198  p = get_rnn_data(text, offsets, inputs, size, streams, steps);
199  }
200 
201  copy_cpu(net->inputs*net->batch, p.x, 1, net->input, 1);
202  copy_cpu(net->truths*net->batch, p.y, 1, net->truth, 1);
203  float loss = train_network_datum(net) / (batch);
204  free(p.x);
205  free(p.y);
206  if (avg_loss < 0) avg_loss = loss;
207  avg_loss = avg_loss*.9 + loss*.1;
208 
209  size_t chars = get_current_batch(net)*batch;
210  fprintf(stderr, "%d: %f, %f avg, %f rate, %lf seconds, %f epochs\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), (float) chars/size);
211 
212  for(j = 0; j < streams; ++j){
213  //printf("%d\n", j);
214  if(rand()%64 == 0){
215  //fprintf(stderr, "Reset\n");
216  offsets[j] = rand_size_t()%size;
217  reset_network_state(net, j);
218  }
219  }
220 
221  if(i%10000==0){
222  char buff[256];
223  sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
224  save_weights(net, buff);
225  }
226  if(i%100==0){
227  char buff[256];
228  sprintf(buff, "%s/%s.backup", backup_directory, base);
229  save_weights(net, buff);
230  }
231  }
232  char buff[256];
233  sprintf(buff, "%s/%s_final.weights", backup_directory, base);
234  save_weights(net, buff);
235 }
236 
/*
 * Writes symbol `n` to stdout. Without a token table, n is printed as the
 * single character with that code. With a table, tokens[n] is printed
 * followed by a space.
 */
void print_symbol(int n, char **tokens){
    if(!tokens){
        printf("%c", n);
        return;
    }
    printf("%s ", tokens[n]);
}
244 
245 void test_char_rnn(char *cfgfile, char *weightfile, int num, char *seed, float temp, int rseed, char *token_file)
246 {
247  char **tokens = 0;
248  if(token_file){
249  size_t n;
250  tokens = read_tokens(token_file, &n);
251  }
252 
253  srand(rseed);
254  char *base = basecfg(cfgfile);
255  fprintf(stderr, "%s\n", base);
256 
257  network *net = load_network(cfgfile, weightfile, 0);
258  int inputs = net->inputs;
259 
260  int i, j;
261  for(i = 0; i < net->n; ++i) net->layers[i].temperature = temp;
262  int c = 0;
263  int len = strlen(seed);
264  float *input = calloc(inputs, sizeof(float));
265 
266  /*
267  fill_cpu(inputs, 0, input, 1);
268  for(i = 0; i < 10; ++i){
269  network_predict(net, input);
270  }
271  fill_cpu(inputs, 0, input, 1);
272  */
273 
274  for(i = 0; i < len-1; ++i){
275  c = seed[i];
276  input[c] = 1;
277  network_predict(net, input);
278  input[c] = 0;
279  print_symbol(c, tokens);
280  }
281  if(len) c = seed[len-1];
282  print_symbol(c, tokens);
283  for(i = 0; i < num; ++i){
284  input[c] = 1;
285  float *out = network_predict(net, input);
286  input[c] = 0;
287  for(j = 32; j < 127; ++j){
288  //printf("%d %c %f\n",j, j, out[j]);
289  }
290  for(j = 0; j < inputs; ++j){
291  if (out[j] < .0001) out[j] = 0;
292  }
293  c = sample_array(out, inputs);
294  print_symbol(c, tokens);
295  }
296  printf("\n");
297 }
298 
299 void test_tactic_rnn_multi(char *cfgfile, char *weightfile, int num, float temp, int rseed, char *token_file)
300 {
301  char **tokens = 0;
302  if(token_file){
303  size_t n;
304  tokens = read_tokens(token_file, &n);
305  }
306 
307  srand(rseed);
308  char *base = basecfg(cfgfile);
309  fprintf(stderr, "%s\n", base);
310 
311  network *net = load_network(cfgfile, weightfile, 0);
312  int inputs = net->inputs;
313 
314  int i, j;
315  for(i = 0; i < net->n; ++i) net->layers[i].temperature = temp;
316  int c = 0;
317  float *input = calloc(inputs, sizeof(float));
318  float *out = 0;
319 
320  while(1){
321  reset_network_state(net, 0);
322  while((c = getc(stdin)) != EOF && c != 0){
323  input[c] = 1;
324  out = network_predict(net, input);
325  input[c] = 0;
326  }
327  for(i = 0; i < num; ++i){
328  for(j = 0; j < inputs; ++j){
329  if (out[j] < .0001) out[j] = 0;
330  }
331  int next = sample_array(out, inputs);
332  if(c == '.' && next == '\n') break;
333  c = next;
334  print_symbol(c, tokens);
335 
336  input[c] = 1;
337  out = network_predict(net, input);
338  input[c] = 0;
339  }
340  printf("\n");
341  }
342 }
343 
344 void test_tactic_rnn(char *cfgfile, char *weightfile, int num, float temp, int rseed, char *token_file)
345 {
346  char **tokens = 0;
347  if(token_file){
348  size_t n;
349  tokens = read_tokens(token_file, &n);
350  }
351 
352  srand(rseed);
353  char *base = basecfg(cfgfile);
354  fprintf(stderr, "%s\n", base);
355 
356  network *net = load_network(cfgfile, weightfile, 0);
357  int inputs = net->inputs;
358 
359  int i, j;
360  for(i = 0; i < net->n; ++i) net->layers[i].temperature = temp;
361  int c = 0;
362  float *input = calloc(inputs, sizeof(float));
363  float *out = 0;
364 
365  while((c = getc(stdin)) != EOF){
366  input[c] = 1;
367  out = network_predict(net, input);
368  input[c] = 0;
369  }
370  for(i = 0; i < num; ++i){
371  for(j = 0; j < inputs; ++j){
372  if (out[j] < .0001) out[j] = 0;
373  }
374  int next = sample_array(out, inputs);
375  if(c == '.' && next == '\n') break;
376  c = next;
377  print_symbol(c, tokens);
378 
379  input[c] = 1;
380  out = network_predict(net, input);
381  input[c] = 0;
382  }
383  printf("\n");
384 }
385 
386 void valid_tactic_rnn(char *cfgfile, char *weightfile, char *seed)
387 {
388  char *base = basecfg(cfgfile);
389  fprintf(stderr, "%s\n", base);
390 
391  network *net = load_network(cfgfile, weightfile, 0);
392  int inputs = net->inputs;
393 
394  int count = 0;
395  int words = 1;
396  int c;
397  int len = strlen(seed);
398  float *input = calloc(inputs, sizeof(float));
399  int i;
400  for(i = 0; i < len; ++i){
401  c = seed[i];
402  input[(int)c] = 1;
403  network_predict(net, input);
404  input[(int)c] = 0;
405  }
406  float sum = 0;
407  c = getc(stdin);
408  float log2 = log(2);
409  int in = 0;
410  while(c != EOF){
411  int next = getc(stdin);
412  if(next == EOF) break;
413  if(next < 0 || next >= 255) error("Out of range character");
414 
415  input[c] = 1;
416  float *out = network_predict(net, input);
417  input[c] = 0;
418 
419  if(c == '.' && next == '\n') in = 0;
420  if(!in) {
421  if(c == '>' && next == '>'){
422  in = 1;
423  ++words;
424  }
425  c = next;
426  continue;
427  }
428  ++count;
429  sum += log(out[next])/log2;
430  c = next;
431  printf("%d %d Perplexity: %4.4f Word Perplexity: %4.4f\n", count, words, pow(2, -sum/count), pow(2, -sum/words));
432  }
433 }
434 
435 void valid_char_rnn(char *cfgfile, char *weightfile, char *seed)
436 {
437  char *base = basecfg(cfgfile);
438  fprintf(stderr, "%s\n", base);
439 
440  network *net = load_network(cfgfile, weightfile, 0);
441  int inputs = net->inputs;
442 
443  int count = 0;
444  int words = 1;
445  int c;
446  int len = strlen(seed);
447  float *input = calloc(inputs, sizeof(float));
448  int i;
449  for(i = 0; i < len; ++i){
450  c = seed[i];
451  input[(int)c] = 1;
452  network_predict(net, input);
453  input[(int)c] = 0;
454  }
455  float sum = 0;
456  c = getc(stdin);
457  float log2 = log(2);
458  while(c != EOF){
459  int next = getc(stdin);
460  if(next == EOF) break;
461  if(next < 0 || next >= 255) error("Out of range character");
462  ++count;
463  if(next == ' ' || next == '\n' || next == '\t') ++words;
464  input[c] = 1;
465  float *out = network_predict(net, input);
466  input[c] = 0;
467  sum += log(out[next])/log2;
468  c = next;
469  printf("%d BPC: %4.4f Perplexity: %4.4f Word Perplexity: %4.4f\n", count, -sum/count, pow(2, -sum/count), pow(2, -sum/words));
470  }
471 }
472 
473 void vec_char_rnn(char *cfgfile, char *weightfile, char *seed)
474 {
475  char *base = basecfg(cfgfile);
476  fprintf(stderr, "%s\n", base);
477 
478  network *net = load_network(cfgfile, weightfile, 0);
479  int inputs = net->inputs;
480 
481  int c;
482  int seed_len = strlen(seed);
483  float *input = calloc(inputs, sizeof(float));
484  int i;
485  char *line;
486  while((line=fgetl(stdin)) != 0){
487  reset_network_state(net, 0);
488  for(i = 0; i < seed_len; ++i){
489  c = seed[i];
490  input[(int)c] = 1;
491  network_predict(net, input);
492  input[(int)c] = 0;
493  }
494  strip(line);
495  int str_len = strlen(line);
496  for(i = 0; i < str_len; ++i){
497  c = line[i];
498  input[(int)c] = 1;
499  network_predict(net, input);
500  input[(int)c] = 0;
501  }
502  c = ' ';
503  input[(int)c] = 1;
504  network_predict(net, input);
505  input[(int)c] = 0;
506 
507  layer l = net->layers[0];
508  #ifdef GPU
509  cuda_pull_array(l.output_gpu, l.output, l.outputs);
510  #endif
511  printf("%s", line);
512  for(i = 0; i < l.outputs; ++i){
513  printf(",%g", l.output[i]);
514  }
515  printf("\n");
516  }
517 }
518 
/*
 * Command-line entry point: argv[2] selects the mode, argv[3] is the cfg,
 * and argv[4] (optional) is the weights file.
 *
 * Modes: train, valid, validtactic, vec, generate, generatetactic.
 * Flags: -file (default data/shakespeare.txt), -seed (default "\n\n"),
 *        -len (default 1000), -temp (default .7), -srand (default time(0)),
 *        -clear, -tokenized, -tokens <vocabulary file>.
 */
void run_char_rnn(int argc, char **argv)
{
    if(argc < 4){
        /* List the modes that are actually dispatched below. */
        fprintf(stderr, "usage: %s %s [train/valid/validtactic/vec/generate/generatetactic] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }
    char *filename = find_char_arg(argc, argv, "-file", "data/shakespeare.txt");
    char *seed = find_char_arg(argc, argv, "-seed", "\n\n");
    int len = find_int_arg(argc, argv, "-len", 1000);
    float temp = find_float_arg(argc, argv, "-temp", .7);
    int rseed = find_int_arg(argc, argv, "-srand", time(0));
    int clear = find_arg(argc, argv, "-clear");
    int tokenized = find_arg(argc, argv, "-tokenized");
    char *tokens = find_char_arg(argc, argv, "-tokens", 0);

    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;
    if(0==strcmp(argv[2], "train")) train_char_rnn(cfg, weights, filename, clear, tokenized);
    else if(0==strcmp(argv[2], "valid")) valid_char_rnn(cfg, weights, seed);
    else if(0==strcmp(argv[2], "validtactic")) valid_tactic_rnn(cfg, weights, seed);
    else if(0==strcmp(argv[2], "vec")) vec_char_rnn(cfg, weights, seed);
    else if(0==strcmp(argv[2], "generate")) test_char_rnn(cfg, weights, len, seed, temp, rseed, tokens);
    else if(0==strcmp(argv[2], "generatetactic")) test_tactic_rnn(cfg, weights, len, temp, rseed, tokens);
    /* Previously an unrecognised mode exited silently. */
    else fprintf(stderr, "Unknown mode: %s\n", argv[2]);
}
void print_symbol(int n, char **tokens)
Definition: rnn.c:237
node
Definition: darknet.h:596
float decay
Definition: darknet.h:447
void reset_network_state(network *net, int b)
Definition: network.c:69
int batch
Definition: darknet.h:436
float temperature
Definition: darknet.h:210
int find_arg(int argc, char *argv[], char *arg)
Definition: utils.c:120
void vec_char_rnn(char *cfgfile, char *weightfile, char *seed)
Definition: rnn.c:473
struct node * next
Definition: darknet.h:598
int * read_tokenized_data(char *filename, size_t *read)
Definition: rnn.c:24
float * y
Definition: rnn.c:7
float learning_rate
Definition: darknet.h:445
float momentum
Definition: darknet.h:446
void test_char_rnn(char *cfgfile, char *weightfile, int num, char *seed, float temp, int rseed, char *token_file)
Definition: rnn.c:245
float * truth
Definition: darknet.h:485
char * find_char_arg(int argc, char **argv, char *arg, char *def)
Definition: utils.c:163
char * basecfg(char *cfgfile)
Definition: utils.c:179
size_t * seen
Definition: darknet.h:437
int size
Definition: darknet.h:603
char ** read_tokens(char *filename, size_t *read)
Definition: rnn.c:47
void save_weights(network *net, char *filename)
Definition: parser.c:1080
network_predict
Definition: darknet.py:79
int max_batches
Definition: darknet.h:453
unsigned char ** load_files(char *filename, int *n)
Definition: rnn.c:10
layer * layers
Definition: darknet.h:441
char * fgetl(FILE *fp)
Definition: utils.c:335
void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear, int tokenized)
Definition: rnn.c:157
float find_float_arg(int argc, char **argv, char *arg, float def)
Definition: utils.c:148
float get_current_rate(network *net)
Definition: network.c:90
void * val
Definition: darknet.h:597
float sec(clock_t clocks)
Definition: utils.c:232
int find_int_arg(int argc, char **argv, char *arg, int def)
Definition: utils.c:133
void run_char_rnn(int argc, char **argv)
Definition: rnn.c:519
network * load_network(char *cfg, char *weights, int clear)
Definition: network.c:53
float_pair get_rnn_data(unsigned char *text, size_t *offsets, int characters, size_t len, int batch, int steps)
Definition: rnn.c:127
node * front
Definition: darknet.h:604
int truths
Definition: darknet.h:466
size_t rand_size_t()
Definition: utils.c:686
list
Definition: darknet.h:602
void copy_cpu(int N, float *X, int INCX, float *Y, int INCY)
Definition: blas.c:226
unsigned char * read_file(char *filename)
Definition: utils.c:260
float_pair get_rnn_token_data(int *tokens, size_t *offsets, int characters, size_t len, int batch, int steps)
Definition: rnn.c:70
void test_tactic_rnn_multi(char *cfgfile, char *weightfile, int num, float temp, int rseed, char *token_file)
Definition: rnn.c:299
float * input
Definition: darknet.h:484
size_t get_current_batch(network *net)
Definition: network.c:63
int n
Definition: darknet.h:435
int time_steps
Definition: darknet.h:451
float * x
Definition: rnn.c:6
list * get_paths(char *filename)
Definition: data.c:12
void test_tactic_rnn(char *cfgfile, char *weightfile, int num, float temp, int rseed, char *token_file)
Definition: rnn.c:344
void valid_tactic_rnn(char *cfgfile, char *weightfile, char *seed)
Definition: rnn.c:386
float train_network_datum(network *net)
Definition: network.c:289
float_pair get_seq2seq_data(char **source, char **dest, int n, int characters, size_t len, int batch, int steps)
Definition: rnn.c:96
void error(const char *s)
Definition: utils.c:253
int sample_array(float *a, int n)
Definition: utils.c:592
void strip(char *s)
Definition: utils.c:302
void valid_char_rnn(char *cfgfile, char *weightfile, char *seed)
Definition: rnn.c:435
float_pair
Definition: rnn.c:5
int inputs
Definition: darknet.h:464
Definition: darknet.h:119