-
Notifications
You must be signed in to change notification settings - Fork 18
/
net_spiking.cuh
43 lines (34 loc) · 957 Bytes
/
net_spiking.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#ifndef _NET_SPIKING_CUH_
#define _NET_SPIKING_CUH_
#include "common/cuMatrix.h"
#include <vector>
#include <stdio.h>
#include <cuda_runtime.h>
#include "common/cuMatrixVector.h"
/*
 * function  : load the network weights from a checkpoint file
 * parameter :
 *   path    : filesystem path of the checkpoint to read
 */
void cuReadSpikingNet(const char *path);
/*
 * function  : train the spiking network
 * parameter :
 *   x         : training inputs, a vector of bool matrices
 *               (presumably spike trains -- confirm against the definition)
 *   y         : int matrix paired with x (presumably class labels -- confirm)
 *   testX     : test inputs, same container type as x
 *   testY     : int matrix paired with testX (presumably labels -- confirm)
 *   batch     : presumably the mini-batch size
 *   nclasses  : presumably the number of output classes
 *   nlrate    : float values, presumably a learning-rate schedule
 *   nMomentum : float values, presumably a momentum schedule
 *   epoCount  : int values, presumably epoch counts per schedule stage
 *               NOTE(review): the three vectors look like parallel arrays;
 *               confirm in the definition
 *   handle    : cuBLAS handle forwarded to the implementation
 * NOTE(review): cublasHandle_t is not declared by <vector>, <stdio.h> or
 * <cuda_runtime.h>; it is assumed to arrive via one of the common/ headers.
 */
void cuTrainSpikingNetwork(cuMatrixVector<bool>& x,
                           cuMatrix<int>* y,
                           cuMatrixVector<bool>& testX,
                           cuMatrix<int>* testY,
                           int batch,
                           int nclasses,
                           std::vector<float>& nlrate,
                           std::vector<float>& nMomentum,
                           std::vector<int>& epoCount,
                           cublasHandle_t handle);
/*
 * function  : set up the spiking network (inferred from the name -- see the
 *             definition for what is actually allocated/initialized)
 * parameter :
 *   trainLen : presumably the number of training samples (confirm)
 *   testLen  : presumably the number of test samples (confirm)
 */
void buildSpikingNetwork(int trainLen,
                         int testLen);
/*
 * function : release the spiking network (inferred from the name -- what
 *            exactly is freed is defined in the implementation)
 */
void cuFreeSpikingNet(void);
/*
 * function  : release memory associated with training/testing
 *             (inferred from the name -- confirm in the definition)
 * parameter :
 *   batch  : presumably the batch size used when the memory was set up
 *   trainX : training input container
 *   testX  : test input container
 */
void cuFreeSNNMemory(int batch,
                     cuMatrixVector<bool>& trainX,
                     cuMatrixVector<bool>& testX);
/*
 * function  : compute the network cost (inferred from the name -- confirm)
 * parameter :
 *   y : pointer to int data, presumably labels
 *       NOTE(review): whether this is host or device memory is not
 *       established here -- confirm against the definition
 */
void getSpikingNetworkCost(int *y);
#endif