darknet  v3
dice.c
Go to the documentation of this file.
1 #include "darknet.h"
2 
/* Class names for the six dice faces, in label-index order:
 * index 0 is "face1", ..., index 5 is "face6". */
char *dice_labels[] = {
    "face1", "face2", "face3",
    "face4", "face5", "face6",
};
4 
5 void train_dice(char *cfgfile, char *weightfile)
6 {
7  srand(time(0));
8  float avg_loss = -1;
9  char *base = basecfg(cfgfile);
10  char *backup_directory = "/home/pjreddie/backup/";
11  printf("%s\n", base);
12  network net = parse_network_cfg(cfgfile);
13  if(weightfile){
14  load_weights(&net, weightfile);
15  }
16  printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
17  int imgs = 1024;
18  int i = *net.seen/imgs;
19  char **labels = dice_labels;
20  list *plist = get_paths("data/dice/dice.train.list");
21  char **paths = (char **)list_to_array(plist);
22  printf("%d\n", plist->size);
23  clock_t time;
24  while(1){
25  ++i;
26  time=clock();
27  data train = load_data_old(paths, imgs, plist->size, labels, 6, net.w, net.h);
28  printf("Loaded: %lf seconds\n", sec(clock()-time));
29 
30  time=clock();
31  float loss = train_network(net, train);
32  if(avg_loss == -1) avg_loss = loss;
33  avg_loss = avg_loss*.9 + loss*.1;
34  printf("%d: %f, %f avg, %lf seconds, %ld images\n", i, loss, avg_loss, sec(clock()-time), *net.seen);
35  free_data(train);
36  if((i % 100) == 0) net.learning_rate *= .1;
37  if(i%100==0){
38  char buff[256];
39  sprintf(buff, "%s/%s_%d.weights",backup_directory,base, i);
40  save_weights(net, buff);
41  }
42  }
43 }
44 
45 void validate_dice(char *filename, char *weightfile)
46 {
47  network net = parse_network_cfg(filename);
48  if(weightfile){
49  load_weights(&net, weightfile);
50  }
51  srand(time(0));
52 
53  char **labels = dice_labels;
54  list *plist = get_paths("data/dice/dice.val.list");
55 
56  char **paths = (char **)list_to_array(plist);
57  int m = plist->size;
58  free_list(plist);
59 
60  data val = load_data_old(paths, m, 0, labels, 6, net.w, net.h);
61  float *acc = network_accuracies(net, val, 2);
62  printf("Validation Accuracy: %f, %d images\n", acc[0], m);
63  free_data(val);
64 }
65 
66 void test_dice(char *cfgfile, char *weightfile, char *filename)
67 {
68  network net = parse_network_cfg(cfgfile);
69  if(weightfile){
70  load_weights(&net, weightfile);
71  }
72  set_batch_network(&net, 1);
73  srand(2222222);
74  int i = 0;
75  char **names = dice_labels;
76  char buff[256];
77  char *input = buff;
78  int indexes[6];
79  while(1){
80  if(filename){
81  strncpy(input, filename, 256);
82  }else{
83  printf("Enter Image Path: ");
84  fflush(stdout);
85  input = fgets(input, 256, stdin);
86  if(!input) return;
87  strtok(input, "\n");
88  }
89  image im = load_image_color(input, net.w, net.h);
90  float *X = im.data;
91  float *predictions = network_predict(net, X);
92  top_predictions(net, 6, indexes);
93  for(i = 0; i < 6; ++i){
94  int index = indexes[i];
95  printf("%s: %f\n", names[index], predictions[index]);
96  }
97  free_image(im);
98  if (filename) break;
99  }
100 }
101 
/*
 * Command-line dispatcher for the dice example:
 *   <prog> <cmd> train|test|valid <cfg> [weights] [image]
 * The optional image argument (argv[5]) is used only by "test".
 */
void run_dice(int argc, char **argv)
{
    if(argc < 4){
        fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)] [image (optional, test only)]\n", argv[0], argv[1]);
        return;
    }

    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : NULL;
    char *filename = (argc > 5) ? argv[5] : NULL;
    if(0 == strcmp(argv[2], "test")) test_dice(cfg, weights, filename);
    else if(0 == strcmp(argv[2], "train")) train_dice(cfg, weights);
    else if(0 == strcmp(argv[2], "valid")) validate_dice(cfg, weights);
    else fprintf(stderr, "unknown mode \"%s\": expected train, test or valid\n", argv[2]);
}
116 
data load_data_old(char **paths, int n, int m, char **labels, int k, int w, int h)
Definition: data.c:1204
float decay
Definition: darknet.h:447
void train_dice(char *cfgfile, char *weightfile)
Definition: dice.c:5
void test_dice(char *cfgfile, char *weightfile, char *filename)
Definition: dice.c:66
void set_batch_network(network *net, int b)
Definition: network.c:339
float learning_rate
Definition: darknet.h:445
float momentum
Definition: darknet.h:446
void free_data(data d)
Definition: data.c:665
char * basecfg(char *cfgfile)
Definition: utils.c:179
void ** list_to_array(list *l)
Definition: list.c:82
size_t * seen
Definition: darknet.h:437
float train_network(network *net, data d)
Definition: network.c:314
int size
Definition: darknet.h:603
void free_list(list *l)
Definition: list.c:67
Definition: darknet.h:512
char * dice_labels[]
Definition: dice.c:3
float * network_accuracies(network *net, data d, int n)
Definition: network.c:689
void save_weights(network *net, char *filename)
Definition: parser.c:1080
network * parse_network_cfg(char *filename)
Definition: parser.c:742
network_predict
Definition: darknet.py:79
free_image
Definition: darknet.py:95
image load_image_color(char *filename, int w, int h)
Definition: image.c:1486
void run_dice(int argc, char **argv)
Definition: dice.c:102
float sec(clock_t clocks)
Definition: utils.c:232
void validate_dice(char *filename, char *weightfile)
Definition: dice.c:45
Definition: darknet.h:602
int h
Definition: darknet.h:468
list * get_paths(char *filename)
Definition: data.c:12
void top_predictions(network *net, int n, int *index)
Definition: network.c:491
void load_weights(network *net, char *filename)
Definition: parser.c:1308
int w
Definition: darknet.h:468
Definition: darknet.h:538
float * data
Definition: darknet.h:516