11 void train_compare(
char *cfgfile,
char *weightfile)
16 char *backup_directory =
"/home/pjreddie/backup/";
44 int epoch = *net.
seen/N;
49 pthread_join(load_thread, 0);
53 printf(
"Loaded: %lf seconds\n",
sec(clock()-time));
56 if(avg_loss == -1) avg_loss = loss;
57 avg_loss = avg_loss*.9 + loss*.1;
58 printf(
"%.3f: %f, %f avg, %lf seconds, %ld images\n", (
float)*net.
seen/N, loss, avg_loss,
sec(clock()-time), *net.
seen);
62 sprintf(buff,
"%s/%s_%d_minor_%d.weights",backup_directory,base, epoch, i);
65 if(*net.
seen/N > epoch){
69 sprintf(buff,
"%s/%s_%d.weights",backup_directory,base, epoch);
74 pthread_join(load_thread, 0);
82 void validate_compare(
char *filename,
char *weightfile)
94 int N = plist->
size/2;
101 int num = (i+1)*N/splits - i*N/splits;
116 for(i = 1; i <= splits; ++i){
119 pthread_join(load_thread, 0);
122 num = (i+1)*N/splits - i*N/splits;
123 char **part = paths+(i*N/splits);
128 printf(
"Loaded: %d images in %lf seconds\n", val.
X.
rows,
sec(clock()-time));
133 for(j = 0; j < val.
y.
rows; ++j){
134 for(k = 0; k < 20; ++k){
135 if(val.
y.
vals[j][k*2] != val.
y.
vals[j][k*2+1]){
137 if((val.
y.
vals[j][k*2] < val.
y.
vals[j][k*2+1]) == (pred.
vals[j][k*2] < pred.
vals[j][k*2+1])){
144 printf(
"%d: Acc: %f, %lf seconds, %d images\n", i, (
float)correct/total,
sec(clock()-time), val.
X.
rows);
158 int total_compares = 0;
159 int current_class = 0;
161 int elo_comparator(
const void*a,
const void *b)
163 sortable_bbox box1 = *(sortable_bbox*)a;
164 sortable_bbox box2 = *(sortable_bbox*)b;
165 if(box1.elos[current_class] == box2.elos[current_class])
return 0;
166 if(box1.elos[current_class] > box2.elos[current_class])
return -1;
170 int bbox_comparator(
const void *a,
const void *b)
173 sortable_bbox box1 = *(sortable_bbox*)a;
174 sortable_bbox box2 = *(sortable_bbox*)b;
176 int class = box1.class;
180 float *X = calloc(net.
w*net.
h*net.
c,
sizeof(
float));
181 memcpy(X, im1.data, im1.w*im1.h*im1.c*
sizeof(
float));
182 memcpy(X+im1.w*im1.h*im1.c, im2.
data, im2.
w*im2.
h*im2.
c*
sizeof(
float));
188 if (predictions[
class*2] > predictions[
class*2+1]){
194 void bbox_update(sortable_bbox *a, sortable_bbox *b,
int class,
int result)
197 float EA = 1./(1+pow(10, (b->elos[
class] - a->elos[
class])/400.));
198 float EB = 1./(1+pow(10, (a->elos[
class] - b->elos[
class])/400.));
199 float SA = result ? 1 : 0;
200 float SB = result ? 0 : 1;
201 a->elos[
class] += k*(SA - EA);
202 b->elos[
class] += k*(SB - EB);
205 void bbox_fight(
network net, sortable_bbox *a, sortable_bbox *b,
int classes,
int class)
209 float *X = calloc(net.
w*net.
h*net.
c,
sizeof(
float));
210 memcpy(X, im1.
data, im1.
w*im1.
h*im1.
c*
sizeof(
float));
211 memcpy(X+im1.
w*im1.
h*im1.
c, im2.
data, im2.
w*im2.
h*im2.
c*
sizeof(
float));
217 if(
class < 0 ||
class == i){
218 int result = predictions[i*2] > predictions[i*2+1];
219 bbox_update(a, b, i, result);
228 void SortMaster3000(
char *filename,
char *weightfile)
243 sortable_bbox *boxes = calloc(N,
sizeof(sortable_bbox));
244 printf(
"Sorting %d boxes...\n", N);
245 for(i = 0; i < N; ++i){
246 boxes[i].filename = paths[i];
251 clock_t time=clock();
252 qsort(boxes, N,
sizeof(sortable_bbox), bbox_comparator);
253 for(i = 0; i < N; ++i){
254 printf(
"%s\n", boxes[i].filename);
256 printf(
"Sorted in %d compares, %f secs\n", total_compares,
sec(clock()-time));
259 void BattleRoyaleWithCheese(
char *filename,
char *weightfile)
278 sortable_bbox *boxes = calloc(N,
sizeof(sortable_bbox));
279 printf(
"Battling %d boxes...\n", N);
280 for(i = 0; i < N; ++i){
281 boxes[i].filename = paths[i];
284 boxes[i].elos = calloc(classes,
sizeof(
float));;
286 boxes[i].elos[j] = 1500;
290 clock_t time=clock();
291 for(round = 1; round <= 4; ++round){
292 clock_t round_time=clock();
293 printf(
"Round: %d\n", round);
294 shuffle(boxes, N,
sizeof(sortable_bbox));
295 for(i = 0; i < N/2; ++i){
296 bbox_fight(net, boxes+i*2, boxes+i*2+1, classes, -1);
298 printf(
"Round: %f secs, %d remaining\n",
sec(clock()-round_time), N);
303 for (
class = 0;
class < classes; ++
class){
306 current_class =
class;
307 qsort(boxes, N,
sizeof(sortable_bbox), elo_comparator);
310 for(round = 1; round <= 100; ++round){
311 clock_t round_time=clock();
312 printf(
"Round: %d\n", round);
315 for(i = 0; i < N/2; ++i){
316 bbox_fight(net, boxes+i*2, boxes+i*2+1, classes,
class);
318 qsort(boxes, N,
sizeof(sortable_bbox), elo_comparator);
319 if(round <= 20) N = (N*9/10)/2*2;
321 printf(
"Round: %f secs, %d remaining\n",
sec(clock()-round_time), N);
324 sprintf(buff,
"results/battle_%d.log",
class);
325 FILE *outfp = fopen(buff,
"w");
326 for(i = 0; i < N; ++i){
327 fprintf(outfp,
"%s %f\n", boxes[i].filename, boxes[i].elos[
class]);
331 printf(
"Tournament in %d compares, %f secs\n", total_compares,
sec(clock()-time));
334 void run_compare(
int argc,
char **argv)
337 fprintf(stderr,
"usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
342 char *weights = (argc > 4) ? argv[4] : 0;
344 if(0==strcmp(argv[2],
"train")) train_compare(cfg, weights);
345 else if(0==strcmp(argv[2],
"valid")) validate_compare(cfg, weights);
346 else if(0==strcmp(argv[2],
"sort")) SortMaster3000(cfg, weights);
347 else if(0==strcmp(argv[2],
"battle")) BattleRoyaleWithCheese(cfg, weights);
pthread_t load_data_in_thread(load_args args)
void set_batch_network(network *net, int b)
char * basecfg(char *cfgfile)
void ** list_to_array(list *l)
float train_network(network *net, data d)
void free_network(network *net)
void save_weights(network *net, char *filename)
network * parse_network_cfg(char *filename)
void sorta_shuffle(void *arr, size_t n, size_t size, size_t sections)
image load_image_color(char *filename, int w, int h)
void shuffle(void *arr, size_t n, size_t size)
float sec(clock_t clocks)
matrix network_predict_data(network *net, data test)
void free_matrix(matrix m)
void * load_thread(void *ptr)
list * get_paths(char *filename)
void load_weights(network *net, char *filename)