14 unsigned char **contents = calloc(*n,
sizeof(
char *));
17 for(i = 0; i < *n; ++i){
28 FILE *fp = fopen(filename,
"r");
29 int *d = calloc(size,
sizeof(
int));
31 one = fscanf(fp,
"%d", &n);
36 d = realloc(d, size*
sizeof(
int));
39 one = fscanf(fp,
"%d", &n);
42 d = realloc(d, count*
sizeof(
int));
51 FILE *fp = fopen(filename,
"r");
52 char **d = calloc(size,
sizeof(
char *));
54 while((line=
fgetl(fp)) != 0){
58 d = realloc(d, size*
sizeof(
char *));
60 if(0==strcmp(line,
"<NEWLINE>")) line =
"\n";
64 d = realloc(d, count*
sizeof(
char *));
72 float *x = calloc(batch * steps * characters,
sizeof(
float));
73 float *y = calloc(batch * steps * characters,
sizeof(
float));
75 for(i = 0; i < batch; ++i){
76 for(j = 0; j < steps; ++j){
77 int curr = tokens[(offsets[i])%len];
78 int next = tokens[(offsets[i] + 1)%len];
80 x[(j*batch + i)*characters + curr] = 1;
81 y[(j*batch + i)*characters + next] = 1;
83 offsets[i] = (offsets[i] + 1) % len;
85 if(curr >= characters || curr < 0 || next >= characters || next < 0){
99 float *x = calloc(batch * steps * characters,
sizeof(
float));
100 float *y = calloc(batch * steps * characters,
sizeof(
float));
101 for(i = 0; i < batch; ++i){
102 int index = rand()%n;
105 for(j = 0; j < steps; ++j){
106 unsigned char curr = source[index][j];
107 unsigned char next = dest[index][j];
109 x[(j*batch + i)*characters + curr] = 1;
110 y[(j*batch + i)*characters + next] = 1;
112 if(curr > 255 || curr <= 0 || next > 255 || next <= 0){
129 float *x = calloc(batch * steps * characters,
sizeof(
float));
130 float *y = calloc(batch * steps * characters,
sizeof(
float));
132 for(i = 0; i < batch; ++i){
133 for(j = 0; j < steps; ++j){
134 unsigned char curr = text[(offsets[i])%len];
135 unsigned char next = text[(offsets[i] + 1)%len];
137 x[(j*batch + i)*characters + curr] = 1;
138 y[(j*batch + i)*characters + next] = 1;
140 offsets[i] = (offsets[i] + 1) % len;
142 if(curr > 255 || curr <= 0 || next > 255 || next <= 0){
157 void train_char_rnn(
char *cfgfile,
char *weightfile,
char *filename,
int clear,
int tokenized)
160 unsigned char *text = 0;
167 size = strlen((
const char*)text);
170 char *backup_directory =
"/home/pjreddie/backup/";
172 fprintf(stderr,
"%s\n", base);
177 fprintf(stderr,
"Learning Rate: %g, Momentum: %g, Decay: %g, Inputs: %d %d %d\n", net->
learning_rate, net->
momentum, net->
decay, inputs, net->
batch, net->
time_steps);
178 int batch = net->
batch;
180 if(clear) *net->
seen = 0;
183 int streams = batch/steps;
184 size_t *offsets = calloc(streams,
sizeof(
size_t));
186 for(j = 0; j < streams; ++j){
198 p =
get_rnn_data(text, offsets, inputs, size, streams, steps);
206 if (avg_loss < 0) avg_loss = loss;
207 avg_loss = avg_loss*.9 + loss*.1;
210 fprintf(stderr,
"%d: %f, %f avg, %f rate, %lf seconds, %f epochs\n", i, loss, avg_loss,
get_current_rate(net),
sec(clock()-time), (
float) chars/size);
212 for(j = 0; j < streams; ++j){
223 sprintf(buff,
"%s/%s_%d.weights", backup_directory, base, i);
228 sprintf(buff,
"%s/%s.backup", backup_directory, base);
233 sprintf(buff,
"%s/%s_final.weights", backup_directory, base);
239 printf(
"%s ", tokens[n]);
245 void test_char_rnn(
char *cfgfile,
char *weightfile,
int num,
char *seed,
float temp,
int rseed,
char *token_file)
255 fprintf(stderr,
"%s\n", base);
263 int len = strlen(seed);
264 float *input = calloc(inputs,
sizeof(
float));
274 for(i = 0; i < len-1; ++i){
281 if(len) c = seed[len-1];
283 for(i = 0; i < num; ++i){
287 for(j = 32; j < 127; ++j){
290 for(j = 0; j < inputs; ++j){
291 if (out[j] < .0001) out[j] = 0;
309 fprintf(stderr,
"%s\n", base);
317 float *input = calloc(inputs,
sizeof(
float));
322 while((c = getc(stdin)) != EOF && c != 0){
327 for(i = 0; i < num; ++i){
328 for(j = 0; j < inputs; ++j){
329 if (out[j] < .0001) out[j] = 0;
332 if(c ==
'.' && next ==
'\n')
break;
344 void test_tactic_rnn(
char *cfgfile,
char *weightfile,
int num,
float temp,
int rseed,
char *token_file)
354 fprintf(stderr,
"%s\n", base);
362 float *input = calloc(inputs,
sizeof(
float));
365 while((c = getc(stdin)) != EOF){
370 for(i = 0; i < num; ++i){
371 for(j = 0; j < inputs; ++j){
372 if (out[j] < .0001) out[j] = 0;
375 if(c ==
'.' && next ==
'\n')
break;
389 fprintf(stderr,
"%s\n", base);
397 int len = strlen(seed);
398 float *input = calloc(inputs,
sizeof(
float));
400 for(i = 0; i < len; ++i){
411 int next = getc(stdin);
412 if(next == EOF)
break;
413 if(next < 0 || next >= 255)
error(
"Out of range character");
419 if(c ==
'.' && next ==
'\n') in = 0;
421 if(c ==
'>' && next ==
'>'){
429 sum += log(out[next])/log2;
431 printf(
"%d %d Perplexity: %4.4f Word Perplexity: %4.4f\n", count, words, pow(2, -sum/count), pow(2, -sum/words));
438 fprintf(stderr,
"%s\n", base);
446 int len = strlen(seed);
447 float *input = calloc(inputs,
sizeof(
float));
449 for(i = 0; i < len; ++i){
459 int next = getc(stdin);
460 if(next == EOF)
break;
461 if(next < 0 || next >= 255)
error(
"Out of range character");
463 if(next ==
' ' || next ==
'\n' || next ==
'\t') ++words;
467 sum += log(out[next])/log2;
469 printf(
"%d BPC: %4.4f Perplexity: %4.4f Word Perplexity: %4.4f\n", count, -sum/count, pow(2, -sum/count), pow(2, -sum/words));
476 fprintf(stderr,
"%s\n", base);
482 int seed_len = strlen(seed);
483 float *input = calloc(inputs,
sizeof(
float));
486 while((line=
fgetl(stdin)) != 0){
488 for(i = 0; i < seed_len; ++i){
495 int str_len = strlen(line);
496 for(i = 0; i < str_len; ++i){
509 cuda_pull_array(l.output_gpu, l.output, l.outputs);
512 for(i = 0; i < l.outputs; ++i){
513 printf(
",%g", l.output[i]);
522 fprintf(stderr,
"usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
525 char *filename =
find_char_arg(argc, argv,
"-file",
"data/shakespeare.txt");
529 int rseed =
find_int_arg(argc, argv,
"-srand", time(0));
530 int clear =
find_arg(argc, argv,
"-clear");
531 int tokenized =
find_arg(argc, argv,
"-tokenized");
535 char *weights = (argc > 4) ? argv[4] : 0;
536 if(0==strcmp(argv[2],
"train"))
train_char_rnn(cfg, weights, filename, clear, tokenized);
537 else if(0==strcmp(argv[2],
"valid"))
valid_char_rnn(cfg, weights, seed);
538 else if(0==strcmp(argv[2],
"validtactic"))
valid_tactic_rnn(cfg, weights, seed);
539 else if(0==strcmp(argv[2],
"vec"))
vec_char_rnn(cfg, weights, seed);
540 else if(0==strcmp(argv[2],
"generate"))
test_char_rnn(cfg, weights, len, seed, temp, rseed, tokens);
541 else if(0==strcmp(argv[2],
"generatetactic"))
test_tactic_rnn(cfg, weights, len, temp, rseed, tokens);
void print_symbol(int n, char **tokens)
void reset_network_state(network *net, int b)
int find_arg(int argc, char *argv[], char *arg)
void vec_char_rnn(char *cfgfile, char *weightfile, char *seed)
int * read_tokenized_data(char *filename, size_t *read)
void test_char_rnn(char *cfgfile, char *weightfile, int num, char *seed, float temp, int rseed, char *token_file)
char * find_char_arg(int argc, char **argv, char *arg, char *def)
char * basecfg(char *cfgfile)
char ** read_tokens(char *filename, size_t *read)
void save_weights(network *net, char *filename)
unsigned char ** load_files(char *filename, int *n)
void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear, int tokenized)
float find_float_arg(int argc, char **argv, char *arg, float def)
float get_current_rate(network *net)
float sec(clock_t clocks)
int find_int_arg(int argc, char **argv, char *arg, int def)
void run_char_rnn(int argc, char **argv)
network * load_network(char *cfg, char *weights, int clear)
float_pair get_rnn_data(unsigned char *text, size_t *offsets, int characters, size_t len, int batch, int steps)
void copy_cpu(int N, float *X, int INCX, float *Y, int INCY)
unsigned char * read_file(char *filename)
float_pair get_rnn_token_data(int *tokens, size_t *offsets, int characters, size_t len, int batch, int steps)
void test_tactic_rnn_multi(char *cfgfile, char *weightfile, int num, float temp, int rseed, char *token_file)
size_t get_current_batch(network *net)
list * get_paths(char *filename)
void test_tactic_rnn(char *cfgfile, char *weightfile, int num, float temp, int rseed, char *token_file)
void valid_tactic_rnn(char *cfgfile, char *weightfile, char *seed)
float train_network_datum(network *net)
float_pair get_seq2seq_data(char **source, char **dest, int n, int characters, size_t len, int batch, int steps)
void error(const char *s)
int sample_array(float *a, int n)
void valid_char_rnn(char *cfgfile, char *weightfile, char *seed)