Skip to content

Commit

Permalink
6-3
Browse files Browse the repository at this point in the history
  • Loading branch information
1zgh committed Jun 4, 2019
1 parent 9f0780e commit b2f0cf5
Show file tree
Hide file tree
Showing 12 changed files with 816 additions and 7 deletions.
Empty file added --device
Empty file.
4 changes: 4 additions & 0 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions .idea/modules.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions .idea/st-gcn.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

775 changes: 775 additions & 0 deletions .idea/workspace.xml

Large diffs are not rendered by default.

Empty file added [--video
Empty file.
8 changes: 6 additions & 2 deletions net/utils/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,18 @@ def get_hop_distance(num_node, edge, max_hop=1):
A[i, j] = 1

# compute hop steps
hop_dis = np.zeros((num_node, num_node)) + np.inf
hop_dis = np.zeros((num_node, num_node)) + np.inf # np.inf 表示一个无穷大的正数
# np.linalg.matrix_power(A, d)求矩阵A的d幂次方,transfer_mat矩阵(I,A)是一个将A矩阵拼接max_hop+1次的矩阵
transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)]
# (np.stack(transfer_mat) > 0)矩阵中大于0的返回Ture,小于0的返回False,最终arrive_mat是一个布尔矩阵,大小与transfer_mat一样
arrive_mat = (np.stack(transfer_mat) > 0)
# range(start,stop,step) step=-1表示倒着取
for d in range(max_hop, -1, -1):
# 将arrive_mat[d]矩阵中为True的对应于hop_dis[]位置的数设置为d
hop_dis[arrive_mat[d]] = d
return hop_dis


# 将矩阵A中的每一列的各个元素分别除以此列元素的形成新的矩阵
def normalize_digraph(A):
Dl = np.sum(A, 0)
num_node = A.shape[0]
Expand Down
1 change: 1 addition & 0 deletions processor/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def load_arg(self, argv=None):
if p.config is not None:
# load config file
with open(p.config, 'r') as f:
# default_arg是字典的形式
default_arg = yaml.load(f)

# update parser from config file
Expand Down
10 changes: 5 additions & 5 deletions processor/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
from .processor import Processor

def weights_init(m):
classname = m.__class__.__name__ # __class__查看对象所在的类
if classname.find('Conv1d') != -1:
m.weight.data.normal_(0.0, 0.02)
classname = m.__class__.__name__ # __class__查看对象所在的类
if classname.find('Conv1d') != -1: # 如果发现这个Conv1d字符串
m.weight.data.normal_(0.0, 0.02) #平均值为0.0,方差为0.02的数据正则化
if m.bias is not None:
m.bias.data.fill_(0)
elif classname.find('Conv2d') != -1:
Expand All @@ -40,8 +40,8 @@ class REC_Processor(Processor):
def load_model(self):
self.model = self.io.load_model(self.arg.model,
**(self.arg.model_args))
self.model.apply(weights_init)
self.loss = nn.CrossEntropyLoss()
self.model.apply(weights_init) #apply()函数复制函数weights_init()的功能给self.modle
self.loss = nn.CrossEntropyLoss() #定义交叉商和损失函数

def load_optimizer(self):
if self.arg.optimizer == 'SGD':
Expand Down
Binary file added resource/media/cuk.mp4
Binary file not shown.
Binary file added resource/media/cuk2.mp4
Binary file not shown.

0 comments on commit b2f0cf5

Please sign in to comment.