darknet v3
convolutional_layer.c
#include "convolutional_layer.h"
#include "utils.h"
#include "batchnorm_layer.h"
#include "im2col.h"
#include "col2im.h"
#include "blas.h"
#include "gemm.h"
#include <stdio.h>
#include <time.h>

#ifdef AI2
#include "xnor_layer.h"
#endif

void swap_binary(convolutional_layer *l)
{
    float *swap = l->weights;
    l->weights = l->binary_weights;
    l->binary_weights = swap;

#ifdef GPU
    swap = l->weights_gpu;
    l->weights_gpu = l->binary_weights_gpu;
    l->binary_weights_gpu = swap;
#endif
}

void binarize_weights(float *weights, int n, int size, float *binary)
{
    int i, f;
    for(f = 0; f < n; ++f){
        float mean = 0;
        for(i = 0; i < size; ++i){
            mean += fabs(weights[f*size + i]);
        }
        mean = mean / size;
        for(i = 0; i < size; ++i){
            binary[f*size + i] = (weights[f*size + i] > 0) ? mean : -mean;
        }
    }
}
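
/* binarize_weights() is the XNOR-Net style weight approximation: each
 * filter is replaced by sign(w) scaled by that filter's mean |w|.
 * Illustrative sketch (made-up values): for one filter of size 4,
 *
 *     float w[4] = {0.5, -0.25, 0.75, -0.5};   // mean |w| = 0.5
 *     float b[4];
 *     binarize_weights(w, 1, 4, b);            // b = {0.5, -0.5, 0.5, -0.5}
 */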

void binarize_cpu(float *input, int n, float *binary)
{
    int i;
    for(i = 0; i < n; ++i){
        binary[i] = (input[i] > 0) ? 1 : -1;
    }
}

void binarize_input(float *input, int n, int size, float *binary)
{
    int i, s;
    for(s = 0; s < size; ++s){
        float mean = 0;
        for(i = 0; i < n; ++i){
            mean += fabs(input[i*size + s]);
        }
        mean = mean / n;
        for(i = 0; i < n; ++i){
            binary[i*size + s] = (input[i*size + s] > 0) ? mean : -mean;
        }
    }
}

int convolutional_out_height(convolutional_layer l)
{
    return (l.h + 2*l.pad - l.size) / l.stride + 1;
}

int convolutional_out_width(convolutional_layer l)
{
    return (l.w + 2*l.pad - l.size) / l.stride + 1;
}
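
/* Standard convolution arithmetic: out = (in + 2*pad - size)/stride + 1,
 * with integer division. For example, a 416x416 input with size=3, pad=1
 * stays 416x416 at stride 1 and becomes 208x208 at stride 2.
 */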

image get_convolutional_image(convolutional_layer l)
{
    return float_to_image(l.out_w,l.out_h,l.out_c,l.output);
}

image get_convolutional_delta(convolutional_layer l)
{
    return float_to_image(l.out_w,l.out_h,l.out_c,l.delta);
}

static size_t get_workspace_size(layer l){
#ifdef CUDNN
    if(gpu_index >= 0){
        size_t most = 0;
        size_t s = 0;
        cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
                l.srcTensorDesc,
                l.weightDesc,
                l.convDesc,
                l.dstTensorDesc,
                l.fw_algo,
                &s);
        if (s > most) most = s;
        cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(),
                l.srcTensorDesc,
                l.ddstTensorDesc,
                l.convDesc,
                l.dweightDesc,
                l.bf_algo,
                &s);
        if (s > most) most = s;
        cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
                l.weightDesc,
                l.ddstTensorDesc,
                l.convDesc,
                l.dsrcTensorDesc,
                l.bd_algo,
                &s);
        if (s > most) most = s;
        return most;
    }
#endif
    return (size_t)l.out_h*l.out_w*l.size*l.size*l.c/l.groups*sizeof(float);
}
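
/* Without cuDNN the workspace is just the im2col buffer: one column of
 * size*size*(c/groups) floats per output pixel, i.e.
 * out_h*out_w * size*size * c/groups * sizeof(float) bytes. For instance
 * (illustrative numbers), a 3x3 convolution over 64 input channels with a
 * 104x104 output needs 104*104*9*64*4 bytes, about 24 MB.
 */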

#ifdef GPU
#ifdef CUDNN
void cudnn_convolutional_setup(layer *l)
{
    cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w);
    cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);

    cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w);
    cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);
    cudnnSetTensor4dDescriptor(l->normTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, l->out_c, 1, 1);

    cudnnSetFilter4dDescriptor(l->dweightDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c/l->groups, l->size, l->size);
    cudnnSetFilter4dDescriptor(l->weightDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c/l->groups, l->size, l->size);
    #if CUDNN_MAJOR >= 6
    cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT);
    #else
    cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION);
    #endif

    #if CUDNN_MAJOR >= 7
    cudnnSetConvolutionGroupCount(l->convDesc, l->groups);
    #else
    if(l->groups > 1){
        error("CUDNN < 7 doesn't support groups, please upgrade!");
    }
    #endif

    cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
            l->srcTensorDesc,
            l->weightDesc,
            l->convDesc,
            l->dstTensorDesc,
            CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
            2000000000,
            &l->fw_algo);
    cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
            l->weightDesc,
            l->ddstTensorDesc,
            l->convDesc,
            l->dsrcTensorDesc,
            CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
            2000000000,
            &l->bd_algo);
    cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(),
            l->srcTensorDesc,
            l->ddstTensorDesc,
            l->convDesc,
            l->dweightDesc,
            CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
            2000000000,
            &l->bf_algo);
}
#endif
#endif

convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int groups, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam)
{
    int i;
    convolutional_layer l = {0};
    l.type = CONVOLUTIONAL;

    l.groups = groups;
    l.h = h;
    l.w = w;
    l.c = c;
    l.n = n;
    l.binary = binary;
    l.xnor = xnor;
    l.batch = batch;
    l.stride = stride;
    l.size = size;
    l.pad = padding;
    l.batch_normalize = batch_normalize;

    l.weights = calloc(c/groups*n*size*size, sizeof(float));
    l.weight_updates = calloc(c/groups*n*size*size, sizeof(float));

    l.biases = calloc(n, sizeof(float));
    l.bias_updates = calloc(n, sizeof(float));

    l.nweights = c/groups*n*size*size;
    l.nbiases = n;

    // float scale = 1./sqrt(size*size*c);
    float scale = sqrt(2./(size*size*c/l.groups));
    //printf("convscale %f\n", scale);
    //scale = .02;
    //for(i = 0; i < c*n*size*size; ++i) l.weights[i] = scale*rand_uniform(-1, 1);
    for(i = 0; i < l.nweights; ++i) l.weights[i] = scale*rand_normal();
    int out_w = convolutional_out_width(l);
    int out_h = convolutional_out_height(l);
    l.out_h = out_h;
    l.out_w = out_w;
    l.out_c = n;
    l.outputs = l.out_h * l.out_w * l.out_c;
    l.inputs = l.w * l.h * l.c;

    l.output = calloc(l.batch*l.outputs, sizeof(float));
    l.delta  = calloc(l.batch*l.outputs, sizeof(float));

    l.forward = forward_convolutional_layer;
    l.backward = backward_convolutional_layer;
    l.update = update_convolutional_layer;
    if(binary){
        l.binary_weights = calloc(l.nweights, sizeof(float));
        l.cweights = calloc(l.nweights, sizeof(char));
        l.scales = calloc(n, sizeof(float));
    }
    if(xnor){
        l.binary_weights = calloc(l.nweights, sizeof(float));
        l.binary_input = calloc(l.inputs*l.batch, sizeof(float));
    }

    if(batch_normalize){
        l.scales = calloc(n, sizeof(float));
        l.scale_updates = calloc(n, sizeof(float));
        for(i = 0; i < n; ++i){
            l.scales[i] = 1;
        }

        l.mean = calloc(n, sizeof(float));
        l.variance = calloc(n, sizeof(float));

        l.mean_delta = calloc(n, sizeof(float));
        l.variance_delta = calloc(n, sizeof(float));

        l.rolling_mean = calloc(n, sizeof(float));
        l.rolling_variance = calloc(n, sizeof(float));
        l.x = calloc(l.batch*l.outputs, sizeof(float));
        l.x_norm = calloc(l.batch*l.outputs, sizeof(float));
    }
    if(adam){
        l.m = calloc(l.nweights, sizeof(float));
        l.v = calloc(l.nweights, sizeof(float));
        l.bias_m = calloc(n, sizeof(float));
        l.scale_m = calloc(n, sizeof(float));
        l.bias_v = calloc(n, sizeof(float));
        l.scale_v = calloc(n, sizeof(float));
    }

#ifdef GPU
    l.forward_gpu = forward_convolutional_layer_gpu;
    l.backward_gpu = backward_convolutional_layer_gpu;
    l.update_gpu = update_convolutional_layer_gpu;

    if(gpu_index >= 0){
        if (adam) {
            l.m_gpu = cuda_make_array(l.m, l.nweights);
            l.v_gpu = cuda_make_array(l.v, l.nweights);
            l.bias_m_gpu = cuda_make_array(l.bias_m, n);
            l.bias_v_gpu = cuda_make_array(l.bias_v, n);
            l.scale_m_gpu = cuda_make_array(l.scale_m, n);
            l.scale_v_gpu = cuda_make_array(l.scale_v, n);
        }

        l.weights_gpu = cuda_make_array(l.weights, l.nweights);
        l.weight_updates_gpu = cuda_make_array(l.weight_updates, l.nweights);

        l.biases_gpu = cuda_make_array(l.biases, n);
        l.bias_updates_gpu = cuda_make_array(l.bias_updates, n);

        l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n);
        l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);

        if(binary){
            l.binary_weights_gpu = cuda_make_array(l.weights, l.nweights);
        }
        if(xnor){
            l.binary_weights_gpu = cuda_make_array(l.weights, l.nweights);
            l.binary_input_gpu = cuda_make_array(0, l.inputs*l.batch);
        }

        if(batch_normalize){
            l.mean_gpu = cuda_make_array(l.mean, n);
            l.variance_gpu = cuda_make_array(l.variance, n);

            l.rolling_mean_gpu = cuda_make_array(l.mean, n);
            l.rolling_variance_gpu = cuda_make_array(l.variance, n);

            l.mean_delta_gpu = cuda_make_array(l.mean, n);
            l.variance_delta_gpu = cuda_make_array(l.variance, n);

            l.scales_gpu = cuda_make_array(l.scales, n);
            l.scale_updates_gpu = cuda_make_array(l.scale_updates, n);

            l.x_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
            l.x_norm_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
        }
#ifdef CUDNN
        cudnnCreateTensorDescriptor(&l.normTensorDesc);
        cudnnCreateTensorDescriptor(&l.srcTensorDesc);
        cudnnCreateTensorDescriptor(&l.dstTensorDesc);
        cudnnCreateFilterDescriptor(&l.weightDesc);
        cudnnCreateTensorDescriptor(&l.dsrcTensorDesc);
        cudnnCreateTensorDescriptor(&l.ddstTensorDesc);
        cudnnCreateFilterDescriptor(&l.dweightDesc);
        cudnnCreateConvolutionDescriptor(&l.convDesc);
        cudnn_convolutional_setup(&l);
#endif
    }
#endif
    l.workspace_size = get_workspace_size(l);
    l.activation = activation;

    fprintf(stderr, "conv  %5d %2d x%2d /%2d  %4d x%4d x%4d   ->  %4d x%4d x%4d  %5.3f BFLOPs\n", n, size, size, stride, w, h, c, l.out_w, l.out_h, l.out_c, (2.0 * l.n * l.size*l.size*l.c/l.groups * l.out_h*l.out_w)/1000000000.);

    return l;
}
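
/*
 * Usage sketch (illustrative values, not from a real config): a 16-filter
 * 3x3 stride-1 convolution with batch normalization and LEAKY activation
 * over a single 416x416x3 input.
 *
 *     convolutional_layer l = make_convolutional_layer(
 *             1, 416, 416, 3, 16, 1, 3, 1, 1, LEAKY, 1, 0, 0, 0);
 *     // l.out_w == 416, l.out_h == 416, l.out_c == 16
 *     // l.nweights == 3*16*3*3 == 432
 */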

void denormalize_convolutional_layer(convolutional_layer l)
{
    int i, j;
    for(i = 0; i < l.n; ++i){
        float scale = l.scales[i]/sqrt(l.rolling_variance[i] + .00001);
        for(j = 0; j < l.c/l.groups*l.size*l.size; ++j){
            l.weights[i*l.c/l.groups*l.size*l.size + j] *= scale;
        }
        l.biases[i] -= l.rolling_mean[i] * scale;
        l.scales[i] = 1;
        l.rolling_mean[i] = 0;
        l.rolling_variance[i] = 1;
    }
}
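
/* This folds the batch-norm statistics into the filters per output channel:
 * w' = w * scale/sqrt(var + eps) and b' = b - mean * scale/sqrt(var + eps),
 * then resets scale=1, mean=0, variance=1 so a subsequent batch-norm pass
 * is (approximately) the identity.
 */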

/*
void test_convolutional_layer()
{
    convolutional_layer l = make_convolutional_layer(1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1, 0, 0, 0);
    l.batch_normalize = 1;
    float data[] = {1,1,1,1,1,
        1,1,1,1,1,
        1,1,1,1,1,
        1,1,1,1,1,
        1,1,1,1,1,
        2,2,2,2,2,
        2,2,2,2,2,
        2,2,2,2,2,
        2,2,2,2,2,
        2,2,2,2,2,
        3,3,3,3,3,
        3,3,3,3,3,
        3,3,3,3,3,
        3,3,3,3,3,
        3,3,3,3,3};
    //net.input = data;
    //forward_convolutional_layer(l);
}
*/

void resize_convolutional_layer(convolutional_layer *l, int w, int h)
{
    l->w = w;
    l->h = h;
    int out_w = convolutional_out_width(*l);
    int out_h = convolutional_out_height(*l);

    l->out_w = out_w;
    l->out_h = out_h;

    l->outputs = l->out_h * l->out_w * l->out_c;
    l->inputs = l->w * l->h * l->c;

    l->output = realloc(l->output, l->batch*l->outputs*sizeof(float));
    l->delta  = realloc(l->delta,  l->batch*l->outputs*sizeof(float));
    if(l->batch_normalize){
        l->x = realloc(l->x, l->batch*l->outputs*sizeof(float));
        l->x_norm = realloc(l->x_norm, l->batch*l->outputs*sizeof(float));
    }

#ifdef GPU
    cuda_free(l->delta_gpu);
    cuda_free(l->output_gpu);

    l->delta_gpu =  cuda_make_array(l->delta,  l->batch*l->outputs);
    l->output_gpu = cuda_make_array(l->output, l->batch*l->outputs);

    if(l->batch_normalize){
        cuda_free(l->x_gpu);
        cuda_free(l->x_norm_gpu);

        l->x_gpu = cuda_make_array(l->output, l->batch*l->outputs);
        l->x_norm_gpu = cuda_make_array(l->output, l->batch*l->outputs);
    }
#ifdef CUDNN
    cudnn_convolutional_setup(l);
#endif
#endif
    l->workspace_size = get_workspace_size(*l);
}
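
/* Resizing only reallocates the activation buffers; weights are untouched
 * because a filter's shape does not depend on the input's spatial size.
 */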

void add_bias(float *output, float *biases, int batch, int n, int size)
{
    int i,j,b;
    for(b = 0; b < batch; ++b){
        for(i = 0; i < n; ++i){
            for(j = 0; j < size; ++j){
                output[(b*n + i)*size + j] += biases[i];
            }
        }
    }
}

void scale_bias(float *output, float *scales, int batch, int n, int size)
{
    int i,j,b;
    for(b = 0; b < batch; ++b){
        for(i = 0; i < n; ++i){
            for(j = 0; j < size; ++j){
                output[(b*n + i)*size + j] *= scales[i];
            }
        }
    }
}

void backward_bias(float *bias_updates, float *delta, int batch, int n, int size)
{
    int i,b;
    for(b = 0; b < batch; ++b){
        for(i = 0; i < n; ++i){
            bias_updates[i] += sum_array(delta+size*(i+b*n), size);
        }
    }
}
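
/* These helpers rely on the NCHW layout of l.output: for batch b and
 * filter i, the out_h*out_w activations sit contiguously starting at
 * (b*n + i)*size, so a single bias or scale covers one whole feature map.
 */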

void forward_convolutional_layer(convolutional_layer l, network net)
{
    int i, j;

    fill_cpu(l.outputs*l.batch, 0, l.output, 1);

    if(l.xnor){
        binarize_weights(l.weights, l.n, l.c/l.groups*l.size*l.size, l.binary_weights);
        swap_binary(&l);
        binarize_cpu(net.input, l.c*l.h*l.w*l.batch, l.binary_input);
        net.input = l.binary_input;
    }

    int m = l.n/l.groups;
    int k = l.size*l.size*l.c/l.groups;
    int n = l.out_w*l.out_h;
    for(i = 0; i < l.batch; ++i){
        for(j = 0; j < l.groups; ++j){
            float *a = l.weights + j*l.nweights/l.groups;
            float *b = net.workspace;
            float *c = l.output + (i*l.groups + j)*n*m;
            float *im = net.input + (i*l.groups + j)*l.c/l.groups*l.h*l.w;

            if (l.size == 1) {
                b = im;
            } else {
                im2col_cpu(im, l.c/l.groups, l.h, l.w, l.size, l.stride, l.pad, b);
            }
            gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
        }
    }

    if(l.batch_normalize){
        forward_batchnorm_layer(l, net);
    } else {
        add_bias(l.output, l.biases, l.batch, l.n, l.out_h*l.out_w);
    }

    activate_array(l.output, l.outputs*l.batch, l.activation);
    if(l.binary || l.xnor) swap_binary(&l);
}
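
/* The forward pass is im2col + GEMM: per group, the weights form an M x K
 * matrix (M = n/groups filters, K = size*size*c/groups), im2col_cpu()
 * unrolls the input into a K x N matrix (N = out_h*out_w), and
 * gemm(0,0,...) produces the M x N output, one row per filter. The
 * size==1 shortcut skips im2col because a 1x1 convolution's unrolled
 * matrix is the input itself (for the usual stride-1 case).
 */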

void backward_convolutional_layer(convolutional_layer l, network net)
{
    int i, j;
    int m = l.n/l.groups;
    int n = l.size*l.size*l.c/l.groups;
    int k = l.out_w*l.out_h;

    gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta);

    if(l.batch_normalize){
        backward_batchnorm_layer(l, net);
    } else {
        backward_bias(l.bias_updates, l.delta, l.batch, l.n, k);
    }

    for(i = 0; i < l.batch; ++i){
        for(j = 0; j < l.groups; ++j){
            float *a = l.delta + (i*l.groups + j)*m*k;
            float *b = net.workspace;
            float *c = l.weight_updates + j*l.nweights/l.groups;

            float *im  = net.input + (i*l.groups + j)*l.c/l.groups*l.h*l.w;
            float *imd = net.delta + (i*l.groups + j)*l.c/l.groups*l.h*l.w;

            if(l.size == 1){
                b = im;
            } else {
                im2col_cpu(im, l.c/l.groups, l.h, l.w,
                        l.size, l.stride, l.pad, b);
            }

            gemm(0,1,m,n,k,1,a,k,b,k,1,c,n);

            if (net.delta) {
                a = l.weights + j*l.nweights/l.groups;
                b = l.delta + (i*l.groups + j)*m*k;
                c = net.workspace;
                if (l.size == 1) {
                    c = imd;
                }

                gemm(1,0,n,k,m,1,a,n,b,k,0,c,k);

                if (l.size != 1) {
                    col2im_cpu(net.workspace, l.c/l.groups, l.h, l.w, l.size, l.stride, l.pad, imd);
                }
            }
        }
    }
}
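
/* Two GEMMs per group on the way back: gemm(0,1,...) accumulates weight
 * updates as delta times the transposed im2col matrix, and gemm(1,0,...)
 * forms the input gradient in column layout as transposed weights times
 * delta, which col2im_cpu() then scatters back into image layout.
 */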

void update_convolutional_layer(convolutional_layer l, update_args a)
{
    float learning_rate = a.learning_rate*l.learning_rate_scale;
    float momentum = a.momentum;
    float decay = a.decay;
    int batch = a.batch;

    axpy_cpu(l.n, learning_rate/batch, l.bias_updates, 1, l.biases, 1);
    scal_cpu(l.n, momentum, l.bias_updates, 1);

    if(l.scales){
        axpy_cpu(l.n, learning_rate/batch, l.scale_updates, 1, l.scales, 1);
        scal_cpu(l.n, momentum, l.scale_updates, 1);
    }

    axpy_cpu(l.nweights, -decay*batch, l.weights, 1, l.weight_updates, 1);
    axpy_cpu(l.nweights, learning_rate/batch, l.weight_updates, 1, l.weights, 1);
    scal_cpu(l.nweights, momentum, l.weight_updates, 1);
}
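
/* Plain SGD with momentum and weight decay, applied per mini-batch:
 *
 *     weight_updates += -decay * batch * weights
 *     weights        += (learning_rate/batch) * weight_updates
 *     weight_updates *= momentum
 *
 * Biases and batch-norm scales get the same update, without decay.
 */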

image get_convolutional_weight(convolutional_layer l, int i)
{
    int h = l.size;
    int w = l.size;
    int c = l.c/l.groups;
    return float_to_image(w,h,c,l.weights+i*h*w*c);
}

void rgbgr_weights(convolutional_layer l)
{
    int i;
    for(i = 0; i < l.n; ++i){
        image im = get_convolutional_weight(l, i);
        if (im.c == 3) {
            rgbgr_image(im);
        }
    }
}

void rescale_weights(convolutional_layer l, float scale, float trans)
{
    int i;
    for(i = 0; i < l.n; ++i){
        image im = get_convolutional_weight(l, i);
        if (im.c == 3) {
            scale_image(im, scale);
            float sum = sum_array(im.data, im.w*im.h*im.c);
            l.biases[i] += sum*trans;
        }
    }
}

image *get_weights(convolutional_layer l)
{
    image *weights = calloc(l.n, sizeof(image));
    int i;
    for(i = 0; i < l.n; ++i){
        weights[i] = copy_image(get_convolutional_weight(l, i));
        normalize_image(weights[i]);
        /*
        char buff[256];
        sprintf(buff, "filter%d", i);
        save_image(weights[i], buff);
        */
    }
    //error("hey");
    return weights;
}

image *visualize_convolutional_layer(convolutional_layer l, char *window, image *prev_weights)
{
    image *single_weights = get_weights(l);
    show_images(single_weights, l.n, window);

    image delta = get_convolutional_image(l);
    image dc = collapse_image_layers(delta, 1);
    char buff[256];
    sprintf(buff, "%s: Output", window);
    //show_image(dc, buff);
    //save_image(dc, buff);
    free_image(dc);
    return single_weights;
}
Definition: darknet.h:242