12 char *backup_directory =
"/home/pjreddie/backup/";
16 char **labels =
get_labels(
"data/cifar/labels.txt");
17 int epoch = (*net->
seen)/N;
23 if(avg_loss == -1) avg_loss = loss;
24 avg_loss = avg_loss*.95 + loss*.05;
25 printf(
"%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n",
get_current_batch(net), (
float)(*net->
seen)/N, loss, avg_loss,
get_current_rate(net),
sec(clock()-time), *net->
seen);
26 if(*net->
seen/N > epoch){
29 sprintf(buff,
"%s/%s_%d.weights",backup_directory,base, epoch);
34 sprintf(buff,
"%s/%s.backup",backup_directory,base);
39 sprintf(buff,
"%s/%s.weights", backup_directory, base);
57 char *backup_directory =
"/home/pjreddie/backup/";
61 char **labels =
get_labels(
"data/cifar/labels.txt");
62 int epoch = (*net->
seen)/N;
76 if(avg_loss == -1) avg_loss = loss;
77 avg_loss = avg_loss*.95 + loss*.05;
78 printf(
"%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n",
get_current_batch(net), (
float)(*net->
seen)/N, loss, avg_loss,
get_current_rate(net),
sec(clock()-time), *net->
seen);
79 if(*net->
seen/N > epoch){
82 sprintf(buff,
"%s/%s_%d.weights",backup_directory,base, epoch);
87 sprintf(buff,
"%s/%s.backup",backup_directory,base);
92 sprintf(buff,
"%s/%s.weights", backup_directory, base);
111 for(i = 0; i < test.
X.
rows; ++i){
114 float pred[10] = {0};
123 int class =
max_index(test.y.vals[i], 10);
124 if(index ==
class) avg_acc += 1;
126 printf(
"%4d: %.2f%%\n", i, 100.*avg_acc/(i+1));
145 printf(
"top1: %f, %lf seconds, %d images\n", avg_acc,
sec(clock()-time), test.
X.
rows);
151 char *labels[] = {
"airplane",
"automobile",
"bird",
"cat",
"deer",
"dog",
"frog",
"horse",
"ship",
"truck"};
155 for(i = 0; i < train.
X.
rows; ++i){
157 int class =
max_index(train.y.vals[i], 10);
159 sprintf(buff,
"data/cifar/train/%d_%s",i,labels[
class]);
162 for(i = 0; i < test.
X.
rows; ++i){
164 int class =
max_index(test.y.vals[i], 10);
166 sprintf(buff,
"data/cifar/test/%d_%s",i,labels[
class]);
181 for(i = 0; i < test.
X.
rows; ++i){
205 for(i = 0; i < test.
X.
rows; ++i){
224 fprintf(stderr,
"%d %d\n", pred.
rows, pred.
cols);
235 fprintf(stderr,
"usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
240 char *weights = (argc > 4) ? argv[4] : 0;
241 if(0==strcmp(argv[2],
"train"))
train_cifar(cfg, weights);
244 else if(0==strcmp(argv[2],
"test"))
test_cifar(cfg, weights);
data load_cifar10_data(char *filename)
matrix csv_to_matrix(char *filename)
void test_cifar_multi(char *filename, char *weightfile)
void test_cifar_csv(char *filename, char *weightfile)
void test_cifar(char *filename, char *weightfile)
void set_batch_network(network *net, int b)
float train_network_sgd(network *net, data d, int n)
int max_index(float *a, int n)
char * basecfg(char *cfgfile)
void run_cifar(int argc, char **argv)
float * network_accuracies(network *net, data d, int n)
void train_cifar(char *cfgfile, char *weightfile)
void free_network(network *net)
image float_to_image(int w, int h, int c, float *data)
void save_weights(network *net, char *filename)
void test_cifar_csvtrain(char *cfg, char *weights)
void save_image_png(image im, const char *name)
void axpy_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY)
float get_current_rate(network *net)
float sec(clock_t clocks)
network * load_network(char *cfg, char *weights, int clear)
char ** get_labels(char *filename)
matrix network_predict_data(network *net, data test)
void free_matrix(matrix m)
void matrix_add_matrix(matrix from, matrix to)
void train_cifar_distill(char *cfgfile, char *weightfile)
size_t get_current_batch(network *net)
float matrix_topk_accuracy(matrix truth, matrix guess, int k)
void matrix_to_csv(matrix m)
void scale_matrix(matrix m, float scale)