darknet  v3
cifar.c
Go to the documentation of this file.
1 #include "darknet.h"
2 
3 void train_cifar(char *cfgfile, char *weightfile)
4 {
5  srand(time(0));
6  float avg_loss = -1;
7  char *base = basecfg(cfgfile);
8  printf("%s\n", base);
9  network *net = load_network(cfgfile, weightfile, 0);
10  printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
11 
12  char *backup_directory = "/home/pjreddie/backup/";
13  int classes = 10;
14  int N = 50000;
15 
16  char **labels = get_labels("data/cifar/labels.txt");
17  int epoch = (*net->seen)/N;
18  data train = load_all_cifar10();
19  while(get_current_batch(net) < net->max_batches || net->max_batches == 0){
20  clock_t time=clock();
21 
22  float loss = train_network_sgd(net, train, 1);
23  if(avg_loss == -1) avg_loss = loss;
24  avg_loss = avg_loss*.95 + loss*.05;
25  printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net->seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net->seen);
26  if(*net->seen/N > epoch){
27  epoch = *net->seen/N;
28  char buff[256];
29  sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
30  save_weights(net, buff);
31  }
32  if(get_current_batch(net)%100 == 0){
33  char buff[256];
34  sprintf(buff, "%s/%s.backup",backup_directory,base);
35  save_weights(net, buff);
36  }
37  }
38  char buff[256];
39  sprintf(buff, "%s/%s.weights", backup_directory, base);
40  save_weights(net, buff);
41 
42  free_network(net);
43  free_ptrs((void**)labels, classes);
44  free(base);
45  free_data(train);
46 }
47 
48 void train_cifar_distill(char *cfgfile, char *weightfile)
49 {
50  srand(time(0));
51  float avg_loss = -1;
52  char *base = basecfg(cfgfile);
53  printf("%s\n", base);
54  network *net = load_network(cfgfile, weightfile, 0);
55  printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
56 
57  char *backup_directory = "/home/pjreddie/backup/";
58  int classes = 10;
59  int N = 50000;
60 
61  char **labels = get_labels("data/cifar/labels.txt");
62  int epoch = (*net->seen)/N;
63 
64  data train = load_all_cifar10();
65  matrix soft = csv_to_matrix("results/ensemble.csv");
66 
67  float weight = .9;
68  scale_matrix(soft, weight);
69  scale_matrix(train.y, 1. - weight);
70  matrix_add_matrix(soft, train.y);
71 
72  while(get_current_batch(net) < net->max_batches || net->max_batches == 0){
73  clock_t time=clock();
74 
75  float loss = train_network_sgd(net, train, 1);
76  if(avg_loss == -1) avg_loss = loss;
77  avg_loss = avg_loss*.95 + loss*.05;
78  printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net->seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net->seen);
79  if(*net->seen/N > epoch){
80  epoch = *net->seen/N;
81  char buff[256];
82  sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
83  save_weights(net, buff);
84  }
85  if(get_current_batch(net)%100 == 0){
86  char buff[256];
87  sprintf(buff, "%s/%s.backup",backup_directory,base);
88  save_weights(net, buff);
89  }
90  }
91  char buff[256];
92  sprintf(buff, "%s/%s.weights", backup_directory, base);
93  save_weights(net, buff);
94 
95  free_network(net);
96  free_ptrs((void**)labels, classes);
97  free(base);
98  free_data(train);
99 }
100 
101 void test_cifar_multi(char *filename, char *weightfile)
102 {
103  network *net = load_network(filename, weightfile, 0);
104  set_batch_network(net, 1);
105  srand(time(0));
106 
107  float avg_acc = 0;
108  data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin");
109 
110  int i;
111  for(i = 0; i < test.X.rows; ++i){
112  image im = float_to_image(32, 32, 3, test.X.vals[i]);
113 
114  float pred[10] = {0};
115 
116  float *p = network_predict(net, im.data);
117  axpy_cpu(10, 1, p, 1, pred, 1);
118  flip_image(im);
119  p = network_predict(net, im.data);
120  axpy_cpu(10, 1, p, 1, pred, 1);
121 
122  int index = max_index(pred, 10);
123  int class = max_index(test.y.vals[i], 10);
124  if(index == class) avg_acc += 1;
125  free_image(im);
126  printf("%4d: %.2f%%\n", i, 100.*avg_acc/(i+1));
127  }
128 }
129 
130 void test_cifar(char *filename, char *weightfile)
131 {
132  network *net = load_network(filename, weightfile, 0);
133  srand(time(0));
134 
135  clock_t time;
136  float avg_acc = 0;
137  float avg_top5 = 0;
138  data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin");
139 
140  time=clock();
141 
142  float *acc = network_accuracies(net, test, 2);
143  avg_acc += acc[0];
144  avg_top5 += acc[1];
145  printf("top1: %f, %lf seconds, %d images\n", avg_acc, sec(clock()-time), test.X.rows);
146  free_data(test);
147 }
148 
150 {
151 char *labels[] = {"airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck"};
152  int i;
153  data train = load_all_cifar10();
154  data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin");
155  for(i = 0; i < train.X.rows; ++i){
156  image im = float_to_image(32, 32, 3, train.X.vals[i]);
157  int class = max_index(train.y.vals[i], 10);
158  char buff[256];
159  sprintf(buff, "data/cifar/train/%d_%s",i,labels[class]);
160  save_image_png(im, buff);
161  }
162  for(i = 0; i < test.X.rows; ++i){
163  image im = float_to_image(32, 32, 3, test.X.vals[i]);
164  int class = max_index(test.y.vals[i], 10);
165  char buff[256];
166  sprintf(buff, "data/cifar/test/%d_%s",i,labels[class]);
167  save_image_png(im, buff);
168  }
169 }
170 
171 void test_cifar_csv(char *filename, char *weightfile)
172 {
173  network *net = load_network(filename, weightfile, 0);
174  srand(time(0));
175 
176  data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin");
177 
178  matrix pred = network_predict_data(net, test);
179 
180  int i;
181  for(i = 0; i < test.X.rows; ++i){
182  image im = float_to_image(32, 32, 3, test.X.vals[i]);
183  flip_image(im);
184  }
185  matrix pred2 = network_predict_data(net, test);
186  scale_matrix(pred, .5);
187  scale_matrix(pred2, .5);
188  matrix_add_matrix(pred2, pred);
189 
190  matrix_to_csv(pred);
191  fprintf(stderr, "Accuracy: %f\n", matrix_topk_accuracy(test.y, pred, 1));
192  free_data(test);
193 }
194 
195 void test_cifar_csvtrain(char *cfg, char *weights)
196 {
197  network *net = load_network(cfg, weights, 0);
198  srand(time(0));
199 
200  data test = load_all_cifar10();
201 
202  matrix pred = network_predict_data(net, test);
203 
204  int i;
205  for(i = 0; i < test.X.rows; ++i){
206  image im = float_to_image(32, 32, 3, test.X.vals[i]);
207  flip_image(im);
208  }
209  matrix pred2 = network_predict_data(net, test);
210  scale_matrix(pred, .5);
211  scale_matrix(pred2, .5);
212  matrix_add_matrix(pred2, pred);
213 
214  matrix_to_csv(pred);
215  fprintf(stderr, "Accuracy: %f\n", matrix_topk_accuracy(test.y, pred, 1));
216  free_data(test);
217 }
218 
220 {
221  data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin");
222 
223  matrix pred = csv_to_matrix("results/combined.csv");
224  fprintf(stderr, "%d %d\n", pred.rows, pred.cols);
225 
226  fprintf(stderr, "Accuracy: %f\n", matrix_topk_accuracy(test.y, pred, 1));
227  free_data(test);
228  free_matrix(pred);
229 }
230 
231 
/*
 * Command-line entry point for the cifar sub-commands.
 *
 * argv[2] selects the action, argv[3] is the cfg path and argv[4], when
 * present, the weights path (0 is passed on otherwise). With fewer than
 * four arguments a usage line is printed on stderr. An action name that
 * matches none of the known ones does nothing.
 */
void run_cifar(int argc, char **argv)
{
    if(argc < 4){
        fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }

    char *action = argv[2];
    char *cfg = argv[3];
    char *weights = 0;
    if(argc > 4){
        weights = argv[4];
    }

    if(strcmp(action, "train") == 0){
        train_cifar(cfg, weights);
        return;
    }
    if(strcmp(action, "extract") == 0){
        extract_cifar();
        return;
    }
    if(strcmp(action, "distill") == 0){
        train_cifar_distill(cfg, weights);
        return;
    }
    if(strcmp(action, "test") == 0){
        test_cifar(cfg, weights);
        return;
    }
    if(strcmp(action, "multi") == 0){
        test_cifar_multi(cfg, weights);
        return;
    }
    if(strcmp(action, "csv") == 0){
        test_cifar_csv(cfg, weights);
        return;
    }
    if(strcmp(action, "csvtrain") == 0){
        test_cifar_csvtrain(cfg, weights);
        return;
    }
    if(strcmp(action, "eval") == 0){
        eval_cifar_csv();
        return;
    }
}
250 
251 
data load_cifar10_data(char *filename)
Definition: data.c:1422
matrix csv_to_matrix(char *filename)
Definition: matrix.c:133
void test_cifar_multi(char *filename, char *weightfile)
Definition: cifar.c:101
void test_cifar_csv(char *filename, char *weightfile)
Definition: cifar.c:171
float decay
Definition: darknet.h:447
void test_cifar(char *filename, char *weightfile)
Definition: cifar.c:130
int rows
Definition: darknet.h:533
void set_batch_network(network *net, int b)
Definition: network.c:339
float learning_rate
Definition: darknet.h:445
int cols
Definition: darknet.h:533
float train_network_sgd(network *net, data d, int n)
Definition: network.c:300
int max_index(float *a, int n)
Definition: utils.c:619
float momentum
Definition: darknet.h:446
void free_data(data d)
Definition: data.c:665
char * basecfg(char *cfgfile)
Definition: utils.c:179
void run_cifar(int argc, char **argv)
Definition: cifar.c:232
size_t * seen
Definition: darknet.h:437
void eval_cifar_csv()
Definition: cifar.c:219
Definition: darknet.h:512
float * network_accuracies(network *net, data d, int n)
Definition: network.c:689
void train_cifar(char *cfgfile, char *weightfile)
Definition: cifar.c:3
void free_network(network *net)
Definition: network.c:716
void flip_image(image a)
Definition: image.c:349
image float_to_image(int w, int h, int c, float *data)
Definition: image.c:774
void save_weights(network *net, char *filename)
Definition: parser.c:1080
void test_cifar_csvtrain(char *cfg, char *weights)
Definition: cifar.c:195
network_predict
Definition: darknet.py:79
void save_image_png(image im, const char *name)
Definition: image.c:700
int max_batches
Definition: darknet.h:453
free_image
Definition: darknet.py:95
data load_all_cifar10()
Definition: data.c:1481
void axpy_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY)
Definition: blas.c:178
void extract_cifar()
Definition: cifar.c:149
float get_current_rate(network *net)
Definition: network.c:90
float sec(clock_t clocks)
Definition: utils.c:232
network * load_network(char *cfg, char *weights, int clear)
Definition: network.c:53
char ** get_labels(char *filename)
Definition: data.c:657
matrix network_predict_data(network *net, data test)
Definition: network.c:616
void free_matrix(matrix m)
Definition: matrix.c:10
matrix X
Definition: darknet.h:540
void matrix_add_matrix(matrix from, matrix to)
Definition: matrix.c:66
void train_cifar_distill(char *cfgfile, char *weightfile)
Definition: cifar.c:48
size_t get_current_batch(network *net)
Definition: network.c:63
float ** vals
Definition: darknet.h:534
float matrix_topk_accuracy(matrix truth, matrix guess, int k)
Definition: matrix.c:17
free_ptrs
Definition: darknet.py:76
void matrix_to_csv(matrix m)
Definition: matrix.c:161
list classes
Definition: voc_label.py:9
void scale_matrix(matrix m, float scale)
Definition: matrix.c:37
Definition: darknet.h:538
float * data
Definition: darknet.h:516
matrix y
Definition: darknet.h:541