<h1>MindSpore Operator Analysis: RandomPoisson</h1>
<p>Published 2022-09-20 · Updated 2022-09-22</p>
<h1>Operator Introduction</h1>
<h2>Definition</h2>
<p>The Poisson distribution is a common discrete probability distribution in statistics and probability theory. It describes the probability of a given number of random events occurring within a unit of time, for example the number of service requests a facility receives in a given period, the number of calls arriving at a telephone exchange, or the number of people waiting at a bus stop.</p>
<p>The probability mass function of the Poisson distribution is:</p>
<p>$$P(X=k) = \frac{e^{-\Lambda}\Lambda^k}{k!}$$</p>
<p>Here the parameter $\Lambda$ is the expected number of occurrences of the random event. If $X$ follows a Poisson distribution with parameter $\Lambda$, we write $X \sim Pois(\Lambda)$.</p>
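<p>As a quick worked example of the formula: with $\Lambda = 4$, the probability of observing exactly two events in one unit of time is</p>
<p>$$P(X=2) = \frac{e^{-4}\,4^{2}}{2!} = 8e^{-4} \approx 0.1465$$</p>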
<p>For a more complete treatment of the Poisson distribution, see the Wikipedia article (<a href="https://zh.wikipedia.org/wiki/%E5%8D%9C%E7%93%A6%E6%9D%BE%E5%88%86%E5%B8%83" target="_blank" rel="noopener">link</a>).</p>
<h2>Operator Functionality</h2>
<p>Given a 1-D Tensor <code>shape</code> describing the dimensions of the matrix to sample, and an expected value <code>rate</code> (the $\Lambda$ in the Poisson definition above), the operator draws Poisson-distributed samples based on <code>rate</code> and <code>shape</code>. The shape of the result is the concatenation of the given <code>shape</code> and the shape of <code>rate</code>. For example, if <code>shape</code> is <code>[3, 3]</code> and <code>rate</code> has shape <code>[5]</code>, the output shape is <code>[3, 3, 5]</code>.</p>
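<p>A minimal usage sketch of that shape rule, adapted from the docstring example shown later (the seed and rate values here are arbitrary):</p>
<figure class="highlight python"><pre>
import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

shape = Tensor(np.array([3, 3]), ms.int32)                        # dimensions of the matrix to sample
rate = Tensor(np.array([0.5, 1.0, 2.0, 5.0, 10.0]), ms.float32)   # five Poisson means, shape [5]
random_poisson = ops.RandomPoisson(seed=1, seed2=2)
output = random_poisson(shape, rate)
print(output.shape)   # expected: (3, 3, 5) -- [3, 3] concatenated with [5]
</pre></figure>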
<h1>Source Code Analysis</h1>
<h2>Python-side Operator Primitive Definition</h2>
<p>In MindSpore, every operator is wrapped in an operator primitive (<code>Primitive</code>), which gives the concrete operator implementations on the underlying devices (Ascend, GPU, AICPU, CPU, etc.) a unified calling interface. On the Python side, an operator definition usually only needs the most basic class initialization, i.e. the <code>__init__</code> function, which also validates the initialization arguments. In <code>mindspore/python/mindspore/ops/operations/random_ops.py</code>, the Python-side primitive of the <code>RandomPoisson</code> operator is defined as follows.</p>
<figure class="highlight python"><pre>
class RandomPoisson(Primitive):
    r"""
    Produces random non-negative values i, distributed according to discrete probability function:

    .. math::
        \text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!},

    Args:
        seed (int): An optional int. Defaults to 0. If either `seed` or `seed2` are set to be non-zero,
            the seed is set by the given seed. Otherwise, it is seeded by a random seed.
        seed2 (int): An optional int. Defaults to 0. A second seed to avoid seed collision.
        dtype (mindspore.dtype): The type of output. Default: mindspore.int64.

    Inputs:
        - **shape** (Tensor) - The shape of random tensor to be generated, 1-D Tensor, whose dtype must be in
          [int32, int64]
        - **rate** (Tensor) - μ parameter the distribution was constructed with. The parameter defines mean
          number of occurrences of the event. Its type must be in [float16, float32, float64, int32, int64]

    Outputs:
        Tensor. Its shape is (*shape, *rate.shape). Its type is specified by `dtype`.

    Raises:
        TypeError: If `shape` is not a Tensor or its dtype is not int32 or int64.
        TypeError: If `dtype` is not int32 or int64.
        ValueError: If `shape` is not a 1-D tensor.
        ValueError: If `shape` elements are negative.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> shape = Tensor(np.array([2, 3]), mstype.int32)
        >>> rate = Tensor(np.array([2, 2]), mstype.int32)
        >>> seed = 0
        >>> seed2 = 0
        >>> random_poisson = ops.RandomPoisson(seed=seed, seed2=seed2)
        >>> output = random_poisson(shape, rate)
        >>> print(output.shape)
        (2, 3, 2)
    """

    @prim_attr_register
    def __init__(self, seed=0, seed2=0, dtype=mstype.int64):
        """Initialize Poisson"""
        self.init_prim_io_names(inputs=['shape', 'rate'], outputs=['output'])
        Validator.check_value_type('seed', seed, [int], self.name)
        Validator.check_value_type('seed2', seed2, [int], self.name)
        valid_values = (mstype.int64, mstype.int32, mstype.float16, mstype.float32, mstype.float64)
        Validator.check_type_name("dtype", dtype, valid_values, self.name)
</pre></figure>
<p>The role of each part of this code is as follows:</p>
<ul>
<li>First line: defines the Poisson operator class, which inherits from the operator primitive class <code>Primitive</code>.</li>
<li>Docstring: the operator documentation, describing the operator's functionality, arguments, inputs, outputs, raised error types, supported platforms, and a code example; this docstring is also used to generate the MindSpore online documentation.</li>
<li><code>__init__</code> function: operator initialization. The <code>init_prim_io_names</code> method registers the operator's input and output names with the MindSpore framework, and the <code>Validator</code> class checks that the arguments are valid (a small sketch of this validation follows the list).</li>
</ul>
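<p>A hedged sketch of what that validation buys us, based only on the <code>check_value_type</code> call visible in <code>__init__</code> above: a non-integer <code>seed</code> is rejected with a <code>TypeError</code> before anything reaches the device.</p>
<figure class="highlight python"><pre>
from mindspore import ops

# Validator.check_value_type('seed', seed, [int], self.name) only accepts Python ints,
# so a float seed should fail fast at primitive construction time.
try:
    ops.RandomPoisson(seed=1.5)
except TypeError as err:
    print("rejected:", err)
</pre></figure>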
<h2>C++-side Operator Registration</h2>
<p>On the C++ side, the operator first has to register itself in a global variable of the MindSpore framework, to "tell" the framework that an operator named <code>RandomPoisson</code> exists. The operator is registered in <code>mindspore/core/ops/core_ops.h</code> with the following code.</p>
<figure class="highlight c++"><pre>
GVAR_DEF(PrimitivePtr, kPrimRandomPoisson, std::make_shared<Primitive>("RandomPoisson"));
</pre></figure>
<h2>C++-side Operator Primitive Definition</h2>
<p>The operator also needs its own <code>Primitive</code> definition on the C++ side, which handles C++-side initialization, argument validation, and so on. The C++ primitive is split into two files, <code>mindspore/core/ops/random_poisson.h</code> and <code>mindspore/core/ops/random_poisson.cc</code>: the header mainly defines the operator class, while the .cc file implements the member functions. Since the source is fairly long, the explanation of what each piece does is given in the code comments.</p>
<figure class="highlight c++"><pre>
// Primitive class of the RandomPoisson operator
#ifndef MINDSPORE_CORE_OPS_RANDOM_POISSON_H_
#define MINDSPORE_CORE_OPS_RANDOM_POISSON_H_
// C++ standard library
#include <map>
#include <vector>
#include <string>
#include <memory>
// MindSpore operator primitive base class
#include "ops/base_operator.h"
#include "mindapi/base/types.h"

// MindSpore namespace; operators have to live in `mindspore::ops`
namespace mindspore {
namespace ops {
// Name of the RandomPoisson operator
constexpr auto kRandomPoisson = "RandomPoisson";
// Operator class, derived from `BaseOperator`
class MIND_API RandomPoisson : public BaseOperator {
 public:
  // Macro that declares the smart-pointer-based constructor and the default destructor
  MIND_API_BASE_MEMBER(RandomPoisson);
  // Constructor: registers the operator's inputs and outputs
  RandomPoisson() : BaseOperator(kRandomPoisson) { InitIOName({"shape", "rate"}, {"output"}); }
  // Implements the parent class `Init` method; it does nothing here
  void Init() const {}
  // Setters and getters for the random seeds `seed` and `seed2`
  void set_seed(const int64_t seed);
  int64_t get_seed() const;
  void set_seed2(const int64_t seed2);
  int64_t get_seed2() const;
};
// Entry point for the operator's argument checking and shape/type inference
abstract::AbstractBasePtr RandomPoissonInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                             const std::vector<abstract::AbstractBasePtr> &input_args);
// Smart pointer type for the operator
using kPrimPrimRandomPoissonPtr = std::shared_ptr<RandomPoisson>;
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_RANDOM_POISSON_H_
</pre></figure>
<figure class="highlight c++"><pre>
// Header of the RandomPoisson operator
#include "ops/random_poisson.h"
// C++ standard library
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include <map>
// MindSpore common utility headers
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "abstract/param_validator.h"
#include "mindapi/src/helper.h"
// Same MindSpore namespace as in the header
namespace mindspore {
namespace ops {
// Anonymous namespace holding the private helper functions of this operator, to avoid name clashes
namespace {
// Validate the shapes of the operator's inputs
abstract::ShapePtr RandomPoissonInferShape(const PrimitivePtr &primitive,
                                           const std::vector<AbstractBasePtr> &input_args) {
  // MindSpore utility macro: throw if the given pointer is null
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_name = primitive->name();
  // Check the rank of the input `shape`; it must be a 1-D Tensor
  auto shape_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
  if (shape_shape.size() != 1) {
    MS_EXCEPTION(ValueError) << "For RandomPoisson, the argument[shape] must be a 1-D tensor, but got "
                             << shape_shape.size() << "-D";
  }
  // Check the value of the input `shape`; throw if the value is null
  auto shape_value = input_args[kInputIndex0]->BuildValue();
  MS_EXCEPTION_IF_NULL(shape_value);
  if (!shape_value->isa<AnyValue>() && !shape_value->isa<None>()) {
    auto out_shape = CheckAndConvertUtils::CheckTensorIntValue("shape", shape_value, op_name);
    (void)CheckAndConvertUtils::CheckPositiveVector("shape", out_shape, op_name);
    // Concatenate the value of `shape` with the shape of `rate` to build the final output shape
    auto rate_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
    auto rate_rank = SizeToLong(rate_shape.size());
    for (int64_t i = 0; i < rate_rank; i++) {
      out_shape.push_back(rate_shape[i]);
    }

    return std::make_shared<abstract::Shape>(out_shape);
  } else {
    // When the value of `shape` is not yet known, no concrete output shape can be produced;
    // an error will be raised when the operator actually executes
    std::vector<int64_t> output_shape = {-2};
    ShapeVector shape_min = {1};
    ShapeVector shape_max = {1};
    return std::make_shared<abstract::Shape>(output_shape, shape_min, shape_max);
  }
}

// Validate the data types of the operator's inputs
TypePtr RandomPoissonInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = prim->name();
  const std::set<TypePtr> valid_shape_types = {kInt32, kInt64};
  (void)CheckAndConvertUtils::CheckTypeValid("shape", input_args[0]->BuildType(), valid_shape_types, prim_name);
  const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kInt32, kInt64};
  (void)CheckAndConvertUtils::CheckTypeValid("rate", input_args[1]->BuildType(), valid_types, prim_name);
  auto dtype_value = prim->GetAttr("dtype");
  if (!dtype_value->isa<Type>()) {
    MS_EXCEPTION(TypeError) << "For RandomPoisson, the dtype of " + prim_name + " is invalid!";
  }
  auto output_type = dtype_value->cast<TypePtr>();
  return CheckAndConvertUtils::CheckSubClass("dtype", output_type, valid_types, prim_name);
}
}  // namespace

// Getter for the first random seed `seed`
int64_t RandomPoisson::get_seed() const {
  auto value_ptr = this->GetAttr(kSeed);
  return GetValue<int64_t>(value_ptr);
}

// Setter for the first random seed `seed`
void RandomPoisson::set_seed(const int64_t seed) { (void)this->AddAttr(kSeed, api::MakeValue(seed)); }

// Getter for the second random seed `seed2`
int64_t RandomPoisson::get_seed2() const {
  auto value_ptr = this->GetAttr(kSeed2);
  return GetValue<int64_t>(value_ptr);
}

// Setter for the second random seed `seed2`
void RandomPoisson::set_seed2(const int64_t seed2) { (void)this->AddAttr(kSeed2, api::MakeValue(seed2)); }

// Inference entry point: calls `RandomPoissonInferType` and `RandomPoissonInferShape` above to validate
// the arguments before the operator is actually executed
AbstractBasePtr RandomPoissonInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t kInputsNum = 2;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
  auto infertype = RandomPoissonInferType(primitive, input_args);
  auto infershape = RandomPoissonInferShape(primitive, input_args);
  return abstract::MakeAbstract(infershape, infertype);
}
// Register the inputs whose values must be available on the host side
REGISTER_HOST_DEPENDS(kRandomPoisson, {0});
// Register the operator with the MindSpore framework
MIND_API_OPERATOR_IMPL(RandomPoisson, BaseOperator);
// Register the operator's inference entry with the MindSpore framework
REGISTER_PRIMITIVE_EVAL_IMPL(RandomPoisson, prim::kPrimRandomPoisson, RandomPoissonInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
</pre></figure>
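<p>The <code>CheckPositiveVector("shape", ...)</code> call in <code>RandomPoissonInferShape</code> above is what backs the <code>ValueError</code> listed in the Python docstring: a <code>shape</code> tensor containing a non-positive element is rejected during inference. A hedged sketch (the exact error message may differ between versions):</p>
<figure class="highlight python"><pre>
import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

bad_shape = Tensor(np.array([2, -3]), ms.int32)   # negative dimension, should be rejected
rate = Tensor(np.array([2.0]), ms.float32)
try:
    ops.RandomPoisson()(bad_shape, rate)
except ValueError as err:
    print("rejected:", err)
</pre></figure>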
<h2>C++-side GPU Operator Implementation</h2>
<p>Defining the primitive alone does not give the operator its full functionality; the execution logic on each type of device still has to be implemented. Since my task was to implement the GPU operator, only the GPU source code is analyzed here. The GPU operator consists of a C++ part and a CUDA part: the C++ side handles the preparation work before execution, while the CUDA side uses the GPU's parallel computing capability to accelerate the computation.</p>
<p>Let us start with the C++ side. The kernel class is again split into a header, <code>mindspore/ccsrc/plugin/device/gpu/kernel/random/random_poisson_gpu_kernel.h</code>, and its implementation, <code>mindspore/ccsrc/plugin/device/gpu/kernel/random/random_poisson_gpu_kernel.cc</code>. Since the code is fairly long, it is analyzed and explained through inline comments.</p>
<figure class="highlight c++"><pre>
// Header of the C++ GPU kernel
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_POISSON_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_POISSON_GPU_KERNEL_H_
// CUDA headers
#include <curand_kernel.h>
#include <cuda_runtime_api.h>
// C++ standard library
#include <vector>
#include <map>
#include <string>
#include <utility>
// MindSpore GPU kernel related headers
#include "mindspore/core/ops/random_poisson.h"
#include "kernel/common_utils.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/random_op_impl.cuh"
// MindSpore puts GPU kernel definitions in the namespace `mindspore::kernel`
namespace mindspore {
namespace kernel {
// GPU kernel class, derived from `NativeGpuKernelMod`; the helper class `MatchKernelHelper` assists with
// argument type checking and registration of the data types the operator supports
class RandomPoissonGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelper<RandomPoissonGpuKernelMod> {
 public:
  // Default constructor and destructor of the GPU kernel
  RandomPoissonGpuKernelMod() = default;
  ~RandomPoissonGpuKernelMod() override = default;
  // Kernel initialization; executed only once when the kernel is constructed. Attributes that are not
  // inputs or outputs, such as the random seeds, can be read and stored here.
  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;
  // Launch entry point: MindSpore assigns the GPU device id and passes the device addresses of the inputs,
  // workspace and outputs. The helper's `kernel_func_` is called here; the helper null-checks all arguments
  // and then calls `LaunchKernel` to actually run the kernel.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
    if (is_null_input_) {
      return true;
    }
    cuda_stream_ = cuda_stream;
    return kernel_func_(this, inputs, workspace, outputs);
  }
  // Computes the memory sizes required for this run; it only computes the sizes, it does not allocate memory.
  // The results are in bytes, so a size is usually `element_count * sizeof(element)`.
  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
  // Helper-class hook: builds the list of all data type combinations supported by the operator
  const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;

 protected:
  // Resets the kernel's resources; usually called from `Resize`
  void ResetResource() noexcept {
    rate_elements_ = 1;
    output_elements_ = 1;
    is_null_input_ = false;
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }
  // Returns the supported data types to the MindSpore framework, working together with the helper's
  // `OpSupport` and `GetFuncList`
  std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }

 private:
  // Kernel execution method: calls the CUDA kernel to accelerate the computation on the GPU
  template <typename R, typename T>
  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<kernel::AddressPtr> &outputs);
  // Private members: mainly the random seed values and the per-element byte sizes of the inputs and output
  int64_t rate_elements_;
  int64_t output_elements_;
  int64_t unit_shape_size_;
  int64_t unit_rate_size_;
  int64_t unit_output_size_;
  int64_t seed_{0};
  int64_t seed2_{0};
  bool is_null_input_{false};
  void *cuda_stream_{nullptr};
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_POISSON_GPU_KERNEL_H_
</pre></figure>
<figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span class="line">105</span><br><span class="line">106</span><br><span class="line">107</span><br><span class="line">108</span><br><span 
class="line">109</span><br><span class="line">110</span><br><span class="line">111</span><br><span class="line">112</span><br><span class="line">113</span><br><span class="line">114</span><br><span class="line">115</span><br><span class="line">116</span><br><span class="line">117</span><br><span class="line">118</span><br><span class="line">119</span><br><span class="line">120</span><br><span class="line">121</span><br><span class="line">122</span><br><span class="line">123</span><br><span class="line">124</span><br><span class="line">125</span><br><span class="line">126</span><br><span class="line">127</span><br><span class="line">128</span><br><span class="line">129</span><br><span class="line">130</span><br><span class="line">131</span><br><span class="line">132</span><br><span class="line">133</span><br><span class="line">134</span><br><span class="line">135</span><br><span class="line">136</span><br><span class="line">137</span><br><span class="line">138</span><br><span class="line">139</span><br><span class="line">140</span><br><span class="line">141</span><br><span class="line">142</span><br><span class="line">143</span><br><span class="line">144</span><br><span class="line">145</span><br><span class="line">146</span><br><span class="line">147</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// 算子GPU Kernel头文件</span></span><br><span class="line"><span class="meta">#<span class="meta-keyword">include</span> <span class="meta-string">"plugin/device/gpu/kernel/random/random_poisson_gpu_kernel.h"</span></span></span><br><span class="line"><span class="comment">// C++标准库</span></span><br><span class="line"><span class="meta">#<span class="meta-keyword">include</span> <span class="meta-string"><functional></span></span></span><br><span class="line"><span class="meta">#<span class="meta-keyword">include</span> <span class="meta-string"><utility></span></span></span><br><span class="line"><span class="meta">#<span class="meta-keyword">include</span> <span class="meta-string"><memory></span></span></span><br><span class="line"><span class="meta">#<span class="meta-keyword">include</span> <span class="meta-string"><string></span></span></span><br><span class="line"><span class="meta">#<span class="meta-keyword">include</span> <span class="meta-string"><algorithm></span></span></span><br><span class="line"><span class="comment">// Mindspore工具类</span></span><br><span class="line"><span class="meta">#<span class="meta-keyword">include</span> <span class="meta-string">"ir/anf.h"</span></span></span><br><span class="line"><span class="meta">#<span class="meta-keyword">include</span> <span class="meta-string">"utils/log_adapter.h"</span></span></span><br><span class="line"><span class="meta">#<span class="meta-keyword">include</span> <span class="meta-string">"kernel/common_utils.h"</span></span></span><br><span class="line"><span class="comment">// CUDA FP16头文件</span></span><br><span class="line"><span class="meta">#<span class="meta-keyword">include</span> <span class="meta-string">"include/cuda_fp16.h"</span></span></span><br><span class="line"><span class="comment">// 命名空间域`mindspore::kernel`</span></span><br><span class="line"><span class="keyword">namespace</span> mindspore {</span><br><span class="line"><span class="keyword">namespace</span> kernel {</span><br><span class="line"><span class="comment">// 匿名命名空间域,存放只有本算子使用的宏定义,用于注册算子支持的数据类型</span></span><br><span class="line"><span class="keyword">namespace</span> {</span><br><span class="line"><span 
class="keyword">using</span> KernelRunFunc = RandomPoissonGpuKernelMod::KernelRunFunc;</span><br><span class="line"><span class="meta">#<span class="meta-keyword">define</span> ADD_KERNEL(shape_dtype, rate_dtype, output_dtype, rate_type, output_type) \</span></span><br><span class="line"> { \</span><br><span class="line"> KernelAttr() \</span><br><span class="line"> .AddInputAttr(kNumberType##shape_dtype) \</span><br><span class="line"> .AddInputAttr(kNumberType##rate_dtype) \</span><br><span class="line"> .AddOutputAttr(kNumberType##output_dtype), \</span><br><span class="line"> &RandomPoissonGpuKernelMod::LaunchKernel<rate_type, output_type> \</span><br><span class="line"> }</span><br><span class="line">} <span class="comment">// namespace</span></span><br><span class="line"><span class="comment">// 算子初始化函数,只会在类构造时调用一次</span></span><br><span class="line"><span class="keyword">bool</span> RandomPoissonGpuKernelMod::Init(<span class="keyword">const</span> BaseOperatorPtr &base_operator, <span class="keyword">const</span> <span class="built_in">std</span>::<span class="built_in">vector</span><KernelTensorPtr> &inputs,</span><br><span class="line"> <span class="keyword">const</span> <span class="built_in">std</span>::<span class="built_in">vector</span><KernelTensorPtr> &outputs) {</span><br><span class="line"> MS_EXCEPTION_IF_NULL(base_operator);</span><br><span class="line"> kernel_name_ = base_operator->name();</span><br><span class="line"> <span class="comment">// 使用Helper函数对参数进行合法性检查</span></span><br><span class="line"> <span class="keyword">if</span> (!MatchKernelFunc(base_operator, inputs, outputs)) {</span><br><span class="line"> <span class="keyword">return</span> <span class="literal">false</span>;</span><br><span class="line"> }</span><br><span class="line"> <span class="comment">// 获取各种参数类型所占的内存空间,用于在`Resize`函数中计算算子运行过程中所需内存大小</span></span><br><span class="line"> <span class="keyword">auto</span> kernel_attr = GetKernelAttrFromTensors(inputs, outputs);</span><br><span class="line"> unit_shape_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(<span class="number">0</span>).first);</span><br><span class="line"> unit_rate_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(<span class="number">1</span>).first);</span><br><span class="line"> unit_output_size_ = abstract::TypeIdSize(kernel_attr.GetOutputAttr(<span class="number">0</span>).first);</span><br><span class="line"> <span class="comment">// 获取算子的随机种子seed、seed2</span></span><br><span class="line"> <span class="keyword">auto</span> kernel_ptr = <span class="built_in">std</span>::make_shared<ops::RandomPoisson>(base_operator->GetPrim());</span><br><span class="line"> seed_ = <span class="keyword">static_cast</span><<span class="keyword">int64_t</span>>(kernel_ptr->get_seed());</span><br><span class="line"> seed2_ = <span class="keyword">static_cast</span><<span class="keyword">int64_t</span>>(kernel_ptr->get_seed2());</span><br><span class="line"> <span class="keyword">return</span> <span class="literal">true</span>;</span><br><span class="line">}</span><br><span class="line"><span class="comment">// 计算算子运行过程中所需内存空间的大小,注意单位为`byte`</span></span><br><span class="line"><span class="keyword">int</span> RandomPoissonGpuKernelMod::Resize(<span class="keyword">const</span> BaseOperatorPtr &base_operator, <span class="keyword">const</span> <span class="built_in">std</span>::<span class="built_in">vector</span><KernelTensorPtr> &inputs,</span><br><span class="line"> <span class="keyword">const</span> <span 
class="built_in">std</span>::<span class="built_in">vector</span><KernelTensorPtr> &outputs,</span><br><span class="line"> <span class="keyword">const</span> <span class="built_in">std</span>::<span class="built_in">map</span><<span class="keyword">uint32_t</span>, tensor::TensorPtr> &) {</span><br><span class="line"> <span class="comment">// 检查所有的输入维度是否合法</span></span><br><span class="line"> <span class="keyword">for</span> (<span class="keyword">const</span> <span class="keyword">auto</span> &input : inputs) {</span><br><span class="line"> <span class="comment">// If any input shape contains -1, means input shape is dynamic, so just return do nothing.</span></span><br><span class="line"> <span class="keyword">auto</span> input_shape = input->GetShapeVector();</span><br><span class="line"> <span class="keyword">if</span> (!IsValidShape(input_shape)) {</span><br><span class="line"> <span class="keyword">return</span> KRET_UNKNOWN_SHAPE;</span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line"> <span class="comment">// 在计算本轮计算所需内存空间时,首先需要将之前产生的脏数据清空</span></span><br><span class="line"> ResetResource();</span><br><span class="line"> <span class="comment">// 计算输入输出的元素的个数</span></span><br><span class="line"> <span class="built_in">std</span>::<span class="built_in">vector</span><<span class="keyword">int64_t</span>> shape_shape = <span class="built_in">std</span>::<span class="built_in">vector</span><<span class="keyword">int64_t</span>>(inputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),</span><br><span class="line"> inputs.at(kIndex0)->GetDeviceShapeAdaptively().end());</span><br><span class="line"> <span class="built_in">std</span>::<span class="built_in">vector</span><<span class="keyword">int64_t</span>> rate_shape = <span class="built_in">std</span>::<span class="built_in">vector</span><<span class="keyword">int64_t</span>>(inputs.at(kIndex1)->GetDeviceShapeAdaptively().begin(),</span><br><span class="line"> inputs.at(kIndex1)->GetDeviceShapeAdaptively().end());</span><br><span class="line"> <span class="built_in">std</span>::<span class="built_in">vector</span><<span class="keyword">int64_t</span>> output_shape = <span class="built_in">std</span>::<span class="built_in">vector</span><<span class="keyword">int64_t</span>>(outputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),</span><br><span class="line"> outputs.at(kIndex0)->GetDeviceShapeAdaptively().end());</span><br><span class="line"> <span class="keyword">int64_t</span> shape_elements = <span class="built_in">std</span>::accumulate(shape_shape.begin(), shape_shape.end(), <span class="number">1</span>, <span class="built_in">std</span>::multiplies<<span class="keyword">int64_t</span>>());</span><br><span class="line"> rate_elements_ = <span class="built_in">std</span>::accumulate(rate_shape.begin(), rate_shape.end(), <span class="number">1</span>, <span class="built_in">std</span>::multiplies<<span class="keyword">int64_t</span>>());</span><br><span class="line"> output_elements_ = <span class="built_in">std</span>::accumulate(output_shape.begin(), output_shape.end(), <span class="number">1</span>, <span class="built_in">std</span>::multiplies<<span class="keyword">int64_t</span>>());</span><br><span class="line"> <span class="keyword">if</span> (output_elements_ == <span class="number">0</span>) {</span><br><span class="line"> is_null_input_ = <span class="literal">true</span>;</span><br><span class="line"> }</span><br><span class="line"> <span class="comment">// 
将计算结果储存到数组中</span></span><br><span class="line"> input_size_list_.emplace_back(shape_elements * unit_shape_size_);</span><br><span class="line"> input_size_list_.emplace_back(rate_elements_ * unit_rate_size_);</span><br><span class="line"> output_size_list_.emplace_back(output_elements_ * unit_output_size_);</span><br><span class="line"> workspace_size_list_.push_back(output_elements_ * <span class="keyword">sizeof</span>(curandState));</span><br><span class="line"> <span class="keyword">return</span> KRET_OK;</span><br><span class="line">}</span><br><span class="line"><span class="comment">// 算子正式执行入口函数</span></span><br><span class="line"><span class="keyword">template</span> <<span class="keyword">typename</span> R, <span class="keyword">typename</span> T></span><br><span class="line"><span class="keyword">bool</span> RandomPoissonGpuKernelMod::LaunchKernel(<span class="keyword">const</span> <span class="built_in">std</span>::<span class="built_in">vector</span><kernel::AddressPtr> &inputs,</span><br><span class="line"> <span class="keyword">const</span> <span class="built_in">std</span>::<span class="built_in">vector</span><AddressPtr> &workspace,</span><br><span class="line"> <span class="keyword">const</span> <span class="built_in">std</span>::<span class="built_in">vector</span><kernel::AddressPtr> &outputs) {</span><br><span class="line"> <span class="comment">// 获取输入输出的Device端地址</span></span><br><span class="line"> R *rate_addr = GetDeviceAddress<R>(inputs, <span class="number">1</span>);</span><br><span class="line"> T *output = GetDeviceAddress<T>(outputs, <span class="number">0</span>);</span><br><span class="line"> curandState *devStates = <span class="literal">nullptr</span>;</span><br><span class="line"> <span class="keyword">void</span> *workspace_addr = GetDeviceAddress<<span class="keyword">void</span> *>(workspace, <span class="number">0</span>);</span><br><span class="line"> devStates = <span class="keyword">reinterpret_cast</span><curandState *>(workspace_addr);</span><br><span class="line"> <span class="comment">// CUDA核函数</span></span><br><span class="line"> RandomPoisson(seed_, seed2_, devStates, rate_addr, rate_elements_, output, output_elements_,</span><br><span class="line"> <span class="keyword">reinterpret_cast</span><cudaStream_t>(cuda_stream_));</span><br><span class="line"> <span class="keyword">return</span> <span class="literal">true</span>;</span><br><span class="line">}</span><br><span class="line"><span class="comment">// 注册算子支持的参数类型,由于参数排列组合类型比较多,因此将实际注册的函数抽离成一个宏定义,使得注册列表看起来更简洁</span></span><br><span class="line"><span class="keyword">const</span> <span class="built_in">std</span>::<span class="built_in">vector</span><<span class="built_in">std</span>::pair<KernelAttr, KernelRunFunc>> &RandomPoissonGpuKernelMod::GetFuncList() <span class="keyword">const</span> {</span><br><span class="line"> <span class="keyword">static</span> <span class="keyword">const</span> <span class="built_in">std</span>::<span class="built_in">vector</span><<span class="built_in">std</span>::pair<KernelAttr, KernelRunFunc>> func_list = {</span><br><span class="line"> ADD_KERNEL(Int32, Float16, Float16, half, half), ADD_KERNEL(Int32, Float16, Float32, half, <span class="keyword">float</span>),</span><br><span class="line"> ADD_KERNEL(Int32, Float16, Float64, half, <span class="keyword">double</span>), ADD_KERNEL(Int32, Float16, Int32, half, <span class="keyword">int</span>),</span><br><span class="line"> ADD_KERNEL(Int32, Float16, Int64, half, <span 
class="keyword">int64_t</span>),</span><br><span class="line"></span><br><span class="line"> ADD_KERNEL(Int32, Float32, Float16, <span class="keyword">float</span>, half), ADD_KERNEL(Int32, Float32, Float32, <span class="keyword">float</span>, <span class="keyword">float</span>),</span><br><span class="line"> ADD_KERNEL(Int32, Float32, Float64, <span class="keyword">float</span>, <span class="keyword">double</span>), ADD_KERNEL(Int32, Float32, Int32, <span class="keyword">float</span>, <span class="keyword">int</span>),</span><br><span class="line"> ADD_KERNEL(Int32, Float32, Int64, <span class="keyword">float</span>, <span class="keyword">int64_t</span>),</span><br><span class="line"></span><br><span class="line"> ADD_KERNEL(Int32, Float64, Float16, <span class="keyword">double</span>, half), ADD_KERNEL(Int32, Float64, Float32, <span class="keyword">double</span>, <span class="keyword">float</span>),</span><br><span class="line"> ADD_KERNEL(Int32, Float64, Float64, <span class="keyword">double</span>, <span class="keyword">double</span>), ADD_KERNEL(Int32, Float64, Int32, <span class="keyword">double</span>, <span class="keyword">int</span>),</span><br><span class="line"> ADD_KERNEL(Int32, Float64, Int64, <span class="keyword">double</span>, <span class="keyword">int64_t</span>),</span><br><span class="line"></span><br><span class="line"> ADD_KERNEL(Int32, Int32, Float16, <span class="keyword">int</span>, half), ADD_KERNEL(Int32, Int32, Float32, <span class="keyword">int</span>, <span class="keyword">float</span>),</span><br><span class="line"> ADD_KERNEL(Int32, Int32, Float64, <span class="keyword">int</span>, <span class="keyword">double</span>), ADD_KERNEL(Int32, Int32, Int32, <span class="keyword">int</span>, <span class="keyword">int</span>),</span><br><span class="line"> ADD_KERNEL(Int32, Int32, Int64, <span class="keyword">int</span>, <span class="keyword">int64_t</span>),</span><br><span class="line"></span><br><span class="line"> ADD_KERNEL(Int32, Int64, Float16, <span class="keyword">int64_t</span>, half), ADD_KERNEL(Int32, Int64, Float32, <span class="keyword">int64_t</span>, <span class="keyword">float</span>),</span><br><span class="line"> ADD_KERNEL(Int32, Int64, Float64, <span class="keyword">int64_t</span>, <span class="keyword">double</span>), ADD_KERNEL(Int32, Int64, Int32, <span class="keyword">int64_t</span>, <span class="keyword">int</span>),</span><br><span class="line"> ADD_KERNEL(Int32, Int64, Int64, <span class="keyword">int64_t</span>, <span class="keyword">int64_t</span>),</span><br><span class="line"></span><br><span class="line"> ADD_KERNEL(Int64, Float16, Float16, half, half), ADD_KERNEL(Int64, Float16, Float32, half, <span class="keyword">float</span>),</span><br><span class="line"> ADD_KERNEL(Int64, Float16, Float64, half, <span class="keyword">double</span>), ADD_KERNEL(Int64, Float16, Int32, half, <span class="keyword">int</span>),</span><br><span class="line"> ADD_KERNEL(Int64, Float16, Int64, half, <span class="keyword">int64_t</span>),</span><br><span class="line"></span><br><span class="line"> ADD_KERNEL(Int64, Float32, Float16, <span class="keyword">float</span>, half), ADD_KERNEL(Int64, Float32, Float32, <span class="keyword">float</span>, <span class="keyword">float</span>),</span><br><span class="line"> ADD_KERNEL(Int64, Float32, Float64, <span class="keyword">float</span>, <span class="keyword">double</span>), ADD_KERNEL(Int64, Float32, Int32, <span class="keyword">float</span>, <span class="keyword">int</span>),</span><br><span class="line"> 
ADD_KERNEL(Int64, Float32, Int64, <span class="keyword">float</span>, <span class="keyword">int64_t</span>),</span><br><span class="line"></span><br><span class="line"> ADD_KERNEL(Int64, Float64, Float16, <span class="keyword">double</span>, half), ADD_KERNEL(Int64, Float64, Float32, <span class="keyword">double</span>, <span class="keyword">float</span>),</span><br><span class="line"> ADD_KERNEL(Int64, Float64, Float64, <span class="keyword">double</span>, <span class="keyword">double</span>), ADD_KERNEL(Int64, Float64, Int32, <span class="keyword">double</span>, <span class="keyword">int</span>),</span><br><span class="line"> ADD_KERNEL(Int64, Float64, Int64, <span class="keyword">double</span>, <span class="keyword">int64_t</span>),</span><br><span class="line"></span><br><span class="line"> ADD_KERNEL(Int64, Int32, Float16, <span class="keyword">int</span>, half), ADD_KERNEL(Int64, Int32, Float32, <span class="keyword">int</span>, <span class="keyword">float</span>),</span><br><span class="line"> ADD_KERNEL(Int64, Int32, Float64, <span class="keyword">int</span>, <span class="keyword">double</span>), ADD_KERNEL(Int64, Int32, Int32, <span class="keyword">int</span>, <span class="keyword">int</span>),</span><br><span class="line"> ADD_KERNEL(Int64, Int32, Int64, <span class="keyword">int</span>, <span class="keyword">int64_t</span>),</span><br><span class="line"></span><br><span class="line"> ADD_KERNEL(Int64, Int64, Float16, <span class="keyword">int64_t</span>, half), ADD_KERNEL(Int64, Int64, Float32, <span class="keyword">int64_t</span>, <span class="keyword">float</span>),</span><br><span class="line"> ADD_KERNEL(Int64, Int64, Float64, <span class="keyword">int64_t</span>, <span class="keyword">double</span>), ADD_KERNEL(Int64, Int64, Int32, <span class="keyword">int64_t</span>, <span class="keyword">int</span>),</span><br><span class="line"> ADD_KERNEL(Int64, Int64, Int64, <span class="keyword">int64_t</span>, <span class="keyword">int64_t</span>)};</span><br><span class="line"> <span class="keyword">return</span> func_list;</span><br><span class="line">}</span><br><span class="line"><span class="comment">// 注册随机泊松算子的GPU Kernel</span></span><br><span class="line">MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, RandomPoisson, RandomPoissonGpuKernelMod);</span><br><span class="line">} <span class="comment">// namespace kernel</span></span><br><span class="line">} <span class="comment">// namespace mindspore</span></span><br></pre></td></tr></table></figure>
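<p>The <code>ADD_KERNEL</code> macro and <code>GetFuncList</code> above boil down to table-driven dispatch: every supported dtype signature is paired with the matching template instantiation of <code>LaunchKernel</code>, and the framework invokes the entry whose signature matches the actual inputs. The toy program below illustrates the same pattern in isolation; it is only a sketch, and all names in it (<code>FakeKernel</code>, the string signatures, and so on) are invented for illustration rather than taken from MindSpore.</p>
<figure class="highlight c++"><table><tr><td class="code"><pre><span class="line">// Toy illustration of the dtype-dispatch pattern used above (not MindSpore code)</span><br><span class="line">#include <cstdint></span><br><span class="line">#include <cstdio></span><br><span class="line">#include <functional></span><br><span class="line">#include <string></span><br><span class="line">#include <utility></span><br><span class="line">#include <vector></span><br><span class="line">struct FakeKernel {</span><br><span class="line">  // Stand-in for the templated RandomPoissonGpuKernelMod::LaunchKernel<R, T></span><br><span class="line">  template <typename R, typename T></span><br><span class="line">  bool LaunchKernel() {</span><br><span class="line">    printf("launch with rate element = %zu bytes, output element = %zu bytes\n", sizeof(R), sizeof(T));</span><br><span class="line">    return true;</span><br><span class="line">  }</span><br><span class="line">  using RunFunc = std::function<bool(FakeKernel *)>;</span><br><span class="line">  // Analogue of GetFuncList(): one entry per supported dtype combination</span><br><span class="line">  const std::vector<std::pair<std::string, RunFunc>> &GetFuncList() const {</span><br><span class="line">    static const std::vector<std::pair<std::string, RunFunc>> func_list = {</span><br><span class="line">      {"float32/int32", [](FakeKernel *k) { return k->LaunchKernel<float, int32_t>(); }},</span><br><span class="line">      {"float32/int64", [](FakeKernel *k) { return k->LaunchKernel<float, int64_t>(); }},</span><br><span class="line">      {"float64/int64", [](FakeKernel *k) { return k->LaunchKernel<double, int64_t>(); }}};</span><br><span class="line">    return func_list;</span><br><span class="line">  }</span><br><span class="line">  // Analogue of MatchKernelFunc(): select the entry whose signature matches the actual inputs</span><br><span class="line">  bool Launch(const std::string &signature) {</span><br><span class="line">    for (const auto &item : GetFuncList()) {</span><br><span class="line">      if (item.first == signature) {</span><br><span class="line">        return item.second(this);</span><br><span class="line">      }</span><br><span class="line">    }</span><br><span class="line">    printf("unsupported dtype combination: %s\n", signature.c_str());</span><br><span class="line">    return false;</span><br><span class="line">  }</span><br><span class="line">};</span><br><span class="line">int main() {</span><br><span class="line">  FakeKernel kernel;</span><br><span class="line">  kernel.Launch("float32/int64");  // dispatches to LaunchKernel<float, int64_t></span><br><span class="line">  kernel.Launch("float16/int8");   // not registered, rejected</span><br><span class="line">  return 0;</span><br><span class="line">}</span><br></pre></td></tr></table></figure>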
<h2 id="CUDA核函数实现"><a href="#CUDA核函数实现" class="headerlink" title="CUDA核函数实现"></a>CUDA核函数实现</h2><p>利用GPU并行计算能力,在计算量大的情况下,算子的加速效果十分明显。CUDA核函数的实现也分为函数定义和函数实现,分别写在文件<code>mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/random_op_impl.cuh</code>和文件<code>mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/random_op_impl.cu</code>中。</p>
<figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// 定义CUDA核函数头文件</span></span><br><span class="line"><span class="meta">#<span class="meta-keyword">ifndef</span> MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_RANDOM_OP_IMPL_CUH_</span></span><br><span class="line"><span class="meta">#<span class="meta-keyword">define</span> MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_RANDOM_OP_IMPL_CUH_</span></span><br><span class="line"><span class="meta">#<span class="meta-keyword">include</span> <span class="meta-string"><curand_kernel.h></span></span></span><br><span class="line"><span class="meta">#<span class="meta-keyword">include</span> <span class="meta-string"><random></span></span></span><br><span class="line"><span class="meta">#<span class="meta-keyword">include</span> <span class="meta-string">"plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"</span></span></span><br><span class="line"><span class="comment">// 随机泊松算子CUDA核函数定义</span></span><br><span class="line"><span class="keyword">template</span> <<span class="keyword">typename</span> R, <span class="keyword">typename</span> T></span><br><span class="line"><span class="function">CUDA_LIB_EXPORT <span class="keyword">void</span> <span class="title">RandomPoisson</span><span class="params">(<span class="keyword">int</span> seed, <span class="keyword">int</span> seed2, curandState *globalState,</span></span></span><br><span class="line"><span class="function"><span class="params"> R *rate, <span class="keyword">int64_t</span> rate_size, T *output, <span class="keyword">size_t</span> count,</span></span></span><br><span class="line"><span class="function"><span class="params"> cudaStream_t cuda_stream)</span></span>;</span><br><span class="line"><span class="meta">#<span class="meta-keyword">endif</span> <span class="comment">// MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_RANDOM_OP_IMPL_CUH_</span></span></span><br></pre></td></tr></table></figure>
<figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">#<span class="meta-keyword">include</span> <span class="meta-string">"random_op_impl.cuh"</span></span></span><br><span class="line"><span class="meta">#<span class="meta-keyword">include</span> <span class="meta-string">"include/cuda_fp16.h"</span></span></span><br><span class="line"><span class="comment">// CUDA核函数实现,注意只有`__global__`字段的函数才会在GPU上执行</span></span><br><span class="line"><span class="keyword">template</span> <<span class="keyword">typename</span> R, <span class="keyword">typename</span> T></span><br><span class="line">__<span class="function">global__ <span class="keyword">void</span> <span class="title">RandomPoissonKernel</span><span class="params">(<span class="keyword">int</span> seed, curandState *globalState, R *rate, <span class="keyword">int</span> rate_size, T *output,</span></span></span><br><span class="line"><span class="function"><span class="params"> <span class="keyword">size_t</span> count)</span> </span>{</span><br><span class="line"> <span class="keyword">for</span> (<span class="keyword">size_t</span> i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {</span><br><span 
class="line"> <span class="comment">// 使用curand库的接口,初始化curandstate</span></span><br><span class="line"> curand_init(seed, i, <span class="number">0</span>, &globalState[i]);</span><br><span class="line"> <span class="comment">// 将结果的顺序与`rate`的顺序对齐</span></span><br><span class="line"> <span class="keyword">auto</span> j = i % rate_size;</span><br><span class="line"> <span class="comment">// 使用curand库的泊松接口生成随机泊松值</span></span><br><span class="line"> output[i] = (T)curand_poisson(&globalState[i], rate[j]);</span><br><span class="line"> }</span><br><span class="line"> <span class="keyword">return</span>;</span><br><span class="line">}</span><br><span class="line"><span class="comment">// CUDA核函数入口</span></span><br><span class="line"><span class="keyword">template</span> <<span class="keyword">typename</span> R, <span class="keyword">typename</span> T></span><br><span class="line"><span class="function"><span class="keyword">void</span> <span class="title">RandomPoisson</span><span class="params">(<span class="keyword">int</span> seed, <span class="keyword">int</span> seed2, curandState *globalState, R *rate, <span class="keyword">int64_t</span> rate_size, T *output, <span class="keyword">size_t</span> count,</span></span></span><br><span class="line"><span class="function"><span class="params"> cudaStream_t cuda_stream)</span> </span>{</span><br><span class="line"> <span class="comment">// 使用两个随机种子seed、seed2对random_device类进行初始化</span></span><br><span class="line"> <span class="keyword">int</span> RNG_seed = <span class="number">0</span>;</span><br><span class="line"> <span class="built_in">std</span>::random_device rd;</span><br><span class="line"> <span class="keyword">if</span> (seed2 != <span class="number">0</span>) {</span><br><span class="line"> RNG_seed = seed2;</span><br><span class="line"> } <span class="keyword">else</span> <span class="keyword">if</span> (seed != <span class="number">0</span>) {</span><br><span class="line"> RNG_seed = seed;</span><br><span class="line"> } <span class="keyword">else</span> {</span><br><span class="line"> RNG_seed = <span class="keyword">static_cast</span><<span class="keyword">int</span>>(rd());</span><br><span class="line"> }</span><br><span class="line"> <span class="comment">// 调用CUDA核函数</span></span><br><span class="line"> RandomPoissonKernel<<<GET_BLOCKS(count), GET_THREADS, <span class="number">0</span>, cuda_stream>>>(RNG_seed, globalState, rate, rate_size,</span><br><span class="line"> output, count);</span><br><span class="line"> <span class="keyword">return</span>;</span><br><span class="line">}</span><br><span class="line"><span class="comment">// 注册核函数支持的数据类型,由于参数类型组合比较多,因此使用宏定义抽离</span></span><br><span class="line"><span class="meta">#<span class="meta-keyword">define</span> ADD_RANDOM_POISSON(rate_type, output_type) \</span></span><br><span class="line"> <span class="keyword">template</span> CUDA_LIB_EXPORT <span class="keyword">void</span> RandomPoisson<rate_type, output_type>(<span class="keyword">int</span> seed, <span class="keyword">int</span> seed2, curandState *globalState, \</span><br><span class="line"> rate_type *rate, <span class="keyword">int64_t</span> rate_size, \</span><br><span class="line"> output_type *output, <span class="keyword">size_t</span> count, \</span><br><span class="line"> cudaStream_t cuda_stream);</span><br><span class="line"></span><br><span class="line">ADD_RANDOM_POISSON(half, half)</span><br><span class="line">ADD_RANDOM_POISSON(half, <span class="keyword">float</span>)</span><br><span 
class="line">ADD_RANDOM_POISSON(half, <span class="keyword">double</span>)</span><br><span class="line">ADD_RANDOM_POISSON(half, <span class="keyword">int</span>)</span><br><span class="line">ADD_RANDOM_POISSON(half, <span class="keyword">int64_t</span>)</span><br><span class="line"></span><br><span class="line">ADD_RANDOM_POISSON(<span class="keyword">float</span>, half)</span><br><span class="line">ADD_RANDOM_POISSON(<span class="keyword">float</span>, <span class="keyword">float</span>)</span><br><span class="line">ADD_RANDOM_POISSON(<span class="keyword">float</span>, <span class="keyword">double</span>)</span><br><span class="line">ADD_RANDOM_POISSON(<span class="keyword">float</span>, <span class="keyword">int</span>)</span><br><span class="line">ADD_RANDOM_POISSON(<span class="keyword">float</span>, <span class="keyword">int64_t</span>)</span><br><span class="line"></span><br><span class="line">ADD_RANDOM_POISSON(<span class="keyword">double</span>, half)</span><br><span class="line">ADD_RANDOM_POISSON(<span class="keyword">double</span>, <span class="keyword">float</span>)</span><br><span class="line">ADD_RANDOM_POISSON(<span class="keyword">double</span>, <span class="keyword">double</span>)</span><br><span class="line">ADD_RANDOM_POISSON(<span class="keyword">double</span>, <span class="keyword">int</span>)</span><br><span class="line">ADD_RANDOM_POISSON(<span class="keyword">double</span>, <span class="keyword">int64_t</span>)</span><br><span class="line"></span><br><span class="line">ADD_RANDOM_POISSON(<span class="keyword">int</span>, half)</span><br><span class="line">ADD_RANDOM_POISSON(<span class="keyword">int</span>, <span class="keyword">float</span>)</span><br><span class="line">ADD_RANDOM_POISSON(<span class="keyword">int</span>, <span class="keyword">double</span>)</span><br><span class="line">ADD_RANDOM_POISSON(<span class="keyword">int</span>, <span class="keyword">int</span>)</span><br><span class="line">ADD_RANDOM_POISSON(<span class="keyword">int</span>, <span class="keyword">int64_t</span>)</span><br><span class="line"></span><br><span class="line">ADD_RANDOM_POISSON(<span class="keyword">int64_t</span>, half)</span><br><span class="line">ADD_RANDOM_POISSON(<span class="keyword">int64_t</span>, <span class="keyword">float</span>)</span><br><span class="line">ADD_RANDOM_POISSON(<span class="keyword">int64_t</span>, <span class="keyword">double</span>)</span><br><span class="line">ADD_RANDOM_POISSON(<span class="keyword">int64_t</span>, <span class="keyword">int</span>)</span><br><span class="line">ADD_RANDOM_POISSON(<span class="keyword">int64_t</span>, <span class="keyword">int64_t</span>)</span><br></pre></td></tr></table></figure>
<h1 id="算子执行流程图"><a href="#算子执行流程图" class="headerlink" title="算子执行流程图"></a>算子执行流程图</h1><p>对算子源码分析完成后,接下来用一副流程图来描述随机泊松算子的完整执行过程。</p>
<p><img src="images/mindspore/RandomPoisson%E6%89%A7%E8%A1%8C%E6%B5%81%E7%A8%8B%E5%9B%BE.png" alt="执行流程图"></p>
<h1 id="算子测试"><a href="#算子测试" class="headerlink" title="算子测试"></a>算子测试</h1><p>对算子计算结果进行测试,测试覆盖了算子支持的所有数据类型,脚本如下:</p>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> mindspore <span class="keyword">as</span> ms</span><br><span class="line"><span class="keyword">from</span> mindspore <span class="keyword">import</span> Tensor</span><br><span class="line"><span class="keyword">from</span> mindspore <span class="keyword">import</span> ops</span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">import</span> time</span><br><span class="line"></span><br><span class="line">start_time = time.time()</span><br><span class="line">ms.set_context(device_target=<span class="string">"GPU"</span>)</span><br><span class="line"></span><br><span class="line">int_type = [ms.int32, ms.int64]</span><br><span class="line">float_type = [ms.int32, ms.int64, ms.float16, ms.float32, ms.float64]</span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> x <span class="keyword">in</span> int_type:</span><br><span class="line"> <span class="keyword">for</span> y <span class="keyword">in</span> float_type:</span><br><span class="line"> <span class="keyword">for</span> z <span class="keyword">in</span> float_type:</span><br><span class="line"> shape = Tensor(np.array([<span class="number">2</span>, <span class="number">4</span>]), x)</span><br><span class="line"> rate = Tensor(np.array([<span class="number">5</span>, <span class="number">50</span>, <span class="number">500</span>]), y)</span><br><span class="line"> seed = <span class="number">10</span></span><br><span class="line"> seed2 = <span class="number">20</span></span><br><span class="line"> random_poisson = ops.operations.random_ops.RandomPoisson(seed=seed, seed2=seed2, dtype=z)</span><br><span class="line"> output = random_poisson(shape, rate)</span><br><span class="line"> print(<span class="string">'=====Devided line====='</span>)</span><br><span class="line"> print(output)</span><br><span class="line"> print(output.shape)</span><br><span class="line"> print(output.dtype)</span><br><span class="line"></span><br><span class="line">end_time = time.time()</span><br><span class="line">print(<span class="string">'Process time(s): '</span>, end_time - start_time)</span><br></pre></td></tr></table></figure>
<p>A partial screenshot of the test results is shown below:</p>
<p><img src="images/mindspore/RandomPoisson_test_result.png" alt="测试结果"></p>
</div>
<footer class="post-footer">
<div class="post-tags">
<a href="/tags/mindspore/" rel="tag"># Mindspore</a>
</div>
<div class="post-nav">
<div class="post-nav-next post-nav-item">
<a href="/gdb调试入门.html" rel="next" title="GDB调试入门">
<i class="fa fa-chevron-left"></i> GDB调试入门
</a>
</div>
<span class="post-nav-divider"></span>
<div class="post-nav-prev post-nav-item">
<a href="/在ubuntu18-04上编译安装mindspore.html" rel="prev" title="在Ubuntu18.04上编译安装MindSpore">
在Ubuntu18.04上编译安装MindSpore <i class="fa fa-chevron-right"></i>
</a>
</div>
</div>
</footer>
</article>
</div>
</div>
</div>
<div class="toggle sidebar-toggle">
<span class="toggle-line toggle-line-first"></span>
<span class="toggle-line toggle-line-middle"></span>
<span class="toggle-line toggle-line-last"></span>
</div>
<aside class="sidebar">
<div class="sidebar-inner">
<ul class="sidebar-nav motion-element">
<li class="sidebar-nav-toc">
文章目录
</li>
<li class="sidebar-nav-overview">
站点概览
</li>
</ul>
<!--noindex-->
<div class="post-toc-wrap sidebar-panel">
<div class="post-toc motion-element"><ol class="nav"><li class="nav-item nav-level-1"><a class="nav-link" href="#算子介绍"><span class="nav-number">1.</span> <span class="nav-text">算子介绍</span></a><ol class="nav-child"><li class="nav-item nav-level-2"><a class="nav-link" href="#定义"><span class="nav-number">1.1.</span> <span class="nav-text">定义</span></a></li><li class="nav-item nav-level-2"><a class="nav-link" href="#算子功能"><span class="nav-number">1.2.</span> <span class="nav-text">算子功能</span></a></li></ol></li><li class="nav-item nav-level-1"><a class="nav-link" href="#源码分析"><span class="nav-number">2.</span> <span class="nav-text">源码分析</span></a><ol class="nav-child"><li class="nav-item nav-level-2"><a class="nav-link" href="#Python侧算子原语定义"><span class="nav-number">2.1.</span> <span class="nav-text">Python侧算子原语定义</span></a></li><li class="nav-item nav-level-2"><a class="nav-link" href="#C-侧算子注册"><span class="nav-number">2.2.</span> <span class="nav-text">C++侧算子注册</span></a></li><li class="nav-item nav-level-2"><a class="nav-link" href="#C-侧算子原语定义"><span class="nav-number">2.3.</span> <span class="nav-text">C++侧算子原语定义</span></a></li><li class="nav-item nav-level-2"><a class="nav-link" href="#C-侧GPU算子实现"><span class="nav-number">2.4.</span> <span class="nav-text">C++侧GPU算子实现</span></a></li><li class="nav-item nav-level-2"><a class="nav-link" href="#CUDA核函数实现"><span class="nav-number">2.5.</span> <span class="nav-text">CUDA核函数实现</span></a></li></ol></li><li class="nav-item nav-level-1"><a class="nav-link" href="#算子执行流程图"><span class="nav-number">3.</span> <span class="nav-text">算子执行流程图</span></a></li><li class="nav-item nav-level-1"><a class="nav-link" href="#算子测试"><span class="nav-number">4.</span> <span class="nav-text">算子测试</span></a></li></ol></div>
</div>
<!--/noindex-->
<div class="site-overview-wrap sidebar-panel">
<div class="site-author motion-element" itemprop="author" itemscope itemtype="http://schema.org/Person">
<img class="site-author-image" itemprop="image" alt="Fisher"
src="/images/icons/avatar.jpg">
<p class="site-author-name" itemprop="name">Fisher</p>
<div class="site-description" itemprop="description">记录学习生活中的点滴</div>
</div>
<div class="site-state-wrap motion-element">
<nav class="site-state">
<div class="site-state-item site-state-posts">
<a href="/archives/">
<span class="site-state-item-count">97</span>
<span class="site-state-item-name">日志</span>
</a>
</div>
<div class="site-state-item site-state-tags">
<a href="/tags/">
<span class="site-state-item-count">25</span>
<span class="site-state-item-name">标签</span></a>
</div>
</nav>
</div>
<div class="links-of-author motion-element">
<span class="links-of-author-item">
<a href="https://github.com/FisherWY" title="GitHub → https://github.com/FisherWY" rel="noopener" target="_blank"><i class="fa fa-fw fa-github"></i>GitHub</a>
</span>
<span class="links-of-author-item">
<a href="mailto:[email protected]" title="E-Mail → mailto:[email protected]" rel="noopener" target="_blank"><i class="fa fa-fw fa-envelope"></i>E-Mail</a>
</span>
</div>
</div>
</div>
</aside>
<div id="sidebar-dimmer"></div>
</div>
</main>
<footer class="footer">
<div class="footer-inner">
<div class="beian"><a href="http://beian.miit.gov.cn" rel="noopener" target="_blank">粤ICP备2022017631号-1 </a>
<img src="/images/icons/beian.png" style="display: inline-block;">
</div>
<div class="copyright">
©
<span itemprop="copyrightYear">2023</span>
<span class="with-love">
<i class="fa fa-user"></i>
</span>
<span class="author" itemprop="copyrightHolder">Fisher</span>
</div>
<div class="powered-by">由 <a href="https://hexo.io" class="theme-link" rel="noopener" target="_blank">Hexo</a> 强力驱动 v3.9.0
</div>
<span class="post-meta-divider">|</span>
<div class="theme-info">主题 – <a href="https://theme-next.org" class="theme-link" rel="noopener" target="_blank">NexT.Gemini</a> v7.4.2
</div>
<script>
function leancloudSelector(url) {
return document.getElementById(url).querySelector('.leancloud-visitors-count');
}
if (CONFIG.page.isPost) {
function addCount(Counter) {
var visitors = document.querySelector('.leancloud_visitors');
var url = visitors.getAttribute('id').trim();
var title = visitors.getAttribute('data-flag-title').trim();
Counter('get', `/classes/Counter?where=${JSON.stringify({ url })}`)
.then(response => response.json())
.then(({ results }) => {
if (results.length > 0) {
var counter = results[0];
Counter('put', '/classes/Counter/' + counter.objectId, { time: { '__op': 'Increment', 'amount': 1 } })
.then(response => response.json())
.then(() => {
leancloudSelector(url).innerText = counter.time + 1;
})
.catch(error => {
console.log('Failed to save visitor count', error);
})
} else {
leancloudSelector(url).innerText = 'Counter not initialized! More info at console err msg.';
console.error('ATTENTION! LeanCloud counter has security bug, see how to solve it here: https://github.com/theme-next/hexo-leancloud-counter-security. \n However, you can still use LeanCloud without security, by setting `security` option to `false`.');
}
})
.catch(error => {
console.log('LeanCloud Counter Error', error);
});
}
} else {
function showTime(Counter) {
var visitors = document.querySelectorAll('.leancloud_visitors');
var entries = [...visitors].map(element => {
return element.getAttribute('id').trim();
});
Counter('get', `/classes/Counter?where=${JSON.stringify({ url: { '$in': entries } })}`)
.then(response => response.json())
.then(({ results }) => {
if (results.length === 0) {
document.querySelectorAll('.leancloud_visitors .leancloud-visitors-count').forEach(element => {
element.innerText = 0;
});
return;
}
for (var i = 0; i < results.length; i++) {
var item = results[i];
var url = item.url;
var time = item.time;
leancloudSelector(url).innerText = time;
}
for (var i = 0; i < entries.length; i++) {
var url = entries[i];
var element = leancloudSelector(url);
if (element.innerText == '') {
element.innerText = 0;
}
}
})
.catch(error => {
console.log('LeanCloud Counter Error', error);
});
}
}
fetch('https://app-router.leancloud.cn/2/route?appId=l8bvleb0PFB0er4hTWo3bGL1-gzGzoHsz')
.then(response => response.json())
.then(({ api_server }) => {
var Counter = (method, url, data) => {
return fetch(`https://${api_server}/1.1${url}`, {
method: method,
headers: {
'X-LC-Id': 'l8bvleb0PFB0er4hTWo3bGL1-gzGzoHsz',
'X-LC-Key': 'vOPowrKK83zOB6LhLKYOsGd1',
'Content-Type': 'application/json',
},
body: JSON.stringify(data)
});
};
if (CONFIG.page.isPost) {
const localhost = /http:\/\/(localhost|127.0.0.1|0.0.0.0)/;
if (localhost.test(document.URL)) return;
addCount(Counter);
} else if (document.querySelectorAll('.post-title-link').length >= 1) {
showTime(Counter);
}
});
</script>
</div>
</footer>
</div>
<script src="/lib/anime.min.js"></script>
<script src="/lib/velocity/velocity.min.js"></script>
<script src="/lib/velocity/velocity.ui.min.js"></script>
<script src="/js/utils.js"></script><script src="/js/motion.js"></script>
<script src="/js/schemes/pisces.js"></script>
<script src="/js/next-boot.js"></script>
<script>
(function(){
var bp = document.createElement('script');
var curProtocol = window.location.protocol.split(':')[0];
bp.src = (curProtocol === 'https') ? 'https://zz.bdstatic.com/linksubmit/push.js' : 'http://push.zhanzhang.baidu.com/push.js';
var s = document.getElementsByTagName("script")[0];
s.parentNode.insertBefore(bp, s);
})();
</script>
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$', '$'], ['\\(', '\\)'] ],
processEscapes: true,
skipTags: ['script', 'noscript', 'style', 'textarea', 'pre', 'code']
},
TeX: {
equationNumbers: {
autoNumber: 'AMS'
}
}
});
MathJax.Hub.Register.StartupHook('TeX Jax Ready', function() {
MathJax.InputJax.TeX.prefilterHooks.Add(function(data) {
if (data.display) {
var next = data.script.nextSibling;
while (next && next.nodeName.toLowerCase() === '#text') {
next = next.nextSibling;
}
if (next && next.nodeName.toLowerCase() === 'br') {
next.parentNode.removeChild(next);
}
}
});
});
MathJax.Hub.Queue(function() {
var all = MathJax.Hub.getAllJax(), i;
for (i = 0; i < all.length; i += 1) {
element = document.getElementById(all[i].inputID + '-Frame').parentNode;
if (element.nodeName.toLowerCase() == 'li') {
element = element.parentNode;
}
element.classList.add('has-jax');
}
});
</script>
<script>
NexT.utils.getScript('//cdn.jsdelivr.net/npm/mathjax@2/MathJax.js?config=TeX-AMS-MML_HTMLorMML', () => {
MathJax.Hub.Typeset();
}, window.MathJax);
</script>
<script src="/live2dw/lib/L2Dwidget.min.js?094cbace49a39548bed64abff5988b05"></script><script>L2Dwidget.init({"pluginRootPath":"live2dw/","pluginJsPath":"lib/","pluginModelPath":"assets/","tagMode":false,"debug":false,"display":{"position":"left","width":175,"height":350},"model":{"jsonPath":"/live2dw/assets/koharu.model.json"},"mobile":{"show":false},"log":false});</script></body>
</html>