// darknet v3: activation_kernels.cu
#include "cuda_runtime.h"
#include "curand.h"
#include "cublas_v2.h"

extern "C" {
#include "activations.h"
#include "cuda.h"
}

// Leaky hard tanh: identity on [0, 1], slope .001 outside.
__device__ float lhtan_activate_kernel(float x)
{
    if(x < 0) return .001f*x;
    if(x > 1) return .001f*(x-1.f) + 1.f;
    return x;
}
__device__ float lhtan_gradient_kernel(float x)
{
    if(x > 0 && x < 1) return 1;
    return .001f;
}

__device__ float hardtan_activate_kernel(float x)
{
    if (x < -1) return -1;
    if (x > 1) return 1;
    return x;
}
__device__ float linear_activate_kernel(float x){return x;}
__device__ float logistic_activate_kernel(float x){return 1.f/(1.f + expf(-x));}
// Logistic rescaled from (0, 1) to (-1, 1).
__device__ float loggy_activate_kernel(float x){return 2.f/(1.f + expf(-x)) - 1.f;}
__device__ float relu_activate_kernel(float x){return x*(x>0);}
__device__ float elu_activate_kernel(float x){return (x >= 0)*x + (x < 0)*(expf(x)-1);}
// SELU with the standard constants lambda = 1.0507, alpha = 1.6732.
__device__ float selu_activate_kernel(float x){return (x >= 0)*1.0507f*x + (x < 0)*1.0507f*1.6732f*(expf(x)-1);}
__device__ float relie_activate_kernel(float x){return (x>0) ? x : .01f*x;}
__device__ float ramp_activate_kernel(float x){return x*(x>0)+.1f*x;}
__device__ float leaky_activate_kernel(float x){return (x>0) ? x : .1f*x;}
__device__ float tanh_activate_kernel(float x){return (2.f/(1.f + expf(-2.f*x)) - 1.f);}
// Piecewise-linear sigmoid: slope .125 on [-4, 4], slope .01 outside.
__device__ float plse_activate_kernel(float x)
{
    if(x < -4) return .01f * (x + 4);
    if(x > 4) return .01f * (x - 4) + 1;
    return .125f*x + .5f;
}
// Staircase: flat on even unit intervals, slope 1 on odd ones.
__device__ float stair_activate_kernel(float x)
{
    int n = floorf(x);
    if (n%2 == 0) return floorf(x/2);
    else return (x - n) + floorf(x/2);
}

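/* Continuity check for plse (illustrative, not in the original file):
   the pieces agree at both breakpoints, e.g. at x = 4,
   .125f*4 + .5f = 1.0f matches .01f*(4 - 4) + 1 = 1.0f, and at x = -4,
   .125f*(-4) + .5f = 0.0f matches .01f*(-4 + 4) = 0.0f. */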
__device__ float hardtan_gradient_kernel(float x)
{
    if (x > -1 && x < 1) return 1;
    return 0;
}
__device__ float linear_gradient_kernel(float x){return 1;}
// Note: darknet gradient kernels receive the stored activation *output*,
// so for y = logistic(x) the derivative y*(1 - y) is evaluated from y.
__device__ float logistic_gradient_kernel(float x){return (1-x)*x;}
__device__ float loggy_gradient_kernel(float x)
{
    float y = (x+1)/2;
    return 2*(1-y)*y;
}
__device__ float relu_gradient_kernel(float x){return (x>0);}
__device__ float elu_gradient_kernel(float x){return (x >= 0) + (x < 0)*(x + 1);}
__device__ float selu_gradient_kernel(float x){return (x >= 0)*1.0507f + (x < 0)*(x + 1.0507f*1.6732f);}
__device__ float relie_gradient_kernel(float x){return (x>0) ? 1 : .01f;}
__device__ float ramp_gradient_kernel(float x){return (x>0)+.1f;}
__device__ float leaky_gradient_kernel(float x){return (x>0) ? 1 : .1f;}
// tanh'(x) = 1 - tanh(x)^2, again computed from the stored output.
__device__ float tanh_gradient_kernel(float x){return 1-x*x;}
__device__ float plse_gradient_kernel(float x){return (x < 0 || x > 1) ? .01f : .125f;}
__device__ float stair_gradient_kernel(float x)
{
    if (floorf(x) == x) return 0;
    return 1;
}

__device__ float activate_kernel(float x, ACTIVATION a)
{
    switch(a){
        case LINEAR:
            return linear_activate_kernel(x);
        case LOGISTIC:
            return logistic_activate_kernel(x);
        case LOGGY:
            return loggy_activate_kernel(x);
        case RELU:
            return relu_activate_kernel(x);
        case ELU:
            return elu_activate_kernel(x);
        case SELU:
            return selu_activate_kernel(x);
        case RELIE:
            return relie_activate_kernel(x);
        case RAMP:
            return ramp_activate_kernel(x);
        case LEAKY:
            return leaky_activate_kernel(x);
        case TANH:
            return tanh_activate_kernel(x);
        case PLSE:
            return plse_activate_kernel(x);
        case STAIR:
            return stair_activate_kernel(x);
        case HARDTAN:
            return hardtan_activate_kernel(x);
        case LHTAN:
            return lhtan_activate_kernel(x);
    }
    return 0;
}

__device__ float gradient_kernel(float x, ACTIVATION a)
{
    switch(a){
        case LINEAR:
            return linear_gradient_kernel(x);
        case LOGISTIC:
            return logistic_gradient_kernel(x);
        case LOGGY:
            return loggy_gradient_kernel(x);
        case RELU:
            return relu_gradient_kernel(x);
        case ELU:
            return elu_gradient_kernel(x);
        case SELU:
            return selu_gradient_kernel(x);
        case RELIE:
            return relie_gradient_kernel(x);
        case RAMP:
            return ramp_gradient_kernel(x);
        case LEAKY:
            return leaky_gradient_kernel(x);
        case TANH:
            return tanh_gradient_kernel(x);
        case PLSE:
            return plse_gradient_kernel(x);
        case STAIR:
            return stair_gradient_kernel(x);
        case HARDTAN:
            return hardtan_gradient_kernel(x);
        case LHTAN:
            return lhtan_gradient_kernel(x);
    }
    return 0;
}
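
/* Worked example of the output-based convention (illustrative, not in the
   original file): for LOGISTIC the stored output at input x = 0 is
   y = 1/(1 + expf(0)) = 0.5, and gradient_kernel(0.5, LOGISTIC)
   = (1 - 0.5)*0.5 = 0.25, which is exactly sigma'(0). */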

__global__ void binary_gradient_array_kernel(float *x, float *dy, int n, int s, BINARY_ACTIVATION a, float *dx)
{
    int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    int i = id % s;
    int b = id / s;
    if(id < n) {
        // Load the pair only after the bounds check so out-of-range
        // threads never read past the end of x.
        float x1 = x[b*s + i];
        float x2 = x[b*s + s/2 + i];
        float de = dy[id];
        // Product rule for y = x1*x2: dy/dx1 = x2, dy/dx2 = x1.
        dx[b*s + i] = x2*de;
        dx[b*s + s/2 + i] = x1*de;
    }
}

extern "C" void binary_gradient_array_gpu(float *x, float *dx, int n, int size, BINARY_ACTIVATION a, float *y)
{
    // Each thread handles one pair of elements, so launch n/2 threads.
    binary_gradient_array_kernel<<<cuda_gridsize(n/2), BLOCK>>>(x, dx, n/2, size, a, y);
    check_error(cudaPeekAtLastError());
}
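
/* Note on the launch configuration (explanatory, not in the original file):
   darknet's cuda_gridsize(n) may split the grid across gridDim.x and
   gridDim.y to stay under the per-dimension block-count limit, which is
   why every kernel in this file flattens the block index as
   blockIdx.x + blockIdx.y*gridDim.x before scaling by blockDim.x. */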
__global__ void binary_activate_array_kernel(float *x, int n, int s, BINARY_ACTIVATION a, float *y)
{
    int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    int i = id % s;
    int b = id / s;
    if(id < n) {
        // Guard the loads as well as the store.
        float x1 = x[b*s + i];
        float x2 = x[b*s + s/2 + i];
        y[id] = x1*x2;
    }
}

extern "C" void binary_activate_array_gpu(float *x, int n, int size, BINARY_ACTIVATION a, float *y)
{
    binary_activate_array_kernel<<<cuda_gridsize(n/2), BLOCK>>>(x, n/2, size, a, y);
    check_error(cudaPeekAtLastError());
}
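
/* Reading of the pair layout (a sketch, not authoritative): output element
   id, with b = id/s and i = id%s, multiplies x[b*s + i] by its partner
   x[b*s + s/2 + i], i.e. elements half a stride apart; the host wrappers
   pass n/2 because each thread consumes one such pair. */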

__global__ void activate_array_kernel(float *x, int n, ACTIVATION a)
{
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if(i < n) x[i] = activate_kernel(x[i], a);
}

__global__ void gradient_array_kernel(float *x, int n, ACTIVATION a, float *delta)
{
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    // Chain rule: scale the incoming delta by the local derivative.
    if(i < n) delta[i] *= gradient_kernel(x[i], a);
}

extern "C" void activate_array_gpu(float *x, int n, ACTIVATION a)
{
    activate_array_kernel<<<cuda_gridsize(n), BLOCK>>>(x, n, a);
    check_error(cudaPeekAtLastError());
}

extern "C" void gradient_array_gpu(float *x, int n, ACTIVATION a, float *delta)
{
    gradient_array_kernel<<<cuda_gridsize(n), BLOCK>>>(x, n, a, delta);
    check_error(cudaPeekAtLastError());
}
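
/* Minimal host-side usage sketch (illustrative; assumes darknet's
   cuda_make_array/cuda_pull_array/cuda_free helpers declared in cuda.h):

       float h[3] = {-1.f, 0.f, 2.f};
       float *d = cuda_make_array(h, 3);   // copy to device
       activate_array_gpu(d, 3, LEAKY);    // in place: {-.1f, 0.f, 2.f}
       cuda_pull_array(d, h, 3);           // copy result back to host
       cuda_free(d);
*/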