<?xml version="1.0" encoding="utf-8"?>
<feed xmlns="http://www.w3.org/2005/Atom">
<title>Thyssen Wen's Blog</title>
<link href="https://thinksky5124.github.io/atom.xml" rel="self"/>
<link href="https://thinksky5124.github.io/"/>
<updated>2024-04-16T08:57:05.647Z</updated>
<id>https://thinksky5124.github.io/</id>
<author>
<name>Thyssen Wen</name>
</author>
<generator uri="https://hexo.io/">Hexo</generator>
<entry>
<title>Hello World</title>
<link href="https://thinksky5124.github.io/2024/04/16/hello-world/"/>
<id>https://thinksky5124.github.io/2024/04/16/hello-world/</id>
<published>2024-04-16T08:57:05.647Z</published>
<updated>2024-04-16T08:57:05.647Z</updated>
<content type="html"><![CDATA[<p>Welcome to <a href="https://hexo.io/">Hexo</a>! This is your very first post. Check the <a href="https://hexo.io/docs/">documentation</a> for more info. If you run into any problems while using Hexo, you can find the answer in <a href="https://hexo.io/docs/troubleshooting.html">troubleshooting</a> or ask me on <a href="https://github.com/hexojs/hexo/issues">GitHub</a>.</p><h2 id="Quick-Start"><a href="#Quick-Start" class="headerlink" title="Quick Start"></a>Quick Start</h2><h3 id="Create-a-new-post"><a href="#Create-a-new-post" class="headerlink" title="Create a new post"></a>Create a new post</h3><figure class="highlight bash"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">$ hexo new <span class="string">"My New Post"</span></span><br></pre></td></tr></tbody></table></figure><p>More info: <a href="https://hexo.io/docs/writing.html">Writing</a></p><h3 id="Run-server"><a href="#Run-server" class="headerlink" title="Run server"></a>Run server</h3><figure class="highlight bash"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">$ hexo server</span><br></pre></td></tr></tbody></table></figure><p>More info: <a href="https://hexo.io/docs/server.html">Server</a></p><h3 id="Generate-static-files"><a href="#Generate-static-files" class="headerlink" title="Generate static files"></a>Generate static files</h3><figure class="highlight bash"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">$ hexo generate</span><br></pre></td></tr></tbody></table></figure><p>More info: <a href="https://hexo.io/docs/generating.html">Generating</a></p><h3 id="Deploy-to-remote-sites"><a href="#Deploy-to-remote-sites" class="headerlink" title="Deploy to remote sites"></a>Deploy to remote sites</h3><figure class="highlight 
bash"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">$ hexo deploy</span><br></pre></td></tr></tbody></table></figure><p>More info: <a href="https://hexo.io/docs/one-command-deployment.html">Deployment</a></p>]]></content>
<summary type="html"><![CDATA[<p>Welcome to <a href="https://hexo.io/">Hexo</a>! This is your very first post. Check the <a href="https://hexo.io/docs/">documentation</a> for]]></summary>
</entry>
<entry>
<title>TensorRT Model Building and Inference</title>
<link href="https://thinksky5124.github.io/2024/03/25/TensorRT_deploy_and_infer/"/>
<id>https://thinksky5124.github.io/2024/03/25/TensorRT_deploy_and_infer/</id>
<published>2024-03-25T03:17:41.679Z</published>
<updated>2024-03-25T03:17:41.971Z</updated>
<content type="html"><![CDATA[<h1 id="TensorRT-模型构建与推理"><a href="#TensorRT-模型构建与推理" class="headerlink" title="TensorRT Model Building and Inference"></a>TensorRT Model Building and Inference</h1><h2 id="TensorRT-简介"><a href="#TensorRT-简介" class="headerlink" title="Introduction to TensorRT"></a>Introduction to TensorRT</h2><p>TensorRT is NVIDIA's SDK for running deep learning inference on its hardware. It supports quantization-aware training (QAT) and post-training (offline) quantization, and lets users choose between INT8 and FP16 optimization modes. This makes it possible to deploy deep learning models to production for tasks such as video streaming, speech recognition, recommendation, fraud detection, text generation, and natural language processing. TensorRT is highly optimized for NVIDIA GPUs and is arguably one of the fastest engines currently available for running models on them. For more details, see the <a href="https://developer.nvidia.com/tensorrt">official TensorRT website</a>.</p><h2 id="安装-TensorRT"><a href="#安装-TensorRT" class="headerlink" title="Installing TensorRT"></a>Installing TensorRT</h2><h3 id="Windows"><a href="#Windows" class="headerlink" title="Windows"></a>Windows</h3><p>This assumes a machine with an NVIDIA GPU on which <a href="https://developer.nvidia.com/cuda-toolkit-archive">CUDA</a> and <a href="https://developer.nvidia.com/rdp/cudnn-archive">cuDNN</a> are already installed. Log in to the NVIDIA website and download the TensorRT archive that matches the host's CUDA version.</p><p>Taking CUDA 10.2 as an example, choose the <a href="https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/8.2.5.1/zip/tensorrt-8.2.5.1.windows10.x86_64.cuda-10.2.cudnn8.2.zip">zip package</a> built for CUDA 10.2. Once the download finishes, users with a conda virtual environment may want to activate it first. Then run commands like the following in PowerShell to install and test:</p><figure class="highlight bash"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="built_in">cd</span> \the\path\of\tensorrt\zip\file </span><br><span class="line">Expand-Archive TensorRT-8.2.5.1.Windows10.x86_64.cuda-10.2.cudnn8.2.zip . 
</span><br><span class="line"><span class="variable">$env</span>:TENSORRT_DIR = <span class="string">"<span class="variable">$pwd</span>\TensorRT-8.2.5.1"</span> </span><br><span class="line"><span class="variable">$env</span>:path = <span class="string">"<span class="variable">$env</span>:TENSORRT_DIR\lib;"</span> + <span class="variable">$env</span>:path </span><br><span class="line">pip install <span class="variable">$env</span>:TENSORRT_DIR\python\tensorrt-8.2.5.1-cp36-none-win_amd64.whl </span><br><span class="line">python -c <span class="string">"import tensorrt;print(tensorrt.__version__)"</span></span><br></pre></td></tr></tbody></table></figure><p>After installation, the last command prints the TensorRT version. If it prints 8.2.5.1, the Python package was installed successfully. The <code>cp36</code> tag in the wheel filename means the wheel targets Python 3.6, so install the wheel from the <code>python</code> folder whose tag matches your Python interpreter.</p><h3 id="Linux"><a href="#Linux" class="headerlink" title="Linux"></a>Linux</h3><p>As on Windows, this assumes a machine with an NVIDIA GPU on which <a href="https://developer.nvidia.com/cuda-toolkit-archive">CUDA</a> and <a href="https://developer.nvidia.com/rdp/cudnn-archive">cuDNN</a> are already installed. Log in to the NVIDIA website and download the TensorRT archive that matches the host's CUDA version.</p><p>Taking CUDA 10.2 as an example, choose the <a href="https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/8.2.5.1/tars/tensorrt-8.2.5.1.linux.x86_64-gnu.cuda-10.2.cudnn8.2.tar.gz">tar package</a> built for CUDA 10.2, then run commands like the following to install and test:</p><figure class="highlight bash"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="built_in">cd</span> /the/path/of/tensorrt/tar/gz/file </span><br><span class="line">tar -zxvf TensorRT-8.2.5.1.linux.x86_64-gnu.cuda-10.2.cudnn8.2.tar.gz </span><br><span class="line"><span class="built_in">export</span> TENSORRT_DIR=$(<span class="built_in">pwd</span>)/TensorRT-8.2.5.1 </span><br><span class="line"><span class="built_in">export</span> LD_LIBRARY_PATH=<span 
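class="variable">$TENSORRT_DIR</span>/lib:<span class="variable">$LD_LIBRARY_PATH</span> </span><br><span class="line">pip install TensorRT-8.2.5.1/python/tensorrt-8.2.5.1-cp37-none-linux_x86_64.whl </span><br><span class="line">python -c <span class="string">"import tensorrt;print(tensorrt.__version__)"</span></span><br></pre></td></tr></tbody></table></figure><p>If the printed version is 8.2.5.1, the Python package was installed successfully.</p><h2 id="模型构建"><a href="#模型构建" class="headerlink" title="Building Models"></a>Building Models</h2><p>There are two main ways to produce a TensorRT model:</p><ol><li>Build the network layer by layer directly with the TensorRT API;</li><li>Convert a model in an intermediate representation (IR) into a TensorRT model, for example by converting an ONNX model into a TensorRT model.</li></ol><p>Below, we build TensorRT models both ways, in Python and in C++, and then run inference with the generated models.</p><h3 id="直接构建"><a href="#直接构建" class="headerlink" title="Building Directly"></a>Building Directly</h3><p>Building a network layer by layer with the TensorRT API is similar to building one in a general-purpose training framework such as PyTorch or TensorFlow. Note that for layers with weights, such as convolution or normalization layers, the weight values must be copied into the TensorRT network. This article does not cover that in detail and only builds a simple network that applies pooling to its input.</p><p>Building with the Python API</p><p>We first build the TensorRT network directly with the Python API. This mainly relies on the <code>create_builder_config</code> and <code>create_network</code> methods of <code>tensorrt.Builder</code>, which create the config and the network respectively. The config sets parameters such as the maximum workspace size. The network is the body of the model, to which layers are added one by one.</p><p>In addition, the input and output names must be defined, and the built network is serialized and saved to a local file. Note that if you want the network to accept inputs and outputs of different resolutions, you need to call the <code>create_optimization_profile</code> method of <code>tensorrt.Builder</code> and set the minimum, optimal, and maximum shapes.</p><p>The implementation is as follows:</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span 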
class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> tensorrt <span class="keyword">as</span> trt </span><br><span class="line"> </span><br><span class="line">verbose = <span class="literal">True</span> </span><br><span class="line">IN_NAME = <span class="string">'input'</span> </span><br><span class="line">OUT_NAME = <span class="string">'output'</span> </span><br><span class="line">IN_H = <span class="number">224</span> </span><br><span class="line">IN_W = <span class="number">224</span> </span><br><span class="line">BATCH_SIZE = <span class="number">1</span> </span><br><span class="line"> </span><br><span class="line">EXPLICIT_BATCH = <span class="number">1</span> << (<span class="built_in">int</span>)( </span><br><span class="line">    trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) </span><br><span class="line"> </span><br><span class="line">TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) <span class="keyword">if</span> verbose <span class="keyword">else</span> trt.Logger() </span><br><span class="line"><span class="keyword">with</span> trt.Builder(TRT_LOGGER) <span class="keyword">as</span> builder, builder.create_builder_config( </span><br><span class="line">) <span class="keyword">as</span> config, builder.create_network(EXPLICIT_BATCH) <span class="keyword">as</span> network: </span><br><span class="line">    <span class="comment"># define network </span></span><br><span class="line">    input_tensor = network.add_input( </span><br><span class="line">        name=IN_NAME, dtype=trt.float32, shape=(BATCH_SIZE, <span class="number">3</span>, IN_H, IN_W)) </span><br><span class="line">    pool = network.add_pooling( </span><br><span class="line">        <span class="built_in">input</span>=input_tensor, <span class="built_in">type</span>=trt.PoolingType.MAX, window_size=(<span class="number">2</span>, <span class="number">2</span>)) </span><br><span class="line">    pool.stride = (<span class="number">2</span>, <span class="number">2</span>) </span><br><span class="line">    pool.get_output(<span class="number">0</span>).name = OUT_NAME </span><br><span class="line">    network.mark_output(pool.get_output(<span class="number">0</span>)) </span><br><span class="line"> </span><br><span class="line">    <span class="comment"># serialize the model to engine file</span></span><br><span class="line">    profile = builder.create_optimization_profile() </span><br><span class="line">    profile.set_shape(IN_NAME, *[[BATCH_SIZE, <span class="number">3</span>, IN_H, IN_W]]*<span class="number">3</span>) </span><br><span class="line">    config.add_optimization_profile(profile) </span><br><span class="line">    builder.max_batch_size = <span class="number">1</span> </span><br><span class="line">    config.max_workspace_size = <span class="number">1</span> << <span class="number">30</span> </span><br><span class="line">    engine = builder.build_engine(network, config) </span><br><span class="line">    <span class="keyword">with</span> <span class="built_in">open</span>(<span class="string">'model_python_trt.engine'</span>, mode=<span class="string">'wb'</span>) <span class="keyword">as</span> f: </span><br><span class="line">        f.write(<span class="built_in">bytearray</span>(engine.serialize())) </span><br><span class="line">        <span class="built_in">print</span>(<span class="string">"generating file done!"</span>)</span><br></pre></td></tr></tbody></table></figure>
<p>Building with the C++ API</p><p>For readers who want to build the network directly in C++, the workflow is very similar to the Python version above. The main points to note are:</p><ol><li><code>nvinfer1::createInferBuilder</code> corresponds to <code>tensorrt.Builder</code> in Python. It takes an instance of the <code>ILogger</code> class, but <code>ILogger</code> is an abstract class, so users have to subclass it and implement its virtual functions. Here, however, we simply use the <code>Logger</code> subclass implemented in <code>../samples/common/logger.h</code>, found in the samples folder of the extracted TensorRT package.</li><li>Setting the input shapes of the TensorRT model requires calling the <code>setDimensions</code> method of <code>IOptimizationProfile</code> several times, which is slightly more verbose than in Python. An <code>IOptimizationProfile</code> is created with <code>createOptimizationProfile</code>, which corresponds to <code>create_optimization_profile</code> in Python.</li></ol><p>The implementation is as follows:</p><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span 
class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string">&lt;cassert&gt;</span> </span></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string">&lt;fstream&gt;</span> </span></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string">&lt;iostream&gt;</span> </span></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string">&lt;NvInfer.h&gt;</span> </span></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string">&lt;../samples/common/logger.h&gt;</span> </span></span><br><span class="line"><span class="keyword">using</span> <span class="keyword">namespace</span> nvinfer1; </span><br><span class="line"><span class="keyword">using</span> <span class="keyword">namespace</span> sample; </span><br><span class="line"> </span><br><span class="line"><span class="type">const</span> <span class="type">char</span>* IN_NAME = <span class="string">"input"</span>; </span><br><span class="line"><span class="type">const</span> <span class="type">char</span>* OUT_NAME = <span class="string">"output"</span>; </span><br><span class="line"><span class="type">static</span> <span class="type">const</span> <span class="type">int</span> IN_H = <span class="number">224</span>; </span><br><span class="line"><span class="type">static</span> <span class="type">const</span> <span class="type">int</span> IN_W = <span class="number">224</span>; </span><br><span class="line"><span class="type">static</span> <span class="type">const</span> <span class="type">int</span> BATCH_SIZE = <span class="number">1</span>; </span><br><span class="line"><span class="type">static</span> <span class="type">const</span> <span class="type">int</span> EXPLICIT_BATCH = <span class="number">1</span> << (<span class="type">int</span>)(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); </span><br><span class="line"> </span><br><span class="line"><span class="function"><span class="type">int</span> <span class="title">main</span><span class="params">(<span class="type">int</span> argc, <span class="type">char</span>** argv)</span> </span></span><br><span class="line"><span class="function"></span>{ </span><br><span class="line">    <span class="comment">// Create builder</span></span><br><span class="line">    Logger m_logger; </span><br><span class="line">    IBuilder* builder = <span class="built_in">createInferBuilder</span>(m_logger); </span><br><span class="line">    IBuilderConfig* config = builder-><span class="built_in">createBuilderConfig</span>(); </span><br><span class="line"> </span><br><span class="line">    <span class="comment">// Create model to populate the network</span></span><br><span class="line">    INetworkDefinition* network = builder-><span class="built_in">createNetworkV2</span>(EXPLICIT_BATCH); </span><br><span class="line">    ITensor* input_tensor = network-><span class="built_in">addInput</span>(IN_NAME, DataType::kFLOAT, Dims4{ BATCH_SIZE, <span class="number">3</span>, IN_H, IN_W }); </span><br><span class="line">    IPoolingLayer* pool = network-><span class="built_in">addPoolingNd</span>(*input_tensor, PoolingType::kMAX, DimsHW{ <span class="number">2</span>, <span class="number">2</span> }); </span><br><span class="line">    pool-><span class="built_in">setStrideNd</span>(DimsHW{ <span class="number">2</span>, <span class="number">2</span> }); </span><br><span class="line">    pool-><span class="built_in">getOutput</span>(<span class="number">0</span>)-><span class="built_in">setName</span>(OUT_NAME); </span><br><span class="line">    network-><span class="built_in">markOutput</span>(*pool-><span class="built_in">getOutput</span>(<span class="number">0</span>)); </span><br><span class="line"> </span><br><span class="line">    <span class="comment">// Build engine</span></span><br><span class="line">    IOptimizationProfile* profile = builder-><span class="built_in">createOptimizationProfile</span>(); </span><br><span class="line">    profile-><span class="built_in">setDimensions</span>(IN_NAME, OptProfileSelector::kMIN, <span class="built_in">Dims4</span>(BATCH_SIZE, <span class="number">3</span>, IN_H, IN_W)); </span><br><span class="line">    profile-><span class="built_in">setDimensions</span>(IN_NAME, OptProfileSelector::kOPT, <span class="built_in">Dims4</span>(BATCH_SIZE, <span class="number">3</span>, IN_H, IN_W)); </span><br><span class="line">    profile-><span class="built_in">setDimensions</span>(IN_NAME, OptProfileSelector::kMAX, <span class="built_in">Dims4</span>(BATCH_SIZE, <span class="number">3</span>, IN_H, IN_W)); </span><br><span class="line">    config-><span class="built_in">addOptimizationProfile</span>(profile); </span><br><span class="line">    config-><span class="built_in">setMaxWorkspaceSize</span>(<span class="number">1</span> << <span class="number">20</span>); </span><br><span class="line">    ICudaEngine* engine = builder-><span class="built_in">buildEngineWithConfig</span>(*network, *config); </span><br><span class="line"> </span><br><span class="line">    <span class="comment">// Serialize the model to engine file </span></span><br><span class="line">    IHostMemory* modelStream{ <span class="literal">nullptr</span> }; </span><br><span class="line">    <span class="built_in">assert</span>(engine != <span class="literal">nullptr</span>); </span><br><span class="line">    modelStream = engine-><span class="built_in">serialize</span>(); </span><br><span class="line"> </span><br><span class="line">    <span class="function">std::ofstream <span class="title">p</span><span class="params">(<span class="string">"model.engine"</span>, std::ios::binary)</span></span>; </span><br><span class="line">    <span class="keyword">if</span> (!p) { </span><br><span class="line">        std::cerr << <span class="string">"could not open output file to save model"</span> << std::endl; </span><br><span class="line">        <span class="keyword">return</span> <span class="number">-1</span>; </span><br><span class="line">    } </span><br><span class="line">    p.<span class="built_in">write</span>(<span class="built_in">reinterpret_cast</span><<span class="type">const</span> <span class="type">char</span>*>(modelStream-><span class="built_in">data</span>()), modelStream-><span class="built_in">size</span>()); </span><br><span class="line">    std::cout << <span class="string">"generating file done!"</span> << std::endl; </span><br><span class="line"> </span><br><span class="line">    <span class="comment">// Release resources</span></span><br><span class="line">    modelStream-><span class="built_in">destroy</span>(); </span><br><span class="line">    network-><span class="built_in">destroy</span>(); </span><br><span class="line">    engine-><span class="built_in">destroy</span>(); </span><br><span class="line">    builder-><span class="built_in">destroy</span>(); </span><br><span class="line">    config-><span class="built_in">destroy</span>(); </span><br><span class="line">    <span class="keyword">return</span> <span class="number">0</span>; </span><br><span class="line">}</span><br></pre></td></tr></tbody></table></figure>
<h3 id="IR-转换模型"><a href="#IR-转换模型" class="headerlink" title="Converting from IR"></a>Converting from IR</h3><p>Besides building a network layer by layer with the TensorRT API and serializing it, TensorRT also supports converting a model in an intermediate representation (such as ONNX) into a TensorRT model.</p><p>Converting with the Python API</p><p>We first implement in PyTorch the same model as above, which applies a single pooling operation to its input and returns the result. We then export the PyTorch model to ONNX and finally convert the ONNX model into a TensorRT model.</p><p>This mainly relies on TensorRT's <code>OnnxParser</code>, which parses an ONNX model into a TensorRT network. The result is again a TensorRT model that behaves the same as the one built directly above.</p><p>The implementation is as follows:</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span 
class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch </span><br><span class="line"><span class="keyword">import</span> onnx </span><br><span class="line"><span class="keyword">import</span> tensorrt <span class="keyword">as</span> trt </span><br><span class="line"> </span><br><span class="line"> </span><br><span class="line">onnx_model = <span class="string">'model.onnx'</span> </span><br><span class="line"> </span><br><span class="line"><span 
class="line"><span class="keyword">class</span> <span class="title class_">NaiveModel</span>(torch.nn.Module): </span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self</span>): </span><br><span class="line">        <span class="built_in">super</span>().__init__() </span><br><span class="line">        self.pool = torch.nn.MaxPool2d(<span class="number">2</span>, <span class="number">2</span>) </span><br><span class="line"> </span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">forward</span>(<span class="params">self, x</span>): </span><br><span class="line">        <span class="keyword">return</span> self.pool(x) </span><br><span class="line"> </span><br><span class="line">device = torch.device(<span class="string">'cuda:0'</span>) </span><br><span class="line"> </span><br><span class="line"><span class="comment"># generate ONNX model</span></span><br><span class="line">torch.onnx.export(NaiveModel(), torch.randn(<span class="number">1</span>, <span class="number">3</span>, <span class="number">224</span>, <span class="number">224</span>), onnx_model, input_names=[<span class="string">'input'</span>], output_names=[<span class="string">'output'</span>], opset_version=<span class="number">11</span>) </span><br><span class="line">onnx_model = onnx.load(onnx_model) </span><br><span class="line"> </span><br><span class="line"><span class="comment"># create builder and network </span></span><br><span class="line">logger = trt.Logger(trt.Logger.ERROR) </span><br><span class="line">builder = trt.Builder(logger) </span><br><span class="line">EXPLICIT_BATCH = <span class="number">1</span> << (<span class="built_in">int</span>)( </span><br><span class="line">    trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) </span><br><span class="line">network = builder.create_network(EXPLICIT_BATCH) </span><br><span class="line"> </span><br><span class="line"><span class="comment"># parse onnx </span></span><br><span class="line">parser = trt.OnnxParser(network, logger) </span><br><span class="line"> </span><br><span class="line"><span class="keyword">if</span> <span class="keyword">not</span> parser.parse(onnx_model.SerializeToString()): </span><br><span class="line">    error_msgs = <span class="string">''</span> </span><br><span class="line">    <span class="keyword">for</span> error <span class="keyword">in</span> <span class="built_in">range</span>(parser.num_errors): </span><br><span class="line">        error_msgs += <span class="string">f'<span class="subst">{parser.get_error(error)}</span>\n'</span> </span><br><span class="line">    <span class="keyword">raise</span> RuntimeError(<span class="string">f'Failed to parse onnx, <span class="subst">{error_msgs}</span>'</span>) </span><br><span class="line"> </span><br><span class="line">config = builder.create_builder_config() </span><br><span class="line">config.max_workspace_size = <span class="number">1</span> << <span class="number">20</span> </span><br><span class="line">profile = builder.create_optimization_profile() </span><br><span class="line"> </span><br><span class="line">profile.set_shape(<span class="string">'input'</span>, [<span class="number">1</span>, <span class="number">3</span>, <span class="number">224</span>, <span class="number">224</span>], [<span class="number">1</span>, <span class="number">3</span>, <span class="number">224</span>, <span class="number">224</span>], [<span class="number">1</span>, <span class="number">3</span>, <span class="number">224</span>, <span class="number">224</span>]) </span><br><span class="line">config.add_optimization_profile(profile) </span><br><span class="line"><span class="comment"># create engine </span></span><br><span class="line"><span class="keyword">with</span> torch.cuda.device(device): </span><br><span class="line">    engine = builder.build_engine(network, config) </span><br><span class="line"> </span><br><span class="line"><span class="keyword">with</span> <span class="built_in">open</span>(<span class="string">'model.engine'</span>, mode=<span class="string">'wb'</span>) <span class="keyword">as</span> f: </span><br><span class="line">    f.write(<span class="built_in">bytearray</span>(engine.serialize())) </span><br><span class="line">    <span class="built_in">print</span>(<span class="string">"generating file done!"</span>)</span><br></pre></td></tr></tbody></table></figure>
<p>When converting from an IR, needs such as multiple batch sizes, multiple inputs, or dynamic shapes can all be configured by calling <code>set_shape</code> once for each input. <code>set_shape</code> takes, in order, the input name, the minimum accepted input shape, the optimal input shape, and the maximum accepted input shape. The three shapes must be non-decreasing, that is, min ≤ opt ≤ max in every dimension.</p><p>Converting with the C++ API</p><p>Having shown how to convert an ONNX model into a TensorRT model in Python, we now do the same in C++. With <code>NvOnnxParser</code>, the ONNX file produced in the previous subsection can be parsed directly into the network.</p><p>The implementation is as follows:</p><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span 
class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string"><fstream></span> </span></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string"><iostream></span> </span></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string"><NvInfer.h></span> </span></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string"><NvOnnxParser.h></span> </span></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string"><../samples/common/logger.h></span> </span></span><br><span class="line"><span class="keyword">using</span> <span class="keyword">namespace</span> nvinfer1; </span><br><span class="line"><span class="keyword">using</span> <span class="keyword">namespace</span> nvonnxparser; </span><br><span class="line"><span class="keyword">using</span> <span class="keyword">namespace</span> sample; </span><br><span class="line"> </span><br><span class="line"><span 
class="function"><span class="type">int</span> <span class="title">main</span><span class="params">(<span class="type">int</span> argc, <span class="type">char</span>** argv)</span> </span></span><br><span class="line"><span class="function"></span>{ </span><br><span class="line"> <span class="comment">// Create builder </span></span><br><span class="line"> Logger m_logger; </span><br><span class="line"> IBuilder* builder = <span class="built_in">createInferBuilder</span>(m_logger); </span><br><span class="line"> <span class="type">const</span> <span class="keyword">auto</span> explicitBatch = <span class="number">1U</span> << <span class="built_in">static_cast</span><<span class="type">uint32_t</span>>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); </span><br><span class="line"> IBuilderConfig* config = builder-><span class="built_in">createBuilderConfig</span>(); </span><br><span class="line"> </span><br><span class="line"> <span class="comment">// Create model to populate the network </span></span><br><span class="line"> INetworkDefinition* network = builder-><span class="built_in">createNetworkV2</span>(explicitBatch); </span><br><span class="line"> </span><br><span class="line"> <span class="comment">// Parse ONNX file </span></span><br><span class="line"> IParser* parser = nvonnxparser::<span class="built_in">createParser</span>(*network, m_logger); </span><br><span class="line"> <span class="type">bool</span> parser_status = parser-><span class="built_in">parseFromFile</span>(<span class="string">"model.onnx"</span>, <span class="built_in">static_cast</span><<span class="type">int</span>>(ILogger::Severity::kWARNING)); </span><br><span class="line"> <span class="keyword">if</span> (!parser_status) { std::cerr << <span class="string">"failed to parse model.onnx"</span> << std::endl; <span class="keyword">return</span> <span class="number">-1</span>; } </span><br><span class="line"> <span class="comment">// Get the dimensions of the network input </span></span><br><span class="line"> Dims dim = network-><span class="built_in">getInput</span>(<span class="number">0</span>)-><span class="built_in">getDimensions</span>(); </span><br><span class="line"> <span 
class="keyword">if</span> (dim.d[<span class="number">0</span>] == <span class="number">-1</span>) <span class="comment">// -1 means the batch dimension is dynamic </span></span><br><span class="line"> { </span><br><span class="line"> <span class="type">const</span> <span class="type">char</span>* name = network-><span class="built_in">getInput</span>(<span class="number">0</span>)-><span class="built_in">getName</span>(); </span><br><span class="line"> IOptimizationProfile* profile = builder-><span class="built_in">createOptimizationProfile</span>(); </span><br><span class="line"> profile-><span class="built_in">setDimensions</span>(name, OptProfileSelector::kMIN, <span class="built_in">Dims4</span>(<span class="number">1</span>, dim.d[<span class="number">1</span>], dim.d[<span class="number">2</span>], dim.d[<span class="number">3</span>])); </span><br><span class="line"> profile-><span class="built_in">setDimensions</span>(name, OptProfileSelector::kOPT, <span class="built_in">Dims4</span>(<span class="number">1</span>, dim.d[<span class="number">1</span>], dim.d[<span class="number">2</span>], dim.d[<span class="number">3</span>])); </span><br><span class="line"> profile-><span class="built_in">setDimensions</span>(name, OptProfileSelector::kMAX, <span class="built_in">Dims4</span>(<span class="number">1</span>, dim.d[<span class="number">1</span>], dim.d[<span class="number">2</span>], dim.d[<span class="number">3</span>])); </span><br><span class="line"> config-><span class="built_in">addOptimizationProfile</span>(profile); </span><br><span class="line"> } </span><br><span class="line"> </span><br><span class="line"> </span><br><span class="line"> <span class="comment">// Build engine </span></span><br><span class="line"> config-><span class="built_in">setMaxWorkspaceSize</span>(<span class="number">1</span> << <span class="number">20</span>); </span><br><span class="line"> ICudaEngine* engine = builder-><span class="built_in">buildEngineWithConfig</span>(*network, 
*config); </span><br><span class="line"> </span><br><span class="line"> <span class="comment">// Serialize the model to engine file </span></span><br><span class="line"> IHostMemory* modelStream{ <span class="literal">nullptr</span> }; </span><br><span class="line"> <span class="built_in">assert</span>(engine != <span class="literal">nullptr</span>); </span><br><span class="line"> modelStream = engine-><span class="built_in">serialize</span>(); </span><br><span class="line"> </span><br><span class="line"> <span class="function">std::ofstream <span class="title">p</span><span class="params">(<span class="string">"model.engine"</span>, std::ios::binary)</span></span>; </span><br><span class="line"> <span class="keyword">if</span> (!p) { </span><br><span class="line"> std::cerr << <span class="string">"could not open output file to save model"</span> << std::endl; </span><br><span class="line"> <span class="keyword">return</span> <span class="number">-1</span>; </span><br><span class="line"> } </span><br><span class="line"> p.<span class="built_in">write</span>(<span class="built_in">reinterpret_cast</span><<span class="type">const</span> <span class="type">char</span>*>(modelStream-><span class="built_in">data</span>()), modelStream-><span class="built_in">size</span>()); </span><br><span class="line"> std::cout << <span class="string">"generate file success!"</span> << std::endl; </span><br><span class="line"> </span><br><span class="line"> <span class="comment">// Release resources </span></span><br><span class="line"> modelStream-><span class="built_in">destroy</span>(); </span><br><span class="line"> network-><span class="built_in">destroy</span>(); </span><br><span class="line"> engine-><span class="built_in">destroy</span>(); </span><br><span class="line"> builder-><span class="built_in">destroy</span>(); </span><br><span class="line"> config-><span class="built_in">destroy</span>(); </span><br><span class="line"> <span class="keyword">return</span> <span 
class="number">0</span>; </span><br><span class="line">}</span><br></pre></td></tr></tbody></table></figure><h2 id="模型推理"><a href="#模型推理" class="headerlink" title="模型推理"></a>Model Inference</h2><p>So far we have used two ways of building a TensorRT model and, across Python and C++, produced four TensorRT models in total. In theory, all four are functionally identical.</p><p>Next, we run inference on the generated TensorRT models, first in Python and then in C++.</p><h3 id="使用-Python-API-推理"><a href="#使用-Python-API-推理" class="headerlink" title="使用 Python API 推理"></a>Inference with the Python API</h3><p>We start with inference through the Python API; part of this code is adapted from <a href="https://github.com/open-mmlab/mmdeploy">MMDeploy</a>. Running the code below feeds in a <code>1x3x224x224</code> tensor and returns a <code>1x3x112x112</code> tensor, which is exactly what we expect after pooling the input.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span 
class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> typing <span class="keyword">import</span> <span class="type">Union</span>, <span class="type">Optional</span>, <span class="type">Sequence</span>, <span class="type">Dict</span>, <span class="type">Any</span> </span><br><span class="line"> </span><br><span class="line"><span class="keyword">import</span> torch </span><br><span class="line"><span class="keyword">import</span> tensorrt <span class="keyword">as</span> trt </span><br><span class="line"> </span><br><span class="line"><span class="keyword">class</span> <span class="title class_">TRTWrapper</span>(torch.nn.Module): </span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self, engine: <span class="type">Union</span>[<span class="built_in">str</span>, trt.ICudaEngine], </span></span><br><span class="line"><span class="params">                 output_names: <span class="type">Optional</span>[<span class="type">Sequence</span>[<span 
class="built_in">str</span>]] = <span class="literal">None</span></span>) -> <span class="literal">None</span>: </span><br><span class="line">        <span class="built_in">super</span>().__init__() </span><br><span class="line">        self.engine = engine </span><br><span class="line">        <span class="keyword">if</span> <span class="built_in">isinstance</span>(self.engine, <span class="built_in">str</span>): </span><br><span class="line">            <span class="keyword">with</span> trt.Logger() <span class="keyword">as</span> logger, trt.Runtime(logger) <span class="keyword">as</span> runtime: </span><br><span class="line">                <span class="keyword">with</span> <span class="built_in">open</span>(self.engine, mode=<span class="string">'rb'</span>) <span class="keyword">as</span> f: </span><br><span class="line">                    engine_bytes = f.read() </span><br><span class="line">                self.engine = runtime.deserialize_cuda_engine(engine_bytes) </span><br><span class="line">        self.context = self.engine.create_execution_context() </span><br><span class="line">        names = [_ <span class="keyword">for</span> _ <span class="keyword">in</span> self.engine] </span><br><span class="line">        input_names = <span class="built_in">list</span>(<span class="built_in">filter</span>(self.engine.binding_is_input, names)) </span><br><span class="line">        self._input_names = input_names </span><br><span class="line">        self._output_names = output_names </span><br><span class="line"> </span><br><span class="line">        <span class="keyword">if</span> self._output_names <span class="keyword">is</span> <span class="literal">None</span>: </span><br><span class="line">            output_names = <span class="built_in">list</span>(<span class="built_in">set</span>(names) - <span class="built_in">set</span>(input_names)) </span><br><span class="line">            self._output_names = output_names </span><br><span class="line"> </span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">forward</span>(<span class="params">self, inputs: <span 
class="type">Dict</span>[<span class="built_in">str</span>, torch.Tensor]</span>): </span><br><span class="line">        <span class="keyword">assert</span> self._input_names <span class="keyword">is</span> <span class="keyword">not</span> <span class="literal">None</span> </span><br><span class="line">        <span class="keyword">assert</span> self._output_names <span class="keyword">is</span> <span class="keyword">not</span> <span class="literal">None</span> </span><br><span class="line">        bindings = [<span class="literal">None</span>] * (<span class="built_in">len</span>(self._input_names) + <span class="built_in">len</span>(self._output_names)) </span><br><span class="line">        profile_id = <span class="number">0</span> </span><br><span class="line">        <span class="keyword">for</span> input_name, input_tensor <span class="keyword">in</span> inputs.items(): </span><br><span class="line">            <span class="comment"># check if input shape is valid </span></span><br><span class="line">            profile = self.engine.get_profile_shape(profile_id, input_name) </span><br><span class="line">            <span class="keyword">assert</span> input_tensor.dim() == <span class="built_in">len</span>( </span><br><span class="line">                profile[<span class="number">0</span>]), <span class="string">'Input dim is different from engine profile.'</span> </span><br><span class="line">            <span class="keyword">for</span> s_min, s_input, s_max <span class="keyword">in</span> <span class="built_in">zip</span>(profile[<span class="number">0</span>], input_tensor.shape, </span><br><span class="line">                                             profile[<span class="number">2</span>]): </span><br><span class="line">                <span class="keyword">assert</span> s_min <= s_input <= s_max, \</span><br><span class="line">                    <span class="string">'Input shape should be between '</span> \</span><br><span class="line">                    + <span class="string">f'<span class="subst">{profile[<span class="number">0</span>]}</span> and <span class="subst">{profile[<span class="number">2</span>]}</span>'</span> \
</span><br><span class="line">                    + <span class="string">f' but got <span class="subst">{<span class="built_in">tuple</span>(input_tensor.shape)}</span>.'</span> </span><br><span class="line">            idx = self.engine.get_binding_index(input_name) </span><br><span class="line"> </span><br><span class="line">            <span class="comment"># All input tensors must be gpu variables </span></span><br><span class="line">            <span class="keyword">assert</span> <span class="string">'cuda'</span> <span class="keyword">in</span> input_tensor.device.<span class="built_in">type</span> </span><br><span class="line">            input_tensor = input_tensor.contiguous() </span><br><span class="line">            <span class="keyword">if</span> input_tensor.dtype == torch.long: </span><br><span class="line">                input_tensor = input_tensor.<span class="built_in">int</span>() </span><br><span class="line">            self.context.set_binding_shape(idx, <span class="built_in">tuple</span>(input_tensor.shape)) </span><br><span class="line">            bindings[idx] = input_tensor.contiguous().data_ptr() </span><br><span class="line"> </span><br><span class="line">        <span class="comment"># create output tensors </span></span><br><span class="line">        outputs = {} </span><br><span class="line">        <span class="keyword">for</span> output_name <span class="keyword">in</span> self._output_names: </span><br><span class="line">            idx = self.engine.get_binding_index(output_name) </span><br><span class="line">            dtype = torch.float32 </span><br><span class="line">            shape = <span class="built_in">tuple</span>(self.context.get_binding_shape(idx)) </span><br><span class="line"> </span><br><span class="line">            device = torch.device(<span class="string">'cuda'</span>) </span><br><span class="line">            output = torch.empty(size=shape, dtype=dtype, device=device) </span><br><span class="line">            outputs[output_name] = output </span><br><span class="line">            bindings[idx] = output.data_ptr() </span><br><span class="line">        self.context.execute_async_v2(bindings, </span><br><span 
class="line">                                      torch.cuda.current_stream().cuda_stream) </span><br><span class="line">        <span class="keyword">return</span> outputs </span><br><span class="line"> </span><br><span class="line">model = TRTWrapper(<span class="string">'model.engine'</span>, [<span class="string">'output'</span>]) </span><br><span class="line">output = model(<span class="built_in">dict</span>(<span class="built_in">input</span>=torch.randn(<span class="number">1</span>, <span class="number">3</span>, <span class="number">224</span>, <span class="number">224</span>).cuda())) </span><br><span class="line"><span class="built_in">print</span>(output)</span><br></pre></td></tr></tbody></table></figure><h3 id="使用-C-API-推理"><a href="#使用-C-API-推理" class="headerlink" title="使用 C++ API 推理"></a>Inference with the C++ API</h3><p>Finally, in many real production environments the actual work is done in C++ for more efficient execution, and TensorRT users generally care more about its C++ API, so we also implement model inference in C++. This also gives a point of comparison with inference through the Python API.</p><p>The implementation is as follows:</p><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span 
class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span 
class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string"><fstream></span> </span></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string"><iostream></span> </span></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string"><NvInfer.h></span> </span></span><br><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string"><../samples/common/logger.h></span> </span></span><br><span class="line"><span class="meta">#<span class="keyword">define</span> CHECK(status) \ </span></span><br><span class="line"> <span class="keyword">do</span>\ </span><br><span class="line"> {\ </span><br><span class="line"> <span class="keyword">auto</span> ret = (status);\ </span><br><span class="line"> <span class="keyword">if</span> (ret != <span class="number">0</span>)\ </span><br><span class="line"> {\ </span><br><span class="line"> std::cerr << <span class="string">"Cuda failure: "</span> << ret << std::endl;\ </span><br><span class="line"> <span class="built_in">abort</span>();\ </span><br><span class="line"> }\ </span><br><span class="line"> } <span class="keyword">while</span> (<span class="number">0</span>) </span><br><span class="line"> </span><br><span class="line"><span class="keyword">using</span> <span class="keyword">namespace</span> nvinfer1; </span><br><span class="line"><span class="keyword">using</span> <span class="keyword">namespace</span> sample; </span><br><span class="line"> </span><br><span class="line"><span class="type">const</span> <span class="type">char</span>* IN_NAME = <span class="string">"input"</span>; 
</span><br><span class="line"><span class="type">const</span> <span class="type">char</span>* OUT_NAME = <span class="string">"output"</span>; </span><br><span class="line"><span class="type">static</span> <span class="type">const</span> <span class="type">int</span> IN_H = <span class="number">224</span>; </span><br><span class="line"><span class="type">static</span> <span class="type">const</span> <span class="type">int</span> IN_W = <span class="number">224</span>; </span><br><span class="line"><span class="type">static</span> <span class="type">const</span> <span class="type">int</span> BATCH_SIZE = <span class="number">1</span>; </span><br><span class="line"><span class="type">static</span> <span class="type">const</span> <span class="type">int</span> EXPLICIT_BATCH = <span class="number">1</span> << (<span class="type">int</span>)(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); </span><br><span class="line"> </span><br><span class="line"> </span><br><span class="line"><span class="function"><span class="type">void</span> <span class="title">doInference</span><span class="params">(IExecutionContext& context, <span class="type">float</span>* input, <span class="type">float</span>* output, <span class="type">int</span> batchSize)</span> </span></span><br><span class="line"><span class="function"></span>{ </span><br><span class="line"> <span class="type">const</span> ICudaEngine& engine = context.<span class="built_in">getEngine</span>(); </span><br><span class="line"> </span><br><span class="line"> <span class="comment">// Pointers to input and output device buffers to pass to engine. </span></span><br><span class="line"> <span class="comment">// Engine requires exactly IEngine::getNbBindings() number of buffers. 
</span></span><br><span class="line"> <span class="built_in">assert</span>(engine.<span class="built_in">getNbBindings</span>() == <span class="number">2</span>); </span><br><span class="line"> <span class="type">void</span>* buffers[<span class="number">2</span>]; </span><br><span class="line"> </span><br><span class="line"> <span class="comment">// In order to bind the buffers, we need to know the names of the input and output tensors. </span></span><br><span class="line"> <span class="comment">// Note that indices are guaranteed to be less than IEngine::getNbBindings()</span></span><br><span class="line"> <span class="type">const</span> <span class="type">int</span> inputIndex = engine.<span class="built_in">getBindingIndex</span>(IN_NAME); </span><br><span class="line"> <span class="type">const</span> <span class="type">int</span> outputIndex = engine.<span class="built_in">getBindingIndex</span>(OUT_NAME); </span><br><span class="line"> </span><br><span class="line"> <span class="comment">// Create GPU buffers on device </span></span><br><span class="line"> <span class="built_in">CHECK</span>(<span class="built_in">cudaMalloc</span>(&buffers[inputIndex], batchSize * <span class="number">3</span> * IN_H * IN_W * <span class="built_in">sizeof</span>(<span class="type">float</span>))); </span><br><span class="line"> <span class="built_in">CHECK</span>(<span class="built_in">cudaMalloc</span>(&buffers[outputIndex], batchSize * <span class="number">3</span> * IN_H * IN_W /<span class="number">4</span> * <span class="built_in">sizeof</span>(<span class="type">float</span>))); </span><br><span class="line"> </span><br><span class="line"> <span class="comment">// Create stream </span></span><br><span class="line"> cudaStream_t stream; </span><br><span class="line"> <span class="built_in">CHECK</span>(<span class="built_in">cudaStreamCreate</span>(&stream)); </span><br><span class="line"> </span><br><span class="line"> <span class="comment">// DMA input batch data to 
device, infer on the batch asynchronously, and DMA output back to host </span></span><br><span class="line"> <span class="built_in">CHECK</span>(<span class="built_in">cudaMemcpyAsync</span>(buffers[inputIndex], input, batchSize * <span class="number">3</span> * IN_H * IN_W * <span class="built_in">sizeof</span>(<span class="type">float</span>), cudaMemcpyHostToDevice, stream)); </span><br><span class="line"> context.<span class="built_in">enqueueV2</span>(buffers, stream, <span class="literal">nullptr</span>); <span class="comment">// explicit-batch engine: enqueueV2 takes no batch size</span></span><br><span class="line"> <span class="built_in">CHECK</span>(<span class="built_in">cudaMemcpyAsync</span>(output, buffers[outputIndex], batchSize * <span class="number">3</span> * IN_H * IN_W / <span class="number">4</span> * <span class="built_in">sizeof</span>(<span class="type">float</span>), cudaMemcpyDeviceToHost, stream)); </span><br><span class="line"> <span class="built_in">CHECK</span>(<span class="built_in">cudaStreamSynchronize</span>(stream)); </span><br><span class="line"> </span><br><span class="line"> <span class="comment">// Release stream and buffers </span></span><br><span class="line"> <span class="built_in">CHECK</span>(<span class="built_in">cudaStreamDestroy</span>(stream)); </span><br><span class="line"> <span class="built_in">CHECK</span>(<span class="built_in">cudaFree</span>(buffers[inputIndex])); </span><br><span class="line"> <span class="built_in">CHECK</span>(<span class="built_in">cudaFree</span>(buffers[outputIndex])); </span><br><span class="line">} </span><br><span class="line"> </span><br><span class="line"><span class="function"><span class="type">int</span> <span class="title">main</span><span class="params">(<span class="type">int</span> argc, <span class="type">char</span>** argv)</span> </span></span><br><span class="line"><span class="function"></span>{ </span><br><span class="line"> <span class="comment">// Read the serialized engine from model.engine</span></span><br><span class="line"> <span class="type">char</span> *trtModelStream{ <span 
class="literal">nullptr</span> }; </span><br><span class="line"> <span class="type">size_t</span> size{ <span class="number">0</span> }; </span><br><span class="line"> </span><br><span class="line"> <span class="function">std::ifstream <span class="title">file</span><span class="params">(<span class="string">"model.engine"</span>, std::ios::binary)</span></span>; </span><br><span class="line"> <span class="keyword">if</span> (file.<span class="built_in">good</span>()) { </span><br><span class="line"> file.<span class="built_in">seekg</span>(<span class="number">0</span>, file.end); </span><br><span class="line"> size = file.<span class="built_in">tellg</span>(); </span><br><span class="line"> file.<span class="built_in">seekg</span>(<span class="number">0</span>, file.beg); </span><br><span class="line"> trtModelStream = <span class="keyword">new</span> <span class="type">char</span>[size]; </span><br><span class="line"> <span class="built_in">assert</span>(trtModelStream); </span><br><span class="line"> file.<span class="built_in">read</span>(trtModelStream, size); </span><br><span class="line"> file.<span class="built_in">close</span>(); </span><br><span class="line"> } </span><br><span class="line"> </span><br><span class="line"> Logger m_logger; </span><br><span class="line"> IRuntime* runtime = <span class="built_in">createInferRuntime</span>(m_logger); </span><br><span class="line"> <span class="built_in">assert</span>(runtime != <span class="literal">nullptr</span>); </span><br><span class="line"> ICudaEngine* engine = runtime-><span class="built_in">deserializeCudaEngine</span>(trtModelStream, size, <span class="literal">nullptr</span>); </span><br><span class="line"> <span class="built_in">assert</span>(engine != <span class="literal">nullptr</span>); </span><br><span class="line"> IExecutionContext* context = engine-><span class="built_in">createExecutionContext</span>(); </span><br><span class="line"> <span class="built_in">assert</span>(context != <span 
class="literal">nullptr</span>); </span><br><span class="line"> </span><br><span class="line"> <span class="comment">// generate input data </span></span><br><span class="line"> <span class="type">float</span> data[BATCH_SIZE * <span class="number">3</span> * IN_H * IN_W]; </span><br><span class="line"> <span class="keyword">for</span> (<span class="type">int</span> i = <span class="number">0</span>; i < BATCH_SIZE * <span class="number">3</span> * IN_H * IN_W; i++) </span><br><span class="line"> data[i] = <span class="number">1</span>; </span><br><span class="line"> </span><br><span class="line"> <span class="comment">// Run inference </span></span><br><span class="line"> <span class="type">float</span> prob[BATCH_SIZE * <span class="number">3</span> * IN_H * IN_W /<span class="number">4</span>]; </span><br><span class="line"> <span class="built_in">doInference</span>(*context, data, prob, BATCH_SIZE); </span><br><span class="line"> </span><br><span class="line"> <span class="comment">// Destroy the engine</span></span><br><span class="line"> context-><span class="built_in">destroy</span>(); </span><br><span class="line"> engine-><span class="built_in">destroy</span>(); </span><br><span class="line"> runtime-><span class="built_in">destroy</span>(); </span><br><span class="line"> <span class="keyword">return</span> <span class="number">0</span>; </span><br><span class="line">}</span><br></pre></td></tr></tbody></table></figure><h2 id="总结"><a href="#总结" class="headerlink" title="总结"></a>Summary</h2><p>In this article we covered two ways to build a TensorRT model: building the network layer by layer directly with the TensorRT API, and converting a model in an intermediate representation into a TensorRT model. We also built TensorRT models and ran inference in both C++ and Python. We hope you found it useful! In the next article, we will learn how to add custom TensorRT operators. Stay tuned!</p><h2 id="FAQ"><a href="#FAQ" class="headerlink" title="FAQ"></a>FAQ</h2><ul><li>Q: Running the code fails with the error: Could not find: cudnn64_8.dll. 
Is it on your PATH?</li><li>A: First check whether your PATH environment variable contains the directory that holds cudnn64_8.dll. If your cuDNN files are in C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.2\bin but that directory only contains cudnn64_7.dll, you have cuDNN 7 installed while cuDNN 8 is required. Download the cuDNN 8 zip package from NVIDIA's website, extract it, and copy cudnn64_8.dll, together with the other cuDNN 8 DLLs shipped next to it, into the bin directory of the CUDA Toolkit. Copying cudnn64_7.dll and renaming the copy to cudnn64_8.dll may get past this loading error, but the library is still cuDNN 7, so it is not a reliable fix.</li></ul><h2 id="参考"><a href="#参考" class="headerlink" title="参考"></a>References</h2><p><a href="https://github.com/wang-xinyu/tensorrtx">GitHub - wang-xinyu/tensorrtx</a></p><p><a href="https://github.com/NVIDIA/TensorRT">GitHub - NVIDIA/TensorRT</a></p>]]></content>
<summary type="html"><h1 id="TensorRT-模型构建与推理"><a href="#TensorRT-模型构建与推理" class="headerlink" title="TensorRT 模型构建与推理"></a>TensorRT Model Building and Inference</h1><h2 id="TensorRT-</summary>
<category term="高效人工智能" scheme="https://thinksky5124.github.io/categories/%E9%AB%98%E6%95%88%E4%BA%BA%E5%B7%A5%E6%99%BA%E8%83%BD/"/>
<category term="TensorRT" scheme="https://thinksky5124.github.io/categories/TensorRT/"/>
<category term="TensorRT" scheme="https://thinksky5124.github.io/tags/TensorRT/"/>
<category term="模型部署" scheme="https://thinksky5124.github.io/tags/%E6%A8%A1%E5%9E%8B%E9%83%A8%E7%BD%B2/"/>
</entry>
<entry>
<title>CPU/GPU Joint Programming</title>
<link href="https://thinksky5124.github.io/2022/08/18/CPU_GPU%E8%81%94%E5%90%88%E7%BC%96%E7%A8%8B/"/>
<id>https://thinksky5124.github.io/2022/08/18/CPU_GPU%E8%81%94%E5%90%88%E7%BC%96%E7%A8%8B/</id>
<published>2022-08-18T08:00:00.000Z</published>
<updated>2024-04-16T08:57:05.643Z</updated>
<content type="html"><![CDATA[<h1 id="CPU-GPU联合编程"><a href="#CPU-GPU联合编程" class="headerlink" title="CPU/GPU联合编程"></a>CPU/GPU Joint Programming</h1><p>As the example code shows, once the model has been moved onto the GPU with a call such as .cuda(), PyTorch can run parallel computations on the GPU using CUDA __global__ kernels.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">model = ToyModel().cuda(device_ids[<span class="number">0</span>]) <span class="comment"># copy the model onto the GPU</span></span><br><span class="line">ddp_model = DDP(model, device_ids)</span><br><span class="line"></span><br><span class="line">loss_fn = nn.MSELoss()</span><br><span class="line">optimizer = optim.SGD(ddp_model.parameters(), lr=<span class="number">0.001</span>)</span><br><span class="line"></span><br><span class="line">optimizer.zero_grad()</span><br><span class="line">outputs = ddp_model(torch.randn(<span class="number">20</span>, <span class="number">10</span>))</span><br></pre></td></tr></tbody></table></figure><p>But we have glossed over a question: how does PyTorch know that it should call the GPU's __global__ kernel here? Why doesn't it call the CPU function, or a function for some other device? That is what we analyze next.</p><h3 id="Dispatcher-机制"><a href="#Dispatcher-机制" class="headerlink" title="Dispatcher 机制"></a><strong><strong>The Dispatcher Mechanism</strong></strong></h3><p>In PyTorch, the expected behavior of an operator is the combined result of many mechanisms, for example:</p><ul><li>The kernel that does the actual work.</li><li>Whether reverse-mode automatic differentiation is supported, e.g. the bookkeeping that makes loss.backward() work.</li><li>Whether torch.jit.trace is active.</li><li>If you are inside a vmap call, the operator exhibits different, batched behavior.</li></ul><p>If a single PyTorch operator such as add had to decide, inside one function, where and how each of these behaviors happens, its implementation would quickly become a tangled, unmaintainable mess. A mechanism is therefore needed that handles this through a proper abstraction rather than a pile of if statements, and it must do so with as little impact on PyTorch's performance as possible. That mechanism is the dispatcher.</p><p><img src="https://s2.loli.net/2024/03/25/sXiz2UkOnHLrbpa.png" alt="Dispatcher.png"></p><h3 id="什么是-Dispatcher"><a href="#什么是-Dispatcher" 
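class="headerlink" title="什么是 Dispatcher"></a><strong>What Is the Dispatcher</strong></h3><p>For every operator, the dispatcher maintains a table of function pointers that provides an implementation for each dispatch key; each dispatch key corresponds roughly to one cross-cutting concern in PyTorch. In the figure above, you can see that this table has dispatch entries for different backends (CPU, CUDA, XLA) as well as for higher-level concepts such as autograd and tracing. The dispatcher's job is to compute a dispatch key from the input tensors and some other state (such as the thread-local state described below), and then jump to the function that the table points to for that key.</p><p>Readers familiar with C++ may notice that this table of function pointers is very similar to a virtual table (vtable) in C++. In C++, virtual functions are implemented by associating each object with a pointer to a vtable that contains the implementation of each virtual function for that object. PyTorch essentially reimplements vtables, with some differences:</p><ul><li>A dispatch table maps dispatch keys to function pointers, and the dispatch keys cover not only backends (CPU, CUDA, XLA) but also higher-level concepts such as autograd and tracing.</li><li>Dispatch tables are allocated per operator, whereas vtables are allocated per class. This means the set of supported operators can be extended by allocating a new dispatch table. In contrast, for a C++ object you can extend the types by subclassing, but you cannot easily add virtual functions. Unlike a typical object-oriented system, most of PyTorch's extensibility lies in defining new operators (rather than new subclasses), so this trade-off makes sense. Moreover, the set of dispatch keys is not publicly extensible: the PyTorch core team expects users who want to add a new dispatch key to do so by submitting a patch to PyTorch.</li><li>Computing a dispatch key in PyTorch takes into account all of the operator's arguments (multiple dispatch) as well as thread-local state (TLS). This differs from vtables, where only the first object (the this pointer) matters.</li><li>Finally, the dispatcher supports boxing and unboxing as part of the operator calling convention.</li></ul><p>An interesting historical note: PyTorch used to implement dynamic dispatch with virtual functions, and reimplemented it once it became clear that more capability than vtables provide was needed.</p><p><img src="https://s2.loli.net/2024/03/25/Qp6McLhIClztuWB.png" alt="dynamic_dispatch.png"></p><h3 id="如何计算key"><a href="#如何计算key" class="headerlink" title="如何计算key"></a><strong><strong>How the Dispatch Key Is Computed</strong></strong></h3><p>So how exactly does PyTorch compute the dispatch key? It works with the dispatch key set, a basic abstraction that is a bitset of dispatch keys. Roughly speaking, PyTorch combines dispatch key sets from several sources (masking out some keys in certain cases) into one final dispatch key set. It then picks the highest-priority key in that set (dispatch keys are implicitly ordered by priority), and that key decides what PyTorch calls this time. So where do these dispatch key sets come from?</p><ul><li>Each tensor input has a dispatch key set consisting of all the dispatch keys on that tensor (intuitively, these keys are values such as CPU, which tells us the tensor is a CPU tensor and should therefore be handled by the CPU handler in the dispatch table).</li><li>PyTorch also has a local include set, used for "modal" functionality such as tracing. It is not associated with any tensor; instead it is a thread-local mode that users can turn on or off within some scope.</li><li>Finally, PyTorch has a global set containing dispatch keys that should always be considered. (Since the slide this section is based on was written, Autograd has moved off the global set and onto tensors, but the high-level structure of the system has not changed.)</li></ul><p>In addition, there is a local exclude set, used to exclude certain dispatch keys from dispatch. A common pattern is for a handler to process a key and then mask itself off via the local exclude set, so that PyTorch does not try to process that key again later.</p><p><img src="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211106210055746-281799638.jpg" alt="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211106210055746-281799638.jpg"></p><h3 id="注册"><a href="#注册" class="headerlink" title="注册"></a><strong>Registration</strong></h3><p>Next, let's look at how implementations are registered into the dispatch table under a dispatch key. This is done through the operator registration API, which is used in three main ways:</p><ul><li>Define a schema for an operator.</li><li>Register implementations (kernels) for specific dispatch keys.</li><li>Register a fallback, which lets users define a single handler for <em>all</em> operators under a given key.</li></ul><p><img src="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211106210123784-2083674405.jpg" alt="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211106210123784-2083674405.jpg"></p><p>To visualize how operator registration works, imagine the dispatch tables of all ops together forming a two-dimensional grid:</p><ul><li>The vertical axis lists every op supported in PyTorch.</li><li>The horizontal axis lists every dispatch key supported by the system.</li></ul><p>Operator registration fills in the cells defined by these two axes with the corresponding implementations.</p><p><img src="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211106210157628-304705603.png" alt="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211106210157628-304705603.png"></p><p>Registering a kernel for one operator on a specific dispatch key fills in a single cell (shown in blue below). For example, the figure below shows a CPU kernel registered for mul.</p><p><img src="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211106210225747-821296730.jpg" alt="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211106210225747-821296730.jpg"></p><p>Users can also register a single "catch-all" kernel for all dispatch keys of an operator, as in the red row below.</p><p><img src="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211109230431379-222705518.png" alt="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211109230431379-222705518.png"></p><p>Finally, a fallback registers one handler for a single dispatch key across all operators, such as "aten::add", "aten::mul" and "aten::sub", as in the green column below.</p><p><img 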
src="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211109230453547-431252092.png" alt="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211109230453547-431252092.png"></p><p>These registration forms have a precedence: a kernel registered for the specific dispatch key has the highest precedence, followed by the catch-all kernel, and finally the fallback. In the figure below, 1 is chosen first, then 2, and finally 3.</p><p><img src="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211109230508275-1642852144.png" alt="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211109230508275-1642852144.png"></p>]]></content>
<summary type="html">CPU/GPU Joint Programming</summary>
<category term="分布式训练" scheme="https://thinksky5124.github.io/categories/%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83/"/>
<category term="深度学习" scheme="https://thinksky5124.github.io/categories/%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
<category term="分布式训练" scheme="https://thinksky5124.github.io/tags/%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83/"/>
<category term="深度学习" scheme="https://thinksky5124.github.io/tags/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
<category term="工程" scheme="https://thinksky5124.github.io/tags/%E5%B7%A5%E7%A8%8B/"/>
</entry>
<entry>
<title>CUDA Programming Model Basics</title>
<link href="https://thinksky5124.github.io/2022/08/18/CUDA%E7%BC%96%E7%A8%8B%E6%A8%A1%E5%9E%8B%E5%9F%BA%E7%A1%80/"/>
<id>https://thinksky5124.github.io/2022/08/18/CUDA%E7%BC%96%E7%A8%8B%E6%A8%A1%E5%9E%8B%E5%9F%BA%E7%A1%80/</id>
<published>2022-08-18T08:00:00.000Z</published>
<updated>2024-04-16T08:57:05.643Z</updated>
<content type="html"><![CDATA[<h1 id="CUDA编程模型基础"><a href="#CUDA编程模型基础" class="headerlink" title="CUDA编程模型基础"></a>CUDA Programming Model Basics</h1><p>CUDA is NVIDIA's parallel computing platform and programming model for heterogeneous programming with GPUs.</p><h3 id="异构模型"><a href="#异构模型" class="headerlink" title="异构模型"></a>The Heterogeneous Model</h3><p>The CUDA programming model is a heterogeneous model. A program runs on a heterogeneous system made up of a CPU and a GPU connected by a bus (typically PCIe), and the CPU and GPU work together while the program runs.</p><p>CUDA has two key concepts: host and device.</p><ul><li>Host: the CPU and its memory.</li><li>Device: the GPU and its memory.</li></ul><p>Accordingly, a program under the CUDA architecture is split into two parts, host code and device code, which run on the CPU and the GPU respectively. The host and device communicate by copying data between their memories.</p><ul><li>Host code: the part that runs on the CPU. It is compiled by the host compiler, GNU gcc on Linux or Microsoft Visual C++ on Windows. Roughly speaking, plain C code works with the CPU and main memory.</li><li>Device code: the part that runs on the GPU, compiled by NVIDIA's nvcc compiler. Roughly speaking, CUDA C code works with the GPU and the GPU's memory (also called device memory).</li></ul><figure class="highlight plaintext"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><span class="line">+-------------------+        +--------------------+</span><br><span class="line">|                   |        |                    |</span><br><span class="line">|   +----------+    |        |    +----------+    |</span><br><span class="line">|   |          |    |        |    |          |    |</span><br><span class="line">|   |   RAM    |    |        |    |   RAM    |    |</span><br><span class="line">|   |          |    |        |    |          |    |</span><br><span class="line">|   +----+-----+    |        |    +----+-----+    |</span><br><span class="line">|        |          +--------+         |          |</span><br><span class="line">|        |          |        |         |          |</span><br><span class="line">|   +----+-----+    |        |    +----+-----+    |</span><br><span class="line">|   |          |    |        |    |          |    |</span><br><span class="line">|   |   CPU    |    |        |    |   GPU    |    |</span><br><span class="line">|   |          |    |        |    |          |    |</span><br><span class="line">|   +----------+    |        |    +----------+    |</span><br><span class="line">|                   |        |                    |</span><br><span class="line">+-------------------+        +--------------------+</span><br><span class="line"></span><br><span class="line">        Host                         Device</span><br></pre></td></tr></tbody></table></figure><h3 id="并行思想"><a href="#并行思想" class="headerlink" title="并行思想"></a><strong>The Idea of Parallelism</strong></h3><p>CUDA programming is built around parallelism, roughly as follows:</p><ul><li>Split a large task into many simple, repeatable operations, and have many threads execute those operations in parallel.</li><li>Partition the data the task processes into small chunks accordingly. For example, a large dataset may be split across several GPUs, each GPU's share split again across thread groups, and the tensors within a thread group may need to be subdivided further into pieces that Tensor Cores can process.</li></ul><p><img src="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211106205857346-47120320.png" alt="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211106205857346-47120320.png"></p><p>A typical CUDA program therefore contains both serial and parallel code.</p><ul><li>Serial code is standard C code, executed by the host.</li><li>Parallel code is CUDA C code, executed on the device.</li></ul><p>A CUDA program starts on the CPU: the host executes serial code, and when it reaches a part that needs data-parallel processing, the device executes the corresponding parallel code. The device can carry out most of its work independently of the host. Kernel launches are asynchronous: once device code is launched, control returns to the CPU immediately so that it can perform other tasks.</p><p><img src="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211106205922879-653846926.png" alt="Figure from the CUDA C++ Programming Guide"></p><p>Figure from <a href="https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html">https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html</a>.</p><h3 id="处理流程"><a href="#处理流程" class="headerlink" title="处理流程"></a><strong>Processing Flow</strong></h3><p>A typical CUDA program executes as follows:</p><ul><li>Allocate host memory and initialize the data.</li><li>Allocate device memory.</li><li>Copy the data to be processed from host memory to device memory.</li><li>Launch CUDA kernels to perform the user-specified computation on the device.</li><li>Copy the results from GPU memory back to host memory.</li><li>Free the memory allocated on the device and the host.</li></ul><p>See the figure below for details.</p><p><img src="https://img-blog.csdnimg.cn/img_convert/010da16d222a960934288b03c67ad6dd.png" 
alt="https://img-blog.csdnimg.cn/img_convert/010da16d222a960934288b03c67ad6dd.png"></p><h3 id="函数"><a href="#函数" class="headerlink" title="函数"></a><strong>Functions</strong></h3><ul><li><p><strong>Kernel functions</strong></p><p> A kernel is a function executed in parallel by device threads. In a CUDA program, the host must specify an execution configuration before launching a GPU kernel, which determines the number of thread blocks, the number of threads per block, and the amount of shared memory. For example, the launch uses <code><<<numBlocks, threadsPerBlock>>></code> to specify how many threads the kernel needs and how they are organized; the GPU then starts that many threads to execute the kernel in parallel, and each thread is assigned a unique thread index.</p><p> CUDA uses function type qualifiers to distinguish host functions from device functions. The three main function type qualifiers are:</p><table><thead><tr><th>Qualifier</th><th>Executes on</th><th>Callable from</th></tr></thead><tbody><tr><td><code>__global__</code></td><td>Device</td><td>The host, and also the device on GPUs that support dynamic parallelism (compute capability 3.5 and above)</td></tr><tr><td><code>__device__</code></td><td>Device</td><td>Device</td></tr><tr><td><code>__host__</code></td><td>Host</td><td>Host</td></tr></tbody></table><p> Their relationship is illustrated below:</p> <figure class="highlight plaintext"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br></pre></td><td class="code"><pre><span class="line">+------------------------+        +------------------------+</span><br><span class="line">|                        |        |                        |</span><br><span class="line">|                        |        |                        |</span><br><span class="line">|  __host__  __global__  |        |       __device__       |</span><br><span class="line">|     +          +       |        |                        |</span><br><span class="line">|     |          |       |        |           +            |</span><br><span class="line">|     |          |       |        |           |            |</span><br><span class="line">|     |          v---------------------------->|            |</span><br><span class="line">|     |          |       |        |           |            |</span><br><span class="line">|     |          |       |        |           |            |</span><br><span class="line">|     |          |       |        |           |            |</span><br><span class="line">|     |          |       |        |           |            |</span><br><span class="line">|     |          |       |        |           |            |</span><br><span class="line">|     |          |       |        |           |            |</span><br><span class="line">|     |          |       |        |           |            |</span><br><span class="line">|     |          |       |        |           |            |</span><br><span class="line">|     |          |       |        |           |            |</span><br><span class="line">|     |          +<---------------------------+            |</span><br><span class="line">|     |          |       |        |           |            |</span><br><span class="line">|     |          |       |        |           |            |</span><br><span class="line">|     |          |       |        |           |            |</span><br><span class="line">|     v          v       |        |           v            |</span><br><span class="line">|                        |        |                        |</span><br><span class="line">+------------------------+        +------------------------+</span><br><span class="line"></span><br><span class="line">           Host                             Device</span><br></pre></td></tr></tbody></table></figure><p> These three qualifiers also correspond to the three common execution scenarios in CUDA. Because __device__ and __global__ functions run on the GPU, they cannot call many common C/C++ library functions (which have no GPU implementation).</p><p> The code below is NVIDIA's example: it uses the built-in threadIdx variable to add two vectors A and B, producing C. Each of the N threads executes VecAdd() once.</p> <figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// Kernel definition</span></span><br><span class="line"><span class="function">__global__ <span class="type">void</span> <span class="title">VecAdd</span><span class="params">(<span class="type">float</span>* A, <span 
class="type">float</span>* B, <span class="type">float</span>* C)</span></span></span><br><span class="line"><span class="function"></span>{</span><br><span class="line"> <span class="type">int</span> i = threadIdx.x;</span><br><span class="line"> C[i] = A[i] + B[i];</span><br><span class="line">}</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="type">int</span> <span class="title">main</span><span class="params">()</span></span></span><br><span class="line"><span class="function"></span>{</span><br><span class="line"> ...</span><br><span class="line"> <span class="comment">// Kernel invocation with N threads</span></span><br><span class="line"> VecAdd<<<<span class="number">1</span>, N>>>(A, B, C);</span><br><span class="line"> ...</span><br><span class="line">}</span><br></pre></td></tr></tbody></table></figure></li><li><p><strong>PyTorch example</strong></p><p> Let's look at a kernel example from third_party/cub/cub/device/dispatch/dispatch_reduce.cuh (the CUB library bundled with PyTorch).</p> <figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span 
class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">/**</span></span><br><span class="line"><span class="comment"> * Reduce region kernel entry point (multi-block). Computes privatized reductions, one per thread block.</span></span><br><span class="line"><span class="comment"> */</span></span><br><span class="line"><span class="keyword">template</span> <</span><br><span class="line"> <span class="keyword">typename</span> ChainedPolicyT, <span class="comment">///< Chained tuning policy</span></span><br><span class="line"> <span class="keyword">typename</span> InputIteratorT, <span class="comment">///< Random-access input iterator type for reading input items \iterator</span></span><br><span class="line"> <span class="keyword">typename</span> OutputIteratorT, <span class="comment">///< Output iterator type for recording the reduced aggregate \iterator</span></span><br><span class="line"> <span class="keyword">typename</span> OffsetT, <span class="comment">///< Signed integer type for global offsets</span></span><br><span class="line"> <span class="keyword">typename</span> ReductionOpT> <span class="comment">///< Binary reduction functor type having member <tt>T operator()(const T &a, const T &b)</tt></span></span><br><span class="line">__launch_bounds__ (<span class="built_in">int</span>(ChainedPolicyT::ActivePolicy::ReducePolicy::BLOCK_THREADS))</span><br><span class="line"><span class="function">__global__ <span class="type">void</span> <span class="title">DeviceReduceKernel</span><span class="params">(</span></span></span><br><span class="line"><span class="params"><span 
class="function"> InputIteratorT d_in, <span class="comment">///< [in] Pointer to the input sequence of data items</span></span></span></span><br><span class="line"><span class="params"><span class="function"> OutputIteratorT d_out, <span class="comment">///< [out] Pointer to the output aggregate</span></span></span></span><br><span class="line"><span class="params"><span class="function"> OffsetT num_items, <span class="comment">///< [in] Total number of input data items</span></span></span></span><br><span class="line"><span class="params"><span class="function"> GridEvenShare<OffsetT> even_share, <span class="comment">///< [in] Even-share descriptor for mapping an equal number of tiles onto each thread block</span></span></span></span><br><span class="line"><span class="params"><span class="function"> ReductionOpT reduction_op)</span> <span class="comment">///< [in] Binary reduction functor</span></span></span><br><span class="line"><span class="function"></span>{</span><br><span class="line"> <span class="comment">// The output value type</span></span><br><span class="line"> <span class="keyword">typedef</span> <span class="keyword">typename</span> If<(Equals<<span class="keyword">typename</span> std::iterator_traits<OutputIteratorT>::value_type, <span class="type">void</span>>::VALUE), <span class="comment">// OutputT = (if output iterator's value type is void) ?</span></span><br><span class="line"> <span class="keyword">typename</span> std::iterator_traits<InputIteratorT>::value_type, <span class="comment">// ... then the input iterator's value type,</span></span><br><span class="line"> <span class="keyword">typename</span> std::iterator_traits<OutputIteratorT>::value_type>::Type OutputT; <span class="comment">// ... 
else the output iterator's value type</span></span><br><span class="line"></span><br><span class="line"> <span class="comment">// Thread block type for reducing input tiles</span></span><br><span class="line"> <span class="keyword">typedef</span> AgentReduce<</span><br><span class="line"> <span class="keyword">typename</span> ChainedPolicyT::ActivePolicy::ReducePolicy,</span><br><span class="line"> InputIteratorT,</span><br><span class="line"> OutputIteratorT,</span><br><span class="line"> OffsetT,</span><br><span class="line"> ReductionOpT></span><br><span class="line"> AgentReduceT;</span><br><span class="line"></span><br><span class="line"> <span class="comment">// Shared memory storage</span></span><br><span class="line"> __shared__ <span class="keyword">typename</span> AgentReduceT::TempStorage temp_storage;</span><br><span class="line"></span><br><span class="line"> <span class="comment">// Consume input tiles</span></span><br><span class="line"> OutputT block_aggregate = <span class="built_in">AgentReduceT</span>(temp_storage, d_in, reduction_op).<span class="built_in">ConsumeTiles</span>(even_share);</span><br><span class="line"></span><br><span class="line"> <span class="comment">// Output result</span></span><br><span class="line"> <span class="keyword">if</span> (threadIdx.x == <span class="number">0</span>)</span><br><span class="line"> d_out[blockIdx.x] = block_aggregate;</span><br><span class="line">}</span><br></pre></td></tr></tbody></table></figure></li></ul>]]></content>
<summary type="html">CUDA Programming Model Basics</summary>
<category term="分布式训练" scheme="https://thinksky5124.github.io/categories/%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83/"/>
<category term="深度学习" scheme="https://thinksky5124.github.io/categories/%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
<category term="分布式训练" scheme="https://thinksky5124.github.io/tags/%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83/"/>
<category term="深度学习" scheme="https://thinksky5124.github.io/tags/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
<category term="工程" scheme="https://thinksky5124.github.io/tags/%E5%B7%A5%E7%A8%8B/"/>
</entry>
<entry>
<title>Dispatcher</title>
<link href="https://thinksky5124.github.io/2022/08/18/Dispatcher/"/>
<id>https://thinksky5124.github.io/2022/08/18/Dispatcher/</id>
<published>2022-08-18T08:00:00.000Z</published>
<updated>2024-04-16T08:57:05.643Z</updated>
<content type="html"><![CDATA[<h1 id="Dispatcher"><a href="#Dispatcher" class="headerlink" title="Dispatcher"></a>Dispatcher</h1><p>Next, let's look at the dispatcher through its source code.</p><h3 id="虚函数表"><a href="#虚函数表" class="headerlink" title="虚函数表"></a><strong>The Virtual Function Table</strong></h3><ul><li><p><strong>Schema examples</strong></p><p> Every operator (playing the role of a virtual function) has a corresponding schema. Examples can be found in aten/src/ATen/native/native_functions.yaml, where each schema is written as a string. The func field is the schema proper: it gives the operator name (e.g. zero_), the number and types of its arguments, and its return type. The other fields control device checks, which variants (method and/or function) are generated, and how the operator is dispatched, i.e. which kernel (such as zero_sparse_) handles each dispatch key.</p> <figure class="highlight yaml"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># dispatch table for zero_</span></span><br><span class="line"><span class="bullet">-</span> <span class="attr">func:</span> <span class="string">zero_(Tensor(a!)</span> <span class="string">self)</span> <span class="string">-></span> <span class="string">Tensor(a!)</span></span><br><span class="line">  <span class="attr">device_check:</span> <span class="string">NoCheck</span> <span class="comment"># TensorIterator</span></span><br><span class="line">  <span class="attr">variants:</span> <span class="string">method,</span> <span class="string">function</span></span><br><span class="line">  <span class="attr">dispatch:</span></span><br><span class="line">    <span class="string">CPU,</span> <span class="attr">CUDA:</span> <span class="string">zero_</span></span><br><span class="line">    <span class="attr">Meta:</span> <span class="string">zero_meta_</span></span><br><span class="line">    <span class="string">SparseCPU,</span> <span class="attr">SparseCUDA:</span> <span class="string">zero_sparse_</span></span><br><span class="line">    <span class="attr">MkldnnCPU:</span> <span class="string">mkldnn_zero_</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># dispatch table for sub.out</span></span><br><span class="line"><span class="bullet">-</span> <span class="attr">func:</span> <span class="string">sub.out(Tensor</span> <span class="string">self,</span> <span class="string">Tensor</span> <span class="string">other,</span> <span class="string">*,</span> <span class="string">Scalar</span> <span class="string">alpha=1,</span> <span class="string">Tensor(a!)</span> <span class="string">out)</span> <span class="string">-></span> <span class="string">Tensor(a!)</span></span><br><span class="line">  <span class="attr">device_check:</span> <span class="string">NoCheck</span> <span class="comment"># TensorIterator</span></span><br><span class="line">  <span class="attr">structured:</span> <span class="literal">True</span></span><br><span class="line">  <span class="attr">structured_inherits:</span> <span class="string">TensorIteratorBase</span></span><br><span class="line">  <span class="attr">dispatch:</span></span><br><span class="line">    <span class="string">CPU,</span> <span class="attr">CUDA:</span> <span class="string">sub_out</span></span><br><span class="line">    <span class="string">SparseCPU,</span> <span class="attr">SparseCUDA:</span> <span class="string">sub_out_sparse</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># dispatch table for sub.Tensor</span></span><br><span class="line"><span class="bullet">-</span> <span class="attr">func:</span> <span class="string">sub.Tensor(Tensor</span> <span class="string">self,</span> <span class="string">Tensor</span> <span class="string">other,</span> <span class="string">*,</span> <span class="string">Scalar</span> <span class="string">alpha=1)</span> <span class="string">-></span> <span class="string">Tensor</span></span><br><span class="line">  <span class="attr">device_check:</span> <span class="string">NoCheck</span> <span class="comment"># TensorIterator</span></span><br><span class="line">  <span class="attr">variants:</span> <span class="string">function,</span> <span class="string">method</span></span><br><span class="line">  <span class="attr">structured_delegate:</span> <span class="string">sub.out</span></span><br><span class="line">  <span class="attr">dispatch:</span></span><br><span class="line">    <span class="string">SparseCPU,</span> <span class="attr">SparseCUDA:</span> <span class="string">sub_sparse</span></span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure></li></ul><h3 id="Operator的实现"><a href="#Operator的实现" class="headerlink" title="Operator的实现"></a><strong>Operator Implementations</strong></h3><p>Let's look at two implementations of zero_. The first is the MkldnnCPU implementation.</p><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><span class="line"><span 
class="function">Tensor& <span class="title">mkldnn_zero_</span><span class="params">(Tensor& self)</span> </span>{</span><br><span class="line"> <span class="keyword">using</span> Vec = vec::Vectorized<<span class="type">float</span>>;</span><br><span class="line"></span><br><span class="line"> ideep::tensor& x = <span class="built_in">itensor_from_mkldnn</span>(self);</span><br><span class="line"></span><br><span class="line"> <span class="keyword">auto</span> n = x.<span class="built_in">get_nelems</span>();</span><br><span class="line"> <span class="keyword">auto</span>* x_ = <span class="built_in">static_cast</span><<span class="type">float</span>*>(x.<span class="built_in">get_data_handle</span>());</span><br><span class="line"> <span class="built_in">parallel_for</span>(<span class="number">0</span>, n, <span class="number">2048</span>, [x_](<span class="type">int64_t</span> begin, <span class="type">int64_t</span> end) {</span><br><span class="line"> vec::<span class="built_in">map</span>(</span><br><span class="line"> [](Vec <span class="comment">/* unused */</span>) { <span class="keyword">return</span> <span class="number">0.0</span>; },</span><br><span class="line"> x_ + begin,</span><br><span class="line"> x_ + begin,</span><br><span class="line"> end - begin);</span><br><span class="line"> });</span><br><span class="line"></span><br><span class="line"> <span class="keyword">return</span> self;</span><br><span class="line">}</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p>The second is the corresponding implementation for SparseCPU and SparseCUDA:</p><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td 
class="code"><pre><span class="line"><span class="comment">// --------------------------------------------------------------------</span></span><br><span class="line"><span class="comment">// zero_(SparseTensor)</span></span><br><span class="line"><span class="comment">// --------------------------------------------------------------------</span></span><br><span class="line"><span class="comment">// hummu hummu</span></span><br><span class="line"><span class="function">SparseTensor& <span class="title">zero_sparse_</span><span class="params">(SparseTensor& self)</span> </span>{</span><br><span class="line"> <span class="built_in">AT_ASSERT</span>(self.<span class="built_in">is_sparse</span>());</span><br><span class="line"> at::<span class="built_in">zeros_out</span>(self, <span class="built_in">get_sparse_impl</span>(self)-><span class="built_in">sizes</span>());</span><br><span class="line"> <span class="keyword">return</span> self._coalesced_(<span class="literal">true</span>);</span><br><span class="line">}</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="Dispatcher-定义"><a href="#Dispatcher-定义" class="headerlink" title="Dispatcher 定义"></a><strong>Dispatcher definition</strong></h3><p>Next, let's look at how Dispatcher is defined. Only some of its member variables are shown here.</p><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span 
class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">class</span> <span class="title class_">TORCH_API</span> Dispatcher <span class="keyword">final</span> {</span><br><span class="line"><span class="keyword">private</span>:</span><br><span class="line"> <span class="comment">// For direct access to backend fallback information</span></span><br><span class="line"> <span class="keyword">friend</span> <span class="keyword">class</span> <span class="title class_">impl</span>::OperatorEntry;</span><br><span class="line"></span><br><span class="line"> <span class="keyword">struct</span> <span class="title class_">OperatorDef</span> <span class="keyword">final</span> {</span><br><span class="line"> <span class="function"><span class="keyword">explicit</span> <span class="title">OperatorDef</span><span class="params">(OperatorName&& op_name)</span></span></span><br><span class="line"><span class="function"> : op(std::move(op_name)) {</span>}</span><br><span class="line"> impl::OperatorEntry op;</span><br><span class="line"> <span class="type">size_t</span> def_count = <span class="number">0</span>;</span><br><span class="line"> <span class="type">size_t</span> def_and_impl_count = <span class="number">0</span>;</span><br><span class="line"> };</span><br><span class="line"> <span class="keyword">friend</span> <span class="keyword">class</span> <span class="title class_">OperatorHandle</span>;</span><br><span class="line"> <span class="keyword">template</span><<span class="keyword">class</span>> <span class="keyword">friend</span> <span class="keyword">class</span> <span class="title class_">TypedOperatorHandle</span>;</span><br><span 
class="line"></span><br><span class="line"><span class="keyword">public</span>:</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="type">static</span> Dispatcher& <span class="title">realSingleton</span><span class="params">()</span></span>;</span><br><span class="line"></span><br><span class="line"> <span class="comment">// Stores every operator; each entry in turn holds that operator's different versions, e.g. cpu, cuda, autograd, ...</span></span><br><span class="line"> std::list<OperatorDef> operators_;</span><br><span class="line"> <span class="comment">// On registration, the operator's name and handle are also stored here, so an operator can be found quickly by name (the handle points to its OperatorDef)</span></span><br><span class="line"> LeftRight<ska::flat_hash_map<OperatorName, OperatorHandle>> operatorLookupTable_;</span><br><span class="line"> <span class="comment">// Map from namespace to debug string (saying, e.g., where the library was defined)</span></span><br><span class="line"> ska::flat_hash_map<std::string, std::string> libraries_;</span><br><span class="line"> std::array<impl::AnnotatedKernel, <span class="keyword">static_cast</span><<span class="type">uint8_t</span>>(DispatchKey::NumDispatchKeys)> backendFallbackKernels_;</span><br><span class="line"> std::unique_ptr<detail::RegistrationListenerList> listeners_;</span><br><span class="line"> std::mutex mutex_;</span><br><span class="line">};</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p>Roughly, the structure looks like this, with operators_ holding every operator:</p><figure class="highlight plaintext"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line">+--------------------------------------------+</span><br><span class="line">| 
Dispatcher |</span><br><span class="line">| |</span><br><span class="line">| |</span><br><span class="line">| |</span><br><span class="line">| std::list<OperatorDef> operators_ |</span><br><span class="line">| |</span><br><span class="line">| operatorLookupTable_ |</span><br><span class="line">| |</span><br><span class="line">+--------------------------------------------+</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="注册"><a href="#注册" class="headerlink" title="注册"></a><strong>Registration</strong></h3><ul><li><p>Next is the method that registers an entry in this virtual function table:</p> <figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br></pre></td><td class="code"><pre><span class="line"><span class="function">RegistrationHandleRAII <span class="title">Dispatcher::registerImpl</span><span class="params">(</span></span></span><br><span class="line"><span class="params"><span class="function"> OperatorName op_name,</span></span></span><br><span class="line"><span class="params"><span class="function"> c10::optional<DispatchKey> dispatch_key,</span></span></span><br><span class="line"><span class="params"><span class="function"> KernelFunction 
kernel,</span></span></span><br><span class="line"><span class="params"><span class="function"> c10::optional<impl::CppSignature> cpp_signature,</span></span></span><br><span class="line"><span class="params"><span class="function"> std::unique_ptr<FunctionSchema> inferred_function_schema,</span></span></span><br><span class="line"><span class="params"><span class="function"> std::string debug</span></span></span><br><span class="line"><span class="params"><span class="function">)</span> </span>{</span><br><span class="line"> <span class="function">std::lock_guard<std::mutex> <span class="title">lock</span><span class="params">(mutex_)</span></span>;</span><br><span class="line"> <span class="keyword">auto</span> op = <span class="built_in">findOrRegisterName_</span>(op_name);</span><br><span class="line"> <span class="keyword">auto</span> handle = op.operatorDef_->op.<span class="built_in">registerKernel</span>( <span class="comment">// perform the registration</span></span><br><span class="line"> *<span class="keyword">this</span>,</span><br><span class="line"> dispatch_key,</span><br><span class="line"> std::<span class="built_in">move</span>(kernel),</span><br><span class="line"> std::<span class="built_in">move</span>(cpp_signature),</span><br><span class="line"> std::<span class="built_in">move</span>(inferred_function_schema),</span><br><span class="line"> std::<span class="built_in">move</span>(debug)</span><br><span class="line"> );</span><br><span class="line"></span><br><span class="line"> ++op.operatorDef_->def_and_impl_count;</span><br><span class="line"></span><br><span class="line"> <span class="keyword">return</span> <span class="built_in">RegistrationHandleRAII</span>([<span class="keyword">this</span>, op, op_name, dispatch_key, handle] {</span><br><span class="line"> <span class="built_in">deregisterImpl_</span>(op, op_name, dispatch_key, handle);</span><br><span class="line"> });</span><br><span class="line">}</span><br><span 
class="line"></span><br></pre></td></tr></tbody></table></figure><p> <strong>Registration table</strong></p><p> OperatorEntry represents one operator together with its dispatch table. Only its member variables are shown here.</p> <figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">class</span> <span class="title class_">TORCH_API</span> OperatorEntry <span class="keyword">final</span> { <span class="comment">// Represents one operator together with its dispatch table</span></span><br><span class="line"><span class="keyword">public</span>:</span><br><span class="line"> OperatorName name_;</span><br><span class="line"> c10::optional<AnnotatedSchema> schema_;</span><br><span class="line"> <span class="comment">// The kernel selected for each dispatch key (cpu, cuda, autograd, ...); every version of the operator is reachable through this table</span></span><br><span class="line"> std::array<KernelFunction, <span class="keyword">static_cast</span><<span class="type">uint8_t</span>>(DispatchKey::NumDispatchKeys)> dispatchTable_;</span><br><span class="line"> DispatchKeyExtractor dispatchKeyExtractor_;</span><br><span class="line"> <span class="comment">// All kernel implementations registered for each DispatchKey</span></span><br><span class="line"> ska::flat_hash_map<DispatchKey, std::list<AnnotatedKernel>> kernels_;</span><br><span class="line">};</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p> Its layout is as follows:</p> <figure class="highlight plaintext"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span 
class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><span class="line">+---------------------------+ +------------------------------------------+</span><br><span class="line">| OperatorEntry | | |</span><br><span class="line">| | | std::array<KernelFunction, uint8_t> |</span><br><span class="line">| | | |</span><br><span class="line">| | | |</span><br><span class="line">| | | int('CPU') : CPU_kernel |</span><br><span class="line">| dispatchTable_ +-------> | |</span><br><span class="line">| | | int('GPU') : GPU_kernel |</span><br><span class="line">| | | |</span><br><span class="line">| | | ...... |</span><br><span class="line">| | | |</span><br><span class="line">| | | int('Metal') : Metal_kernel |</span><br><span class="line">| | | |</span><br><span class="line">+---------------------------+ +------------------------------------------+</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p> <strong>What registration does</strong></p><p> Ultimately, registration amounts to writing an entry into dispatchTable_:</p> <figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="type">void</span> <span class="title">OperatorEntry::updateDispatchTableEntry_</span><span class="params">(<span class="type">const</span> c10::Dispatcher& dispatcher, DispatchKey dispatch_key)</span> </span>{</span><br><span class="line"> <span class="keyword">auto</span> dispatch_ix = <span class="built_in">static_cast</span><<span class="type">uint8_t</span>>(dispatch_key);</span><br><span class="line"> 
dispatchTable_[dispatch_ix] = <span class="built_in">computeDispatchTableEntry</span>(dispatcher, dispatch_key);</span><br><span class="line"> dispatchKeyExtractor_.<span class="built_in">setOperatorHasFallthroughForKey</span>(dispatch_key, dispatchTable_[dispatch_ix].<span class="built_in">isFallthrough</span>());</span><br><span class="line">}</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p> Putting it all together, the expanded Dispatcher data structure looks roughly like this. It contains two OperatorEntry objects, one for op1 and one for op2. In other words, the system currently has two operators, and each has four kernel functions, one for each of four backends (CPU, GPU, and so on).</p> <figure class="highlight plaintext"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span 
class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br></pre></td><td class="code"><pre><span class="line">+-----------------------------------------+</span><br><span class="line">| Dispatcher |</span><br><span class="line">| |</span><br><span class="line">| |</span><br><span class="line">| std::list<OperatorDef> operators_ +--------+</span><br><span class="line">| | |</span><br><span class="line">| | |</span><br><span class="line">| operatorLookupTable_ | |</span><br><span class="line">| | |</span><br><span class="line">+-----------------------------------------+ |</span><br><span class="line"> |</span><br><span class="line"> |</span><br><span class="line"> v</span><br><span class="line"> +-----------------------------------+------------------------------------------+</span><br><span class="line"> | +---------------------------+ +--------------------------------------+ |</span><br><span class="line"> | | OperatorEntry | | | |</span><br><span class="line"> | | | | std::array<KernelFunction, uint8_t> | |</span><br><span class="line"> | | | | | |</span><br><span class="line"> | | name_ = op1 | | | |</span><br><span class="line"> | | | | int('CPU') : op1_cpu | |</span><br><span class="line"> | | dispatchTable_ +-------> | | |</span><br><span class="line"> | | | | int('GPU') : op1_gpu | |</span><br><span class="line"> | | | | | |</span><br><span class="line"> | | | | int('XLA') : op1_xla | |</span><br><span class="line"> | | | | | |</span><br><span class="line"> | | | | int('Metal') : op1_metal | |</span><br><span class="line"> | | | | | |</span><br><span class="line"> | +---------------------------+ +--------------------------------------+ |</span><br><span class="line"> | |</span><br><span class="line"> | |</span><br><span class="line"> | +---------------------------+ +--------------------------------------+ |</span><br><span class="line"> | | OperatorEntry | | | |</span><br><span class="line"> | | | 
| std::array<KernelFunction, uint8_t> | |</span><br><span class="line"> | | | | | |</span><br><span class="line"> | | name_ = op2 | | | |</span><br><span class="line"> | | | | int('CPU') : op2_cpu | |</span><br><span class="line"> | | dispatchTable_ +-------> | | |</span><br><span class="line"> | | | | int('GPU') : op2_gpu | |</span><br><span class="line"> | | | | | |</span><br><span class="line"> | | | | int('XLA') : op2_xla | |</span><br><span class="line"> | | | | | |</span><br><span class="line"> | | | | int('Metal') : op2_metal | |</span><br><span class="line"> | | | | | |</span><br><span class="line"> | +---------------------------+ +--------------------------------------+ |</span><br><span class="line"> +------------------------------------------------------------------------------+</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure></li></ul><h3 id="如何dispatch"><a href="#如何dispatch" class="headerlink" title="如何dispatch"></a><strong>How dispatch works</strong></h3><ul><li><p><strong>What dispatch is based on</strong></p><p> PyTorch selects a different operator implementation depending on the tensor's dtype, device, and layout.</p><ul><li>Most dtypes (e.g. int32) can be mapped directly using templates, but some operators cannot be templated this way, which calls for a dynamic dispatcher like this one.</li><li>PyTorch tensors can run not only on the CPU but also on GPU, mkldnn, xla, and other devices, which also requires dynamic dispatch.</li><li>Layout describes how a tensor's elements are arranged, for example strided layout versus sparse layout, and this too requires dynamic dispatch.</li></ul></li><li><p><strong>Dispatch code</strong></p><p> Only part of the code is shown here.</p><p> Operator dispatch proceeds as follows:</p><ol><li>Look up the operator's schema through the Dispatcher, using the operator name together with its overload name. The schema describes the operator's inputs, outputs, arguments, and so on.</li><li>Call Dispatcher::call to run the operator:<ol><li>Compute the DispatchKeySet from the arguments, using the operator's dispatchKeyExtractor.</li><li>Take the highest-priority key in that set and use op.lookup to find the corresponding KernelFunction.</li><li>Call the kernel.</li></ol></li></ol><p> First, let's use the definition of range to see how a schema is found. Inside findSchemaOrThrow, the op is looked up through operatorLookupTable_:</p> <figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span 
class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="function">at::Tensor <span class="title">range::call</span><span class="params">(<span class="type">const</span> at::Scalar & start, <span class="type">const</span> at::Scalar & end, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<<span class="type">bool</span>> pin_memory)</span> </span>{</span><br><span class="line"> <span class="type">static</span> <span class="keyword">auto</span> op = c10::Dispatcher::<span class="built_in">singleton</span>()</span><br><span class="line"> .<span class="built_in">findSchemaOrThrow</span>(<span class="string">"aten::range"</span>, <span class="string">""</span>)</span><br><span class="line"> .typed<at::<span class="built_in">Tensor</span> (<span class="type">const</span> at::Scalar &, <span class="type">const</span> at::Scalar &, c10::optional<at::ScalarType>, c10::optional<at::Layout>, c10::optional<at::Device>, c10::optional<<span class="type">bool</span>>)>();</span><br><span class="line"> <span class="keyword">return</span> op.<span class="built_in">call</span>(start, end, dtype, layout, device, pin_memory);</span><br><span class="line">}</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p> Next, Dispatcher::call is defined as follows:</p> <figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span 
class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">template</span><<span class="keyword">class</span> Return, <span class="keyword">class</span>... Args></span></span><br><span class="line"><span class="function">C10_DISPATCHER_INLINE_UNLESS_MOBILE Return <span class="title">Dispatcher::call</span><span class="params">(<span class="type">const</span> TypedOperatorHandle<Return(Args...)>& op, Args... args)</span> <span class="type">const</span> </span>{</span><br><span class="line"> detail::<span class="built_in">unused_arg_</span>(args...);</span><br><span class="line"></span><br><span class="line"> <span class="comment">// Compute the dispatch key set from the arguments</span></span><br><span class="line"> <span class="keyword">auto</span> dispatchKeySet = op.operatorDef_->op.<span class="built_in">dispatchKeyExtractor</span>()</span><br><span class="line"> .<span class="keyword">template</span> <span class="built_in">getDispatchKeySetUnboxed</span><Args...>(args...);</span><br><span class="line"> <span class="built_in">TORCH_INTERNAL_ASSERT_DEBUG_ONLY</span>(!c10::<span class="built_in">isAliasDispatchKey</span>(dispatchKeySet.<span class="built_in">highestPriorityTypeId</span>()));</span><br><span class="line"></span><br><span class="line"> <span class="comment">// Look up the kernel for the highest-priority key</span></span><br><span class="line"> <span class="type">const</span> KernelFunction& kernel = op.operatorDef_->op.<span class="built_in">lookup</span>(dispatchKeySet.<span class="built_in">highestPriorityTypeId</span>());</span><br><span class="line"></span><br><span class="line"> <span class="comment">// Dispatch: take the profiling slow path if RecordFunction should run, otherwise call the kernel directly</span></span><br><span class="line"><span class="meta">#<span class="keyword">ifndef</span> PYTORCH_DISABLE_PER_OP_PROFILING</span></span><br><span class="line"> <span class="type">bool</span> pre_sampled = <span class="literal">false</span>;</span><br><span class="line"> <span class="keyword">if</span> (<span class="built_in">C10_UNLIKELY</span>(at::<span class="built_in">shouldRunRecordFunction</span>(&pre_sampled))) {</span><br><span class="line"> <span class="keyword">return</span> <span class="built_in">callWithDispatchKeySlowPath</span><Return, Args...>(op, pre_sampled, dispatchKeySet, kernel, std::forward<Args>(args)...);</span><br><span class="line"> }</span><br><span class="line"><span class="meta">#<span class="keyword">endif</span> <span class="comment">// PYTORCH_DISABLE_PER_OP_PROFILING</span></span></span><br><span class="line"> <span class="keyword">return</span> kernel.<span class="keyword">template</span> call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);</span><br><span class="line">}</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure></li></ul><h3 id="key"><a href="#key" class="headerlink" title="key"></a><strong>key</strong></h3><ul><li><p>Next, let's look at how the keys are defined. There are many of them, so only some of the values are listed.</p> <figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span 
class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">enum class</span> <span class="title class_">DispatchKey</span> : <span class="type">uint8_t</span> {</span><br><span class="line"> CPU, <span class="comment">// registered at build/aten/src/ATen/RegisterCPU.cpp</span></span><br><span class="line"> CUDA, <span class="comment">// registered at build/aten/src/ATen/RegisterCUDA.cpp</span></span><br><span class="line"> HIP, <span class="comment">// NB: I think this is not actually used, due to Note [Masquerading as</span></span><br><span class="line"> <span class="comment">// CUDA]</span></span><br><span class="line"> FPGA, <span class="comment">// Xilinx support lives out of tree at</span></span><br><span class="line"> <span class="comment">// 
https://gitlab.com/pytorch-complex/vitis_kernels</span></span><br><span class="line"> MSNPU, <span class="comment">// unused externally, but tested at</span></span><br><span class="line"> <span class="comment">// test/cpp_extensions/msnpu_extension.cpp</span></span><br><span class="line"> XLA, <span class="comment">// lives out of tree at https://github.com/pytorch/xla</span></span><br><span class="line"> MLC, <span class="comment">// lives out of tree at https://github.com/pytorch/MLCompute</span></span><br><span class="line"> Vulkan,</span><br><span class="line"> Metal,</span><br><span class="line"> XPU, <span class="comment">// For out of tree Intel's heterogeneous computing plug-in</span></span><br><span class="line"> HPU, <span class="comment">// For out of tree & closed source integration of HPU / Habana</span></span><br><span class="line"> VE, <span class="comment">// For out of tree & closed source integration of SX-Aurora / NEC</span></span><br><span class="line"> Lazy, <span class="comment">// For lazy tensor backends</span></span><br><span class="line"> <span class="comment">// A meta tensor is a tensor without any data associated with it. 
(They</span></span><br><span class="line"> <span class="comment">// have also colloquially been referred to as tensors on the "null" device).</span></span><br><span class="line"> <span class="comment">// A meta tensor can be used to dry run operators without actually doing any</span></span><br><span class="line"> <span class="comment">// computation, e.g., add on two meta tensors would give you another meta</span></span><br><span class="line"> <span class="comment">// tensor with the output shape and dtype, but wouldn't actually add anything.</span></span><br><span class="line"> Meta,</span><br><span class="line"> <span class="comment">// Here are backends which specify more specialized operators</span></span><br><span class="line"> <span class="comment">// based on the dtype of the tensor.</span></span><br><span class="line"> QuantizedCPU, <span class="comment">// registered at build/aten/src/ATen/RegisterQuantizedCPU.cpp</span></span><br><span class="line"> QuantizedCUDA, <span class="comment">// registered at build/aten/src/ATen/RegisterQuantizedCUDA.cpp</span></span><br><span class="line"> QuantizedXPU, <span class="comment">// For out of tree Intel's heterogeneous computing plug-in</span></span><br><span class="line"> <span class="comment">// This backend is to support custom RNGs; it lets you go</span></span><br><span class="line"> <span class="comment">// to a different kernel if you pass in a generator that is not a</span></span><br><span class="line"> <span class="comment">// traditional CPUGeneratorImpl/CUDAGeneratorImpl. 
To make use of this</span></span><br><span class="line"> <span class="comment">// key:</span></span><br><span class="line"> <span class="comment">// 1) set it as a second parameter of at::Generator constructor call in</span></span><br><span class="line"> <span class="comment">// the user-defined PRNG class.</span></span><br><span class="line"> <span class="comment">// 2) use it as a dispatch key while registering custom kernels</span></span><br><span class="line"> <span class="comment">// (templatized kernels specialized for user-defined PRNG class)</span></span><br><span class="line"> <span class="comment">// intended for out of tree use; tested by aten/src/ATen/test/rng_test.cpp</span></span><br><span class="line"> CustomRNGKeyId,</span><br><span class="line"></span><br><span class="line"> <span class="comment">// Here are backends which specify more specialized operators</span></span><br><span class="line"> <span class="comment">// based on the layout of the tensor. Note that the sparse backends</span></span><br><span class="line"> <span class="comment">// are one case where ordering matters: sparse multi-dispatches with</span></span><br><span class="line"> <span class="comment">// the corresponding dense tensors, and must be handled before them.</span></span><br><span class="line"> MkldnnCPU, <span class="comment">// registered at build/aten/src/ATen/RegisterMkldnnCPU.cpp</span></span><br><span class="line"> <span class="comment">// NB: not to be confused with MKLDNN, which is Caffe2 only</span></span><br><span class="line"> SparseCPU, <span class="comment">// registered at build/aten/src/ATen/RegisterSparseCPU.cpp</span></span><br><span class="line"> SparseCUDA, <span class="comment">// registered at build/aten/src/ATen/RegisterSparseCUDA.cpp</span></span><br><span class="line"> SparseHIP, <span class="comment">// <span class="doctag">TODO:</span> I think this is not actually used, due to Note</span></span><br><span class="line"> <span class="comment">// 
[Masquerading as CUDA]</span></span><br><span class="line"> SparseXPU, <span class="comment">// For out of tree Intel's heterogeneous computing plug-in</span></span><br><span class="line"> SparseVE, <span class="comment">// For out of tree & closed source integration of SX-Aurora / NEC</span></span><br><span class="line"> SparseCsrCPU,</span><br><span class="line"> SparseCsrCUDA,</span><br><span class="line"></span><br><span class="line"> AutogradOther,</span><br><span class="line"> AutogradCPU,</span><br><span class="line"> AutogradCUDA,</span><br><span class="line"> AutogradXLA,</span><br><span class="line"> AutogradLazy,</span><br><span class="line"> AutogradXPU,</span><br><span class="line"> AutogradMLC,</span><br><span class="line"> AutogradHPU,</span><br><span class="line"></span><br><span class="line"> ......</span><br><span class="line">};</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure></li><li><p><strong>How the keys are used</strong></p><p>For reasons of space we cannot analyze every case in depth, so here we only show the scenario that starts from a DeviceType. The function below shows how a DeviceType is mapped to a DispatchKey.</p> <figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span 
class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">template</span> &lt;<span class="keyword">typename</span> Func&gt;</span><br><span class="line"><span class="function"><span class="keyword">inline</span> CppFunction <span class="title">dispatch</span><span class="params">(c10::DeviceType type, Func&amp;&amp; raw_f)</span> </span>{</span><br><span class="line"> <span class="keyword">auto</span> deviceTypeToDispatchKey = [](c10::DeviceType t){</span><br><span class="line"> <span class="keyword">switch</span> (t) {</span><br><span class="line"> <span class="comment">// This list is synchronized with the k-constants in c10/core/DeviceType.h</span></span><br><span class="line"> <span class="keyword">case</span> c10::DeviceType::CPU:</span><br><span class="line"> <span class="keyword">return</span> c10::DispatchKey::CPU;</span><br><span class="line"> <span class="keyword">case</span> c10::DeviceType::CUDA:</span><br><span class="line"> <span class="keyword">return</span> c10::DispatchKey::CUDA;</span><br><span class="line"> <span class="keyword">case</span> c10::DeviceType::XLA:</span><br><span class="line"> <span class="keyword">return</span> c10::DispatchKey::XLA;</span><br><span class="line"> <span class="keyword">case</span> c10::DeviceType::Lazy:</span><br><span class="line"> <span class="keyword">return</span> c10::DispatchKey::Lazy;</span><br><span class="line"> <span class="keyword">case</span> c10::DeviceType::MLC:</span><br><span class="line"> <span class="keyword">return</span> c10::DispatchKey::MLC;</span><br><span class="line"> <span class="keyword">case</span> c10::DeviceType::Meta:</span><br><span class="line"> <span class="keyword">return</span> c10::DispatchKey::Meta;</span><br><span class="line"> <span class="keyword">case</span> 
c10::DeviceType::HIP:</span><br><span class="line"> <span class="keyword">return</span> c10::DispatchKey::HIP;</span><br><span class="line"> <span class="keyword">case</span> c10::DeviceType::MSNPU:</span><br><span class="line"> <span class="keyword">return</span> c10::DispatchKey::MSNPU;</span><br><span class="line"> <span class="keyword">case</span> c10::DeviceType::HPU:</span><br><span class="line"> <span class="keyword">return</span> c10::DispatchKey::HPU;</span><br><span class="line"> <span class="keyword">default</span>:</span><br><span class="line"> <span class="built_in">TORCH_CHECK</span>(<span class="literal">false</span>,</span><br><span class="line"> <span class="string">"Device type "</span>, t, <span class="string">" cannot be overloaded at dispatch time, "</span></span><br><span class="line"> <span class="string">"please file a bug report explaining what you were trying to do."</span>);</span><br><span class="line"> }</span><br><span class="line"> };</span><br><span class="line"> <span class="keyword">return</span> <span class="built_in">dispatch</span>(<span class="built_in">deviceTypeToDispatchKey</span>(type), std::forward&lt;Func&gt;(raw_f));</span><br><span class="line">}</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure></li></ul><h3 id="小结"><a href="#小结" class="headerlink" title="小结"></a><strong>Summary</strong></h3><p>We have now seen that, through the Dispatcher mechanism, PyTorch can dispatch to different operator implementations depending on a tensor's dtype, device, and layout. This answers our third question: how does PyTorch switch seamlessly between CPU and GPU operations?</p><p>We can now also answer the fourth question: does the loss function need to be moved to the GPU?</p><p>The arguments of the loss function are the outputs of the forward pass and the labels. The outputs are already on the GPU because the training data was moved there, and the user moves the labels to the GPU manually. Because all of the loss function's arguments are on the GPU, the Dispatcher selects the GPU implementation of each operator based on the device. The loss function itself therefore does not need to be moved to the GPU. (This assumes the loss function holds no tensors of its own. A loss module constructed with a tensor argument, such as a class <code>weight</code>, needs that tensor on the same device as well.)</p><p>The overall flow can be summarized as the following sequence:</p><ol><li>Move the training data (inputs) to the GPU.</li><li>Run the forward pass. Assuming it consists of a single operator, op1, the Dispatcher is queried with the dispatch key 'device=GPU'.</li><li>The operator op1-gpu is found and executed, producing the outputs.</li><li>The outputs therefore reside on the GPU automatically.</li><li>Move the labels to the GPU as well.</li><li>Compute the loss. Assuming the loss function also consists of a single operator, op2: all of its arguments are on the GPU, so the Dispatcher is again queried with the dispatch key 'device=GPU'.</li><li>The operator op2-gpu is found and executed, producing the loss.</li></ol><figure class="highlight plaintext"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br></pre></td><td class="code"><pre><span class="line"> +--------------------+</span><br><span class="line"> +-----------+ | Forward | +------------+ +------------------+</span><br><span class="line"> | GPU | | | | GPU | | Loss Function |</span><br><span class="line"> | +---> | op1 op1-gpu() +----> | +---> | | +--------+</span><br><span class="line"> | Inputs | 1 | | 4 | Outputs | | | | GPU |</span><br><span class="line"> | | | + ^ | | | | | | |</span><br><span class="line"> +-----------+ | | | | +------------+ | op2 op2-gpu() +-->+ loss |</span><br><span class="line"> | | | | | | | |</span><br><span class="line"> +--------------------+ 
+------------+ | + ^ | | |</span><br><span class="line"> | | | GPU | 5 | | | | +--------+</span><br><span class="line"> | | | +---> | | 6 | 7 |</span><br><span class="line"> 2 | | 3 | Labels | | | | |</span><br><span class="line"> | | | | | | | |</span><br><span class="line"> | | +------------+ +------------------+</span><br><span class="line"> +----------------------------+ +--------------------------------+ | |</span><br><span class="line"> | | | |</span><br><span class="line">+-----------------------------------------------------------------------------+ |</span><br><span class="line">| | | |</span><br><span class="line">| | +-------------------------------------------------------+ | |</span><br><span class="line">| | | Dispatcher | | |</span><br><span class="line">| | | + + + + | | |</span><br><span class="line">| | | | XLA | CPU | Metal | GPU | | |</span><br><span class="line">| | | +---------------------------------------------------+ | | |</span><br><span class="line">| | | | | | | | | |</span><br><span class="line">| +--------> | OP1 | op1-xla | op1-cpu | op1-metal | op1-gpu +---+ |</span><br><span class="line">| 'device=GPU' | | | | | +------+ | |</span><br><span class="line">| | +---------------------------------------------------+ | |</span><br><span class="line">| | | | | | | |</span><br><span class="line">+------------> | OP2 | op2-xla | op2-cpu | op2-metal | op2-gpu +---------------+</span><br><span class="line"> 'device=GPU' | | | | | +------+ |</span><br><span class="line"> | +---------------------------------------------------+ |</span><br><span class="line"> | | | | | |</span><br><span class="line"> | OP3 | op3-xla | op3-cpu | op3-metal | op3-gpu |</span><br><span class="line"> | | | | | |</span><br><span class="line"> | +---------------------------------------------------+ |</span><br><span class="line"> +-------------------------------------------------------+</span><br></pre></td></tr></tbody></table></figure><p>The same diagram as a screenshot:</p><p><img 
src="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211106210302126-278214823.png" alt="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211106210302126-278214823.png"></p>]]></content>
<summary type="html">Dispatcher</summary>
<category term="分布式训练" scheme="https://thinksky5124.github.io/categories/%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83/"/>
<category term="深度学习" scheme="https://thinksky5124.github.io/categories/%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
<category term="分布式训练" scheme="https://thinksky5124.github.io/tags/%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83/"/>
<category term="深度学习" scheme="https://thinksky5124.github.io/tags/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
<category term="工程" scheme="https://thinksky5124.github.io/tags/%E5%B7%A5%E7%A8%8B/"/>
</entry>
<entry>
<title>DistributedDataParallel Initialization Methods &amp; Store</title>
<link href="https://thinksky5124.github.io/2022/08/18/DistributedDataParallel_%E5%88%9D%E5%A7%8B%E5%8C%96%E6%96%B9%E6%B3%95&%E5%AD%98%E5%82%A8/"/>
<id>https://thinksky5124.github.io/2022/08/18/DistributedDataParallel_%E5%88%9D%E5%A7%8B%E5%8C%96%E6%96%B9%E6%B3%95&%E5%AD%98%E5%82%A8/</id>
<published>2022-08-18T08:00:00.000Z</published>
<updated>2024-04-16T08:57:05.647Z</updated>
<content type="html"><![CDATA[<h1 id="DistributedDataParallel-初始化方法-存储"><a href="#DistributedDataParallel-初始化方法-存储" class="headerlink" title="DistributedDataParallel 初始化方法&存储"></a>DistributedDataParallel Initialization Methods &amp; Store</h1><h2 id="回顾"><a href="#回顾" class="headerlink" title="回顾"></a><strong>Recap</strong></h2><h3 id="基本概念"><a href="#基本概念" class="headerlink" title="基本概念"></a><strong>Basic Concepts</strong></h3><p>For distributed communication, PyTorch provides the following concepts: process group, backend, initialization, and Store.</p><ul><li><strong>Process group</strong>: DDP is genuinely distributed training, so multiple machines can together form a single parallel job. To let the DDP workers communicate with one another, PyTorch introduces the concept of a process group.</li><li><strong>Backend</strong>: a backend is a logical concept. In essence, it is an inter-process communication (IPC) mechanism.</li><li><strong>Initialization</strong>: even with backends and process groups in place, how do workers discover one another before the process group is built? This requires an initialization method that gives every worker one piece of information: how to reach the processes on the other machines.</li><li><strong>Store</strong>: a distributed key-value store. Processes in a group use it to share information and to initialize the distributed package. A store can also be created explicitly as an alternative to <code>init_method</code>.</li></ul><h3 id="初始化进程组"><a href="#初始化进程组" class="headerlink" title="初始化进程组"></a><strong>Initializing the Process Group</strong></h3><p>Before calling any other DDP method, you must initialize the package with <code>torch.distributed.init_process_group()</code>. This function initializes the default distributed process group and the distributed package, and it blocks until all processes have joined. Its signature is:</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">init_process_group(backend,</span><br><span class="line"> init_method=<span class="literal">None</span>,</span><br><span class="line"> timeout=default_pg_timeout,</span><br><span class="line"> world_size=-<span class="number">1</span>,</span><br><span class="line"> rank=-<span class="number">1</span>,</span><br><span class="line"> store=<span class="literal">None</span>,</span><br><span class="line"> group_name=<span class="string">''</span>,</span><br><span class="line"> pg_options=<span class="literal">None</span>
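)</span><br></pre></td></tr></tbody></table></figure><p>There are two main ways to initialize a process group:</p><ol><li>Specify <code>store</code>, <code>rank</code>, and <code>world_size</code> explicitly.</li><li>Specify <code>init_method</code> (a URL string), which indicates where and how to discover peers.</li></ol><p>If neither is specified, <code>init_method</code> is assumed to be "env://". Both options serve the same purpose (supplying the store used for rendezvous), so <code>store</code> and <code>init_method</code> are mutually exclusive.</p><p>The parameters of <code>init_process_group</code> are as follows:</p><ul><li><strong><code>backend</code></strong> – the backend to use. Valid values include <code>mpi</code>, <code>gloo</code>, and <code>nccl</code>. It should be given as a lowercase string (e.g. <code>"gloo"</code>), which can also be accessed via <code>Backend</code> attributes (e.g. <code>Backend.GLOO</code>). When using multiple processes per machine with the <code>nccl</code> backend, each process must have exclusive access to every GPU it uses, because sharing GPUs between processes can result in deadlocks.</li><li><strong><code>init_method</code></strong> – URL specifying how to initialize the process group. Defaults to <code>"env://"</code> if neither <code>init_method</code> nor <code>store</code> is specified. Mutually exclusive with <code>store</code>.</li><li><strong><code>world_size</code></strong> – number of processes participating in the job. Required if <code>store</code> is specified.</li><li><strong><code>rank</code></strong> – rank of the current process (a number between 0 and <code>world_size</code> - 1). Required if <code>store</code> is specified.</li><li><strong><code>store</code></strong> – key/value store accessible to all workers, used to exchange connection/address information. Mutually exclusive with <code>init_method</code>.</li><li><strong><code>timeout</code></strong> – timeout for operations executed against the process group. The default is 30 minutes. This applies to the <code>gloo</code> backend; for <code>nccl</code>, it applies only if the environment variable <code>NCCL_BLOCKING_WAIT</code> or <code>NCCL_ASYNC_ERROR_HANDLING</code> is set to 1.</li><li><strong><code>group_name</code></strong> – group name.</li><li><strong><code>pg_options</code></strong> (<em>ProcessGroupOptions, optional</em>) – process group options specifying which additional options to pass in when constructing a specific process group.</li></ul><h2 id="初始化"><a href="#初始化" class="headerlink" title="初始化"></a><strong>Initialization</strong></h2><h3 id="初始化方法"><a href="#初始化方法" class="headerlink" title="初始化方法"></a><strong>Initialization Methods</strong></h3><p>Three initialization methods are currently supported:</p><ul><li>Environment variable initialization: <code>init_method='env://'</code></li><li>Shared file-system initialization: <code>init_method='file:///mnt/nfs/sharedfile'</code></li><li>TCP initialization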
: <code>init_method='tcp://10.1.1.20:23456'</code></li></ul><p><strong>Environment variables</strong></p><p>This method reads its configuration from environment variables, which gives full control over how that information is obtained. By setting the following four environment variables on all machines, every process can connect to the master (the rank 0 process), obtain information about the other processes, and finally handshake with them.</p><ul><li><code>MASTER_PORT</code>: a port on the machine hosting the rank 0 process.</li><li><code>MASTER_ADDR</code>: the IP address of the machine hosting the rank 0 process.</li><li><code>WORLD_SIZE</code>: the total number of processes, so that the master knows how many workers to wait for.</li><li><code>RANK</code>: the rank of each process, so that each process knows whether it is the master.</li></ul><p><strong>Shared file system</strong></p><p>This method requires a shared file system that all processes can access; the processes coordinate through a shared file. Each process opens the file, writes its information, and waits until every other process has done the same. After that, all required information is available to all processes. To avoid race conditions, the file system must support locking through <a href="http://man7.org/linux/man-pages/man2/fcntl.2.html">fcntl</a>.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">dist.init_process_group(backend,</span><br><span class="line"> init_method=<span class="string">'file:///mnt/nfs/sharedfile'</span>,</span><br><span class="line"> rank=args.rank,</span><br><span class="line"> world_size=<span class="number">4</span>)</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p><strong>TCP</strong></p><p>TCP initialization works by providing the IP address and port of the rank 0 process. All workers connect to the rank 0 process and exchange the information they need to reach one another.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">dist.init_process_group(backend,</span><br><span class="line"> init_method=<span class="string">'tcp://10.1.1.20:23456'</span>,</span><br><span class="line"> rank=args.rank,</span><br><span class="line"> world_size=<span class="number">4</span>)</span><br></pre></td></tr></tbody></table></figure><h3 id="init-method-VS-store"><a href="#init-method-VS-store" class="headerlink" 
title="init_method VS store"></a><strong><code>init_method</code> VS <code>store</code></strong></h3><p>Why are there two separate parameters, <code>init_method</code> and <code>store</code>?</p><p>Reading the <code>init_process_group</code> code reveals the following pattern:</p><ul><li>With the MPI backend, <code>init_method</code> is not used.</li><li>With a non-MPI backend, if no <code>store</code> argument is given, a <code>store</code> is built from <code>init_method</code>.</li></ul><p><strong>So everything ultimately comes down to the store: the store is the entity that actually does the work</strong>.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">if</span> store <span class="keyword">is</span> <span class="literal">None</span>:</span><br><span class="line"> rendezvous_iterator = rendezvous(</span><br><span class="line"> init_method, rank, world_size, timeout=timeout</span><br><span class="line"> )</span><br><span class="line"> store, rank, world_size = <span class="built_in">next</span>(rendezvous_iterator)</span><br><span class="line"> store.set_timeout(timeout)</span><br></pre></td></tr></tbody></table></figure><p>The code of <code>init_process_group</code> is shown below:</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span 
class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">init_process_group</span>(<span class="params">backend,</span></span><br><span class="line"><span class="params"> init_method=<span class="literal">None</span>,</span></span><br><span class="line"><span class="params"> timeout=default_pg_timeout,</span></span><br><span class="line"><span class="params"> world_size=-<span class="number">1</span>,</span></span><br><span class="line"><span class="params"> rank=-<span class="number">1</span>,</span></span><br><span class="line"><span class="params"> store=<span class="literal">None</span>,</span></span><br><span class="line"><span class="params"> group_name=<span 
class="string">''</span>,</span></span><br><span class="line"><span class="params"> pg_options=<span class="literal">None</span></span>):</span><br><span class="line"></span><br><span class="line"> <span class="keyword">global</span> _pg_group_ranks</span><br><span class="line"> <span class="keyword">global</span> _backend</span><br><span class="line"> <span class="keyword">global</span> _default_pg_init_method</span><br><span class="line"></span><br><span class="line"> <span class="keyword">if</span> store <span class="keyword">is</span> <span class="keyword">not</span> <span class="literal">None</span>:</span><br><span class="line"> <span class="keyword">assert</span> world_size > <span class="number">0</span>, <span class="string">'world_size must be positive if using store'</span></span><br><span class="line"> <span class="keyword">assert</span> rank >= <span class="number">0</span>, <span class="string">'rank must be non-negative if using store'</span></span><br><span class="line"> <span class="keyword">elif</span> init_method <span class="keyword">is</span> <span class="literal">None</span>:</span><br><span class="line"> init_method = <span class="string">"env://"</span></span><br><span class="line"></span><br><span class="line"> backend = Backend(backend)</span><br><span class="line"></span><br><span class="line"> <span class="keyword">if</span> backend == Backend.MPI:</span><br><span class="line"> default_pg = _new_process_group_helper(</span><br><span class="line"> -<span class="number">1</span>,</span><br><span class="line"> -<span class="number">1</span>,</span><br><span class="line"> [],</span><br><span class="line"> Backend.MPI,</span><br><span class="line"> <span class="literal">None</span>,</span><br><span class="line"> group_name=group_name,</span><br><span class="line"> timeout=timeout)</span><br><span class="line"> _update_default_pg(default_pg)</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> <span 
class="comment"># backward compatible API</span></span><br><span class="line"> <span class="keyword">if</span> store <span class="keyword">is</span> <span class="literal">None</span>:</span><br><span class="line"> <span class="comment"># If no store is given, build one from init_method.</span></span><br><span class="line"> rendezvous_iterator = rendezvous(</span><br><span class="line"> init_method, rank, world_size, timeout=timeout</span><br><span class="line"> )</span><br><span class="line"> store, rank, world_size = <span class="built_in">next</span>(rendezvous_iterator)</span><br><span class="line"> store.set_timeout(timeout)</span><br><span class="line"></span><br><span class="line"> default_pg = _new_process_group_helper(</span><br><span class="line"> world_size,</span><br><span class="line"> rank,</span><br><span class="line"> [],</span><br><span class="line"> backend,</span><br><span class="line"> store,</span><br><span class="line"> pg_options=pg_options,</span><br><span class="line"> group_name=group_name,</span><br><span class="line"> timeout=timeout)</span><br><span class="line"> _update_default_pg(default_pg)</span><br><span class="line"></span><br><span class="line"> _pg_group_ranks[GroupMember.WORLD] = {i: i <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(GroupMember.WORLD.size())} <span class="comment"># type: ignore[attr-defined, index]</span></span><br><span class="line"> _backend = _pg_map[GroupMember.WORLD][<span class="number">0</span>] <span class="comment"># type: ignore[index]</span></span><br><span class="line"> _default_pg_init_method = init_method</span><br><span class="line"></span><br><span class="line"> <span class="comment"># (remainder omitted)</span></span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="rendezvous"><a href="#rendezvous" class="headerlink" title="rendezvous"></a><strong>rendezvous</strong></h3><p>The code above mentions rendezvous, so let us look at this concept.</p><p>Before collective algorithms can run, the participating processes must find one another and exchange information so that they can communicate. This process is called rendezvous. The result of a rendezvous is a triplet containing a shared key/value store (store), the rank of the process (rank), and the total number of participating processes (world_size). If none of the built-in rendezvous methods fits your execution environment, you can register your own rendezvous handler: pick a unique name and use it as the URL scheme that identifies the handler when calling the <code>rendezvous</code> function.</p><p>The <code>rendezvous</code> function simply selects a handler according to the URL scheme of its argument and delegates to it.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">rendezvous</span>(<span class="params">url: <span class="built_in">str</span>, rank: <span class="built_in">int</span> = -<span class="number">1</span>, world_size: <span class="built_in">int</span> = -<span class="number">1</span>, **kwargs</span>):</span><br><span class="line"></span><br><span class="line"> <span class="comment"># Append node-specific arguments.</span></span><br><span class="line"> result = urlparse(url)</span><br><span class="line"> <span class="keyword">if</span> rank != -<span class="number">1</span> <span class="keyword">or</span> world_size != -<span class="number">1</span>:</span><br><span class="line"> query_dict: <span class="type">Dict</span>[<span class="built_in">str</span>, <span class="type">Union</span>[<span class="built_in">int</span>, <span class="built_in">str</span>]] = <span 
class="built_in">dict</span>(</span><br><span class="line"> <span class="comment"># mypy doesn't allow dict() to accept List of values (#257)</span></span><br><span class="line"> pair.split(<span class="string">"="</span>) <span class="keyword">for</span> pair <span class="keyword">in</span> <span class="built_in">filter</span>(<span class="literal">None</span>, result.query.split(<span class="string">"&"</span>)) <span class="comment"># type: ignore[arg-<span class="built_in">type</span>, misc]</span></span><br><span class="line"> )</span><br><span class="line"> <span class="keyword">if</span> rank != -<span class="number">1</span>:</span><br><span class="line"> query_dict[<span class="string">"rank"</span>] = rank</span><br><span class="line"> <span class="keyword">if</span> world_size != -<span class="number">1</span>:</span><br><span class="line"> query_dict[<span class="string">"world_size"</span>] = world_size</span><br><span class="line"></span><br><span class="line"> result = result._replace(</span><br><span class="line"> query=<span class="string">"{}"</span>.<span class="built_in">format</span>(<span class="string">"&"</span>.join([<span class="string">"{}={}"</span>.<span class="built_in">format</span>(k, v) <span class="keyword">for</span> k, v <span class="keyword">in</span> query_dict.items()]))</span><br><span class="line"> )</span><br><span class="line"> url = urlunparse(result)</span><br><span class="line"></span><br><span class="line"> <span class="keyword">return</span> _rendezvous_handlers[result.scheme](url, **kwargs)</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p>The handlers are registered as follows. Note that they correspond exactly to the three initialization methods:</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">register_rendezvous_handler(<span 
class="string">"tcp"</span>, _tcp_rendezvous_handler)</span><br><span class="line">register_rendezvous_handler(<span class="string">"env"</span>, _env_rendezvous_handler)</span><br><span class="line">register_rendezvous_handler(<span class="string">"file"</span>, _file_rendezvous_handler)</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="小结"><a href="#小结" class="headerlink" title="小结"></a><strong>Summary</strong></h3><p>From the analysis so far, we can draw the following conclusions:</p><ul><li><code>init_method</code> ultimately comes down to a store; the store is the entity that actually does the work.</li><li>The participating processes must find one another and exchange information before they can communicate. This process is called rendezvous.</li></ul><h2 id="Store"><a href="#Store" class="headerlink" title="Store"></a><strong>Store</strong></h2><p>Here is a more formal definition. A Store is a distributed key-value store provided by the distributed package. All workers access it to share information and to initialize the distributed package. Users can create a store explicitly as an alternative to <code>init_method</code>. There are currently three kinds of key-value store: <code>TCPStore</code>, <code>FileStore</code>, and <code>HashStore</code>.</p><p>Continuing from the previous section, let us look at the handlers in more detail.</p><h3 id="rendezvous-handlers"><a href="#rendezvous-handlers" class="headerlink" title="_rendezvous_handlers"></a><strong>_rendezvous_handlers</strong></h3><p>PyTorch defines a global variable, <code>_rendezvous_handlers</code>, that maps each URL scheme to a function producing a store. You can think of these functions as factory methods.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">_rendezvous_handlers = {}</span><br></pre></td></tr></tbody></table></figure><p>The handlers are registered as follows:</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">register_rendezvous_handler(<span class="string">"tcp"</span>, _tcp_rendezvous_handler)</span><br><span class="line">register_rendezvous_handler(<span class="string">"env"</span>, _env_rendezvous_handler)</span><br><span class="line">register_rendezvous_handler(<span class="string">"file"</span>, 
_file_rendezvous_handler)</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p>The registration function simply inserts the handler into this global dictionary, raising an error if the scheme is already registered:</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">register_rendezvous_handler</span>(<span class="params">scheme, handler</span>):</span><br><span class="line"> <span class="string">"""Registers a new rendezvous handler.</span></span><br><span class="line"><span class="string"> Args:</span></span><br><span class="line"><span class="string"> scheme (str): URL scheme to identify your rendezvous handler.</span></span><br><span class="line"><span class="string"> handler (function): Handler that is invoked when the</span></span><br><span class="line"><span class="string"> `rendezvous()` function is called with a URL that uses</span></span><br><span class="line"><span class="string"> the corresponding scheme. 
It must be a generator function</span></span><br><span class="line"><span class="string"> that yields the triplet.</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> <span class="keyword">global</span> _rendezvous_handlers</span><br><span class="line"> <span class="keyword">if</span> scheme <span class="keyword">in</span> _rendezvous_handlers:</span><br><span class="line"> <span class="keyword">raise</span> RuntimeError(</span><br><span class="line"> <span class="string">"Rendezvous handler for {}:// already registered"</span>.<span class="built_in">format</span>(scheme)</span><br><span class="line"> )</span><br><span class="line"> _rendezvous_handlers[scheme] = handler</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="handlers"><a href="#handlers" class="headerlink" title="handlers"></a><strong>handlers</strong></h3><p>A close look at the handler code shows that each handler simply returns a different kind of store. For example, <code>_tcp_rendezvous_handler</code> parses the host, port, rank, and world size from the URL, builds a TCPStore from them, and returns it.</p><p>Non-essential code has been removed from all of the listings below.</p><h3 id="file-rendezvous-handler"><a href="#file-rendezvous-handler" class="headerlink" title="_file_rendezvous_handler"></a><strong>_file_rendezvous_handler</strong></h3><p>This handler returns a FileStore.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">_file_rendezvous_handler</span>(<span class="params">url: <span
class="built_in">str</span>, **kwargs</span>):</span><br><span class="line"></span><br><span class="line"> result = urlparse(url)</span><br><span class="line"> path = result.path</span><br><span class="line"> query: <span class="type">Dict</span>[<span class="built_in">str</span>, <span class="built_in">str</span>]</span><br><span class="line"> <span class="comment"># mypy doesn't allow dict() to accept List of values (#257)</span></span><br><span class="line"> query = <span class="built_in">dict</span>(pair.split(<span class="string">"="</span>) <span class="keyword">for</span> pair <span class="keyword">in</span> <span class="built_in">filter</span>(<span class="literal">None</span>, result.query.split(<span class="string">"&"</span>))) <span class="comment"># type: ignore[misc, arg-<span class="built_in">type</span>]</span></span><br><span class="line"></span><br><span class="line"> rank = <span class="built_in">int</span>(query[<span class="string">"rank"</span>])</span><br><span class="line"> world_size = <span class="built_in">int</span>(query[<span class="string">"world_size"</span>])</span><br><span class="line"> store = FileStore(path, world_size)</span><br><span class="line"> <span class="keyword">yield</span> (store, rank, world_size)</span><br><span class="line"></span><br><span class="line"> <span class="comment"># If this configuration is invalidated, there is nothing we can do about it</span></span><br><span class="line"> <span class="keyword">raise</span> RuntimeError(<span class="string">"Unable to perform rerendezvous using file:// method"</span>)</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="tcp-rendezvous-handler"><a href="#tcp-rendezvous-handler" class="headerlink" title="_tcp_rendezvous_handler"></a><strong>_tcp_rendezvous_handler</strong></h3><p>This handler returns a TCPStore. Rank 0 also starts the store's server daemon (<code>start_daemon = rank == 0</code>).</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span
class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">_tcp_rendezvous_handler</span>(<span class="params">url: <span class="built_in">str</span>, timeout: timedelta = default_pg_timeout, **kwargs</span>):</span><br><span class="line"> result = urlparse(url)</span><br><span class="line"> query: <span class="type">Dict</span>[<span class="built_in">str</span>, <span class="type">Union</span>[<span class="built_in">int</span>, <span class="built_in">str</span>]]</span><br><span class="line"> <span class="comment"># mypy doesn't allow dict() to accept List of values (#257)</span></span><br><span class="line"> query = <span class="built_in">dict</span>(pair.split(<span class="string">"="</span>) <span class="keyword">for</span> pair <span class="keyword">in</span> <span class="built_in">filter</span>(<span class="literal">None</span>, result.query.split(<span class="string">"&"</span>))) <span class="comment"># type: ignore[misc, arg-<span class="built_in">type</span>]</span></span><br><span class="line"></span><br><span class="line"> rank = <span class="built_in">int</span>(query[<span class="string">"rank"</span>])</span><br><span class="line"> world_size = <span class="built_in">int</span>(query[<span class="string">"world_size"</span>])</span><br><span class="line"> start_daemon = rank == <span class="number">0</span></span><br><span class="line"> <span class="keyword">assert</span> result.hostname <span class="keyword">is</span> <span 
class="keyword">not</span> <span class="literal">None</span></span><br><span class="line"> store = TCPStore(result.hostname, result.port, world_size, start_daemon, timeout)</span><br><span class="line"> <span class="keyword">yield</span> (store, rank, world_size)</span><br><span class="line"></span><br><span class="line"> <span class="comment"># If this configuration is invalidated, there is nothing we can do about it</span></span><br><span class="line"> <span class="keyword">raise</span> RuntimeError(<span class="string">"Unable to perform rerendezvous using tcp:// method"</span>)</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="env-rendezvous-handler"><a href="#env-rendezvous-handler" class="headerlink" title="_env_rendezvous_handler"></a><strong>_env_rendezvous_handler</strong></h3><p>Perhaps surprisingly, this handler also returns a TCPStore, but it reads the information it needs from environment variables. It always reads <code>MASTER_ADDR</code> and <code>MASTER_PORT</code>, and reads <code>RANK</code> and <code>WORLD_SIZE</code> only when they are not given in the URL query.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span
class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">_env_rendezvous_handler</span>(<span class="params">url: <span class="built_in">str</span>, timeout: timedelta = default_pg_timeout, **kwargs</span>):</span><br><span class="line"></span><br><span class="line"> result = urlparse(url)</span><br><span class="line"> query: <span class="type">Dict</span>[<span class="built_in">str</span>, <span class="type">Union</span>[<span class="built_in">int</span>, <span class="built_in">str</span>]]</span><br><span class="line"> query = <span class="built_in">dict</span>(pair.split(<span class="string">"="</span>) <span class="keyword">for</span> pair <span class="keyword">in</span> <span class="built_in">filter</span>(<span class="literal">None</span>, result.query.split(<span class="string">"&"</span>)))</span><br><span class="line"> rank: <span class="type">Optional</span>[<span class="type">Union</span>[<span class="built_in">str</span>, <span class="built_in">int</span>]]</span><br><span class="line"> world_size: <span class="type">Optional</span>[<span class="type">Union</span>[<span class="built_in">str</span>, <span class="built_in">int</span>]]</span><br><span class="line"> master_port: <span class="type">Optional</span>[<span class="type">Union</span>[<span class="built_in">str</span>, <span class="built_in">int</span>]]</span><br><span class="line"></span><br><span class="line"> <span class="keyword">if</span> <span class="string">"rank"</span> <span class="keyword">in</span> query:</span><br><span class="line"> rank = <span class="built_in">int</span>(query[<span class="string">"rank"</span>])</span><br><span class="line"> <span 
class="keyword">else</span>:</span><br><span class="line"> rank = <span class="built_in">int</span>(_get_env_or_raise(<span class="string">"RANK"</span>))</span><br><span class="line"></span><br><span class="line"> <span class="keyword">if</span> <span class="string">"world_size"</span> <span class="keyword">in</span> query:</span><br><span class="line"> world_size = <span class="built_in">int</span>(query[<span class="string">"world_size"</span>])</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> world_size = <span class="built_in">int</span>(_get_env_or_raise(<span class="string">"WORLD_SIZE"</span>))</span><br><span class="line"></span><br><span class="line"> master_addr = _get_env_or_raise(<span class="string">"MASTER_ADDR"</span>)</span><br><span class="line"> master_port = <span class="built_in">int</span>(_get_env_or_raise(<span class="string">"MASTER_PORT"</span>))</span><br><span class="line"></span><br><span class="line"> use_torchelastic_store = os.environ.get(<span class="string">"TORCHELASTIC_USE_AGENT_STORE"</span>, <span class="literal">None</span>)</span><br><span class="line"></span><br><span class="line"> <span class="keyword">if</span> use_torchelastic_store == <span class="built_in">str</span>(<span class="literal">True</span>):</span><br><span class="line"> worker_process_prefix = <span class="string">"/worker"</span></span><br><span class="line"> <span class="comment"># When TORCHELASTIC_USE_AGENT_STORE is set up, the worker process is assumed</span></span><br><span class="line"> <span class="comment"># to be invoked by the torchelastic agent. 
Torchelastic agent creates a tcp daemon thread</span></span><br><span class="line"> <span class="comment"># on the GROUP_RANK=0, as a result all user worker processes should create store with: daemon=False</span></span><br><span class="line"> tcp_store = TCPStore(master_addr, master_port, world_size, <span class="literal">False</span>, timeout)</span><br><span class="line"> <span class="keyword">yield</span> (PrefixStore(worker_process_prefix, tcp_store), rank, world_size)</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> <span class="comment"># Start the TCP store daemon on the rank 0</span></span><br><span class="line"> start_daemon = rank == <span class="number">0</span></span><br><span class="line"> store = TCPStore(master_addr, master_port, world_size, start_daemon, timeout)</span><br><span class="line"> <span class="keyword">yield</span> (store, rank, world_size)</span><br><span class="line"></span><br><span class="line"> <span class="comment"># If this configuration is invalidated, there is nothing we can do about it</span></span><br><span class="line"> <span class="keyword">raise</span> RuntimeError(<span class="string">"Unable to perform rerendezvous using env:// method"</span>)</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h2 id="使用"><a href="#使用" class="headerlink" title="使用"></a><strong>Usage</strong></h2><h3 id="使用-handler"><a href="#使用-handler" class="headerlink" title="使用 handler"></a><strong>Using the handler</strong></h3><p>How is a handler used? <code>init_process_group</code> contains the following:</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">rendezvous_iterator = rendezvous(</span><br><span class="line"> init_method, rank, world_size, timeout=timeout</span><br><span class="line">)</span><br><span
class="line">store, rank, world_size = <span class="built_in">next</span>(rendezvous_iterator)</span><br></pre></td></tr></tbody></table></figure><p><code>rendezvous</code> first appends any explicitly passed <code>rank</code> and <code>world_size</code> to the query string of <code>init_method</code>. It then uses the URL scheme to look up a handler in <code>_rendezvous_handlers</code> and calls it. The result is a generator whose first <code>next()</code> yields the store together with the rank and world size.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">rendezvous</span>(<span class="params">url: <span class="built_in">str</span>, rank: <span class="built_in">int</span> = -<span class="number">1</span>, world_size: <span class="built_in">int</span> = -<span class="number">1</span>, **kwargs</span>):</span><br><span class="line"> <span class="comment"># Append node-specific arguments.</span></span><br><span class="line"> result = urlparse(url)</span><br><span class="line"> <span class="keyword">if</span> rank != -<span class="number">1</span> <span class="keyword">or</span> world_size != -<span class="number">1</span>:</span><br><span class="line"> query_dict: <span class="type">Dict</span>[<span class="built_in">str</span>, <span class="type">Union</span>[<span class="built_in">int</span>, <span class="built_in">str</span>]] = <span class="built_in">dict</span>(</span><br><span class="line"> <span class="comment"># mypy doesn't allow 
dict() to accept List of values (#257)</span></span><br><span class="line"> pair.split(<span class="string">"="</span>) <span class="keyword">for</span> pair <span class="keyword">in</span> <span class="built_in">filter</span>(<span class="literal">None</span>, result.query.split(<span class="string">"&"</span>)) <span class="comment"># type: ignore[arg-<span class="built_in">type</span>, misc]</span></span><br><span class="line"> )</span><br><span class="line"> <span class="keyword">if</span> rank != -<span class="number">1</span>:</span><br><span class="line"> query_dict[<span class="string">"rank"</span>] = rank</span><br><span class="line"> <span class="keyword">if</span> world_size != -<span class="number">1</span>:</span><br><span class="line"> query_dict[<span class="string">"world_size"</span>] = world_size</span><br><span class="line"></span><br><span class="line"> result = result._replace(</span><br><span class="line"> query=<span class="string">"{}"</span>.<span class="built_in">format</span>(<span class="string">"&"</span>.join([<span class="string">"{}={}"</span>.<span class="built_in">format</span>(k, v) <span class="keyword">for</span> k, v <span class="keyword">in</span> query_dict.items()]))</span><br><span class="line"> )</span><br><span class="line"> url = urlunparse(result)</span><br><span class="line"></span><br><span class="line"> <span class="keyword">return</span> _rendezvous_handlers[result.scheme](url, **kwargs)</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="使用-Store"><a href="#使用-Store" class="headerlink" title="使用 Store"></a><strong>Using the Store</strong></h3><p>Next, let's see how the store is used. <code>init_process_group</code> then passes the store to <code>_new_process_group_helper</code> to create the default process group.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span
class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line">default_pg = _new_process_group_helper(</span><br><span class="line"> world_size,</span><br><span class="line"> rank,</span><br><span class="line"> [],</span><br><span class="line"> backend,</span><br><span class="line"> store,</span><br><span class="line"> pg_options=pg_options,</span><br><span class="line"> group_name=group_name,</span><br><span class="line"> timeout=timeout)</span><br><span class="line">_update_default_pg(default_pg)</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="new-process-group-helper"><a href="#new-process-group-helper" class="headerlink" title="_new_process_group_helper"></a><strong>_new_process_group_helper</strong></h3><p>Before continuing with <code>_new_process_group_helper</code>, let's look at a few global variables. They record information about each ProcessGroup globally, for example <code>_pg_map[pg] = (Backend.NCCL, store)</code>.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># Cached process groups</span></span><br><span class="line"><span class="comment"># For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store)</span></span><br><span class="line"><span class="comment"># For MPI pg, it is a map from ProcessGroup to (Backend, None)</span></span><br><span class="line">_pg_map: <span class="type">Dict</span>[ProcessGroup, <span class="type">Tuple</span>[<span class="built_in">str</span>, <span class="type">Optional</span>[Store]]] = {}</span><br><span
class="line"><span class="comment"># Process group's names, map from ProcessGroup to str</span></span><br><span class="line">_pg_names: <span class="type">Dict</span>[ProcessGroup, <span class="built_in">str</span>] = {}</span><br><span class="line"><span class="comment"># Process group's global rank to local rank mapping</span></span><br><span class="line">_pg_group_ranks: <span class="type">Dict</span>[ProcessGroup, <span class="type">Dict</span>[<span class="built_in">int</span>, <span class="built_in">int</span>]] = {}</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p>Once <code>_new_process_group_helper</code> receives the store argument, it wraps the store in a <code>PrefixStore</code> (<code>prefix_store</code>) keyed by the group name. It then uses that <code>prefix_store</code> to construct the process group, for example ProcessGroupGloo or ProcessGroupNCCL. The MPI backend does not use the store at all. The code of <code>_new_process_group_helper</code> is as follows:</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span
class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">_new_process_group_helper</span>(<span class="params">world_size,</span></span><br><span class="line"><span class="params"> rank,</span></span><br><span class="line"><span class="params"> group_ranks,</span></span><br><span 
class="line"><span class="params"> backend,</span></span><br><span class="line"><span class="params"> store,</span></span><br><span class="line"><span class="params"> pg_options=<span class="literal">None</span>,</span></span><br><span class="line"><span class="params"> group_name=<span class="literal">None</span>,</span></span><br><span class="line"><span class="params"> timeout=default_pg_timeout</span>):</span><br><span class="line"> <span class="string">"""</span></span><br><span class="line"><span class="string"> Create a new distributed process group.</span></span><br><span class="line"><span class="string"></span></span><br><span class="line"><span class="string"> This function must be called by ALL processes in the global group, even if</span></span><br><span class="line"><span class="string"> the calling process is not part of the newly created group. In that case,</span></span><br><span class="line"><span class="string"> this function returns GroupMember.NON_GROUP_MEMBER.</span></span><br><span class="line"><span class="string"></span></span><br><span class="line"><span class="string"> This function is called with ``group_ranks == []`` for the default group.</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> <span class="keyword">global</span> _pg_map</span><br><span class="line"> <span class="keyword">global</span> _group_count</span><br><span class="line"> <span class="keyword">global</span> _pg_names</span><br><span class="line"></span><br><span class="line"> <span class="keyword">if</span> <span class="keyword">not</span> group_name:</span><br><span class="line"> group_name = <span class="built_in">str</span>(_group_count)</span><br><span class="line"> _group_count += <span class="number">1</span></span><br><span class="line"></span><br><span class="line"> <span class="comment"># The list of group ranks is empty if we're creating the default group.</span></span><br><span class="line"> is_default_group = 
(<span class="built_in">len</span>(group_ranks) == <span class="number">0</span>)</span><br><span class="line"></span><br><span class="line"> backend = Backend(backend)</span><br><span class="line"> pg: <span class="type">Union</span>[ProcessGroupGloo, ProcessGroupMPI, ProcessGroupNCCL]</span><br><span class="line"> <span class="keyword">if</span> backend == Backend.MPI: <span class="comment"># the store is not used</span></span><br><span class="line"> pg = ProcessGroupMPI.create(group_ranks)</span><br><span class="line"> <span class="keyword">if</span> <span class="keyword">not</span> pg:</span><br><span class="line"> <span class="keyword">return</span> GroupMember.NON_GROUP_MEMBER</span><br><span class="line"> _pg_map[pg] = (Backend.MPI, <span class="literal">None</span>)</span><br><span class="line"> _pg_names[pg] = group_name</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> <span class="comment"># the store is used here</span></span><br><span class="line"></span><br><span class="line"> <span class="comment"># If this is a subgroup (which means group_ranks is specified),</span></span><br><span class="line"> <span class="comment"># we check if the current process is a member of the new group.</span></span><br><span class="line"> <span class="keyword">if</span> <span class="keyword">not</span> is_default_group:</span><br><span class="line"> global_rank = _get_default_group().rank()</span><br><span class="line"> <span class="keyword">if</span> global_rank <span class="keyword">not</span> <span class="keyword">in</span> group_ranks:</span><br><span class="line"> <span class="keyword">return</span> GroupMember.NON_GROUP_MEMBER</span><br><span class="line"></span><br><span class="line"> <span class="comment"># Use the group name as prefix in the default store, such that</span></span><br><span class="line"> <span class="comment"># a single store can be reused by multiple groups.</span></span><br><span class="line"></span><br><span class="line"> 
prefix_store = PrefixStore(group_name, store) <span class="comment"># build a PrefixStore</span></span><br><span class="line"></span><br><span class="line"> <span class="keyword">if</span> backend == Backend.GLOO:</span><br><span class="line"> pg = ProcessGroupGloo(</span><br><span class="line"> prefix_store, <span class="comment"># build the process group on top of the PrefixStore</span></span><br><span class="line"> rank,</span><br><span class="line"> world_size,</span><br><span class="line"> timeout=timeout)</span><br><span class="line"> _pg_map[pg] = (Backend.GLOO, store)</span><br><span class="line"> _pg_names[pg] = group_name</span><br><span class="line"> <span class="keyword">elif</span> backend == Backend.NCCL:</span><br><span class="line"> <span class="keyword">if</span> pg_options <span class="keyword">is</span> <span class="keyword">not</span> <span class="literal">None</span>:</span><br><span class="line"> <span class="keyword">assert</span> <span class="built_in">isinstance</span>(pg_options, ProcessGroupNCCL.Options), \</span><br><span class="line"> <span class="string">"Expected pg_options argument to be of type ProcessGroupNCCL.Options"</span></span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> <span class="comment"># default pg_options for NCCL</span></span><br><span class="line"> pg_options = ProcessGroupNCCL.Options()</span><br><span class="line"> pg_options.is_high_priority_stream = <span class="literal">False</span></span><br><span class="line"> pg_options._timeout = timeout</span><br><span class="line"></span><br><span class="line"> pg = ProcessGroupNCCL(</span><br><span class="line"> prefix_store, <span class="comment"># build the process group on top of the PrefixStore</span></span><br><span class="line"> rank,</span><br><span class="line"> world_size,</span><br><span class="line"> pg_options)</span><br><span class="line"> _pg_map[pg] = (Backend.NCCL, store)</span><br><span class="line"> _pg_names[pg] = group_name</span><br><span class="line"> <span 
class="keyword">else</span>:</span><br><span class="line"> pg = <span class="built_in">getattr</span>(Backend, backend.upper())(</span><br><span class="line"> prefix_store,</span><br><span class="line"> rank,</span><br><span class="line"> world_size,</span><br><span class="line"> timeout)</span><br><span class="line"> _pg_map[pg] = (backend, store)</span><br><span class="line"> _pg_names[pg] = group_name</span><br><span class="line"></span><br><span class="line"> <span class="keyword">return</span> pg</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="ProcessGroupGloo"><a href="#ProcessGroupGloo" class="headerlink" title="ProcessGroupGloo"></a><strong>ProcessGroupGloo</strong></h3><p>ProcessGroupGloo is where the store is actually put to work. Its constructor wraps the PrefixStore in a GlooStore. Then, for each device, it creates another PrefixStore on top of that and uses it to set up a full-mesh connection among all ranks.</p><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span
class="line">35</span><br></pre></td><td class="code"><pre><span class="line">ProcessGroupGloo::<span class="built_in">ProcessGroupGloo</span>(</span><br><span class="line"> <span class="type">const</span> c10::intrusive_ptr<Store>& store,</span><br><span class="line"> <span class="type">int</span> rank,</span><br><span class="line"> <span class="type">int</span> size,</span><br><span class="line"> c10::intrusive_ptr<Options> options)</span><br><span class="line"> : <span class="built_in">ProcessGroup</span>(rank, size),</span><br><span class="line"> <span class="built_in">store_</span>(<span class="keyword">new</span> <span class="built_in">GlooStore</span>(store)), <span class="comment">// wrap the PrefixStore in a GlooStore</span></span><br><span class="line"> <span class="built_in">options_</span>(options),</span><br><span class="line"> <span class="built_in">stop_</span>(<span class="literal">false</span>),</span><br><span class="line"> <span class="built_in">collectiveCounter_</span>(<span class="number">0</span>) {</span><br><span class="line"> <span class="keyword">auto</span>& devices = options->devices;</span><br><span class="line"></span><br><span class="line"> contexts_.<span class="built_in">reserve</span>(options->devices.<span class="built_in">size</span>());</span><br><span class="line"> <span class="keyword">for</span> (<span class="type">size_t</span> i = <span class="number">0</span>; i < options->devices.<span class="built_in">size</span>(); i++) {</span><br><span class="line"> <span class="keyword">auto</span> context = std::<span class="built_in">make_shared</span><::gloo::rendezvous::Context>(rank_, size_);</span><br><span class="line"> <span class="comment">// create another PrefixStore, one per device, keyed by the device index</span></span><br><span class="line"> <span class="keyword">auto</span> store = ::gloo::rendezvous::<span class="built_in">PrefixStore</span>(std::<span class="built_in">to_string</span>(i), *store_);</span><br><span class="line"> context-><span 
class="built_in">setTimeout</span>(options->timeout);</span><br><span class="line"> <span class="comment">// use the PrefixStore to build a full mesh of connections</span></span><br><span class="line"> context-><span class="built_in">connectFullMesh</span>(store, options->devices[i]);</span><br><span class="line"> contexts_.<span class="built_in">push_back</span>(std::<span class="built_in">move</span>(context));</span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> <span class="comment">// Every worker thread stores the AsyncWork object it's currently</span></span><br><span class="line"> <span class="comment">// working on in the workInProgress_ vector. It must have size equal</span></span><br><span class="line"> <span class="comment">// to the number of workers such that they can simply index into it</span></span><br><span class="line"> <span class="comment">// using the worker index they are started with.</span></span><br><span class="line"> workInProgress_.<span class="built_in">resize</span>(options->threads);</span><br><span class="line"></span><br><span class="line"> threads_.<span class="built_in">resize</span>(options->threads);</span><br><span class="line"> <span class="keyword">for</span> (<span class="type">size_t</span> i = <span class="number">0</span>; i < threads_.<span class="built_in">size</span>(); i++) {</span><br><span class="line"> threads_[i] = std::<span class="built_in">thread</span>(&ProcessGroupGloo::runLoop, <span class="keyword">this</span>, i);</span><br><span class="line"> }</span><br><span class="line">}</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p>The following code also uses <code>store_</code>, for example to set a key, wait for it, and read it back.</p><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span
class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="type">void</span> <span class="title">ProcessGroupGloo::setSequenceNumberForGroup</span><span class="params">()</span> </span>{</span><br><span class="line"> <span class="keyword">if</span> (rank_ == <span class="number">0</span>) {</span><br><span class="line"> <span class="comment">// Create and broadcast sequence number</span></span><br><span class="line"> <span class="keyword">auto</span> seq = <span class="number">1</span> + <span class="built_in">rand</span>();</span><br><span class="line"> sequenceNum_ = c10d::<span class="built_in">SequenceNum</span>(seq);</span><br><span class="line"> std::vector<<span class="type">char</span>> values = c10d::<span class="built_in">toVec</span><<span class="type">char</span>>(seq, kBytes);</span><br><span class="line"> store_-><span class="built_in">set</span>(kSeqNumStoreKey, values); <span class="comment">// store the value</span></span><br><span class="line"> } <span class="keyword">else</span> {</span><br><span class="line"> <span class="comment">// Read rank 0's sequence number from store.</span></span><br><span class="line"> sequenceNum_ = c10d::<span class="built_in">SequenceNum</span>();</span><br><span class="line"> store_-><span class="built_in">wait</span>({kSeqNumStoreKey}, options_->timeout); <span class="comment">// wait until the key is set</span></span><br><span class="line"> std::vector<<span class="type">char</span>> values = store_-><span class="built_in">get</span>(kSeqNumStoreKey); <span class="comment">// read the value</span></span><br><span class="line"> <span class="type">uint64_t</span> num = c10d::<span class="built_in">fromVec</span><<span
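class="type">char</span>>(values);</span><br><span class="line"> sequenceNum_-><span class="built_in">set</span>(num);</span><br><span class="line"> }</span><br><span class="line">}</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="小结-1"><a href="#小结-1" class="headerlink" title="小结"></a><strong>Summary</strong></h3><p>Based on the analysis so far, we can extend our conclusions as follows:</p><ul><li><code>init_method</code> ultimately comes down to a store. The store is the entity that actually does the work.</li><li>The participating processes must find each other and exchange information before they can communicate. This process is called rendezvous.</li><li>Rendezvous essentially returns some kind of store for subsequent communication.</li><li>Within a process group, the store is used to set up communication, to wait on keys, and to set and get values.</li></ul><p>Next, we take TCPStore as the example and analyze it in detail.</p><h2 id="TCPStore"><a href="#TCPStore" class="headerlink" title="TCPStore"></a><strong>TCPStore</strong></h2><p>TCPStore is a TCP-based distributed key-value store. The server store holds the data, while client stores connect to the server store over TCP and perform operations such as <code>set()</code> to insert a key-value pair and <code>get()</code> to retrieve one. There should always be one initialized TCPStore server, because the client stores wait for this server in order to establish a connection.</p><p>TCPStore takes the following parameters:</p><ul><li><code>host_name (str)</code> – The hostname or IP address on which the server store runs.</li><li><code>port (int)</code> – The port on which the server store listens for incoming requests.</li><li><code>world_size (int, optional)</code> – The total number of store users.<ul><li><code>world_size</code> = number of clients + 1, where the 1 is the server.</li><li>The default is -1 (a negative value means the number of store users is not fixed).</li></ul></li><li><code>is_master (bool, optional)</code> – True when initializing the server store, False when initializing a client store. The default is False.</li><li><code>timeout (timedelta, optional)</code> – The timeout used by the store during initialization and by the <code>get()</code> and <code>wait()</code> methods. The default is <code>timedelta(seconds=300)</code>.</li><li><code>wait_for_worker (bool, optional)</code> – Whether to wait for all workers to connect to the server store. This only applies when <code>world_size</code> is a fixed value. The default is True.</li></ul><p>Usage example:</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span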
class="keyword">import</span> torch.distributed <span class="keyword">as</span> dist</span><br><span class="line"><span class="keyword">from</span> datetime <span class="keyword">import</span> timedelta</span><br><span class="line"><span class="comment"># Run on process 1 (server)</span></span><br><span class="line">server_store = dist.TCPStore(<span class="string">"127.0.0.1"</span>, <span class="number">1234</span>, <span class="number">2</span>, <span class="literal">True</span>, timedelta(seconds=<span class="number">30</span>))</span><br><span class="line"><span class="comment"># Run on process 2 (client)</span></span><br><span class="line">client_store = dist.TCPStore(<span class="string">"127.0.0.1"</span>, <span class="number">1234</span>, <span class="number">2</span>, <span class="literal">False</span>)</span><br><span class="line"><span class="comment"># Use any of the store methods from either the client or server after initialization</span></span><br><span class="line">server_store.<span class="built_in">set</span>(<span class="string">"first_key"</span>, <span class="string">"first_value"</span>)</span><br><span class="line">client_store.get(<span class="string">"first_key"</span>)</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p>Another example, in which <code>wait()</code> times out on a key that is never set:</p><figure class="highlight bash"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">>>> import torch.distributed as dist</span><br><span class="line">>>> from datetime import timedelta</span><br><span class="line">>>> <span class="comment"># Using TCPStore as an example, other store types can also be used</span></span><br><span class="line">>>> store = dist.TCPStore(<span class="string">"127.0.0.1"</span>, 0, 1, True, 
timedelta(seconds=30))</span><br><span class="line">>>> <span class="comment"># This will throw an exception after 10 seconds</span></span><br><span class="line">>>> store.wait([<span class="string">"bad_key"</span>], timedelta(seconds=10))</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p>As the examples show, this is a simple <code>server</code>/<code>client</code> (or <code>master</code>/<code>worker</code>) relationship. Let us analyze it in detail.</p><h3 id="TCPStore-in-python"><a href="#TCPStore-in-python" class="headerlink" title="TCPStore in python"></a><strong>TCPStore in Python</strong></h3><p>On the Python side, TCPStore simply takes the host and port. The class below is only a signature stub, as its own comment notes.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">class</span> <span class="title class_">TCPStore</span>(<span class="title class_ inherited__">Store</span>):</span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self, host_name, port, world_size=-<span class="number">1</span>, is_master=<span class="literal">False</span>, timeout=<span class="literal">None</span>, *args, **kwargs</span>): <span class="comment"># real signature unknown; <span class="doctag">NOTE:</span> unreliably restored from __doc__</span></span><br><span class="line"> <span class="keyword">pass</span></span><br><span class="line"></span><br><span class="line"> host = <span class="built_in">property</span>(<span class="keyword">lambda</span> self: <span class="built_in">object</span>(), <span class="keyword">lambda</span> self, v: <span class="literal">None</span>, <span class="keyword">lambda</span> 
self: <span class="literal">None</span>) <span class="comment"># default</span></span><br><span class="line"> <span class="string">"""Gets the hostname on which the store listens for requests."""</span></span><br><span class="line"></span><br><span class="line"> port = <span class="built_in">property</span>(<span class="keyword">lambda</span> self: <span class="built_in">object</span>(), <span class="keyword">lambda</span> self, v: <span class="literal">None</span>, <span class="keyword">lambda</span> self: <span class="literal">None</span>) <span class="comment"># default</span></span><br><span class="line"> <span class="string">"""Gets the port number on which the store listens for requests."""</span></span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p>To see how it actually works, we need to dig into the C++ side.</p><h3 id="TCPStore-in-CPP"><a href="#TCPStore-in-CPP" class="headerlink" title="TCPStore in CPP"></a><strong>TCPStore in C++</strong></h3><h3 id="API接口"><a href="#API接口" class="headerlink" title="API接口"></a><strong>API</strong></h3><p>First, the C++ TCPStore can be regarded as the API layer. It is declared as follows:</p><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span
class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">class</span> <span class="title class_">TCPStore</span> : <span class="keyword">public</span> Store {</span><br><span class="line"> <span class="keyword">public</span>:</span><br><span class="line"> <span class="function"><span class="keyword">explicit</span> <span class="title">TCPStore</span><span class="params">(</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">const</span> std::string& masterAddr,</span></span></span><br><span class="line"><span class="params"><span class="function"> PortType masterPort,</span></span></span><br><span class="line"><span class="params"><span class="function"> c10::optional<<span class="type">int</span>> numWorkers 
= c10::<span class="type">nullopt_t</span>(<span class="number">-1</span>),</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">bool</span> isServer = <span class="literal">false</span>,</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">const</span> std::chrono::milliseconds& timeout = kDefaultTimeout,</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">bool</span> waitWorkers = <span class="literal">true</span>)</span></span>;</span><br><span class="line"></span><br><span class="line"> <span class="keyword">virtual</span> ~<span class="built_in">TCPStore</span>();</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">set</span><span class="params">(<span class="type">const</span> std::string& key, <span class="type">const</span> std::vector<<span class="type">uint8_t</span>>& value)</span> <span class="keyword">override</span></span>;</span><br><span class="line"> <span class="function">std::vector<<span class="type">uint8_t</span>> <span class="title">compareSet</span><span class="params">(</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">const</span> std::string& key,</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">const</span> std::vector<<span class="type">uint8_t</span>>& expectedValue,</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">const</span> std::vector<<span class="type">uint8_t</span>>& desiredValue)</span> <span class="keyword">override</span></span>;</span><br><span class="line"> <span class="function">std::vector<<span class="type">uint8_t</span>> <span class="title">get</span><span class="params">(<span class="type">const</span> 
std::string& key)</span> <span class="keyword">override</span></span>;</span><br><span class="line"> <span class="function"><span class="type">int64_t</span> <span class="title">add</span><span class="params">(<span class="type">const</span> std::string& key, <span class="type">int64_t</span> value)</span> <span class="keyword">override</span></span>;</span><br><span class="line"> <span class="function"><span class="type">bool</span> <span class="title">deleteKey</span><span class="params">(<span class="type">const</span> std::string& key)</span> <span class="keyword">override</span></span>;</span><br><span class="line"></span><br><span class="line"> <span class="comment">// <span class="doctag">NOTE:</span> calling other TCPStore APIs inside the callback is NOT threadsafe</span></span><br><span class="line"> <span class="comment">// watchKey() is a blocking operation. It will register the socket on</span></span><br><span class="line"> <span class="comment">// TCPStoreMasterDaemon and the callback on TCPStoreWorkerDaemon. It will</span></span><br><span class="line"> <span class="comment">// return once it has verified the callback is registered on both background</span></span><br><span class="line"> <span class="comment">// threads. 
Only one thread can call watchKey() at a time.</span></span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">watchKey</span><span class="params">(<span class="type">const</span> std::string& key, WatchKeyCallback callback)</span> <span class="keyword">override</span></span>;</span><br><span class="line"> <span class="function"><span class="type">bool</span> <span class="title">check</span><span class="params">(<span class="type">const</span> std::vector<std::string>& keys)</span> <span class="keyword">override</span></span>;</span><br><span class="line"> <span class="function"><span class="type">int64_t</span> <span class="title">getNumKeys</span><span class="params">()</span> <span class="keyword">override</span></span>;</span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">wait</span><span class="params">(<span class="type">const</span> std::vector<std::string>& keys)</span> <span class="keyword">override</span></span>;</span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">wait</span><span class="params">(</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">const</span> std::vector<std::string>& keys,</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">const</span> std::chrono::milliseconds& timeout)</span> <span class="keyword">override</span></span>;</span><br><span class="line"> <span class="comment">// Waits for all workers to join.</span></span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">waitForWorkers</span><span class="params">()</span></span>;</span><br><span class="line"> <span class="comment">// Returns the hostname used by the TCPStore.</span></span><br><span class="line"> <span class="function"><span class="type">const</span> std::string& <span 
class="title">getHost</span><span class="params">()</span> <span class="type">const</span> <span class="keyword">noexcept</span></span>;</span><br><span class="line"> <span class="comment">// Returns the port used by the TCPStore.</span></span><br><span class="line"> <span class="function">PortType <span class="title">getPort</span><span class="params">()</span> <span class="type">const</span> <span class="keyword">noexcept</span></span>;</span><br><span class="line"></span><br><span class="line"> <span class="keyword">private</span>:</span><br><span class="line"> <span class="function"><span class="type">int64_t</span> <span class="title">addHelper_</span><span class="params">(<span class="type">const</span> std::string& key, <span class="type">int64_t</span> value)</span></span>;</span><br><span class="line"> <span class="function">std::vector<<span class="type">uint8_t</span>> <span class="title">getHelper_</span><span class="params">(<span class="type">const</span> std::string& key)</span></span>;</span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">waitHelper_</span><span class="params">(</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">const</span> std::vector<std::string>& keys,</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">const</span> std::chrono::milliseconds& timeout)</span></span>;</span><br><span class="line"></span><br><span class="line"> std::mutex watchKeyMutex_;</span><br><span class="line"> <span class="type">bool</span> isServer_;</span><br><span class="line"> <span class="type">int</span> storeSocket_ = <span class="number">-1</span>; <span class="comment">// client connection to the master, used by set/get/add etc.</span></span><br><span class="line"> <span class="type">int</span> listenSocket_ = <span class="number">-1</span>; <span class="comment">// handed to TCPStoreWorkerDaemon, carries watchKey traffic</span></span><br><span class="line"> <span class="type">int</span> 
masterListenSocket_ = <span class="number">-1</span>; <span class="comment">// the master listens here</span></span><br><span class="line"></span><br><span class="line"> std::string tcpStoreAddr_;</span><br><span class="line"> PortType tcpStorePort_;</span><br><span class="line"></span><br><span class="line"> c10::optional<<span class="type">int</span>> numWorkers_;</span><br><span class="line"> <span class="type">const</span> std::string initKey_;</span><br><span class="line"> <span class="type">const</span> std::string regularPrefix_;</span><br><span class="line"></span><br><span class="line"> std::unique_ptr<TCPStoreMasterDaemon> tcpStoreMasterDaemon_ = <span class="literal">nullptr</span>;</span><br><span class="line"> std::unique_ptr<TCPStoreWorkerDaemon> tcpStoreWorkerDaemon_ = <span class="literal">nullptr</span>;</span><br><span class="line">};</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="socket用处"><a href="#socket用处" class="headerlink" title="socket用处"></a><strong>Purpose of the sockets</strong></h3><p>The most important member variables are the three sockets. They are the essence (and the hardest part) of the store.</p><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="type">int</span> storeSocket_ = <span class="number">-1</span>; <span class="comment">// client connection to the master, used by set/get/add etc.</span></span><br><span class="line"><span class="type">int</span> listenSocket_ = <span class="number">-1</span>; <span class="comment">// handed to TCPStoreWorkerDaemon, carries watchKey traffic</span></span><br><span class="line"><span class="type">int</span> masterListenSocket_ = <span class="number">-1</span>; <span class="comment">// the master listens here</span></span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="业务分工"><a href="#业务分工" class="headerlink" title="业务分工"></a><strong>Division of labor</strong></h3><p>They work as follows (we will come back to this with the code later):</p><ul><li><code>masterListenSocket_</code> listens on <code>masterPort</code>.<ul><li><code>tcpStoreMasterDaemon_</code> is the master itself, i.e. the server that serves the whole TCPStore.</li><li><code>tcpStoreMasterDaemon_</code> uses <code>tcputil::addPollfd(fds, storeListenSocket_, POLLIN)</code> to listen on <code>masterListenSocket_</code> (inside the daemon, this socket is held as <code>storeListenSocket_</code>).</li><li>The key-value data is kept in <code>std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_</code>.</li></ul></li><li><code>storeSocket_</code> lives on the worker (client) side and connects to <code>masterListenSocket_</code> on <code>masterPort</code>.<ul><li><code>storeSocket_</code> encapsulates the operations against the master port: users just call <code>set</code>, <code>get</code>, and so on, without needing to know the master port.</li><li><code>set(key, data)</code> sends a request through <code>storeSocket_</code> asking the master to set <code>key : value</code>.</li><li>When <code>tcpStoreMasterDaemon_</code> detects activity on that socket, it handles the request.</li><li>Internally, <code>tcpStoreMasterDaemon_</code> stores <code>key : value</code> in <code>std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_</code>.</li></ul></li><li><code>listenSocket_</code> belongs to <code>tcpStoreWorkerDaemon_</code> and also connects to <code>masterListenSocket_</code> on <code>masterPort</code>. The work is decoupled here, as the code comment says: <code>It will register the socket on TCPStoreMasterDaemon and the callback on TCPStoreWorkerDaemon</code>.<ul><li><code>listenSocket_</code> encapsulates the handling of <code>watchKey</code>. A store client requests a registration by calling <code>watchKey(const std::string& key, WatchKeyCallback callback)</code>, which works as follows:<ul><li><strong>Worker requests registration</strong>. It calls <code>tcpStoreWorkerDaemon_->setCallback(regKey, callback)</code> to add the callback to the <code>std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_</code> of <code>tcpStoreWorkerDaemon_</code>.</li><li><strong>Worker sends the request</strong>. It sends a (key, WATCH_KEY) message to the master through <code>listenSocket_</code>, asking the master to notify it whenever the value of key changes, so that this callback can be run.</li></ul></li><li><strong>Master performs the registration</strong>. On receiving the WATCH_KEY message, the master calls watchHandler, which records the socket with <code>watchedSockets_[key].push_back(socket)</code>. In other words, the master remembers to send a message to this socket whenever this key changes.</li><li><strong>Master notifies the worker</strong>. In <code>TCPStoreMasterDaemon::setHandler</code>, after a new value is set, the master calls <code>sendKeyUpdatesToClients</code>, which iterates over <code>watchedSockets_[key]</code> and sends a change notification to every socket registered there.</li><li><strong>Worker runs the callback</strong>. So when the key changes, the callback is invoked inside <code>tcpStoreWorkerDaemon_</code>.</li></ul></li></ul><h3 id="Set-例子"><a href="#Set-例子" class="headerlink" title="Set 例子"></a><strong>Set example</strong></h3><p>First, consider the Set example: a worker sets a value on the master through a socket.</p><figure class="highlight plaintext"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br></pre></td><td class="code"><pre><span class="line"> +</span><br><span class="line">+----------------------------------------------------------------------+ | +----------------------------------------------+</span><br><span class="line">| TCPStore Master | | | TCPStore Worker |</span><br><span class="line">| | | | |</span><br><span class="line">| | | | |</span><br><span class="line">| | | | |</span><br><span class="line">| +------------------------------------------------------------+ | | | |</span><br><span class="line">| | TcpStoreMasterDaemon_ MasterPort| | | | |</span><br><span class="line">| | | | | | |</span><br><span class="line">| | TCPStore.masterListenSocket_ | | | | +---------------------------------+ |</span><br><span
class="line">| | | | | | | set(key, value) | |</span><br><span class="line">| | | | | | | | |</span><br><span class="line">| | tcpStore_[key] = value <------------------------------------------------+ | storeSocket_ | |</span><br><span class="line">| | | | | | | | |</span><br><span class="line">| | | | | | +---------------------------------+ |</span><br><span class="line">| | | | | | |</span><br><span class="line">| +------------------------------------------------------------+ | | | |</span><br><span class="line">| | | | |</span><br><span class="line">+----------------------------------------------------------------------+ | +----------------------------------------------+</span><br><span class="line"> +</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p>The same diagram as an image:</p><p><img src="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211114220923634-1185814714.png" alt="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211114220923634-1185814714.png"></p><h3 id="Set-和-watchKey-结合"><a href="#Set-和-watchKey-结合" class="headerlink" title="Set 和 watchKey 结合"></a><strong>Combining Set and watchKey</strong></h3><p>The diagram below combines Set and watchKey (the worker requests the registration and runs the callback; the master performs the registration and notifies the worker to run the callback):</p><ol><li><strong>Worker requests registration</strong>. When a store client calls <code>watchKey(const std::string& key, WatchKeyCallback callback)</code>, it uses <code>tcpStoreWorkerDaemon_->setCallback(regKey, callback)</code> to add a callback to the <code>std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_</code> of <code>tcpStoreWorkerDaemon_</code>.</li><li><strong>Worker sends the request</strong>. The worker sends a (key, WATCH_KEY) message to the master through <code>listenSocket_</code>, asking to be notified whenever the value of key changes, so that this callback can be run.</li><li><strong>Master performs the registration</strong>. On receiving the WATCH_KEY message, the master calls watchHandler, which records <code>watchedSockets_[key].push_back(socket)</code>. In other words, it remembers to send a message to this socket whenever this key changes.</li><li>Next, assume a store client sets a value (here we assume it is the same worker, but in practice it may be a different one).</li><li><strong>Master notifies the worker</strong>. In <code>TCPStoreMasterDaemon::setHandler</code>, after setting the new value, the master calls <code>sendKeyUpdatesToClients</code>, which iterates over <code>watchedSockets_[key]</code> and sends a change notification to every socket registered there.</li><li><strong>Worker runs the callback</strong>. When the key changes, the callback is invoked inside <code>tcpStoreWorkerDaemon_</code>.</li></ol><figure class="highlight plaintext"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br></pre></td><td class="code"><pre><span class="line">+----------------------------------------------------------------------+ + +------------------------------------------------------------------------+</span><br><span class="line">| TCPStore Master | | | TCPStore Worker 
|</span><br><span class="line">| | | | |</span><br><span class="line">| +------------------------------------------------------------+ | | | |</span><br><span class="line">| | TcpStoreMasterDaemon_ MasterPort| | | | +---------------------------------+ |</span><br><span class="line">| | | | | | | | |</span><br><span class="line">| | 2 | | | | | watchKey(key, callback) +----------------------+ |</span><br><span class="line">| | TCPStore.masterListenSocket_ <----------------------------------+ | | | |</span><br><span class="line">| | + | | | | | listenSocket_ | | |</span><br><span class="line">| | | 3 | | | | | | 1 | |</span><br><span class="line">| | v | | | | | | | |</span><br><span class="line">| | watchedSockets_[key] = socket | | | | +---------------------------------+ | |</span><br><span class="line">| | | | | | | |</span><br><span class="line">| | +-------------------------------------------------+ | | | | | |</span><br><span class="line">| | | | | | | | | |</span><br><span class="line">| | | setHandler | | | | | +----------------------------------------------------------------+ |</span><br><span class="line">| | | | | | | | | TCPStoreWorkerDaemon | | |</span><br><span class="line">| | | | | | | | | v | |</span><br><span class="line">| | | tcpStore_[key] = newData | | | | | | unordered_map<string, WatchKeyCallback> keyToCallbacks_ | |</span><br><span class="line">| | | + | | | | | | | |</span><br><span class="line">| | | | | | | | | | TCPStore.listenSocket_ | |</span><br><span class="line">| | | | | | | | | | | |</span><br><span class="line">| | | v | | | | | | +----------------------------------------------------------+ | |</span><br><span class="line">| | | sendKeyUpdatesToClients | | | | | | | run | | |</span><br><span class="line">| | | + | | 5 | | | | | | | |</span><br><span class="line">| | | | | +---------------------->+ 6 | | |</span><br><span class="line">| | | | | | | | | | | | callbackHandler +-----> keyToCallbacks_(callback) | | |</span><br><span 
class="line">| | | v | | | | | | | | | | |</span><br><span class="line">| | | | | | | | | | +----------------------------------------------------------+ | |</span><br><span class="line">| | | for (int socket : watchedSockets_[key]){ | | | | | | +----------------------------------------------------------------+ |</span><br><span class="line">| | | tcputil::sendString(socket, key, true) +-----+ | | | | |</span><br><span class="line">| | | } | | | | | |</span><br><span class="line">| | | | | | | | +------------------------+ |</span><br><span class="line">| | | | | 4 | | | | set(key, newData) | |</span><br><span class="line">| | | | <-----------------------+ | | |</span><br><span class="line">| | +-------------------------------------------------+ | | | | | | |</span><br><span class="line">| | | | | | +------------------------+ |</span><br><span class="line">| +------------------------------------------------------------+ | | | |</span><br><span class="line">| | | | |</span><br><span class="line">+----------------------------------------------------------------------+ + +------------------------------------------------------------------------+</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p>The same diagram as an image:</p><p><img src="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211114220956117-265439998.png" alt="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211114220956117-265439998.png"></p><h3 id="功能函数"><a href="#功能函数" class="headerlink" title="功能函数"></a><strong>API Functions</strong></h3><p>TCPStore provides several API functions, such as <code>set</code>, <code>get</code>, and <code>add</code>.</p><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span 
class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="type">void</span> <span class="title">TCPStore::set</span><span class="params">(<span class="type">const</span> std::string& key, <span class="type">const</span> std::vector<<span class="type">uint8_t</span>>& data)</span> </span>{</span><br><span class="line"> std::string regKey = regularPrefix_ + key;</span><br><span class="line"> tcputil::<span class="built_in">sendValue</span><QueryType>(storeSocket_, QueryType::SET);</span><br><span class="line"> tcputil::<span class="built_in">sendString</span>(storeSocket_, regKey, <span class="literal">true</span>);</span><br><span class="line"> tcputil::<span class="built_in">sendVector</span><<span class="type">uint8_t</span>>(storeSocket_, data);</span><br><span class="line">}</span><br><span class="line"></span><br><span class="line"><span class="function">std::vector<<span class="type">uint8_t</span>> <span class="title">TCPStore::get</span><span class="params">(<span class="type">const</span> std::string& key)</span> </span>{</span><br><span class="line"> std::string regKey = regularPrefix_ + key;</span><br><span class="line"> <span class="keyword">return</span> <span class="built_in">getHelper_</span>(regKey);</span><br><span class="line">}</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="type">int64_t</span> <span class="title">TCPStore::add</span><span class="params">(<span class="type">const</span> std::string& key, <span class="type">int64_t</span> value)</span> 
</span>{</span><br><span class="line"> std::string regKey = regularPrefix_ + key;</span><br><span class="line"> <span class="keyword">return</span> <span class="built_in">addHelper_</span>(regKey, value);</span><br><span class="line">}</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="type">int64_t</span> <span class="title">TCPStore::addHelper_</span><span class="params">(<span class="type">const</span> std::string& key, <span class="type">int64_t</span> value)</span> </span>{</span><br><span class="line"> tcputil::<span class="built_in">sendValue</span><QueryType>(storeSocket_, QueryType::ADD);</span><br><span class="line"> tcputil::<span class="built_in">sendString</span>(storeSocket_, key, <span class="literal">true</span>);</span><br><span class="line"> tcputil::<span class="built_in">sendValue</span><<span class="type">int64_t</span>>(storeSocket_, value);</span><br><span class="line"> <span class="keyword">return</span> tcputil::<span class="built_in">recvValue</span><<span class="type">int64_t</span>>(storeSocket_);</span><br><span class="line">}</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p>These API functions send and receive data through the following basic helpers.</p><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// this is only for convenience when sending rvalues</span></span><br><span class="line"><span class="keyword">template</span> <<span class="keyword">typename</span> T></span><br><span class="line"><span 
class="function"><span class="type">void</span> <span class="title">sendValue</span><span class="params">(<span class="type">int</span> socket, <span class="type">const</span> T& value, <span class="type">bool</span> moreData = <span class="literal">false</span>)</span> </span>{</span><br><span class="line"> <span class="built_in">sendBytes</span><T>(socket, &value, <span class="number">1</span>, moreData);</span><br><span class="line">}</span><br><span class="line"></span><br><span class="line"><span class="keyword">template</span> <<span class="keyword">typename</span> T></span><br><span class="line"><span class="function">T <span class="title">recvValue</span><span class="params">(<span class="type">int</span> socket)</span> </span>{</span><br><span class="line"> T value;</span><br><span class="line"> <span class="built_in">recvBytes</span><T>(socket, &value, <span class="number">1</span>);</span><br><span class="line"> <span class="keyword">return</span> value;</span><br><span class="line">}</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="构建函数"><a href="#构建函数" class="headerlink" title="构建函数"></a><strong>Constructor</strong></h3><p>From the constructor we can see that:</p><ul><li>For the store server role, the main step is starting <code>tcpStoreMasterDaemon_</code>. Note that after the daemon is started, if <code>numWorkers</code> is set and <code>waitWorkers</code> is true, the server blocks in <code>waitForWorkers()</code> until enough workers have registered (or the timeout expires), <strong>and only then reaches the code below that starts <code>tcpStoreWorkerDaemon_</code></strong>.</li><li>For a store client, the constructor connects to the master and starts <code>tcpStoreWorkerDaemon_</code>.</li></ul><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span 
class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// TCPStore class methods</span></span><br><span class="line">TCPStore::<span class="built_in">TCPStore</span>(</span><br><span class="line"> <span class="type">const</span> std::string& masterAddr,</span><br><span class="line"> PortType masterPort,</span><br><span class="line"> c10::optional<<span class="type">int</span>> numWorkers,</span><br><span class="line"> <span class="type">bool</span> isServer,</span><br><span class="line"> <span class="type">const</span> std::chrono::milliseconds& timeout,</span><br><span class="line"> <span class="type">bool</span> waitWorkers)</span><br><span class="line"> : <span 
class="built_in">Store</span>(timeout),</span><br><span class="line"> <span class="built_in">isServer_</span>(isServer),</span><br><span class="line"> <span class="built_in">tcpStoreAddr_</span>(masterAddr),</span><br><span class="line"> <span class="built_in">tcpStorePort_</span>(masterPort),</span><br><span class="line"> <span class="built_in">numWorkers_</span>(numWorkers),</span><br><span class="line"> <span class="built_in">initKey_</span>(<span class="string">"init/"</span>),</span><br><span class="line"> <span class="built_in">regularPrefix_</span>(<span class="string">"/"</span>) {</span><br><span class="line"> tcputil::<span class="built_in">socketInitialize</span>();</span><br><span class="line"> <span class="keyword">if</span> (isServer_) { <span class="comment">// if this instance is the server, listen on masterPort</span></span><br><span class="line"> <span class="comment">// Opening up the listening socket</span></span><br><span class="line"> std::<span class="built_in">tie</span>(masterListenSocket_, tcpStorePort_) = tcputil::<span class="built_in">listen</span>(masterPort);</span><br><span class="line"> }</span><br><span class="line"> <span class="keyword">try</span> {</span><br><span class="line"> <span class="keyword">if</span> (isServer_) { <span class="comment">// if this instance is the server, start tcpStoreMasterDaemon_</span></span><br><span class="line"> <span class="comment">// Now start the daemon</span></span><br><span class="line"> tcpStoreMasterDaemon_ =</span><br><span class="line"> std::<span class="built_in">make_unique</span><TCPStoreMasterDaemon>(masterListenSocket_);</span><br><span class="line"> }</span><br><span class="line"> <span class="comment">// Connect to the daemon</span></span><br><span class="line"> <span class="comment">// every instance (server included) connects to the master port</span></span><br><span class="line"> storeSocket_ = tcputil::<span class="built_in">connect</span>(</span><br><span class="line"> tcpStoreAddr_, tcpStorePort_, <span class="comment">/* wait= */</span> <span 
class="literal">true</span>, timeout_);</span><br><span class="line"> <span class="keyword">if</span> (numWorkers.<span class="built_in">value_or</span>(<span class="number">-1</span>) >= <span class="number">0</span> && waitWorkers) {</span><br><span class="line"> <span class="built_in">waitForWorkers</span>(); <span class="comment">// every instance checks in; only the server blocks here</span></span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> <span class="comment">// socket to handle requests from server, because the master also sends messages to workers</span></span><br><span class="line"> listenSocket_ = tcputil::<span class="built_in">connect</span>(</span><br><span class="line"> tcpStoreAddr_, tcpStorePort_, <span class="comment">/* wait= */</span> <span class="literal">true</span>, timeout_);</span><br><span class="line"> <span class="comment">// start the worker daemon</span></span><br><span class="line"> tcpStoreWorkerDaemon_ =</span><br><span class="line"> std::<span class="built_in">make_unique</span><TCPStoreWorkerDaemon>(listenSocket_);</span><br><span class="line"> } <span class="built_in">catch</span> (<span class="type">const</span> std::exception&) {</span><br><span class="line"> <span class="keyword">if</span> (isServer_) {</span><br><span class="line"> tcpStoreMasterDaemon_ = <span class="literal">nullptr</span>;</span><br><span class="line"> tcputil::<span class="built_in">closeSocket</span>(masterListenSocket_);</span><br><span class="line"> }</span><br><span class="line"> tcpStoreWorkerDaemon_ = <span class="literal">nullptr</span>;</span><br><span class="line"> <span class="keyword">if</span> (listenSocket_ != <span class="number">-1</span>) {</span><br><span class="line"> tcputil::<span class="built_in">closeSocket</span>(listenSocket_);</span><br><span class="line"> }</span><br><span class="line"> <span class="keyword">if</span> (storeSocket_ != <span class="number">-1</span>) {</span><br><span class="line"> tcputil::<span 
class="built_in">closeSocket</span>(storeSocket_);</span><br><span class="line"> }</span><br><span class="line"> <span class="keyword">throw</span>;</span><br><span class="line"> }</span><br><span class="line">}</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p>The server uses the following function to wait for the workers. Each instance first adds 1 to the <code>init/</code> counter; the server then polls that counter every 10 ms until it reaches <code>numWorkers</code> or the timeout expires.</p><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="type">void</span> <span class="title">TCPStore::waitForWorkers</span><span class="params">()</span> </span>{</span><br><span class="line"> <span class="built_in">addHelper_</span>(initKey_, <span class="number">1</span>);</span><br><span class="line"> <span class="comment">// Let server block until all workers have completed, this ensures that</span></span><br><span class="line"> <span class="comment">// the server daemon thread is always running until the very end</span></span><br><span class="line"> <span class="keyword">if</span> (isServer_) {</span><br><span class="line"> <span class="type">const</span> <span class="keyword">auto</span> start = std::chrono::steady_clock::<span 
class="built_in">now</span>();</span><br><span class="line"> <span class="keyword">while</span> (<span class="literal">true</span>) {</span><br><span class="line"> std::vector<<span class="type">uint8_t</span>> value = <span class="built_in">getHelper_</span>(initKey_);</span><br><span class="line"> <span class="keyword">auto</span> buf = <span class="built_in">reinterpret_cast</span><<span class="type">const</span> <span class="type">char</span>*>(value.<span class="built_in">data</span>());</span><br><span class="line"> <span class="keyword">auto</span> len = value.<span class="built_in">size</span>();</span><br><span class="line"> <span class="type">int</span> numWorkersCompleted = std::<span class="built_in">stoi</span>(std::<span class="built_in">string</span>(buf, len));</span><br><span class="line"> <span class="keyword">if</span> (numWorkersCompleted >= numWorkers_.<span class="built_in">value_or</span>(<span class="number">-1</span>)) {</span><br><span class="line"> <span class="keyword">break</span>;</span><br><span class="line"> }</span><br><span class="line"> <span class="type">const</span> <span class="keyword">auto</span> elapsed = std::chrono::<span class="built_in">duration_cast</span><std::chrono::seconds>(</span><br><span class="line"> std::chrono::steady_clock::<span class="built_in">now</span>() - start);</span><br><span class="line"> <span class="keyword">if</span> (timeout_ != kNoTimeout && elapsed > timeout_) {</span><br><span class="line"> <span class="keyword">break</span>;</span><br><span class="line"> }</span><br><span class="line"> <span class="comment">/* sleep override */</span></span><br><span class="line"> std::this_thread::<span class="built_in">sleep_for</span>(std::chrono::<span class="built_in">milliseconds</span>(<span class="number">10</span>));</span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line">}</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 
id="TCPStoreWorkerDaemon"><a href="#TCPStoreWorkerDaemon" class="headerlink" title="TCPStoreWorkerDaemon"></a><strong>TCPStoreWorkerDaemon</strong></h3><p>This daemon is a background thread (not a separate process) that currently only handles callbacks registered through watchKey.</p><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// Separate thread that is launched on all instances (including master)</span></span><br><span class="line"><span class="comment">// Right now only handles callbacks registered from watchKey()</span></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">TCPStoreWorkerDaemon</span> : <span class="keyword">public</span> BackgroundThread {</span><br><span class="line"> <span class="keyword">public</span>:</span><br><span class="line"> <span class="function"><span class="keyword">explicit</span> <span class="title">TCPStoreWorkerDaemon</span><span class="params">(<span 
class="type">int</span> listenSocket)</span></span>;</span><br><span class="line"> <span class="comment">// Set the callback to run key change</span></span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">setCallback</span><span class="params">(std::string key, WatchKeyCallback cb)</span></span>;</span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">waitForCallbackRegistration</span><span class="params">()</span> </span>{</span><br><span class="line"> <span class="comment">// Block until callback has been registered successfully</span></span><br><span class="line"> <span class="function">std::unique_lock<std::mutex> <span class="title">callbackRegistrationLock</span><span class="params">(</span></span></span><br><span class="line"><span class="params"><span class="function"> callbackRegistrationMutex_)</span></span>;</span><br><span class="line"> callbackRegisteredCV_.<span class="built_in">wait</span>(</span><br><span class="line"> callbackRegistrationLock, [&] { <span class="keyword">return</span> callbackRegisteredData_; });</span><br><span class="line"></span><br><span class="line"> <span class="comment">// Reset payload for next callback</span></span><br><span class="line"> callbackRegisteredData_ = <span class="literal">false</span>;</span><br><span class="line"> }</span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">setCallbackRegistered</span><span class="params">()</span> </span>{</span><br><span class="line"> callbackRegisteredData_ = <span class="literal">true</span>;</span><br><span class="line"> callbackRegisteredCV_.<span class="built_in">notify_one</span>();</span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> <span class="keyword">private</span>:</span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">run</span><span 
class="params">()</span></span>;</span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">callbackHandler</span><span class="params">(<span class="type">int</span> socket)</span></span>;</span><br><span class="line"> <span class="comment">// List of callbacks map each watched key</span></span><br><span class="line"> std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_;</span><br><span class="line"> std::mutex keyToCallbacksMutex_;</span><br><span class="line"> std::mutex callbackRegistrationMutex_;</span><br><span class="line"> std::condition_variable callbackRegisteredCV_;</span><br><span class="line"> <span class="type">bool</span> callbackRegisteredData_ = <span class="literal">false</span>;</span><br><span class="line">};</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p>Its constructor simply starts a thread that runs <code>TCPStoreWorkerDaemon::run</code>.</p><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// TCPStoreListener class methods</span></span><br><span class="line">TCPStoreWorkerDaemon::<span class="built_in">TCPStoreWorkerDaemon</span>(<span class="type">int</span> listenSocket)</span><br><span class="line"> : <span class="built_in">BackgroundThread</span>(listenSocket) {</span><br><span class="line"> daemonThread_ = std::<span class="built_in">thread</span>(&TCPStoreWorkerDaemon::run, <span class="keyword">this</span>);</span><br><span class="line">}</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="watchKey"><a href="#watchKey" class="headerlink" title="watchKey"></a><strong>watchKey</strong></h3><p>A client store calls <code>watchKey(const std::string& key, WatchKeyCallback callback)</code> to register a watch on a key with the master:</p><ul><li><strong>Worker registers the callback</strong>. It calls <code>tcpStoreWorkerDaemon_->setCallback(regKey, callback)</code> to add the callback to the <code>std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_</code> of <code>tcpStoreWorkerDaemon_</code>.</li><li><strong>Worker sends the request</strong>. Over <code>listenSocket_</code>, it sends the master a WATCH_KEY query followed by the key, asking the master to notify this socket whenever the key's value changes so that the worker can run the callback.</li><li>It then blocks in waitForCallbackRegistration until the registration is complete.</li></ul><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="type">void</span> <span class="title">TCPStore::watchKey</span><span class="params">(<span class="type">const</span> std::string& key, WatchKeyCallback callback)</span> </span>{</span><br><span class="line"> <span class="comment">// Only allow one thread to perform watchKey() at a time</span></span><br><span class="line"> <span class="function"><span class="type">const</span> std::lock_guard<std::mutex> <span class="title">watchKeyLock</span><span class="params">(watchKeyMutex_)</span></span>;</span><br><span class="line"></span><br><span class="line"> <span class="comment">// Register callback with TCPStoreMasterDaemon to call TCPStoreWorkerDaemon on</span></span><br><span class="line"> <span class="comment">// key change</span></span><br><span class="line"> std::string regKey = regularPrefix_ + key;</span><br><span class="line"> tcpStoreWorkerDaemon_-><span 
class="built_in">setCallback</span>(regKey, callback);</span><br><span class="line"> tcputil::<span class="built_in">sendValue</span><QueryType>(listenSocket_, QueryType::WATCH_KEY);</span><br><span class="line"> tcputil::<span class="built_in">sendString</span>(listenSocket_, regKey);</span><br><span class="line"></span><br><span class="line"> <span class="comment">// Block until callback has been registered successfully</span></span><br><span class="line"> tcpStoreWorkerDaemon_-><span class="built_in">waitForCallbackRegistration</span>();</span><br><span class="line">}</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="运行"><a href="#运行" class="headerlink" title="运行"></a><strong>Running</strong></h3><p>The <code>run()</code> loop has separate implementations for Windows and for other systems, but in both it waits for a key notification on the socket and then handles it.</p><ul><li><strong>Master performs the registration</strong>. On receiving a WATCH_KEY message, the master calls watchHandler, which executes <code>watchedSockets_[key].push_back(socket)</code> to record that this socket must be notified whenever the key changes.</li><li><strong>Master notifies Worker</strong>. After setting the new value in <code>TCPStoreMasterDaemon::setHandler</code>, the master calls sendKeyUpdatesToClients, which iterates over <code>watchedSockets_[key]</code> and sends a change notification to every socket registered there.</li><li><strong>Worker runs the callback</strong>. So whenever the key changes, the callback is invoked inside <code>tcpStoreWorkerDaemon_</code>.</li></ul><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span 
class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">#<span class="keyword">ifdef</span> _WIN32</span></span><br><span class="line"><span class="function"><span class="type">void</span> <span class="title">TCPStoreWorkerDaemon::run</span><span class="params">()</span> </span>{ <span class="comment">// Windows implementation</span></span><br><span class="line"> std::vector<<span class="keyword">struct</span> pollfd> fds;</span><br><span class="line"> 
tcputil::<span class="built_in">addPollfd</span>(fds, storeListenSocket_, POLLIN);</span><br><span class="line"></span><br><span class="line"> <span class="keyword">while</span> (<span class="literal">true</span>) {</span><br><span class="line"> <span class="comment">// Check control and exit early if triggered</span></span><br><span class="line"> <span class="type">int</span> res;</span><br><span class="line"> <span class="built_in">SYSCHECK_ERR_RETURN_NEG1</span>(</span><br><span class="line"> res = <span class="built_in">WSAPoll</span>(fds.<span class="built_in">data</span>(), fds.<span class="built_in">size</span>(), checkTimeout_.<span class="built_in">count</span>()))</span><br><span class="line"> <span class="keyword">if</span> (res == <span class="number">0</span>) {</span><br><span class="line"> <span class="keyword">auto</span> rvPoll = <span class="built_in">WaitForSingleObject</span>(ghStopEvent_, <span class="number">0</span>);</span><br><span class="line"> <span class="keyword">if</span> (rvPoll != WAIT_TIMEOUT) {</span><br><span class="line"> <span class="keyword">break</span>;</span><br><span class="line"> }</span><br><span class="line"> <span class="keyword">continue</span>;</span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> <span class="comment">// if connection is closed gracefully by master, peeked data will return 0</span></span><br><span class="line"> <span class="type">char</span> data;</span><br><span class="line"> <span class="type">int</span> ret = <span class="built_in">recv</span>(fds[<span class="number">0</span>].fd, &data, <span class="number">1</span>, MSG_PEEK);</span><br><span class="line"> <span class="keyword">if</span> (ret == <span class="number">0</span>) {</span><br><span class="line"> <span class="keyword">auto</span> rvData = <span class="built_in">WaitForSingleObject</span>(ghStopEvent_, <span class="number">0</span>);</span><br><span class="line"> <span class="keyword">if</span> 
(rvData != WAIT_TIMEOUT) {</span><br><span class="line"> <span class="keyword">break</span>;</span><br><span class="line"> }</span><br><span class="line"> <span class="keyword">continue</span>;</span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> <span class="comment">// valid request, perform callback logic</span></span><br><span class="line"> <span class="built_in">callbackHandler</span>(fds[<span class="number">0</span>].fd); <span class="comment">// handle the key notification</span></span><br><span class="line"> }</span><br><span class="line">}</span><br><span class="line"><span class="meta">#<span class="keyword">else</span></span></span><br><span class="line"><span class="function"><span class="type">void</span> <span class="title">TCPStoreWorkerDaemon::run</span><span class="params">()</span> </span>{ <span class="comment">// non-Windows implementation</span></span><br><span class="line"> std::vector<<span class="keyword">struct</span> pollfd> fds;</span><br><span class="line"> tcputil::<span class="built_in">addPollfd</span>(fds, controlPipeFd_[<span class="number">0</span>], POLLHUP);</span><br><span class="line"> tcputil::<span class="built_in">addPollfd</span>(fds, storeListenSocket_, POLLIN);</span><br><span class="line"></span><br><span class="line"> <span class="keyword">while</span> (<span class="literal">true</span>) {</span><br><span class="line"> <span class="built_in">SYSCHECK_ERR_RETURN_NEG1</span>(::<span class="built_in">poll</span>(fds.<span class="built_in">data</span>(), fds.<span class="built_in">size</span>(), <span class="number">-1</span>));</span><br><span class="line"></span><br><span class="line"> <span class="comment">// Check control and exit early if triggered</span></span><br><span class="line"> <span class="comment">// The pipe receives an event which tells us to shutdown the listener thread</span></span><br><span class="line"> <span class="keyword">if</span> (fds[<span class="number">0</span>].revents != <span class="number">0</span>) {</span><br><span class="line"> <span class="comment">// Will be POLLHUP when the pipe is closed</span></span><br><span class="line"> <span class="keyword">if</span> (fds[<span 
class="number">0</span>].revents ^ POLLHUP) {</span><br><span class="line"> <span class="keyword">throw</span> std::<span class="built_in">system_error</span>(</span><br><span class="line"> ECONNABORTED,</span><br><span class="line"> std::<span class="built_in">system_category</span>(),</span><br><span class="line"> <span class="string">"Unexpected poll revent on the control pipe's reading fd: "</span> +</span><br><span class="line"> std::<span class="built_in">to_string</span>(fds[<span class="number">0</span>].revents));</span><br><span class="line"> }</span><br><span class="line"> <span class="keyword">break</span>;</span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> <span class="comment">// if connection is closed gracefully by master, peeked data will return 0</span></span><br><span class="line"> <span class="type">char</span> data;</span><br><span class="line"> <span class="type">int</span> ret = <span class="built_in">recv</span>(fds[<span class="number">1</span>].fd, &data, <span class="number">1</span>, MSG_PEEK);</span><br><span class="line"> <span class="keyword">if</span> (ret == <span class="number">0</span>) {</span><br><span class="line"> <span class="keyword">continue</span>;</span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> <span class="comment">// valid request, perform callback logic</span></span><br><span class="line"> <span class="built_in">callbackHandler</span>(fds[<span class="number">1</span>].fd); <span class="comment">// handle the request</span></span><br><span class="line"> }</span><br><span class="line">}</span><br><span class="line"><span class="meta">#<span class="keyword">endif</span></span></span><br></pre></td></tr></tbody></table></figure><h3 id="TCPStoreMasterDaemon"><a href="#TCPStoreMasterDaemon" class="headerlink" title="TCPStoreMasterDaemon"></a><strong>TCPStoreMasterDaemon</strong></h3><p>Here, <code>std::unordered_map<std::string, std::vector<uint8_t>> 
tcpStore_;</code> is the actual key-value store.</p><p>TCPStoreMasterDaemon is therefore the component that operates on this key-value data, e.g. storing and retrieving values.</p><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// Separate thread that is only launched on master</span></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">TCPStoreMasterDaemon</span> : <span class="keyword">public</span> BackgroundThread {</span><br><span class="line"> <span class="keyword">public</span>:</span><br><span class="line"> <span class="function"><span class="keyword">explicit</span> <span class="title">TCPStoreMasterDaemon</span><span class="params">(<span 
class="type">int</span> storeListenSocket)</span></span>;</span><br><span class="line"></span><br><span class="line"> <span class="keyword">private</span>:</span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">run</span><span class="params">()</span></span>;</span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">queryFds</span><span class="params">(std::vector<<span class="keyword">struct</span> pollfd>& fds)</span></span>;</span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">query</span><span class="params">(<span class="type">int</span> socket)</span></span>;</span><br><span class="line"></span><br><span class="line"> <span class="comment">// The master runs on a single thread so only</span></span><br><span class="line"> <span class="comment">// one handler can be executed at a time</span></span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">setHandler</span><span class="params">(<span class="type">int</span> socket)</span></span>;</span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">compareSetHandler</span><span class="params">(<span class="type">int</span> socket)</span></span>;</span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">addHandler</span><span class="params">(<span class="type">int</span> socket)</span></span>;</span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">getHandler</span><span class="params">(<span class="type">int</span> socket)</span> <span class="type">const</span></span>;</span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">checkHandler</span><span class="params">(<span class="type">int</span> socket)</span> <span 
class="type">const</span></span>;</span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">getNumKeysHandler</span><span class="params">(<span class="type">int</span> socket)</span> <span class="type">const</span></span>;</span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">deleteHandler</span><span class="params">(<span class="type">int</span> socket)</span></span>;</span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">waitHandler</span><span class="params">(<span class="type">int</span> socket)</span></span>;</span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">watchHandler</span><span class="params">(<span class="type">int</span> socket)</span></span>;</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="type">bool</span> <span class="title">checkKeys</span><span class="params">(<span class="type">const</span> std::vector<std::string>& keys)</span> <span class="type">const</span></span>;</span><br><span class="line"> <span class="comment">// Helper function to alert waiting workers, used in setHandler, getHandler</span></span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">wakeupWaitingClients</span><span class="params">(<span class="type">const</span> std::string& key)</span></span>;</span><br><span class="line"> <span class="comment">// Helper function used when the key is changed</span></span><br><span class="line"> <span class="comment">// used in setHandler, addHandler, getHandler, deleteHandler</span></span><br><span class="line"> <span class="function"><span class="type">void</span> <span class="title">sendKeyUpdatesToClients</span><span class="params">(</span></span></span><br><span class="line"><span class="params"><span class="function"> <span 
class="type">const</span> std::string& key,</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">const</span> <span class="keyword">enum</span> WatchResponseType& type,</span></span></span><br><span class="line"><span class="params"><span class="function"> std::vector<<span class="type">uint8_t</span>>& oldData,</span></span></span><br><span class="line"><span class="params"><span class="function"> std::vector<<span class="type">uint8_t</span>>& newData)</span></span>;</span><br><span class="line"> std::unordered_map<std::string, std::vector<<span class="type">uint8_t</span>>> tcpStore_;</span><br><span class="line"> <span class="comment">// From key -> the list of sockets waiting on the key</span></span><br><span class="line"> std::unordered_map<std::string, std::vector<<span class="type">int</span>>> waitingSockets_;</span><br><span class="line"> <span class="comment">// From socket -> number of keys awaited</span></span><br><span class="line"> std::unordered_map<<span class="type">int</span>, <span class="type">size_t</span>> keysAwaited_;</span><br><span class="line"> <span class="comment">// From key -> the list of sockets watching the key</span></span><br><span class="line"> std::unordered_map<std::string, std::vector<<span class="type">int</span>>> watchedSockets_;</span><br><span class="line">};</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="运行-1"><a href="#运行-1" class="headerlink" title="运行"></a><strong>Run</strong></h3><p>TCPStoreMasterDaemon waits on a socket: <code>masterListenSocket_</code> listens on <code>masterPort</code>.</p><ul><li><code>tcpStoreMasterDaemon_</code> calls <code>tcputil::addPollfd(fds, storeListenSocket_, POLLIN)</code> to listen on <code>masterListenSocket_</code>.</li><li><code>tcpStoreMasterDaemon_</code> itself acts as the master, i.e. the server that serves the entire TCPStore.</li><li>The key-value data is held in <code>std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_</code>.</li></ul><figure
class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span 
class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">#<span class="keyword">ifdef</span> _WIN32</span></span><br><span class="line"><span class="function"><span class="type">void</span> <span class="title">TCPStoreMasterDaemon::run</span><span class="params">()</span> </span>{</span><br><span class="line"> std::vector<<span class="keyword">struct</span> pollfd> fds;</span><br><span class="line"> tcputil::<span class="built_in">addPollfd</span>(fds, storeListenSocket_, POLLIN);</span><br><span class="line"></span><br><span class="line"> <span class="comment">// receive the queries</span></span><br><span class="line"> <span class="type">bool</span> finished = <span class="literal">false</span>;</span><br><span class="line"> <span class="keyword">while</span> (!finished) {</span><br><span class="line"> <span class="keyword">for</span> (<span class="type">size_t</span> i = <span class="number">0</span>; i < sockets_.<span class="built_in">size</span>(); i++) {</span><br><span class="line"> fds[i].revents = <span class="number">0</span>;</span><br><span class="line"> 
}</span><br><span class="line"></span><br><span class="line"> <span class="type">int</span> res;</span><br><span class="line"> <span class="built_in">SYSCHECK_ERR_RETURN_NEG1</span>(</span><br><span class="line"> res = <span class="built_in">WSAPoll</span>(fds.<span class="built_in">data</span>(), fds.<span class="built_in">size</span>(), checkTimeout_.<span class="built_in">count</span>()))</span><br><span class="line"> <span class="keyword">if</span> (res == <span class="number">0</span>) {</span><br><span class="line"> <span class="keyword">auto</span> rv = <span class="built_in">WaitForSingleObject</span>(ghStopEvent_, <span class="number">0</span>);</span><br><span class="line"> <span class="keyword">if</span> (rv != WAIT_TIMEOUT) {</span><br><span class="line"> finished = <span class="literal">true</span>;</span><br><span class="line"> <span class="keyword">break</span>;</span><br><span class="line"> }</span><br><span class="line"> <span class="keyword">continue</span>;</span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> <span class="comment">// TCPStore's listening socket has an event and it should now be able to</span></span><br><span class="line"> <span class="comment">// accept new connections.</span></span><br><span class="line"> <span class="keyword">if</span> (fds[<span class="number">0</span>].revents != <span class="number">0</span>) { <span class="comment">// the listening socket has an event</span></span><br><span class="line"> <span class="keyword">if</span> (!(fds[<span class="number">0</span>].revents & POLLIN)) {</span><br><span class="line"> <span class="keyword">throw</span> std::<span class="built_in">system_error</span>(</span><br><span class="line"> ECONNABORTED,</span><br><span class="line"> std::<span class="built_in">system_category</span>(),</span><br><span class="line"> <span class="string">"Unexpected poll revent on the master's listening socket: "</span> +</span><br><span class="line"> std::<span 
class="built_in">to_string</span>(fds[<span class="number">0</span>].revents));</span><br><span class="line"> }</span><br><span class="line"> <span class="type">int</span> sockFd = std::<span class="built_in">get</span><<span class="number">0</span>>(tcputil::<span class="built_in">accept</span>(storeListenSocket_));</span><br><span class="line"> sockets_.<span class="built_in">push_back</span>(sockFd);</span><br><span class="line"> tcputil::<span class="built_in">addPollfd</span>(fds, sockFd, POLLIN);</span><br><span class="line"> }</span><br><span class="line"> <span class="built_in">queryFds</span>(fds); <span class="comment">// handle client queries</span></span><br><span class="line"> }</span><br><span class="line">}</span><br><span class="line"><span class="meta">#<span class="keyword">else</span></span></span><br><span class="line"><span class="function"><span class="type">void</span> <span class="title">TCPStoreMasterDaemon::run</span><span class="params">()</span> </span>{</span><br><span class="line"> std::vector<<span class="keyword">struct</span> pollfd> fds;</span><br><span class="line"> tcputil::<span class="built_in">addPollfd</span>(fds, storeListenSocket_, POLLIN);</span><br><span class="line"> <span class="comment">// Push the read end of the pipe to signal the stopping of the daemon run</span></span><br><span class="line"> tcputil::<span class="built_in">addPollfd</span>(fds, controlPipeFd_[<span class="number">0</span>], POLLHUP);</span><br><span class="line"></span><br><span class="line"> <span class="comment">// receive the queries</span></span><br><span class="line"> <span class="type">bool</span> finished = <span class="literal">false</span>;</span><br><span class="line"> <span class="keyword">while</span> (!finished) {</span><br><span class="line"> <span class="keyword">for</span> (<span class="type">size_t</span> i = <span class="number">0</span>; i < sockets_.<span class="built_in">size</span>(); i++) {</span><br><span class="line"> fds[i].revents = <span class="number">0</span>;</span><br><span class="line"> }</span><br><span class="line"></span><br><span 
class="built_in">SYSCHECK_ERR_RETURN_NEG1</span>(::<span class="built_in">poll</span>(fds.<span class="built_in">data</span>(), fds.<span class="built_in">size</span>(), <span class="number">-1</span>));</span><br><span class="line"></span><br><span class="line"> <span class="comment">// TCPStore's listening socket has an event and it should now be able to</span></span><br><span class="line"> <span class="comment">// accept new connections.</span></span><br><span class="line"> <span class="keyword">if</span> (fds[<span class="number">0</span>].revents != <span class="number">0</span>) {</span><br><span class="line"> <span class="keyword">if</span> (fds[<span class="number">0</span>].revents ^ POLLIN) {</span><br><span class="line"> <span class="keyword">throw</span> std::<span class="built_in">system_error</span>(</span><br><span class="line"> ECONNABORTED,</span><br><span class="line"> std::<span class="built_in">system_category</span>(),</span><br><span class="line"> <span class="string">"Unexpected poll revent on the master's listening socket: "</span> +</span><br><span class="line"> std::<span class="built_in">to_string</span>(fds[<span class="number">0</span>].revents));</span><br><span class="line"> }</span><br><span class="line"> <span class="type">int</span> sockFd = std::<span class="built_in">get</span><<span class="number">0</span>>(tcputil::<span class="built_in">accept</span>(storeListenSocket_));</span><br><span class="line"> sockets_.<span class="built_in">push_back</span>(sockFd);</span><br><span class="line"> tcputil::<span class="built_in">addPollfd</span>(fds, sockFd, POLLIN);</span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> <span class="comment">// The pipe receives an event which tells us to shutdown the daemon</span></span><br><span class="line"> <span class="keyword">if</span> (fds[<span class="number">1</span>].revents != <span class="number">0</span>) { <span class="comment">// 
the control pipe has an event</span></span><br><span class="line"> <span class="comment">// Will be POLLHUP when the pipe is closed</span></span><br><span class="line"> <span class="keyword">if</span> (fds[<span class="number">1</span>].revents ^ POLLHUP) {</span><br><span class="line"> <span class="keyword">throw</span> std::<span class="built_in">system_error</span>(</span><br><span class="line"> ECONNABORTED,</span><br><span class="line"> std::<span class="built_in">system_category</span>(),</span><br><span class="line"> <span class="string">"Unexpected poll revent on the control pipe's reading fd: "</span> +</span><br><span class="line"> std::<span class="built_in">to_string</span>(fds[<span class="number">1</span>].revents));</span><br><span class="line"> }</span><br><span class="line"> finished = <span class="literal">true</span>;</span><br><span class="line"> <span class="keyword">break</span>;</span><br><span class="line"> }</span><br><span class="line"> <span class="built_in">queryFds</span>(fds); <span class="comment">// handle client queries</span></span><br><span class="line"> }</span><br><span class="line">}</span><br><span class="line"><span class="meta">#<span class="keyword">endif</span></span></span><br></pre></td></tr></tbody></table></figure><h3 id="调用业务"><a href="#调用业务" class="headerlink" title="调用业务"></a><strong>Dispatching queries</strong></h3><p><code>queryFds</code> walks the poll results and calls <code>query</code> for every client socket that has a pending event. If handling a query throws, the socket is closed and its wait-tracking state is removed.</p><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span 
class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="type">void</span> <span class="title">TCPStoreMasterDaemon::queryFds</span><span class="params">(std::vector<<span class="keyword">struct</span> pollfd>& fds)</span> </span>{</span><br><span class="line"> <span class="comment">// Skipping the fds[0] and fds[1],</span></span><br><span class="line"> <span class="comment">// fds[0] is master's listening socket</span></span><br><span class="line"> <span class="comment">// fds[1] is control pipe's reading fd, it is not for Windows platform</span></span><br><span class="line"> <span class="keyword">for</span> (<span class="type">size_t</span> fdIdx = CONNECT_SOCKET_OFFSET; fdIdx < fds.<span class="built_in">size</span>(); ++fdIdx) {</span><br><span class="line"> <span class="keyword">if</span> (fds[fdIdx].revents == <span class="number">0</span>) {</span><br><span class="line"> <span class="keyword">continue</span>;</span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> <span class="comment">// Now query the socket that has 
the event</span></span><br><span class="line"> <span class="keyword">try</span> {</span><br><span class="line"> <span class="built_in">query</span>(fds[fdIdx].fd); <span class="comment">// handle the query</span></span><br><span class="line"> } <span class="built_in">catch</span> (...) {</span><br><span class="line"> tcputil::<span class="built_in">closeSocket</span>(fds[fdIdx].fd);</span><br><span class="line"></span><br><span class="line"> <span class="comment">// Remove all the tracking state of the closed FD</span></span><br><span class="line"> <span class="keyword">for</span> (<span class="keyword">auto</span> it = waitingSockets_.<span class="built_in">begin</span>(); it != waitingSockets_.<span class="built_in">end</span>();) {</span><br><span class="line"> <span class="keyword">for</span> (<span class="keyword">auto</span> vecIt = it->second.<span class="built_in">begin</span>(); vecIt != it->second.<span class="built_in">end</span>();) {</span><br><span class="line"> <span class="keyword">if</span> (*vecIt == fds[fdIdx].fd) {</span><br><span class="line"> vecIt = it->second.<span class="built_in">erase</span>(vecIt);</span><br><span class="line"> } <span class="keyword">else</span> {</span><br><span class="line"> ++vecIt;</span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line"> <span class="keyword">if</span> (it->second.<span class="built_in">size</span>() == <span class="number">0</span>) {</span><br><span class="line"> it = waitingSockets_.<span class="built_in">erase</span>(it);</span><br><span class="line"> } <span class="keyword">else</span> {</span><br><span class="line"> ++it;</span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line"> <span class="keyword">for</span> (<span class="keyword">auto</span> it = keysAwaited_.<span class="built_in">begin</span>(); it != keysAwaited_.<span class="built_in">end</span>();) {</span><br><span class="line"> <span class="keyword">if</span> (it->first == 
fds[fdIdx].fd) {</span><br><span class="line"> it = keysAwaited_.<span class="built_in">erase</span>(it);</span><br><span class="line"> } <span class="keyword">else</span> {</span><br><span class="line"> ++it;</span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line"> fds.<span class="built_in">erase</span>(fds.<span class="built_in">begin</span>() + fdIdx);</span><br><span class="line"> sockets_.<span class="built_in">erase</span>(sockets_.<span class="built_in">begin</span>() + fdIdx - CONNECT_SOCKET_OFFSET);</span><br><span class="line"> --fdIdx;</span><br><span class="line"> <span class="keyword">continue</span>;</span><br><span class="line"> }</span><br><span class="line"> }</span><br><span class="line">}</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="处理业务"><a href="#处理业务" class="headerlink" title="处理业务"></a><strong>Handling queries</strong></h3><p><code>query</code> reads the query type from the socket and dispatches to the matching handler based on it.</p><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span 
class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// query communicates with the worker. The format</span></span><br><span class="line"><span class="comment">// of the query is as follows:</span></span><br><span class="line"><span class="comment">// type of query | size of arg1 | arg1 | size of arg2 | arg2 | ...</span></span><br><span class="line"><span class="comment">// or, in the case of wait</span></span><br><span class="line"><span class="comment">// type of query | number of args | size of arg1 | arg1 | ...</span></span><br><span class="line"><span class="function"><span class="type">void</span> <span class="title">TCPStoreMasterDaemon::query</span><span class="params">(<span class="type">int</span> socket)</span> </span>{</span><br><span class="line"> QueryType qt;</span><br><span class="line"> tcputil::<span class="built_in">recvBytes</span><QueryType>(socket, &qt, <span class="number">1</span>);</span><br><span class="line"> <span class="keyword">if</span> (qt == QueryType::SET) {</span><br><span class="line"> <span class="built_in">setHandler</span>(socket);</span><br><span class="line"></span><br><span class="line"> } <span class="keyword">else</span> <span class="keyword">if</span> (qt == QueryType::COMPARE_SET) {</span><br><span class="line"> <span class="built_in">compareSetHandler</span>(socket);</span><br><span class="line"></span><br><span class="line"> } <span class="keyword">else</span> <span class="keyword">if</span> (qt == QueryType::ADD) {</span><br><span class="line"> <span class="built_in">addHandler</span>(socket);</span><br><span class="line"></span><br><span class="line"> } <span class="keyword">else</span> 
<span class="keyword">if</span> (qt == QueryType::GET) {</span><br><span class="line"> <span class="built_in">getHandler</span>(socket);</span><br><span class="line"></span><br><span class="line"> } <span class="keyword">else</span> <span class="keyword">if</span> (qt == QueryType::CHECK) {</span><br><span class="line"> <span class="built_in">checkHandler</span>(socket);</span><br><span class="line"></span><br><span class="line"> } <span class="keyword">else</span> <span class="keyword">if</span> (qt == QueryType::WAIT) {</span><br><span class="line"> <span class="built_in">waitHandler</span>(socket);</span><br><span class="line"></span><br><span class="line"> } <span class="keyword">else</span> <span class="keyword">if</span> (qt == QueryType::GETNUMKEYS) {</span><br><span class="line"> <span class="built_in">getNumKeysHandler</span>(socket);</span><br><span class="line"></span><br><span class="line"> } <span class="keyword">else</span> <span class="keyword">if</span> (qt == QueryType::DELETE_KEY) {</span><br><span class="line"> <span class="built_in">deleteHandler</span>(socket);</span><br><span class="line"></span><br><span class="line"> } <span class="keyword">else</span> <span class="keyword">if</span> (qt == QueryType::WATCH_KEY) {</span><br><span class="line"> <span class="built_in">watchHandler</span>(socket);</span><br><span class="line"></span><br><span class="line"> } <span class="keyword">else</span> {</span><br><span class="line"> <span class="keyword">throw</span> std::<span class="built_in">runtime_error</span>(<span class="string">"Unexpected query type"</span>);</span><br><span class="line"> }</span><br><span class="line">}</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="添加"><a href="#添加" class="headerlink" title="添加"></a><strong>Setting a value</strong></h3><p>This handles requests that set a value. <code>setHandler</code> reads the key and value from the socket and writes them into <code>tcpStore_</code>. It then wakes up waiting clients and notifies watching clients that the key was created or updated.</p><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span 
class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="type">void</span> <span class="title">TCPStoreMasterDaemon::setHandler</span><span class="params">(<span class="type">int</span> socket)</span> </span>{</span><br><span class="line"> std::string key = tcputil::<span class="built_in">recvString</span>(socket);</span><br><span class="line"> std::vector<<span class="type">uint8_t</span>> newData = tcputil::<span class="built_in">recvVector</span><<span class="type">uint8_t</span>>(socket);</span><br><span class="line"> std::vector<<span class="type">uint8_t</span>> oldData;</span><br><span class="line"> <span class="type">bool</span> newKey = <span class="literal">true</span>;</span><br><span class="line"> <span class="keyword">auto</span> it = tcpStore_.<span class="built_in">find</span>(key);</span><br><span class="line"> <span class="keyword">if</span> (it != tcpStore_.<span class="built_in">end</span>()) {</span><br><span class="line"> oldData = it->second;</span><br><span class="line"> newKey = <span class="literal">false</span>;</span><br><span class="line"> }</span><br><span class="line"> tcpStore_[key] = newData;</span><br><span class="line"> <span class="comment">// On "set", wake up all clients that have been waiting</span></span><br><span class="line"> <span class="built_in">wakeupWaitingClients</span>(key);</span><br><span class="line"> <span class="comment">// Send 
key update to all watching clients</span></span><br><span class="line"> newKey ? <span class="built_in">sendKeyUpdatesToClients</span>(</span><br><span class="line"> key, WatchResponseType::KEY_CREATED, oldData, newData)</span><br><span class="line"> : <span class="built_in">sendKeyUpdatesToClients</span>(</span><br><span class="line"> key, WatchResponseType::KEY_UPDATED, oldData, newData);</span><br><span class="line">}</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="获取"><a href="#获取" class="headerlink" title="获取"></a><strong>Get</strong></h3><p>This handler serves get requests. It reads the key from the socket, looks it up in <code>tcpStore_</code>, and sends the stored value back.</p><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="type">void</span> <span class="title">TCPStoreMasterDaemon::getHandler</span><span class="params">(<span class="type">int</span> socket)</span> <span class="type">const</span> </span>{</span><br><span class="line"> std::string key = tcputil::<span class="built_in">recvString</span>(socket);</span><br><span class="line"> <span class="keyword">auto</span> data = tcpStore_.<span class="built_in">at</span>(key);</span><br><span class="line"> tcputil::<span class="built_in">sendVector</span><<span class="type">uint8_t</span>>(socket, data);</span><br><span class="line">}</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="watchKey-1"><a href="#watchKey-1" class="headerlink" title="watchKey"></a><strong>watchKey</strong></h3><p>This is where a key to be watched is registered.</p><p>For a WATCH_KEY request, the Master appends the requesting socket to the list kept for that key. That socket becomes the target of later update notifications.</p><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span 
class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="type">void</span> <span class="title">TCPStoreMasterDaemon::watchHandler</span><span class="params">(<span class="type">int</span> socket)</span> </span>{</span><br><span class="line"> std::string key = tcputil::<span class="built_in">recvString</span>(socket);</span><br><span class="line"></span><br><span class="line"> <span class="comment">// Record the socket to respond to when the key is updated</span></span><br><span class="line"> watchedSockets_[key].<span class="built_in">push_back</span>(socket);</span><br><span class="line"></span><br><span class="line"> <span class="comment">// Send update to TCPStoreWorkerDaemon on client</span></span><br><span class="line"> tcputil::<span class="built_in">sendValue</span><WatchResponseType>(</span><br><span class="line"> socket, WatchResponseType::KEY_CALLBACK_REGISTERED);</span><br><span class="line">}</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="通知"><a href="#通知" class="headerlink" title="通知"></a><strong>Notify</strong></h3><p>When a key changes, the Master sends the update (response type, key, old value, new value) to every socket watching that key.</p><figure class="highlight cpp"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="type">void</span> <span 
class="title">TCPStoreMasterDaemon::sendKeyUpdatesToClients</span><span class="params">(</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">const</span> std::string& key,</span></span></span><br><span class="line"><span class="params"><span class="function"> <span class="type">const</span> <span class="keyword">enum</span> WatchResponseType& type,</span></span></span><br><span class="line"><span class="params"><span class="function"> std::vector<<span class="type">uint8_t</span>>& oldData,</span></span></span><br><span class="line"><span class="params"><span class="function"> std::vector<<span class="type">uint8_t</span>>& newData)</span> </span>{</span><br><span class="line"> <span class="keyword">for</span> (<span class="type">int</span> socket : watchedSockets_[key]) {</span><br><span class="line"> tcputil::<span class="built_in">sendValue</span><WatchResponseType>(socket, type);</span><br><span class="line"> tcputil::<span class="built_in">sendString</span>(socket, key, <span class="literal">true</span>);</span><br><span class="line"> tcputil::<span class="built_in">sendVector</span><<span class="type">uint8_t</span>>(socket, oldData);</span><br><span class="line"> tcputil::<span class="built_in">sendVector</span><<span class="type">uint8_t</span>>(socket, newData);</span><br><span class="line"> }</span><br><span class="line">}</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h2 id="总结"><a href="#总结" class="headerlink" title="总结"></a><strong>Summary</strong></h2><p>The figure below can be summarized as follows:</p><ul><li>The Master listens for requests on MasterPort.</li><li>Storing and retrieving values:<ul><li>On the Worker, <code>storeSocket_</code> is used to set and get values (number 1 in the figure below).</li><li>On the Master, this corresponds to <code>tcpStore_</code>.</li></ul></li><li>Watching keys:<ul><li>On the Worker, <code>listenSocket_</code> is used to tell the Master that this worker wants to watch a key (number 2 in the figure below). At the same time, the worker registers a callback for that key internally (number 3 in the figure below).</li><li>On the Master, the watch corresponds to <code>watchedSockets_[key].push_back(socket)</code>.</li><li>On the Master, when a value is set for a watched key, every socket in <code>watchedSockets_[key]</code> is notified (number 4 in the figure below).</li><li>The Worker then invokes the callback registered for that key.</li></ul></li></ul>
<figure class="highlight plaintext"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br></pre></td><td class="code"><pre><span class="line"> +</span><br><span class="line">+----------------------------------------------------------------------+ | +------------------------------------------------------------------------+</span><br><span class="line">| TCPStore Master | | | TCPStore Worker |</span><br><span class="line">| | | | |</span><br><span class="line">| storeSocket_ | | | |</span><br><span class="line">| | | | |</span><br><span class="line">| 
+------------------------------------------------------------+ | | | |</span><br><span class="line">| | TcpStoreMasterDaemon_ MasterPort| | | | 1 +---------------------------------+ |</span><br><span class="line">| | | <--------------+ | set(key, value) | |</span><br><span class="line">| | unordered_map<string, vector<uint8_t> > tcpStore_+---+ | | | | | | |</span><br><span class="line">| | | | | | | | storeSocket_ | |</span><br><span class="line">| | TCPStore.masterListenSocket_ | | | | | | | |</span><br><span class="line">| | | | | | | +---------------------------------+ |</span><br><span class="line">| | +-----------------------------------------------+ | | | | | |</span><br><span class="line">| | | run | | | | | | 2 +---------------------------------+ |</span><br><span class="line">| | | | | | <--------------+ | | |</span><br><span class="line">| | | queryFds query | | | | | | | watchKey(key, callback) +-------------------------------+ |</span><br><span class="line">| | | | | | | | | | | 3 | |</span><br><span class="line">| | | setHandler getHandler | | | | | | | listenSocket_ | | |</span><br><span class="line">| | | | | | | | | | | | |</span><br><span class="line">| | +-----------------------------------------------+ | | | | | | | | |</span><br><span class="line">| | | | | | | +---------------------------------+ | |</span><br><span class="line">| +------------------------------------------------------------+ | | | | |</span><br><span class="line">| | | | | | |</span><br><span class="line">| | | | | | |</span><br><span class="line">| | | | | +----------------------------------------------------------------+ |</span><br><span class="line">| | | | | | TCPStoreWorkerDaemon | | |</span><br><span class="line">| | | | | | | | |</span><br><span class="line">| | | | | | unordered_map<string, WatchKeyCallback> keyToCallbacks_ | | |</span><br><span class="line">| | | | | | | | |</span><br><span class="line">| | | | | | TCPStore.listenSocket_ +----+ | |</span><br><span 
class="line">| | | | | | | | |</span><br><span class="line">| | | | | | +----------------------------------------------------------+ | |</span><br><span class="line">| | | | | | | run | | | |</span><br><span class="line">| | 4 | | | | | | | | |</span><br><span class="line">| +--------------------->+ v | | |</span><br><span class="line">| | | | | | callbackHandler +-----> keyToCallbacks_(callback) | | |</span><br><span class="line">| | | | | | | | |</span><br><span class="line">| | | | | +----------------------------------------------------------+ | |</span><br><span class="line">| | | | +----------------------------------------------------------------+ |</span><br><span class="line">+----------------------------------------------------------------------+ + +------------------------------------------------------------------------+</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p>The same structure as an image:</p><p><img src="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211114221048360-1216105416.png" alt="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211114221048360-1216105416.png"></p><p>This completes our walk through two concepts, initialization methods and Stores. In the end, it is the Store that does the actual work during initialization. The TCPStore analysis also shows what functionality a Store needs to provide, such as setting key-value pairs and watching a key for changes. It is exactly these capabilities that let a group of processes discover one another.</p>]]></content>
<summary type="html">DistributedDataParallel: Initialization Methods and Stores</summary>
<category term="分布式训练" scheme="https://thinksky5124.github.io/categories/%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83/"/>
<category term="深度学习" scheme="https://thinksky5124.github.io/categories/%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
<category term="分布式训练" scheme="https://thinksky5124.github.io/tags/%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83/"/>
<category term="深度学习" scheme="https://thinksky5124.github.io/tags/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
<category term="工程" scheme="https://thinksky5124.github.io/tags/%E5%B7%A5%E7%A8%8B/"/>
</entry>
<entry>
<title>DistributedDataParallel: Overview and Usage</title>
<link href="https://thinksky5124.github.io/2022/08/18/DistributedDataParallel_%E6%80%BB%E8%BF%B0&%E5%A6%82%E4%BD%95%E4%BD%BF%E7%94%A8/"/>
<id>https://thinksky5124.github.io/2022/08/18/DistributedDataParallel_%E6%80%BB%E8%BF%B0&%E5%A6%82%E4%BD%95%E4%BD%BF%E7%94%A8/</id>
<published>2022-08-18T08:00:00.000Z</published>
<updated>2024-04-16T08:57:05.647Z</updated>
<content type="html"><![CDATA[<h1 id="DistributedDataParallel-总述-如何使用"><a href="#DistributedDataParallel-总述-如何使用" class="headerlink" title="DistributedDataParallel 总述&如何使用"></a>DistributedDataParallel: Overview and Usage</h1><h2 id="数据并行"><a href="#数据并行" class="headerlink" title="数据并行"></a><strong>Data Parallelism</strong></h2><p>Because DistributedDataParallel implements data parallelism, let us first use two figures to review what data parallelism is.</p><p>The first figure shows the difference between model parallelism and data parallelism.</p><p><img src="https://s2.loli.net/2024/03/25/Vf3zSQbqKiMIACY.png" alt="DP_vs_DDP.png"></p><p>The second figure comes from the fairscale GitHub repository and clearly shows how data parallelism runs:</p><p>each replica holds a full copy of the model and receives its own shard of the data. It then runs the forward pass locally and the backward pass locally, synchronizes gradients with AllReduce, and updates the parameters locally.</p><p><img src="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211114211025872-537994393.png" alt="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211114211025872-537994393.png"></p><h2 id="DDP-运行逻辑"><a href="#DDP-运行逻辑" class="headerlink" title="DDP 运行逻辑"></a><strong>How DDP Works</strong></h2><p>The <code>torch.distributed</code> package provides PyTorch with multi-process communication primitives across multiple compute nodes, so computation can be parallelized across processes and across clusters. Building on <code>torch.distributed</code>, <code>torch.nn.parallel.DistributedDataParallel</code> provides a synchronous distributed-training wrapper around a PyTorch model. Its core relies on communication between multiple processes. This makes it clearly different from the parallelism offered by <a href="https://pytorch.org/docs/stable/multiprocessing.html">Multiprocessing package - torch.multiprocessing</a> and by DataParallel.</p><p>The figure below shows the overall architecture of DDP, including where DDP sits in the stack and what it depends on. The figure comes from the PyTorch source.</p><p><img src="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211114211122447-104368073.png" alt="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211114211122447-104368073.png"></p><p>The following figure illustrates how DDP runs.</p><p><img src="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211114211206302-864708996.png" alt="图片来自 [https://www.telesens.co/2019/04/04/distributed-data-parallel-training-using-pytorch-on-aws/](https://www.telesens.co/2019/04/04/distributed-data-parallel-training-using-pytorch-on-aws/)"></p><p>Image source: <a 
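href="https://www.telesens.co/2019/04/04/distributed-data-parallel-training-using-pytorch-on-aws/">https://www.telesens.co/2019/04/04/distributed-data-parallel-training-using-pytorch-on-aws/</a></p><p>The steps are as follows:</p><ol><li><strong>Loading the model</strong>. Every GPU holds its own replica of the model, so the model does not need to be copied during training. The process with rank 0 broadcasts the initial network parameters to every other process, ensuring that the model in every process starts from the same initial values.</li><li><strong>Loading data</strong>. DDP does not broadcast data. Instead, multiple processes load data in parallel. On each host, every worker process loads the data it is responsible for from disk into page-locked memory. DistributedSampler ensures that the data loaded by different processes does not overlap.</li><li><strong>Forward pass</strong>. The forward pass runs on every GPU to compute the outputs. Every GPU performs the same training work, so there is no need for a master GPU.</li><li><strong>Computing the loss</strong>. The loss is computed on every GPU.</li><li><strong>Backward pass</strong>. The backward pass computes the gradients, and the gradients are all-reduced while they are still being computed.</li><li><strong>Updating model parameters</strong>. Every GPU starts from exactly the same model and the gradients are all-reduced. Each GPU therefore ends the backward pass with an identical copy of the averaged gradients, so the weight updates are the same on all GPUs and no model synchronization is needed. Note that in every iteration, the model's buffers are broadcast from the rank 0 process to the other processes in the group (the default, <code>broadcast_buffers=True</code>).</li></ol>
<h2 id="DistributedDataParallel-VS-DataParallel"><a href="#DistributedDataParallel-VS-DataParallel" class="headerlink" title="DistributedDataParallel VS DataParallel"></a>DistributedDataParallel <strong>VS DataParallel</strong></h2><h3 id="本质区别"><a href="#本质区别" class="headerlink" title="本质区别"></a><strong>Fundamental Differences</strong></h3><p>DataParallel can already do data-parallel training, so why was DistributedDataParallel introduced? To answer this, we need to understand how the two approaches work and how they differ:</p><ul><li>Training large models:<ul><li>If a model is too large to fit on a single GPU, model parallelism must be used to split it across multiple GPUs.<ul><li>DataParallel must place the whole model on a single GPU, so it struggles with large models. In other words, it cannot be combined with model parallelism (splitting a single model across multiple GPUs).</li><li>With DistributedDataParallel, each process can hold just part of a large model, so it can be combined with model parallelism.</li></ul></li><li>If the data is too large to fit on a single machine, data parallelism is needed.<ul><li>In this case, every DistributedDataParallel process runs its model replica in parallel, and all processes consume the data in parallel. Conceptually, this is not very different from DP.</li></ul></li><li>If your model needs to span multiple machines, or your use case does not fit the data-parallel paradigm, see the <a href="https://pytorch.org/docs/stable/rpc.html">RPC API</a> for more general distributed-training support.</li></ul></li><li>Multi-process vs. multi-thread:<ul><li>DataParallel uses a single process with multiple threads, and it runs on a single machine only.</li><li>DistributedDataParallel uses multiple processes and works for both single-machine and multi-machine training. It also replicates the model once up front instead of on every iteration, and it avoids contention on the global interpreter lock (GIL).<ul><li>Each process maintains its own optimizer and performs a complete optimization step in every iteration. The gradients have already been all-reduced and averaged across processes, so they are identical in every process. No parameter-broadcast step is needed, which saves the time otherwise spent transferring tensors between nodes.</li><li>Each process has its own independent Python interpreter. This eliminates the extra interpreter overhead and "GIL thrashing" (<code>GIL-thrashing</code>) caused by driving multiple execution threads, model replicas, or GPUs from a single Python process. This matters especially for models that rely heavily on the Python runtime (for example, <code>models</code> with <code>RNN</code> layers or many small components).</li></ul></li><li>Even on a single machine, <code>DataParallel</code> is usually slower than <code>DistributedDataParallel</code>. The causes are GIL contention across threads, replicating the model on every iteration, and the overhead of scattering inputs and gathering outputs.</li></ul></li></ul><h3 id="实现区别"><a href="#实现区别" class="headerlink" title="实现区别"></a><strong>Implementation Differences</strong></h3><p>DDP and DP differ in their implementation as follows:</p><ul><li>Optimizer:<ul><li>DDP: in every iteration, each DDP process has its own <code>optimizer</code> and performs all optimization steps independently, just as in non-distributed training.</li><li>DP: there is only one <code>optimizer</code>, and it runs in the main thread. The gradients from all <code>GPU</code>s are summed and the parameters are updated on the main <code>GPU</code>. The model parameters are then <code>broadcast</code> to the other <code>GPU</code>s.</li></ul></li><li>Gradients:<ul><li>DDP: each process computes the loss on its own GPU and runs the backward pass to compute the gradients. The gradients are all-reduced while they are being computed.</li><li>DP: after the gradients have been computed on each GPU (each driven by its own thread), the <strong>gradients</strong> are reduced onto the main GPU. The main GPU uses them to update the model weights and then <strong><code>broadcast</code></strong>s the model to all other GPUs for the next step.</li></ul></li><li>Data transferred:<ul><li>DDP: only a small amount of data, mainly gradients, is exchanged. The models in all processes start with identical parameters (one <code>broadcast</code> at initialization), and every update uses identical gradients, so the model parameters stay consistent across processes. Compared with <code>DataParallel</code>, <code>torch.distributed</code> transfers less data, so it is faster and more efficient.</li><li>DP: every iteration involves a lot of communication, covering the model, forward outputs, loss, gradients, and so on.</li></ul></li></ul><h2 id="使用"><a href="#使用" class="headerlink" title="使用"></a><strong>Usage</strong></h2><p>The basic workflow for distributed training in <code>PyTorch</code> is:</p><ol><li>First, call <code>init_process_group</code> to initialize the process group. This also initializes the <code>distributed</code> package, and only after that can the package's other functions be used.</li><li>If collective communication is needed within a subset of processes, create a subgroup with <code>new_group</code>.</li><li>Create the DistributedDataParallel model with <code>DDP(model, device_ids=device_ids)</code>.</li><li>Create a distributed <code>Sampler</code> for the dataset.</li><li>Use the launcher <code>torch.distributed.launch</code> (superseded by <code>torchrun</code> in newer PyTorch releases) to run the script on every host and start training.</li><li>Destroy the process group with <code>destroy_process_group()</code>.</li></ol><h3 id="基本示例"><a href="#基本示例" 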
class="headerlink" title="基本示例"></a><strong>Basic Example</strong></h3><p>We start with the example from <a href="https://pytorch.org/tutorials/intermediate/ddp_tutorial.html">https://pytorch.org/tutorials/intermediate/ddp_tutorial.html</a>.</p><h3 id="第一步:设置进程组"><a href="#第一步:设置进程组" class="headerlink" title="第一步:设置进程组"></a>Step 1: <strong>Set Up the Process Group</strong></h3><p>The first thing the example does is set up the process group correctly.</p><p>The arguments to <code>init_process_group</code> are:</p><ul><li><code>"gloo"</code> selects Gloo as the backend.</li><li>rank is the rank of the current process. Rank 0 is the master process, which is responsible for tasks such as broadcasting the model state.</li><li><code>world_size</code> is the total number of parallel processes. While fewer than <code>world_size</code> processes have connected, the processes block in <code>init_process_group</code>, and execution continues only once <code>world_size</code> processes have joined. It also scales the effective batch size: if each process uses <code>batch_size = 16</code>, the global <code>batch size</code> is <code>16 * world_size</code>.</li></ul><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> 
os</span><br><span class="line"><span class="keyword">import</span> sys</span><br><span class="line"><span class="keyword">import</span> tempfile</span><br><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">import</span> torch.distributed <span class="keyword">as</span> dist</span><br><span class="line"><span class="keyword">import</span> torch.nn <span class="keyword">as</span> nn</span><br><span class="line"><span class="keyword">import</span> torch.optim <span class="keyword">as</span> optim</span><br><span class="line"><span class="keyword">import</span> torch.multiprocessing <span class="keyword">as</span> mp</span><br><span class="line"></span><br><span class="line"><span class="keyword">from</span> torch.nn.parallel <span class="keyword">import</span> DistributedDataParallel <span class="keyword">as</span> DDP</span><br><span class="line"></span><br><span class="line"><span class="comment"># On Windows platform, the torch.distributed package only</span></span><br><span class="line"><span class="comment"># supports Gloo backend, FileStore and TcpStore.</span></span><br><span class="line"><span class="comment"># For FileStore, set init_method parameter in init_process_group</span></span><br><span class="line"><span class="comment"># to a local file. 
Example as follows:</span></span><br><span class="line"><span class="comment"># init_method="file:///f:/libtmp/some_file"</span></span><br><span class="line"><span class="comment"># dist.init_process_group(</span></span><br><span class="line"><span class="comment"># "gloo",</span></span><br><span class="line"><span class="comment"># rank=rank,</span></span><br><span class="line"><span class="comment"># init_method=init_method,</span></span><br><span class="line"><span class="comment"># world_size=world_size)</span></span><br><span class="line"><span class="comment"># For TcpStore, same way as on Linux.</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">setup</span>(<span class="params">rank, world_size</span>):</span><br><span class="line"> os.environ[<span class="string">'MASTER_ADDR'</span>] = <span class="string">'localhost'</span></span><br><span class="line"> os.environ[<span class="string">'MASTER_PORT'</span>] = <span class="string">'12355'</span></span><br><span class="line"></span><br><span class="line"> <span class="comment"># initialize the process group</span></span><br><span class="line"> dist.init_process_group(<span class="string">"gloo"</span>, rank=rank, world_size=world_size) <span class="comment"># blocks (the master process included) until all world_size processes have joined</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">cleanup</span>():</span><br><span class="line"> dist.destroy_process_group()</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="简单模型"><a href="#简单模型" class="headerlink" title="简单模型"></a><strong>A Simple Model</strong></h3><p>Now let us create a simple module, wrap it with DDP, and feed it some dummy input data. In its constructor, DDP broadcasts the model state from the rank 0 process to all other processes. All DDP processes therefore start from the same model parameters, and users do not need to worry about different DDP processes starting from different initial values.</p><figure class="highlight plaintext"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span 
class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br></pre></td><td class="code"><pre><span class="line"> +-----------+</span><br><span class="line"> | |</span><br><span class="line"> | Rank 0 |</span><br><span class="line"> | |</span><br><span class="line"> +-----+-----+</span><br><span class="line"> |</span><br><span class="line"> | Model Parameters</span><br><span class="line"> |</span><br><span class="line"> |</span><br><span class="line"> +---------------+---------v----------------------+</span><br><span class="line"> | | |</span><br><span class="line"> | | |</span><br><span class="line"> | | |</span><br><span class="line"> | | |</span><br><span class="line"> v v v</span><br><span class="line">+----+-----+ +----+-----+ +---+-------+</span><br><span class="line">| | | | | |</span><br><span class="line">| Rank 1 | | Rank 2 | ...... 
| Rank n |</span><br><span class="line">| | | | | |</span><br><span class="line">+----------+ +----------+ +-----------+</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p>DDP hides the lower-level details of distributed communication and provides a clean API, as if it were a local model. Gradient synchronization happens during the backward pass and overlaps with the backward computation. When <code>backward()</code> returns, <code>param.grad</code> already contains the synchronized gradient tensor. Because DDP wraps the distributed communication primitives, the gradients of the model parameters are all-reduced without any extra user code.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">class</span> <span class="title class_">ToyModel</span>(nn.Module):</span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self</span>):</span><br><span class="line"> <span class="built_in">super</span>(ToyModel, self).__init__()</span><br><span class="line"> self.net1 = 
nn.Linear(<span class="number">10</span>, <span class="number">10</span>)</span><br><span class="line"> self.relu = nn.ReLU()</span><br><span class="line"> self.net2 = nn.Linear(<span class="number">10</span>, <span class="number">5</span>)</span><br><span class="line"></span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">forward</span>(<span class="params">self, x</span>):</span><br><span class="line"> <span class="keyword">return</span> self.net2(self.relu(self.net1(x)))</span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">demo_basic</span>(<span class="params">rank, world_size</span>):</span><br><span class="line"> <span class="built_in">print</span>(<span class="string">f"Running basic DDP example on rank <span class="subst">{rank}</span>."</span>)</span><br><span class="line"> setup(rank, world_size)</span><br><span class="line"></span><br><span class="line"> <span class="comment"># create model and move it to GPU with id rank</span></span><br><span class="line"> model = ToyModel().to(rank)</span><br><span class="line"> ddp_model = DDP(model, device_ids=[rank])</span><br><span class="line"></span><br><span class="line"> loss_fn = nn.MSELoss()</span><br><span class="line"> optimizer = optim.SGD(ddp_model.parameters(), lr=<span class="number">0.001</span>)</span><br><span class="line"></span><br><span class="line"> optimizer.zero_grad()</span><br><span class="line"> outputs = ddp_model(torch.randn(<span class="number">20</span>, <span class="number">10</span>))</span><br><span class="line"> labels = torch.randn(<span class="number">20</span>, <span class="number">5</span>).to(rank)</span><br><span class="line"> loss_fn(outputs, labels).backward()</span><br><span class="line"> optimizer.step()</span><br><span class="line"></span><br><span class="line"> cleanup()</span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> 
<span class="title function_">run_demo</span>(<span class="params">demo_fn, world_size</span>):</span><br><span class="line"> mp.spawn(demo_fn,</span><br><span class="line"> args=(world_size,),</span><br><span class="line"> nprocs=world_size,</span><br><span class="line"> join=<span class="literal">True</span>)</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p>The relationships between these objects are shown below:</p><figure class="highlight plaintext"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br></pre></td><td class="code"><pre><span class="line">+--------------------------+ +------------------------+</span><br><span class="line">| torch.optim.SGD | | DDP |</span><br><span class="line">| | parameters() | |</span><br><span class="line">| | | +------------+ |</span><br><span class="line">| | <-----------------+ | | |</span><br><span class="line">| | | | ToyModel | |</span><br><span class="line">| | | | | |</span><br><span class="line">| | | +------------+ |</span><br><span class="line">| | | |</span><br><span class="line">+--------------------------+ +--------+---------------+</span><br><span class="line"> |</span><br><span class="line"> |</span><br><span class="line"> | forward outputs</span><br><span 
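class="line"> |</span><br><span class="line"> |</span><br><span class="line"> v</span><br><span class="line"></span><br><span class="line"> +-------------------------+</span><br><span class="line"> | nn.MSELoss() |</span><br><span class="line"> | |</span><br><span class="line"> | |</span><br><span class="line"> | |</span><br><span class="line"> | |</span><br><span class="line"> +-------------------------+</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="处理速度偏差"><a href="#处理速度偏差" class="headerlink" title="处理速度偏差"></a><strong>Handling Skewed Processing Speeds</strong></h3><p>In DDP, the constructor, the forward pass, and the backward pass are distributed synchronization points. Different processes are expected to launch the same number of synchronization operations and to reach these synchronization points in the same order at roughly the same time. Otherwise, a fast process may reach a synchronization point early, and if it waits too long for the stragglers, it will time out.</p><p>Users are therefore responsible for balancing the workload across processes. Sometimes skewed processing speeds are unavoidable, for example because of network delays, resource contention, or unpredictable workload spikes. To avoid timeouts in these situations, pass a sufficiently large value for the <code>timeout</code> argument when calling <code>init_process_group</code>.</p><h3 id="保存和加载检查点"><a href="#保存和加载检查点" class="headerlink" title="保存和加载检查点"></a><strong>Saving and Loading Checkpoints</strong></h3><p>In general, users can write checkpoints with <code>torch.save</code> and <code>torch.load</code> and resume training from them.</p><p>When using DDP, one optimization is to save the model in only one process and then load it in all processes, which reduces write overhead (somewhat like read/write splitting in a database). This works because all processes start from the same parameters and gradients are synchronized in the backward pass, so the optimizers keep setting the parameters to the same values. If you use this optimization, make sure that no process starts loading before the save has finished.</p><p>In addition, when loading the module you need to provide an appropriate <code>map_location</code> argument to prevent a process from stepping onto other processes' devices. If <code>map_location</code> is missing, <code>torch.load</code> first loads the module to the CPU and then copies each parameter to the device it was saved from. As a result, all processes on the same machine would end up using the same set of devices.</p><p>For more advanced failure recovery and elasticity support, see <a href="https://pytorch.org/elastic">TorchElastic</a>.</p><p>As the figure below shows, Rank 0 saves the model to storage, and the other ranks load the model locally.</p><figure class="highlight plaintext"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span 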
class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br></pre></td><td class="code"><pre><span class="line"> +-----------+</span><br><span class="line"> | |</span><br><span class="line"> | Rank 0 |</span><br><span class="line"> | |</span><br><span class="line"> +-----+-----+</span><br><span class="line"> |</span><br><span class="line"> save | Model Parameters</span><br><span class="line"> |</span><br><span class="line"> |</span><br><span class="line"> v</span><br><span class="line"> +-------+------+</span><br><span class="line"> | |</span><br><span class="line"> +-----------+ Model file +---------------------+</span><br><span class="line"> | | | |</span><br><span class="line"> | +---+----------+ |</span><br><span class="line"> | | |</span><br><span class="line"> | | |</span><br><span class="line"> | | |</span><br><span class="line"> | | |</span><br><span class="line"> |load |load load |</span><br><span class="line"> | | |</span><br><span class="line"> | | |</span><br><span class="line"> | | |</span><br><span class="line"> | | |</span><br><span class="line"> v v v</span><br><span class="line">+----+-----+ +----+-----+ +---+-------+</span><br><span class="line">| | | | | |</span><br><span class="line">| Rank 1 | | Rank 2 | ...... 
| Rank n |</span><br><span class="line">| | | | | |</span><br><span class="line">+----------+ +----------+ +-----------+</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p>The code is as follows:</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">demo_checkpoint</span>(<span class="params">rank, world_size</span>):</span><br><span class="line"> <span class="built_in">print</span>(<span class="string">f"Running DDP checkpoint example on rank <span class="subst">{rank}</span>."</span>)</span><br><span class="line"> setup(rank, world_size)</span><br><span 
class="line"></span><br><span class="line"> model = ToyModel().to(rank)</span><br><span class="line"> ddp_model = DDP(model, device_ids=[rank])</span><br><span class="line"></span><br><span class="line"> loss_fn = nn.MSELoss()</span><br><span class="line"> optimizer = optim.SGD(ddp_model.parameters(), lr=<span class="number">0.001</span>)</span><br><span class="line"></span><br><span class="line"> CHECKPOINT_PATH = tempfile.gettempdir() + <span class="string">"/model.checkpoint"</span></span><br><span class="line"> <span class="keyword">if</span> rank == <span class="number">0</span>:</span><br><span class="line"> <span class="comment"># All processes should see same parameters as they all start from same</span></span><br><span class="line"> <span class="comment"># random parameters and gradients are synchronized in backward passes.</span></span><br><span class="line"> <span class="comment"># Therefore, saving it in one process is sufficient.</span></span><br><span class="line"> torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)</span><br><span class="line"></span><br><span class="line"> <span class="comment"># Use a barrier() to make sure that process 1 loads the model after process</span></span><br><span class="line"> <span class="comment"># 0 saves it.</span></span><br><span class="line"> dist.barrier()</span><br><span class="line"> <span class="comment"># configure map_location properly</span></span><br><span class="line"> map_location = {<span class="string">'cuda:%d'</span> % <span class="number">0</span>: <span class="string">'cuda:%d'</span> % rank}</span><br><span class="line"> ddp_model.load_state_dict(</span><br><span class="line"> torch.load(CHECKPOINT_PATH, map_location=map_location))</span><br><span class="line"></span><br><span class="line"> optimizer.zero_grad()</span><br><span class="line"> outputs = ddp_model(torch.randn(<span class="number">20</span>, <span class="number">10</span>))</span><br><span class="line"> labels = torch.randn(<span 
class="number">20</span>, <span class="number">5</span>).to(rank)</span><br><span class="line"> loss_fn = nn.MSELoss()</span><br><span class="line"> loss_fn(outputs, labels).backward()</span><br><span class="line"> optimizer.step()</span><br><span class="line"></span><br><span class="line"> <span class="comment"># Not necessary to use a dist.barrier() to guard the file deletion below</span></span><br><span class="line"> <span class="comment"># as the AllReduce ops in the backward pass of DDP already served as</span></span><br><span class="line"> <span class="comment"># a synchronization.</span></span><br><span class="line"></span><br><span class="line"> <span class="keyword">if</span> rank == <span class="number">0</span>:</span><br><span class="line"> os.remove(CHECKPOINT_PATH)</span><br><span class="line"></span><br><span class="line"> cleanup()</span><br></pre></td></tr></tbody></table></figure><h3 id="将-DDP-与模型并行相结合"><a href="#将-DDP-与模型并行相结合" class="headerlink" title="将 DDP 与模型并行相结合"></a><strong>Combining DDP with model parallelism</strong></h3><p>The second half of <a href="https://pytorch.org/tutorials/intermediate/ddp_tutorial.html">https://pytorch.org/tutorials/intermediate/ddp_tutorial.html</a> covers combining DDP with model parallelism.</p><p>DDP also works with multi-GPU models. Wrapping a multi-GPU model in DDP is especially helpful when training large models on large amounts of data.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">class</span> <span class="title class_">ToyMpModel</span>(nn.Module):</span><br><span class="line"> <span class="keyword">def</span> 
<span class="title function_">__init__</span>(<span class="params">self, dev0, dev1</span>):</span><br><span class="line"> <span class="built_in">super</span>(ToyMpModel, self).__init__()</span><br><span class="line"> self.dev0 = dev0</span><br><span class="line"> self.dev1 = dev1</span><br><span class="line"> self.net1 = torch.nn.Linear(<span class="number">10</span>, <span class="number">10</span>).to(dev0)</span><br><span class="line"> self.relu = torch.nn.ReLU()</span><br><span class="line"> self.net2 = torch.nn.Linear(<span class="number">10</span>, <span class="number">5</span>).to(dev1)</span><br><span class="line"></span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">forward</span>(<span class="params">self, x</span>):</span><br><span class="line"> x = x.to(self.dev0)</span><br><span class="line"> x = self.relu(self.net1(x))</span><br><span class="line"> x = x.to(self.dev1)</span><br><span class="line"> <span class="keyword">return</span> self.net2(x)</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p>Note that when passing a multi-GPU model to DDP, <code>device_ids</code> and <code>output_device</code> must NOT be set.</p><p>Input and output data are placed on the appropriate devices either by the application or by the model's <code>forward()</code> method.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span 
class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">demo_model_parallel</span>(<span class="params">rank, world_size</span>):</span><br><span class="line"> <span class="built_in">print</span>(<span class="string">f"Running DDP with model parallel example on rank <span class="subst">{rank}</span>."</span>)</span><br><span class="line"> setup(rank, world_size)</span><br><span class="line"></span><br><span class="line"> <span class="comment"># setup mp_model and devices for this process; each process owns its own pair of GPUs</span></span><br><span class="line"> dev0 = rank * <span class="number">2</span></span><br><span class="line"> dev1 = rank * <span class="number">2</span> + <span class="number">1</span></span><br><span class="line"> mp_model = ToyMpModel(dev0, dev1)</span><br><span class="line"> ddp_mp_model = DDP(mp_model)</span><br><span class="line"></span><br><span class="line"> loss_fn = nn.MSELoss()</span><br><span class="line"> optimizer = optim.SGD(ddp_mp_model.parameters(), lr=<span class="number">0.001</span>)</span><br><span class="line"></span><br><span class="line"> optimizer.zero_grad()</span><br><span class="line"> <span class="comment"># outputs will be on dev1</span></span><br><span class="line"> outputs = ddp_mp_model(torch.randn(<span class="number">20</span>, <span class="number">10</span>))</span><br><span class="line"> labels = torch.randn(<span class="number">20</span>, <span class="number">5</span>).to(dev1)</span><br><span class="line"> loss_fn(outputs, labels).backward()</span><br><span class="line"> optimizer.step()</span><br><span class="line"></span><br><span class="line"> cleanup()</span><br><span 
class="line"></span><br><span class="line"><span class="keyword">if</span> __name__ == <span class="string">"__main__"</span>:</span><br><span class="line"> n_gpus = torch.cuda.device_count()</span><br><span class="line"> <span class="keyword">assert</span> n_gpus >= <span class="number">2</span>, <span class="string">f"Requires at least 2 GPUs to run, but got <span class="subst">{n_gpus}</span>"</span></span><br><span class="line"> world_size = n_gpus</span><br><span class="line"> run_demo(demo_basic, world_size)</span><br><span class="line"> run_demo(demo_checkpoint, world_size)</span><br><span class="line"> world_size = n_gpus // <span class="number">2</span>  <span class="comment"># the model-parallel demo uses 2 GPUs per process</span></span><br><span class="line"> run_demo(demo_model_parallel, world_size)</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p>Note that no sampler is used here. In normal use, DDP is paired with <code>DistributedSampler</code>, which partitions the dataset's samples across processes so that each process reads only its own share. In DDP mode you should also call <code>DistributedSampler.set_epoch(epoch)</code> at the start of every epoch so that the dataset is <code>shuffle</code>d differently in each epoch.</p><h2 id="如何多进程启动"><a href="#如何多进程启动" class="headerlink" title="如何多进程启动"></a><strong>How to launch multiple processes</strong></h2><p>As mentioned earlier, if an application needs to scale beyond the boundaries of a single machine, it should use multi-machine <code>DistributedDataParallel</code> together with the launch script. <code>torch.nn.parallel.DistributedDataParallel()</code> supports multiple network-connected machines, and the user must explicitly launch a separate copy of the main training script for each process.</p><p>Let us now look at the launch script as described in <a href="https://github.com/pytorch/examples/blob/master/distributed/ddp/README.md">https://github.com/pytorch/examples/blob/master/distributed/ddp/README.md</a>. What follows is a translation of that README.</p><p>In this tutorial, we demonstrate how to structure a distributed model-training application so that it can be launched conveniently on multiple nodes, each with multiple GPUs, using PyTorch's distributed launcher script <a href="https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py">https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py</a>, i.e. the launch utility <code>torch.distributed.launch</code>. This utility can be used to start multiple processes per node for distributed training: it spawns several distributed training processes on each training node.</p><p>The utility can be used for CPU training or GPU 
training. When used for GPU training, it spawns one process per GPU. It supports both single-node multi-GPU training and multi-node multi-GPU training.</p><ul><li>For single-node multi-GPU training, each distributed process runs on a single GPU, which is said to improve single-node training performance significantly.</li><li>For multi-node distributed training, spawning multiple processes on each node yields better multi-node training performance. The speedup is higher still on systems with an <code>InfiniBand</code> interconnect.</li></ul><p>In both the single-node and the multi-node case, the utility <strong>launches a given number of processes per node</strong> (<code>--nproc_per_node</code>). For GPU training, this number must be less than or equal to the number of GPUs on the current system, and each process runs on a single GPU, from GPU 0 to GPU <code>nproc_per_node - 1</code>.</p><h3 id="先决条件"><a href="#先决条件" class="headerlink" title="先决条件"></a><strong>Prerequisites</strong></h3><p>Multiple workers train the same global model by processing different portions of a large dataset. Each worker independently computes local gradients (also known as <em>sub</em>-gradients) and then synchronizes them using the <code>AllReduce</code> primitive. Because the same program runs in every process, but each process works on a different portion of the training dataset, this execution model is called <em>single program, multiple data</em> (SPMD) in HPC terminology.</p><h3 id="应用进程拓扑"><a href="#应用进程拓扑" class="headerlink" title="应用进程拓扑"></a><strong>Application process topology</strong></h3><p>A Distributed Data Parallel (DDP) application can be executed on multiple nodes, where each node can consist of multiple GPU devices. Each node in turn can run multiple copies of the DDP application, each of which processes its model on multiple GPUs.</p><p>Let <em>N</em> be the number of nodes on which the application is running and <em>G</em> the number of GPUs per node. The total number of application processes running across all nodes at one time is called the <strong>World Size</strong>, <em>W</em> for short. The number of processes running on each node is called the <strong>Local World Size</strong>, <em>L</em> for short.</p><p>Each application process is assigned two IDs: a <code>local rank</code> in [0, <em>L</em> - 1] and a <code>global rank</code> in [0, <em>W</em> - 1].</p><p>To illustrate these terms, consider launching a DDP application on two nodes, each with four GPUs, where we want each process to span two GPUs. The mapping of processes to nodes is shown in the figure below:</p><p><img src="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211114211252570-39939734.png" alt="Mapping of DDP application processes to nodes and GPUs"></p><p>This figure is also taken from <a href="https://github.com/pytorch/examples/blob/master/distributed/ddp/README.md">https://github.com/pytorch/examples/blob/master/distributed/ddp/README.md</a>.</p><p>While there are many ways to map processes to nodes, a good rule of thumb is to have one process span a single GPU. This gives the DDP application as many parallel reader streams as there are GPUs, and in practice it also provides a good balance between I/O and computational costs.</p><h3 id="准备和启动-DDP-应用程序"><a href="#准备和启动-DDP-应用程序" class="headerlink" title="准备和启动
DDP 应用程序"></a><strong>Preparing and launching a DDP application</strong></h3><p>Regardless of how a DDP application is launched, each process needs a mechanism to learn its global and local ranks. With that information, all processes create a <code>ProcessGroup</code>, which lets them participate in collective communication operations such as <code>AllReduce</code>.</p><p>A convenient way to start multiple DDP processes and initialize all the values needed to create a <code>ProcessGroup</code> is to use the distributed launch script <code>launch.py</code> provided with PyTorch.</p><p>The launcher can be found under the <code>distributed</code> subdirectory of the local <code>torch</code> installation directory. Here is a quick way to get the path of <code>launch.py</code> on any operating system:</p><figure class="highlight bash"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">python -c <span class="string">"from os import path; import torch; print(path.join(path.dirname(torch.__file__), 'distributed', 'launch.py'))"</span></span><br></pre></td></tr></tbody></table></figure><p>This prints something like:</p><figure class="highlight bash"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">/home/username/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/distributed/launch.py</span><br></pre></td></tr></tbody></table></figure><p>When a DDP application is started via <code>launch.py</code>, the launcher passes the <code>world size</code>, <code>global rank</code>, <code>master address</code>, and master port to each instance through environment variables, and passes the <code>local rank</code> as a command-line argument. To use the launcher, an application must follow these conventions:</p><ul><li>It must provide an entry-point function for a <em>single worker</em>. For example, it should not launch subprocesses using <code>torch.multiprocessing.spawn</code>.</li><li>It must use environment variables to initialize the process group.</li></ul><p>For simplicity, the application can assume that each process maps to a single GPU, but in the next section we also show how to perform the process-to-GPU mapping in a more general way.</p><h3 id="示例应用"><a href="#示例应用" class="headerlink" title="示例应用"></a><strong>Sample application</strong></h3><p>This sample DDP application is based on the “Hello, World” app from the <a href="https://pytorch.org/tutorials/intermediate/ddp_tutorial.html">DDP tutorial</a>.</p><h3 id="参数传递约定"><a href="#参数传递约定" class="headerlink" title="参数传递约定"></a><strong>Argument passing convention</strong></h3><p>The DDP application takes two command-line arguments:</p><ol><li><code>--local_rank</code>: this argument is passed in by 
<code>launch.py</code>.</li><li><code>--local_world_size</code>: this is passed in explicitly and is typically either 1 or the number of GPUs per node.</li></ol><p>The application parses these and calls the <code>spmd_main</code> entry point:</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">if</span> __name__ == <span class="string">"__main__"</span>:</span><br><span class="line"> parser = argparse.ArgumentParser()</span><br><span class="line"> parser.add_argument(<span class="string">"--local_rank"</span>, <span class="built_in">type</span>=<span class="built_in">int</span>, default=<span class="number">0</span>)</span><br><span class="line"> parser.add_argument(<span class="string">"--local_world_size"</span>, <span class="built_in">type</span>=<span class="built_in">int</span>, default=<span class="number">1</span>)</span><br><span class="line"> args = parser.parse_args()</span><br><span class="line"> spmd_main(args.local_world_size, args.local_rank)</span><br></pre></td></tr></tbody></table></figure><p>In <code>spmd_main</code>, the process group is initialized with a backend (NCCL or Gloo). The rest of the information needed for rendezvous comes from the environment variables set by <code>launch.py</code>:</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br></pre></td><td class="code"><pre><span 
class="keyword">def</span> <span class="title function_">spmd_main</span>(<span class="params">local_world_size, local_rank</span>):</span><br><span class="line"> <span class="comment"># These are the parameters used to initialize the process group</span></span><br><span class="line"> env_dict = {</span><br><span class="line"> key: os.environ[key]</span><br><span class="line"> <span class="keyword">for</span> key <span class="keyword">in</span> (<span class="string">"MASTER_ADDR"</span>, <span class="string">"MASTER_PORT"</span>, <span class="string">"RANK"</span>, <span class="string">"WORLD_SIZE"</span>)</span><br><span class="line"> }</span><br><span class="line"> <span class="built_in">print</span>(<span class="string">f"[<span class="subst">{os.getpid()}</span>] Initializing process group with: <span class="subst">{env_dict}</span>"</span>)</span><br><span class="line"> dist.init_process_group(backend=<span class="string">"nccl"</span>)</span><br><span class="line"> <span class="built_in">print</span>(</span><br><span class="line"> <span class="string">f"[<span class="subst">{os.getpid()}</span>] world_size = <span class="subst">{dist.get_world_size()}</span>, "</span></span><br><span class="line"> + <span class="string">f"rank = <span class="subst">{dist.get_rank()}</span>, backend=<span class="subst">{dist.get_backend()}</span>"</span></span><br><span class="line"> )</span><br><span class="line"></span><br><span class="line"> demo_basic(local_world_size, local_rank)</span><br><span class="line"></span><br><span class="line"> <span class="comment"># Tear down the process group</span></span><br><span class="line"> dist.destroy_process_group()</span><br></pre></td></tr></tbody></table></figure><p>Given the local rank and the local world size, the training function <code>demo_basic</code> initializes the <code>DistributedDataParallel</code> model on a set of GPUs of the local node via <code>device_ids</code>:</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span 
class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">demo_basic</span>(<span class="params">local_world_size, local_rank</span>):</span><br><span class="line"></span><br><span class="line"> <span class="comment"># setup devices for this process. 
For local_world_size = 2, num_gpus = 8,</span></span><br><span class="line"> <span class="comment"># rank 0 uses GPUs [0, 1, 2, 3] and</span></span><br><span class="line"> <span class="comment"># rank 1 uses GPUs [4, 5, 6, 7].</span></span><br><span class="line"> n = torch.cuda.device_count() // local_world_size</span><br><span class="line"> device_ids = <span class="built_in">list</span>(<span class="built_in">range</span>(local_rank * n, (local_rank + <span class="number">1</span>) * n))</span><br><span class="line"></span><br><span class="line"> <span class="built_in">print</span>(</span><br><span class="line"> <span class="string">f"[<span class="subst">{os.getpid()}</span>] rank = <span class="subst">{dist.get_rank()}</span>, "</span></span><br><span class="line"> + <span class="string">f"world_size = <span class="subst">{dist.get_world_size()}</span>, n = <span class="subst">{n}</span>, device_ids = <span class="subst">{device_ids}</span>"</span></span><br><span class="line"> )</span><br><span class="line"></span><br><span class="line"> model = ToyModel().cuda(device_ids[<span class="number">0</span>])</span><br><span class="line"> ddp_model = DDP(model, device_ids)</span><br><span class="line"></span><br><span class="line"> loss_fn = nn.MSELoss()</span><br><span class="line"> optimizer = optim.SGD(ddp_model.parameters(), lr=<span class="number">0.001</span>)</span><br><span class="line"></span><br><span class="line"> optimizer.zero_grad()</span><br><span class="line"> outputs = ddp_model(torch.randn(<span class="number">20</span>, <span class="number">10</span>))</span><br><span class="line"> labels = torch.randn(<span class="number">20</span>, <span class="number">5</span>).to(device_ids[<span class="number">0</span>])</span><br><span class="line"> loss_fn(outputs, labels).backward()</span><br><span class="line"> optimizer.step()</span><br></pre></td></tr></tbody></table></figure><p>The application can be launched with <code>launch.py</code> on a single 8-GPU node, one process per GPU, as follows:</p><figure 
class="highlight bash"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">python /path/to/launch.py --nnodes=1 --node_rank=0 --nproc_per_node=8 example.py --local_world_size=8</span><br></pre></td></tr></tbody></table></figure><p>It produces output similar to the following:</p><figure class="highlight bash"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br></pre></td><td class="code"><pre><span class="line">*****************************************</span><br><span class="line">Setting OMP_NUM_THREADS environment variable <span class="keyword">for</span> each process to be 1 <span class="keyword">in</span> default, to avoid your system being overloaded, please further tune the variable <span class="keyword">for</span> optimal performance <span class="keyword">in</span> your application as needed.</span><br><span class="line">*****************************************</span><br><span class="line">[238627] Initializing process group with: {<span class="string">'MASTER_ADDR'</span>: <span class="string">'127.0.0.1'</span>, <span class="string">'MASTER_PORT'</span>: 
<span class="string">'29500'</span>, <span class="string">'RANK'</span>: <span class="string">'0'</span>, <span class="string">'WORLD_SIZE'</span>: <span class="string">'8'</span>}</span><br><span class="line">[238630] Initializing process group with: {<span class="string">'MASTER_ADDR'</span>: <span class="string">'127.0.0.1'</span>, <span class="string">'MASTER_PORT'</span>: <span class="string">'29500'</span>, <span class="string">'RANK'</span>: <span class="string">'3'</span>, <span class="string">'WORLD_SIZE'</span>: <span class="string">'8'</span>}</span><br><span class="line">[238628] Initializing process group with: {<span class="string">'MASTER_ADDR'</span>: <span class="string">'127.0.0.1'</span>, <span class="string">'MASTER_PORT'</span>: <span class="string">'29500'</span>, <span class="string">'RANK'</span>: <span class="string">'1'</span>, <span class="string">'WORLD_SIZE'</span>: <span class="string">'8'</span>}</span><br><span class="line">[238634] Initializing process group with: {<span class="string">'MASTER_ADDR'</span>: <span class="string">'127.0.0.1'</span>, <span class="string">'MASTER_PORT'</span>: <span class="string">'29500'</span>, <span class="string">'RANK'</span>: <span class="string">'7'</span>, <span class="string">'WORLD_SIZE'</span>: <span class="string">'8'</span>}</span><br><span class="line">[238631] Initializing process group with: {<span class="string">'MASTER_ADDR'</span>: <span class="string">'127.0.0.1'</span>, <span class="string">'MASTER_PORT'</span>: <span class="string">'29500'</span>, <span class="string">'RANK'</span>: <span class="string">'4'</span>, <span class="string">'WORLD_SIZE'</span>: <span class="string">'8'</span>}</span><br><span class="line">[238632] Initializing process group with: {<span class="string">'MASTER_ADDR'</span>: <span class="string">'127.0.0.1'</span>, <span class="string">'MASTER_PORT'</span>: <span class="string">'29500'</span>, <span class="string">'RANK'</span>: <span 
class="string">'5'</span>, <span class="string">'WORLD_SIZE'</span>: <span class="string">'8'</span>}</span><br><span class="line">[238629] Initializing process group with: {<span class="string">'MASTER_ADDR'</span>: <span class="string">'127.0.0.1'</span>, <span class="string">'MASTER_PORT'</span>: <span class="string">'29500'</span>, <span class="string">'RANK'</span>: <span class="string">'2'</span>, <span class="string">'WORLD_SIZE'</span>: <span class="string">'8'</span>}</span><br><span class="line">[238633] Initializing process group with: {<span class="string">'MASTER_ADDR'</span>: <span class="string">'127.0.0.1'</span>, <span class="string">'MASTER_PORT'</span>: <span class="string">'29500'</span>, <span class="string">'RANK'</span>: <span class="string">'6'</span>, <span class="string">'WORLD_SIZE'</span>: <span class="string">'8'</span>}</span><br><span class="line">[238633] world_size = 8, rank = 6, backend=nccl</span><br><span class="line">[238628] world_size = 8, rank = 1, backend=nccl</span><br><span class="line">[238629] world_size = 8, rank = 2, backend=nccl</span><br><span class="line">[238631] world_size = 8, rank = 4, backend=nccl</span><br><span class="line">[238630] world_size = 8, rank = 3, backend=nccl</span><br><span class="line">[238632] world_size = 8, rank = 5, backend=nccl</span><br><span class="line">[238634] world_size = 8, rank = 7, backend=nccl</span><br><span class="line">[238627] world_size = 8, rank = 0, backend=nccl</span><br><span class="line">[238633] rank = 6, world_size = 8, n = 1, device_ids = [6]</span><br><span class="line">[238628] rank = 1, world_size = 8, n = 1, device_ids = [1]</span><br><span class="line">[238632] rank = 5, world_size = 8, n = 1, device_ids = [5]</span><br><span class="line">[238634] rank = 7, world_size = 8, n = 1, device_ids = [7]</span><br><span class="line">[238629] rank = 2, world_size = 8, n = 1, device_ids = [2]</span><br><span class="line">[238630] rank = 3, world_size = 8, n = 1, device_ids 
= [3]</span><br><span class="line">[238631] rank = 4, world_size = 8, n = 1, device_ids = [4]</span><br><span class="line">[238627] rank = 0, world_size = 8, n = 1, device_ids = [0]</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p>Similarly, it can be launched with a single process that spans all 8 GPUs. (This single-process, multi-GPU mode relies on passing several <code>device_ids</code> to DDP, which recent PyTorch releases no longer accept; there, use one process per GPU instead.)</p><figure class="highlight bash"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">python /path/to/launch.py --nnodes=1 --node_rank=0 --nproc_per_node=1 example.py --local_world_size=1</span><br></pre></td></tr></tbody></table></figure><p>The launcher creates <code>nproc_per_node</code> processes on the current host. Each process runs the training script independently and is given a <code>local_rank</code> argument that identifies it among the processes on that host.</p><p>For example, <code>node_rank = 2</code> and <code>local_rank = 0</code> denote the first process on the node whose <code>node_rank</code> is 2 (node ranks start at 0, so this is the third node).</p><p>The single-process launch produces the following output:</p><figure class="highlight bash"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">[262816] Initializing process group with: {<span class="string">'MASTER_ADDR'</span>: <span class="string">'127.0.0.1'</span>, <span class="string">'MASTER_PORT'</span>: <span class="string">'29500'</span>, <span class="string">'RANK'</span>: <span class="string">'0'</span>, <span class="string">'WORLD_SIZE'</span>: <span class="string">'1'</span>}</span><br><span class="line">[262816]: world_size = 1, rank = 0, backend=nccl</span><br><span class="line">[262816] rank = 0, world_size = 1, n = 8, device_ids = [0, 1, 2, 3, 4, 5, 6, 7]</span><br></pre></td></tr></tbody></table></figure><h3 id="结论"><a href="#结论" class="headerlink" title="结论"></a><strong>Conclusion</strong></h3><p>As the author of a distributed data-parallel application, your code needs to be aware of two types of resources: compute nodes and the GPUs within each node. However, the bookkeeping needed to track how sets of GPUs map to application processes can be tedious and error-prone.</p><p>The README's authors therefore hope that structuring your application around the launcher, as shown in this example, will significantly simplify the setup of distributed training.</p><h3 id="启动脚本的背后"><a href="#启动脚本的背后" class="headerlink" 
title="启动脚本的背后"></a><strong>Behind the Launch Script</strong></h3><p>Knowing what the launch script does is not enough; we also need to understand what it does internally.</p><h3 id="launch-py"><a href="#launch-py" class="headerlink" title="launch.py"></a><strong>launch.py</strong></h3><p>launch.py lives at torch/distributed/launch.py, but in practice most of its functionality has moved into torch/distributed/run.py.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">main</span>(<span class="params">args=<span class="literal">None</span></span>):</span><br><span class="line">    logger.warn(</span><br><span class="line">        <span class="string">"The module torch.distributed.launch is deprecated "</span></span><br><span class="line">        <span class="string">"and going to be removed in future."</span></span><br><span class="line">        <span class="string">"Migrate to torch.distributed.run"</span></span><br><span class="line">    )</span><br><span class="line">    args = parse_args(args)</span><br><span class="line">    run(args)</span><br></pre></td></tr></tbody></table></figure><p>So the real work happens in <code>run.py</code>.</p><h3 id="run-py"><a href="#run-py" class="headerlink" title="run.py"></a><strong>run.py</strong></h3><p>The basic flow of <code>run.py</code> is as follows. If <code>--standalone</code> is given, it fills in a local <code>c10d</code> rendezvous on <code>localhost:29400</code> with a random <code>rdzv_id</code>. It then calls <code>config_from_args</code> to extract information from the command line and build the launch configuration, the entrypoint command, and its arguments, and finally hands them to <code>elastic_launch</code> for execution. In other words, even the classic launcher now runs on top of the elastic-training machinery, which indicates that elastic training is the direction PyTorch is heading.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">run</span>(<span class="params">args</span>):</span><br><span class="line">    <span class="keyword">if</span> args.standalone:</span><br><span class="line">        args.rdzv_backend = <span class="string">"c10d"</span></span><br><span class="line">        args.rdzv_endpoint = <span class="string">"localhost:29400"</span></span><br><span class="line">        args.rdzv_id = <span class="built_in">str</span>(uuid.uuid4())</span><br><span class="line">        log.info(</span><br><span class="line">            <span class="string">f"\n**************************************\n"</span></span><br><span class="line">            <span class="string">f"Rendezvous info:\n"</span></span><br><span class="line">            <span class="string">f"--rdzv_backend=<span class="subst">{args.rdzv_backend}</span> "</span></span><br><span class="line">            <span class="string">f"--rdzv_endpoint=<span class="subst">{args.rdzv_endpoint}</span> "</span></span><br><span class="line">            <span class="string">f"--rdzv_id=<span class="subst">{args.rdzv_id}</span>\n"</span></span><br><span class="line">            <span class="string">f"**************************************\n"</span></span><br><span class="line">        )</span><br><span class="line"></span><br><span class="line">    config, cmd, cmd_args = config_from_args(args)</span><br><span class="line">    elastic_launch(</span><br><span class="line">        config=config,</span><br><span class="line">        entrypoint=cmd,</span><br><span class="line">    )(*cmd_args)</span><br></pre></td></tr></tbody></table></figure><p><code>run.py</code> can also be invoked directly, for example:</p><figure class="highlight bash"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">python -m torch.distributed.run \</span><br><span class="line">    --nnodes=<span class="variable">$NUM_NODES</span> \</span><br><span class="line">    --nproc_per_node=<span class="variable">$NUM_TRAINERS</span> \</span><br><span class="line">    --rdzv_id=<span class="variable">$JOB_ID</span> \</span><br><span class="line">    --rdzv_backend=c10d \</span><br><span class="line">    --rdzv_endpoint=<span class="variable">$HOST_NODE_ADDR</span> \</span><br><span class="line">    YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><h3 id="定义"><a href="#定义" class="headerlink" title="定义"></a><strong>Definitions</strong></h3><p>Because <code>run.py</code> has many configuration parameters, let us briefly go over the terms it uses.</p><ol><li><p><code>Node</code> - A physical instance or a container; maps to the unit that the job manager works with.</p></li><li><p><code>Worker</code> - A worker in the context of distributed training.</p></li><li><p><code>WorkerGroup</code> - The set of workers that execute the same function (e.g. trainers).</p></li><li><p><code>LocalWorkerGroup</code> - The subset of the workers in a worker group that run on the same node.</p></li><li><p><code>RANK</code> - The rank of a worker within the worker group. It is a global rank; you can think of it as an index into a global list of GPU resources.</p></li><li><p><code>LOCAL_RANK</code> - The rank of a worker within its local worker group; you can think of it as an index into the GPUs on the current node.</p></li><li><p><code>GROUP_RANK</code> - The rank of the worker group, a number between 0 and the maximum number of nodes. If each node runs a single worker group, this is the rank of the node.</p></li><li><p><code>ROLE_RANK</code> - The rank of a worker among all workers that have the same role; the role is specified in the <code>WorkerSpec</code>.</p></li><li><p><code>WORLD_SIZE</code> - The total number of workers in the worker group. Because nodes can join and leave, <code>WORLD_SIZE</code> can change, so code must not rely on it remaining stable.</p></li><li><p><code>LOCAL_WORLD_SIZE</code> - The size of the local worker group, i.e. the number of workers running locally, equal to the <code>--nproc_per_node</code> passed to <code>torch.distributed.run</code>. Currently torch/distributed/run.py only supports a homogeneous <code>LOCAL_WORLD_SIZE</code>; that is, it assumes every node runs the same number of local workers (per role).</p></li><li><p><code>ROLE_WORLD_SIZE</code> - The total number of workers with the same role, as specified in the <code>WorkerSpec</code>.</p></li><li><p><code>rdzv_id</code> - A user-defined id that uniquely identifies the worker group of a job. Each node uses this id to join as a member of a particular worker group.</p></li><li><p><code>rdzv_backend</code> - The rendezvous backend (e.g. <code>c10d</code>). This is typically a strongly consistent key-value store.</p></li><li><p><code>rdzv_endpoint</code> - The rendezvous backend endpoint, usually of the form <code>&lt;host&gt;:&lt;port&gt;</code>.</p></li><li><p><code>run_id</code> - A user-defined id that uniquely identifies an instance of a distributed application. It typically maps to a job id and is used to let nodes join the correct distributed application.</p></li><li><p><code>TORCHELASTIC_RESTART_COUNT</code> - The number of worker group restarts so far.</p></li><li><p><code>TORCHELASTIC_MAX_RESTARTS</code> - The configured maximum number of restarts.</p></li><li><p><code>TORCHELASTIC_RUN_ID</code> - Equal to the rendezvous <code>run_id</code>, i.e. the unique job id.</p></li></ol>]]></content>
<summary type="html">DistributedDataParallel overview and usage</summary>
<category term="分布式训练" scheme="https://thinksky5124.github.io/categories/%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83/"/>
<category term="深度学习" scheme="https://thinksky5124.github.io/categories/%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
<category term="分布式训练" scheme="https://thinksky5124.github.io/tags/%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83/"/>
<category term="深度学习" scheme="https://thinksky5124.github.io/tags/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
<category term="工程" scheme="https://thinksky5124.github.io/tags/%E5%B7%A5%E7%A8%8B/"/>
</entry>
<entry>
<title>PyTorch Distributed Training</title>
<link href="https://thinksky5124.github.io/2022/08/18/PyTorch%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83/"/>
<id>https://thinksky5124.github.io/2022/08/18/PyTorch%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83/</id>
<published>2022-08-18T08:00:00.000Z</published>
<updated>2024-04-16T08:57:05.647Z</updated>
<content type="html"><![CDATA[<h1 id="PyTorch分布式训练"><a href="#PyTorch分布式训练" class="headerlink" title="PyTorch分布式训练"></a>PyTorch Distributed Training</h1><h1 id="数据并行训练"><a href="#数据并行训练" class="headerlink" title="数据并行训练"></a>Data-Parallel Training</h1><p>PyTorch offers several options for data-parallel training. Applications usually start simple and grow more complex as they move from prototype to production. The common development trajectory is:</p><ol><li>Use single-device training if the data and model fit on one GPU and training speed is not a concern.</li><li>Use single-machine multi-GPU DataParallel if the server has multiple GPUs and you want to speed up training with minimal code changes.</li><li>Use single-machine multi-GPU DistributedDataParallel if you want to speed up training further and are willing to write a bit more code to set it up.</li><li>Use multi-machine DistributedDataParallel and the <a href="https://github.com/pytorch/examples/blob/master/distributed/ddp/README.md">launch script</a> if the application needs to scale across machine boundaries.</li><li>Use <a href="https://pytorch.org/elastic">torchelastic</a> to launch distributed training if errors (e.g. OOM) are expected or if resources can join and leave dynamically during training.</li></ol><p><img src="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211105002040087-1992121188.png" alt="Untitled"></p><h1 id="Torch如何使用GPU"><a href="#Torch如何使用GPU" class="headerlink" title="Torch如何使用GPU"></a>How Torch Uses the GPU</h1><p><img src="https://img2020.cnblogs.com/blog/1850883/202111/1850883-20211101212805949-1423202605.png" alt="Untitled"></p><ul><li><strong>The _apply method</strong><ul><li>Iterate over _parameters:<ul><li>Call fn on each parameter to get param_applied, then reset the parameter with param_applied.</li><li>If the parameter has a gradient:<ul><li>Call fn on the parameter's grad to get grad_applied.</li><li>Reset the parameter's gradient with grad_applied.</li></ul></li></ul></li><li>Iterate over _buffers:<ul><li>Call fn on each buffer.</li></ul></li></ul></li></ul><p>Calling the cuda or to method to move a model to the GPU actually moves the model's <code>self._parameters</code> and <code>self._buffers</code> to the GPU; <code>self._modules</code> itself is not moved. The process is recursive: <code>_apply</code> is first applied to every child module, so the parameters and buffers of every leaf of the model end up on the GPU.</p><figure class="highlight plaintext"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br></pre></td><td class="code"><pre><span class="line">                                       +</span><br><span class="line">                                       |</span><br><span class="line">+---------------------------------+    |    +----------------------------------+</span><br><span class="line">| CPU                             |    |    | CPU                              |</span><br><span class="line">|  +--------------+               |    |    |   +--------------------+         |</span><br><span class="line">|  |Module        |               |    |    |   | Module             |         |</span><br><span class="line">|  |              |               |    |    |   |                    |         |</span><br><span class="line">|  | _parameters  +--> Parameters |    |    |   |  _parameters -----------+    |</span><br><span class="line">|  |              |               |    |    |   |                    |    |    |</span><br><span class="line">|  | _buffers     +--> Buffers    |    |    | +--- _buffers          |    |    |</span><br><span class="line">|  |              |               |    |    | | |                    |    |    |</span><br><span class="line">|  | _modules     |               |    |    | | |  _modules          |    |    |</span><br><span class="line">|  |              |               |    |    | | |                    |    |    |</span><br><span class="line">|  +--------------+               |    |    | | +--------------------+    |    |</span><br><span class="line">|                                 |    |    | |                           |    |</span><br><span class="line">+---------------------------------+    |    +-+---------------------------+----+</span><br><span class="line">                                       |      |                           |</span><br><span class="line">                                       |      |                           |</span><br><span class="line">+------------------------------> Module.cuda() +-------------------------------> Time</span><br><span class="line">                                       |      |                           |</span><br><span class="line">                                       |      |                           |</span><br><span class="line">+---------------------------------+    |    +-+---------------------------+----+</span><br><span class="line">| GPU                             |    |    | | GPU                       |    |</span><br><span class="line">|                                 |    |    | |                           |    |</span><br><span class="line">|                                 |    |    | |             Parameters &lt;--+    |</span><br><span class="line">|                                 |    |    | |                                |</span><br><span class="line">|                                 |    |    | |                                |</span><br><span class="line">|                                 |    |    | +----> Buffers                   |</span><br><span class="line">|                                 |    |    |                                  |</span><br><span class="line">|                                 |    |    |                                  |</span><br><span class="line">+---------------------------------+    |    +----------------------------------+</span><br><span class="line">                                       |</span><br><span class="line">                                       +</span><br></pre></td></tr></tbody></table></figure>]]></content>
<summary type="html">PyTorch Distributed Training</summary>
<category term="分布式训练" scheme="https://thinksky5124.github.io/categories/%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83/"/>
<category term="深度学习" scheme="https://thinksky5124.github.io/categories/%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
<category term="分布式训练" scheme="https://thinksky5124.github.io/tags/%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83/"/>
<category term="深度学习" scheme="https://thinksky5124.github.io/tags/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
<category term="工程" scheme="https://thinksky5124.github.io/tags/%E5%B7%A5%E7%A8%8B/"/>
</entry>
<entry>
<title>Mathematical Foundations of SLAM</title>
<link href="https://thinksky5124.github.io/2022/08/18/SLAM_math_fundation/"/>
<id>https://thinksky5124.github.io/2022/08/18/SLAM_math_fundation/</id>
<published>2022-08-18T08:00:00.000Z</published>
<updated>2024-04-16T08:57:05.647Z</updated>
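<content type="html"><![CDATA[<h1 id="SLAM数学基础"><a href="#SLAM数学基础" class="headerlink" title="SLAM数学基础"></a>Mathematical Foundations of SLAM</h1><h1 id="三维空间的刚体运动"><a href="#三维空间的刚体运动" class="headerlink" title="三维空间的刚体运动"></a>Rigid-Body Motion in 3D Space</h1><p>Rigid body: an object whose shape and size do not change under motion or applied forces, and whose internal points keep the same relative positions.</p><h2 id="旋转矩阵"><a href="#旋转矩阵" class="headerlink" title="旋转矩阵"></a>Rotation Matrices</h2><p>$$ SO(n)=\{R \in \mathbb{R}^{n \times n} \mid R R^T = I, \det(R)=1\} $$</p><p>$SO(n)$ is the special orthogonal group. The transformation of a rigid body between two coordinate frames can be written as</p><p>$$ a'=Ra+t $$</p><p>where $t$ is the translation vector between the origins of the two frames.</p><p>The entries of a 3D rotation matrix are the inner products of the two frames' basis vectors:</p><p>$$<br>R=\begin{bmatrix}<br> e^T_1 e'_1 & e^T_1 e'_2 & e^T_1 e'_3 \\<br> e^T_2 e'_1 & e^T_2 e'_2 & e^T_2 e'_3 \\<br> e^T_3 e'_1 & e^T_3 e'_2 & e^T_3 e'_3<br>\end{bmatrix}<br>$$</p><p>where $e_i$ and $e'_i$ are the basis vectors of the two coordinate frames.</p><h2 id="齐次坐标"><a href="#齐次坐标" class="headerlink" title="齐次坐标"></a>Homogeneous Coordinates</h2><p>When frames are chained, the transformation above is affine rather than linear, so compositions quickly become cumbersome, for example</p><p>$$<br>c=R_2(R_1a+t_1)+t_2<br>$$</p><p>Appending a 1 to the end of a 3D vector turns it into a 4D vector, called its homogeneous coordinates, and the transformation becomes</p><p>$$<br>\begin{bmatrix}a'\\ 1\end{bmatrix}=\begin{bmatrix}R & t\\ \mathbf{0}^T&1\end{bmatrix}\begin{bmatrix}a\\ 1\end{bmatrix}=T\begin{bmatrix}a\\ 1\end{bmatrix}<br>$$</p><p>The matrix $T$ is called the transformation matrix. Chained transformations then become a simple matrix product: $\begin{bmatrix}c\\ 1\end{bmatrix}=T_2T_1\begin{bmatrix}a\\ 1\end{bmatrix}$</p><p>The inverse transformation is</p><p>$$<br>T^{-1}=\begin{bmatrix}R^T & -R^Tt\\ \mathbf{0}^T&1\end{bmatrix}<br>$$</p><p><strong>Skew-symmetric matrix</strong></p><p>$$<br>a^{\wedge}=\begin{bmatrix}a_1 \\ a_2\\ a_3\end{bmatrix}^{\wedge}=\begin{bmatrix}0 & -a_3 & a_2 \\ a_3 & 0 & -a_1 \\ -a_2 & a_1 & 0 \end{bmatrix}<br>$$</p><h2 id="旋转向量和欧拉角"><a href="#旋转向量和欧拉角" class="headerlink" title="旋转向量和欧拉角"></a>Rotation Vectors and Euler Angles</h2><p>Representing rotations with matrices has at least two drawbacks:</p><ol><li>A rotation matrix in $SO(3)$ has 9 entries, but a rotation has only 3 degrees of freedom, so the representation is redundant.</li><li>A rotation matrix carries its own constraints: it must be orthogonal with determinant 1. These constraints make estimation and optimization harder.</li></ol><h3 id="旋转向量"><a href="#旋转向量" class="headerlink" title="旋转向量"></a>Rotation Vectors</h3><p>Any rotation can be described by a <strong>rotation axis</strong> and a <strong>rotation angle</strong>. We therefore use a vector whose direction is the rotation axis and whose length is the rotation angle, called the rotation vector $\theta n$, where $n$ is the unit axis and $\theta$ the angle.</p><p>The conversion from a rotation vector to a rotation matrix is given by Rodrigues' formula</p><p>$$<br>R = \cos\theta I + (1 - \cos\theta)nn^T + \sin\theta\,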
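n^{\wedge}<br>$$</p><p>To go back from a rotation matrix to a rotation vector, take the trace of both sides, using $\operatorname{tr}(nn^T)=n^Tn=1$ and $\operatorname{tr}(n^{\wedge})=0$:</p><p>$$<br>\operatorname{tr}(R) = \cos \theta \operatorname{tr}(I) + (1 - \cos\theta)\operatorname{tr}(nn^T) + \sin\theta \operatorname{tr}(n^{\wedge}) = 3 \cos\theta + (1 - \cos \theta) = 1 + 2 \cos\theta<br>$$</p><p>$$<br>\theta=\arccos \frac{\operatorname{tr}(R) - 1}{2}<br>$$</p><p>As for the axis $n$, vectors on the rotation axis are unchanged by the rotation, so</p><p>$$<br>Rn=n<br>$$</p><p>that is, $n$ is the unit eigenvector of $R$ for eigenvalue 1.</p><h3 id="欧拉角"><a href="#欧拉角" class="headerlink" title="欧拉角"></a>Euler Angles</h3><p>Taking the ZYX convention as an example, any rotation can be decomposed into rotations about the following 3 axes:</p><ol><li>yaw, about the Z axis</li><li>pitch, about the rotated Y axis</li><li>roll, about the rotated X axis</li></ol><p>The rotation is then described by the vector $\begin{bmatrix}r & p & y\end{bmatrix}^T$. A major drawback of Euler angles is the well-known gimbal lock problem: when the pitch angle is $\pm90^{\circ}$, the first and third rotations use the same axis, so the system loses one degree of freedom. This is called a singularity. For this reason Euler angles are ill-suited to interpolation and iterative optimization and are mostly used in human-machine interaction.</p><h2 id="四元数"><a href="#四元数" class="headerlink" title="四元数"></a>Quaternions</h2><h3 id="定义"><a href="#定义" class="headerlink" title="定义"></a>Definition</h3><p>A quaternion has one real part and three imaginary parts, $q=q_0+q_1 i+q_2 j+q_3 k$, where the three imaginary units satisfy</p><p>$$<br>\begin{cases}<br>i^2=j^2=k^2=-1 \\<br>ij=k,\ ji=-k \\ jk=i,\ kj=-i \\ ki=j,\ ik=-j<br>\end{cases}<br>$$</p><p>$i,j,k$ correspond to the three coordinate axes, and multiplying once by a unit $i$, $j$, or $k$ corresponds to a rotation of $180^{\circ}$.</p><h3 id="用四元数表示旋转"><a href="#用四元数表示旋转" class="headerlink" title="用四元数表示旋转"></a>Representing Rotations with Quaternions</h3><p>Suppose we have a 3D point $p=\begin{bmatrix}x,y,z\end{bmatrix} \in \mathbb{R}^3$ and a rotation specified by a unit quaternion $q$. Writing the point as a pure imaginary quaternion, the rotated point $p'$ is</p><p>$$<br>p=\begin{bmatrix}0,x,y,z\end{bmatrix}^T = \begin{bmatrix}0, v\end{bmatrix}^T, \quad p'=qpq^{-1}<br>$$</p><h2 id="变换的种类"><a href="#变换的种类" class="headerlink" title="变换的种类"></a>Types of Transformations</h2><p><img src="https://s2.loli.net/2024/03/25/D3v1nbXtsxwIWBH.png" alt="change1.png"></p><p><img src="https://s2.loli.net/2024/03/25/WReSs1XKxoL4a3f.png" alt="change2.png"></p><h1 id="李群和李代数"><a href="#李群和李代数" class="headerlink" title="李群和李代数"></a>Lie Groups and Lie Algebras</h1><h2 id="前言"><a href="#前言" class="headerlink" title="前言"></a>Introduction</h2><p>In SLAM we need not only to represent 3D rotations and translations but also to estimate them, because the whole SLAM process consists of continually estimating the robot's pose and the map. To do this we need to interpolate, differentiate, and iterate on transformation matrices. For example, in the classic ICP problem we are given two sets of 3D points and must compute the transformation between them. Let the first set be $\mathbf{P}=\{ \mathbf{p}_i \mid i = 1,2, \ldots, N \}$ and the second set be $\mathbf{Q}=\{ \mathbf{q}_i \mid i = 1,2, \ldots, N \}$. What we actually want is a Euclidean transformation $\mathbf{T}$ such that</p><p>$$ 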
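\forall i, \quad \mathbf{q}_i = \mathbf{T} \mathbf{p}_i $$</p><p>Homogeneous coordinates are used here. These matched points usually come from feature matching and form an overdetermined system, which, because of noise, generally has no exact solution. We therefore solve a least-squares problem instead:</p><p>$$ \mathop {\min }\limits_{\mathbf{T} } u\left( {\mathbf{T} } \right) = \sum\limits_{i = 1}^N { { { \left\| { {\mathbf{q}_i} - \mathbf{T} {\mathbf{p}_i} } \right\| }^2 } } $$</p><p>This is where the problem arises: if we solve this optimization iteratively (although it can also be solved without iteration), how do we compute the derivative of the objective $u$ with respect to $\mathbf{T}$? First, $\mathbf{T}$ has only 6 degrees of freedom, so ideally we would express it in a six-dimensional space; the derivative (Jacobian) of the scalar $u(\mathbf{T})$ with respect to that space is then a $1 \times 6$ matrix. Second, $\mathbf{T}$ is closed under multiplication but not under addition: the sum of two transformation matrices is generally not a transformation matrix, mainly because rotation matrices are not closed under addition.<br>For these two reasons we want better mathematical tools, and Lie group and Lie algebra theory provides exactly that. Lie groups and Lie algebras are widely used in robotics and computer vision and play an important role in deriving robot dynamics. Since SLAM involves little dynamics, however, we focus on the few key results relevant to SLAM and omit most of the proofs. In particular, we focus on the Lie groups $SO(3)$ and $SE(3)$ and their Lie algebras.</p><h2 id="李代数基础"><a href="#李代数基础" class="headerlink" title="李代数基础"></a>Lie Algebra Basics</h2><p>We start with the simpler three-dimensional rotation group. To describe its structure, we first introduce the concept of a group.<br>A <strong>group</strong> is an algebraic structure consisting of a set together with an operation, written $(A,\cdot)$, where $A$ is the set and $\cdot$ is a binary operation defined on it. $G=(A, \cdot)$ is a group if the operation satisfies the following conditions:</p><ul><li>Closure: $\quad \forall a_1, a_2, \quad a_1 \cdot a_2 \in A$</li><li>Associativity: $\quad \forall a_1, a_2, a_3, \quad (a_1 \cdot a_2) \cdot a_3 = a_1 \cdot ( a_2 \cdot a_3)$</li><li>Identity: $\quad \exists a_0 \in A, \quad s.t. \quad \forall a \in A, \quad a_0 \cdot a = a \cdot a_0 = a$</li><li>Inverse: $\quad \forall a \in A, \quad \exists a^{-1} \in A, \quad s.t. 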
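\quad a \cdot a^{-1} = a_0$</li></ul><p>These four axioms can be remembered as closure, associativity, identity, inverse (the Chinese mnemonic 封结幺逆). You can check them against some familiar groups, for example the integers under addition (identity 0) or the nonzero rationals under multiplication (identity 1). For matrices, common matrix groups include:</p><ul><li>The general linear group $GL(n)$: the invertible $n \times n$ matrices, which form a group under matrix multiplication.</li><li>The special orthogonal group $SO(n)$, i.e. the group of rotation matrices, of which $SO(2)$ and $SO(3)$ are the most common. Formally: $SO(n) = \{ \mathbf{R} \in \mathbb{R}^{n \times n} \mid \mathbf{R R}^T = \mathbf{I}, \det(\mathbf{R})=1 \}$</li><li>The special Euclidean group $SE(n)$, i.e. the $n$-dimensional Euclidean transformations mentioned earlier, such as $SE(2)$ and $SE(3)$. For $SE(3)$: $SE(3)=\left\{ T = \begin{bmatrix}R & t\\ \mathbf{0}^T&1\end{bmatrix} \in \mathbb{R}^{4 \times 4} \mid \mathbf{R} \in SO(3), \mathbf{t} \in \mathbb{R}^3 \right\}$</li></ul><p>The group structure guarantees that operations on the group have good properties, and group theory studies the various structures and properties of groups.</p><p>A <strong>Lie group</strong> is a group with continuous structure. Moreover, the operation on a continuous group is in general infinitely differentiable and even analytic (analyticity is stronger than infinite differentiability: it also requires the Taylor expansion around every point to converge in a neighborhood). This question, known as Hilbert's fifth problem, was posed at the beginning of the 20th century and has since been solved. A Lie group, then, is a continuous group over the real numbers. Common Lie groups include the $GL(n), SO(n), SE(n)$ mentioned above, as well as others such as the unitary group $U(n)$ and the symplectic group $Sp(2n)$.</p><h2 id="三维旋转群-SO-3"><a href="#三维旋转群-SO-3" class="headerlink" title="三维旋转群$SO(3)$"></a>The 3D Rotation Group $SO(3)$</h2><p>The three-dimensional rotation group $SO(3)$ is the special case $n=3$ of the special orthogonal group $SO(n)$. It describes rotations in 3D space, and its elements are the $3 \times 3$ orthogonal matrices with determinant $+1$. Suppose $\mathbf{R}$ is such a matrix, so $\mathbf{R} \mathbf{R}^T=\mathbf{I}$. Now let it change over time, from $\mathbf{R}$ to $\mathbf{R}(t)$; we still have $\mathbf{R}(t) \mathbf{R}(t) ^T = \mathbf{I}$. Differentiating both sides with respect to time gives:</p><p>$$<br>\mathbf{\dot{R} } (t) \mathbf{R} {(t)^T} + \mathbf{R} (t) \mathbf{\dot{R} } {(t)^T} = 0<br>$$</p><p>and therefore</p><p>$$<br>\mathbf{\dot{R} } (t) \mathbf{R} {(t)^T} = - \left( \mathbf{\dot{R} } (t) \mathbf{R} {(t)^T} \right)^T<br>$$</p><p>So $\mathbf{\dot{R} } (t) \mathbf{R} {(t)^T}$ is a skew-symmetric matrix. For any $3 \times 3$ skew-symmetric matrix, call it $\mathbf{A}$: since $\mathbf{A}^T=-\mathbf{A}$, its diagonal entries must be $0$ and its off-diagonal entries have only three degrees of freedom. We can therefore map it to a vector $\mathbf{a}=[a_1, a_2, a_3]^T$:</p><p>$$<br>{\mathbf{a}^ \wedge } = \mathbf{A} = \begin{bmatrix}0 & -a_3 & a_2\\ a_3&0&-a_1\\ -a_2&a_1&0 \end{bmatrix}<br>$$</p><p>Here $^{\wedge}$ maps a vector to a matrix; conversely, we use $^{\vee}$ to map a skew-symmetric matrix back to a vector:</p><p>$$<br>\mathbf{A}^{\vee} = \mathbf{a}<br>$$</p><p>One benefit of this definition is its compatibility with the cross product: the product of the matrix with any vector, $\mathbf{A} \mathbf{b}$, equals $\mathbf{a} \times \mathbf{b}$. You can verify this yourself. Vectors defined this way also have some other useful properties, discussed later.</p><p>Now, since $\mathbf{\dot{R} } (t) 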
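\mathbf{R} {(t)^T}$ is skew-symmetric, we can find a 3D vector $\mathbf{\phi} (t) \in \mathbb{R}^3$ corresponding to it, so that:</p><p>$$<br>\mathbf{ \dot{R} } (t) \mathbf{R}(t)^T = \mathbf{\phi} (t) ^ {\wedge}<br>$$</p><p>Right-multiplying both sides by $\mathbf{R}(t)$ and using the fact that $\mathbf{R}$ is orthogonal gives:</p><p>$$<br>\mathbf{ \dot{R} } (t) = \mathbf{\phi} (t)^{\wedge} \mathbf{R}(t) = \begin{bmatrix}0&- \phi_3&\phi_2\\ \phi_3&0&-\phi_1\\ -\phi_2&\phi_1&0 \end{bmatrix} \mathbf{R}(t)<br>$$</p><p>So differentiating a rotation matrix amounts to left-multiplying it by the matrix $\mathbf{\phi}^{\wedge}$. Because $\mathbf{\phi}$ captures the derivative of $\mathbf{R}$, it is said to lie in the tangent space of $SO(3)$. Moreover, viewing the equation above as a linear differential equation in $\mathbf{R}$ and assuming $\mathbf{\phi}$ is constant near $t_0$, we obtain:</p><p>$$<br>\mathbf{R}(t) = \exp \left( \mathbf{\phi} ^\wedge (t - t_0) \right) \mathbf{R}(t_0)<br>$$</p><p>This raises two questions:</p><ol><li>How is $\mathbf{\phi}$ obtained, and what is its structure? Answer: $\mathbf{\phi}$ belongs to $\mathfrak{so}(3)$, the Lie algebra of $SO(3)$.</li><li>How is $\exp( \mathbf{\phi}^{\wedge})$ computed? Answer: through the exponential and logarithm maps between the Lie group and its Lie algebra. We introduce each in turn below.</li></ol><h2 id="李代数"><a href="#李代数" class="headerlink" title="李代数"></a>Lie Algebras</h2><p>For $SO(3)$ and $SE(3)$, the Lie algebra can be defined on the tangent space of the Lie group and describes the local properties of the group's elements. We write them in lowercase as $\mathfrak{so}(3)$ and $\mathfrak{se}(3)$. First, here is the general definition of a Lie algebra.</p><p>A Lie algebra consists of a set $\mathbb{V}$, a field $\mathbb{F}$, and a binary operation $[\cdot,\cdot]$. If they satisfy the following properties, $(\mathbb{V}, \mathbb{F}, [\cdot,\cdot])$ is called a Lie algebra, written $\mathfrak{g}$.</p><ul><li>Closure: $\forall \mathbf{X}, \mathbf{Y} \in \mathbb{V}, [\mathbf{X}, \mathbf{Y}] \in \mathbb{V}$</li><li>Bilinearity: $\forall \mathbf{X,Y,Z} \in \mathbb{V}, a,b \in \mathbb{F}$, we have $[a\mathbf{X}+b\mathbf{Y}, \mathbf{Z}] = a[\mathbf{X}, \mathbf{Z}] + b [ \mathbf{Y}, \mathbf{Z} ], \quad [\mathbf{Z}, a \mathbf{X}+b\mathbf{Y}] = a [\mathbf{Z}, \mathbf{X} ]+ b [\mathbf{Z}, \mathbf{Y}]$</li><li>Alternativity: $\forall \mathbf{X} \in \mathbb{V}, [\mathbf{X}, \mathbf{X}] = \mathbf{0}$</li><li>Jacobi identity: $\forall \mathbf{X,Y,Z} \in \mathbb{V}, [\mathbf{X}, [\mathbf{Y}, \mathbf{Z}] ] + [\mathbf{Z}, [\mathbf{X}, \mathbf{Y}] ] + [\mathbf{Y}, [\mathbf{Z}, \mathbf{X}]] = \mathbf{0}$</li></ul><p>On the surface a Lie algebra requires quite a few properties. The binary operation is called the <strong>Lie bracket</strong>. Compared with the simpler binary operation of a group, the Lie bracket expresses the difference between two elements. It does not require associativity, but it is antisymmetric, and the bracket of any element with itself is zero. As an analogy, the cross product $\times$ defined on 3D vectors $\mathbb{R}^3$ is a Lie bracket, so $\mathfrak{g} = (\mathbb{R}^3, \mathbb{R}, 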
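\times)$ forms a Lie algebra. You can try checking the properties of the cross product against the four properties above.</p><h2 id="三维旋转群与对应的李代数"><a href="#三维旋转群与对应的李代数" class="headerlink" title="三维旋转群与对应的李代数"></a>The 3D Rotation Group and Its Lie Algebra</h2><p>The Lie algebra of $SO(3)$ consists of vectors in $\mathbb{R}^3$, which we denote $\mathbf{\phi}$ (note that this is a vector, even though bold Greek letters are hard to distinguish). From the derivation above, each $\mathbf{\phi}$ generates a skew-symmetric matrix:</p><p>$$<br>\mathbf{\Phi} = \mathbf{\phi}^{\wedge} = \begin{bmatrix} 0&-\phi_3&\phi_2\\ \phi_3&0&-\phi_1\\ -\phi_2&\phi_1&0\end{bmatrix} \in \mathbb{R}^{3 \times 3}<br>$$</p><p>Under this definition, the Lie bracket of two vectors $\mathbf{\phi}_1, \mathbf{\phi}_2$ is $[\mathbf{\phi}_1, \mathbf{\phi}_2] = ( \mathbf{ \Phi }_1 \mathbf{ \Phi }_2 - \mathbf{ \Phi }_2 \mathbf{ \Phi }_1 )^{\vee}$.</p><p>You can verify that this Lie bracket satisfies the properties above. Because $\mathbf{\phi}$ is so closely tied to its skew-symmetric matrix, when there is no ambiguity we say that the elements of $\mathfrak{so}(3)$ are 3D vectors or $3 \times 3$ skew-symmetric matrices interchangeably: $\mathfrak{so}(3) = \left\{ \mathbf{\Phi} = \mathbf{\phi^\wedge} \in \mathbb{R}^{3 \times 3} \mid \mathbf{\phi} \in \mathbb{R}^3 \right\}$</p><p>For any vector $\mathbf{\phi}$ we have $\mathbf{\phi} \mathbf{\phi}^T = \mathbf{\phi}^{\wedge} \mathbf{\phi}^{\wedge} + \| \mathbf{\phi} \|^2 \mathbf{I}_{3 \times 3}$. For a unit vector $\mathbf{a}$, this gives two important properties of skew-symmetric matrices (the second also uses $\mathbf{a}^{\wedge}\mathbf{a} = \mathbf{a} \times \mathbf{a} = \mathbf{0}$):</p><ul><li>$\mathbf{a} \mathbf{a}^T = \mathbf{a}^{\wedge} \mathbf{a}^{\wedge} + \mathbf{I}$</li><li>$\mathbf{a}^{\wedge} \mathbf{a}^{\wedge} \mathbf{a}^{\wedge} = - \mathbf{a}^{\wedge}$</li></ul><p>You can verify both properties yourself; we will use them in the exponential map.</p><p>We now understand the structure of $\mathfrak{so}(3)$: it is a set of 3D vectors, each corresponding to a skew-symmetric matrix, and these express the derivative of a rotation matrix. Next we consider how to compute $\exp ( \mathbf{\phi}^{\wedge} )$, for which we introduce the exponential map.</p><h2 id="指数映射"><a href="#指数映射" class="headerlink" title="指数映射"></a>The Exponential Map</h2><p>First, recall the exponential of an arbitrary square matrix. It is defined by a Taylor series, which converges for every square matrix, and the result is again a matrix.</p><p>$$<br>\exp(\mathbf{A}) = \sum\limits_{n = 0}^\infty {\frac{1}{ { n! 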
} }{ \mathbf{A}^n} }<br>$$</p><p>Likewise, for any element $\mathbf{\phi}$ of $\mathfrak{so}(3)$ we define its exponential map in the same way:</p><p>$$<br>\exp(\mathbf{\phi}^\wedge) = \sum\limits_{n = 0}^\infty {\frac{1}{ {n!} }{ (\mathbf{\phi}^{\wedge})^n} }<br>$$</p><p>Let us look at what this means. Since $\mathbf{\phi}$ is a 3D vector, we can define its length and its direction, written $\theta$ and $\mathbf{a}$, so that $\mathbf{\phi} = \theta \mathbf{a}$ with $\mathbf{a}$ a unit vector. Expanding the series and using the two properties of skew-symmetric matrices above, $\mathbf{a}^{\wedge}\mathbf{a}^{\wedge} = \mathbf{a}\mathbf{a}^T - \mathbf{I}$ and $(\mathbf{a}^{\wedge})^3 = -\mathbf{a}^{\wedge}$, gives:</p><p>$$<br>\begin{aligned}<br>\exp \left( \mathbf{\phi} ^ \wedge \right) &= \exp \left( \theta \mathbf{a}^ \wedge \right) = \sum\limits_{n = 0}^\infty \frac{1}{n!}\left( \theta \mathbf{a}^ \wedge \right)^n \\<br>&= \mathbf{I} + \theta \mathbf{a}^ \wedge + \frac{1}{2!}\theta ^2 \mathbf{a}^ \wedge \mathbf{a}^ \wedge + \frac{1}{3!}\theta ^3 \mathbf{a}^ \wedge \mathbf{a}^ \wedge \mathbf{a}^ \wedge + \frac{1}{4!}\theta ^4 \left( \mathbf{a}^ \wedge \right)^4 + \ldots \\<br>&= \mathbf{a}\mathbf{a}^T - \mathbf{a}^ \wedge \mathbf{a}^ \wedge + \theta \mathbf{a}^ \wedge + \frac{1}{2!}\theta ^2 \mathbf{a}^ \wedge \mathbf{a}^ \wedge - \frac{1}{3!}\theta ^3 \mathbf{a}^ \wedge - \frac{1}{4!}\theta ^4 \mathbf{a}^ \wedge \mathbf{a}^ \wedge + \ldots \\<br>&= \mathbf{a}\mathbf{a}^T + \left( \theta - \frac{1}{3!}\theta ^3 + \frac{1}{5!}\theta ^5 - \ldots \right)\mathbf{a}^ \wedge - \left( 1 - \frac{1}{2!}\theta ^2 + \frac{1}{4!}\theta ^4 - \ldots \right)\mathbf{a}^ \wedge \mathbf{a}^ \wedge \\<br>&= \mathbf{a}^ \wedge \mathbf{a}^ \wedge + \mathbf{I} + \sin \theta \, \mathbf{a}^ \wedge - \cos \theta \, \mathbf{a}^ \wedge \mathbf{a}^ \wedge \\<br>&= (1 - \cos \theta )\mathbf{a}^ \wedge \mathbf{a}^ \wedge + \mathbf{I} + \sin \theta \, \mathbf{a}^ \wedge \\<br>&= \cos \theta \, \mathbf{I} + (1 - \cos \theta )\mathbf{a}\mathbf{a}^T + \sin \theta \, \mathbf{a}^ \wedge<br>\end{aligned}<br>$$</p><p>We finally obtain a familiar-looking formula:</p><p>$$<br>\exp( \theta \mathbf{a}^\wedge ) = \cos \theta \mathbf{I} + (1 - \cos \theta )\mathbf{a}{\mathbf{a}^T} + \sin \theta {\mathbf{a}^ \wedge 
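}<br>$$</p><p>Recalling the previous section, this is exactly <strong>Rodrigues' formula</strong>. It shows that $\mathfrak{so}(3)$ is in fact the space of so-called <strong>rotation vectors</strong>. Through Rodrigues' formula, or equivalently the exponential map, a vector in $\mathbb{R}^3$ is mapped to a 3D rotation in $SO(3)$.</p><p>Conversely, by defining the logarithm map we can also map elements of $SO(3)$ to $\mathfrak{so}(3)$:</p><p>$$<br>\mathbf{\phi} = \ln {\left( \mathbf{R} \right)^ \vee } = {\left( {\sum\limits_{n = 0}^\infty {\frac{ { { {\left( { - 1} \right)}^n } } }{ {n + 1} }{ {\left( { \mathbf{R} - \mathbf{I} } \right)}^{n + 1} } } } \right)^ \vee }<br>$$</p><p>where $^\vee$ maps a skew-symmetric matrix to a vector and is the inverse of $^\wedge$.</p><p>You may ask what properties the exponential map has: is it a bijection? Unfortunately, it is only surjective. Every element of $SO(3)$ has at least one corresponding element in $\mathfrak{so}(3)$, but several elements of $\mathfrak{so}(3)$ may correspond to the same element of $SO(3)$. For instance, the rotation angle $\theta$ is periodic: rotation vectors with angles $\theta$ and $\theta + 2\pi$ about the same axis give the same rotation.</p><p>The result for $SO(3)$ and $\mathfrak{so}(3)$ is roughly what we expected. It closely mirrors the earlier discussion of rotation vectors and rotation matrices, and the exponential map is simply Rodrigues' formula. The rotation vector can be viewed as the derivative of the rotation matrix, telling us how to do calculus with rotation matrices.</p><h2 id="三维欧氏群与对应的李代数"><a href="#三维欧氏群与对应的李代数" class="headerlink" title="三维欧氏群与对应的李代数"></a>The 3D Euclidean Group and Its Lie Algebra</h2><p>Next we introduce the three-dimensional Euclidean group $SE(3)$ and its Lie algebra $\mathfrak{se}(3)$. With the groundwork above, we can go directly to their structure and operations. The structure of $SE(3)$ was already given when we introduced groups:</p><p>$$<br>SE(3) = \left\{ \mathbf{T} =\begin{bmatrix}\mathbf{R} & \mathbf{t} \\ {\mathbf{0}^T} & 1 \end{bmatrix} \in \mathbb{R}^{4 \times 4} \mid \mathbf{R} \in SO(3), \mathbf{t} \in \mathbb{R}^3\right\}<br>$$</p><p>Each transformation matrix has six degrees of freedom, so the corresponding Lie algebra lives in $\mathbb{R}^6$: $\mathfrak{se}(3) = \left\{ \mathbf{ \Xi } = \mathbf{\xi}^\wedge \in \mathbb{R}^{4 \times 4} \mid \mathbf{\xi} \in \mathbb{R}^6 \right\}$</p><p>Here, however, $^\wedge$ no longer produces a skew-symmetric matrix; instead</p><p>$$<br>\mathbf{\xi}^\wedge = \begin{bmatrix}\mathbf{\rho} \\ \mathbf{\phi} \end{bmatrix}^ \wedge = \begin{bmatrix}{\mathbf{\phi} ^ \wedge }&\mathbf{\rho} \\ {\mathbf{0}^T}&0\end{bmatrix} = \mathbf{\Xi}<br>$$</p><p>So the first three components of $\mathbf{\xi}$, namely $\mathbf{\rho}$, relate to translation (though $\mathbf{\rho}$ is not the translation $\mathbf{t}$ itself, as shown below), and the last three, $\mathbf{\phi}$, are the rotation vector. This Lie algebra corresponds to the differential equation $\mathbf{\dot{T} }(t) = \mathbf{\xi}^\wedge \mathbf{T}(t)$, so for constant $\mathbf{\xi}$ we have $\mathbf{T}(t) = \exp ( \mathbf{\xi}^\wedge (t - t_0) ) \mathbf{T}(t_0)$. What does the exponential map look like on $\mathfrak{se}(3)$? A short derivation gives:</p><p>$$<br>\exp \left( { { \mathbf{\xi} ^ \wedge } } \right) = 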
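\begin{bmatrix} {\sum\limits_{n = 0}^\infty {\frac{1}{ {n!} }{ {\left( { {\mathbf{\phi} ^ \wedge } } \right)}^n} } }&{\sum\limits_{n = 0}^\infty {\frac{1}{ {\left( {n + 1} \right)!} }{ {\left( { {\mathbf{\phi} ^ \wedge } } \right)}^n} \mathbf{\rho} } } \\ { {\mathbf{0}^T} }&1\end{bmatrix} = \begin{bmatrix} \mathbf{R} &{\mathbf{J\rho} } \\ { {\mathbf{0}^T} }&1 \end{bmatrix}<br>$$</p><p>The top-left block $\mathbf{R} = \exp(\mathbf{\phi}^\wedge)$ is the familiar element of $SO(3)$ introduced above. The matrix $\mathbf{J}$ in the top-right block can be written as (with $\mathbf{\phi}=\theta\mathbf{a}$):</p><p>$$<br>\mathbf{J} = \frac{ {\sin \theta } }{\theta } \mathbf{I} + \left( {1 - \frac{ {\sin \theta } }{\theta } } \right) \mathbf{a} { \mathbf{a}^T} + \frac{ {1 - \cos \theta } }{\theta }{ \mathbf{a}^ \wedge }<br>$$</p><p>This gives the exponential map of $\mathfrak{se}(3)$; in particular, the translation part of $\exp(\mathbf{\xi}^\wedge)$ is $\mathbf{t} = \mathbf{J}\mathbf{\rho}$. The logarithm map can be derived analogously.</p><p><img src="https://s2.loli.net/2024/03/25/9d8X5enQjCgKiUZ.jpg" alt="list.jpg"></p><h1 id="李代数求导与扰动模型"><a href="#李代数求导与扰动模型" class="headerlink" title="李代数求导与扰动模型"></a>Lie Algebra Derivatives and the Perturbation Model</h1><h2 id="BCH公式与近似形式"><a href="#BCH公式与近似形式" class="headerlink" title="BCH公式与近似形式"></a>The BCH Formula and Its Approximation</h2><p>A main purpose of using Lie algebras is optimization, and derivatives are essential information during optimization. Consider the following question: when we multiply two matrices in $SO(3)$, what happens to the corresponding Lie algebra elements in $\mathfrak{so}(3)$? Conversely, when we add two Lie algebra elements in $\mathfrak{so}(3)$, does that correspond to the product of two matrices in $SO(3)$? If so, we would have $\exp( \phi _ {1}^ {\wedge } )\exp( \phi _ {2}^ {\wedge } )=\exp(( \phi _ {1} + \phi _ {2} )^{\wedge} )$.</p><p>If $\phi_{1}$ and $\phi_{2}$ were scalars this would clearly hold, but here we are taking exponentials of matrices, not scalars. In other words, we are asking whether $\ln(\exp(A)\exp(B)) = A + B$.</p><p>For matrices this does not hold in general. The complete form of the product of two matrix exponentials is given by the <strong>Baker-Campbell-Hausdorff formula (BCH formula)</strong>:</p><p>$$<br>\ln ( \exp(A) \exp(B))=A+B+ \frac {1}{2} [A,B]+ \frac {1}{12} [A,[A,B]]- \frac {1}{12} [B,[A,B]]+ \cdots<br>$$</p><p>where $[\cdot,\cdot]$ is the Lie bracket. The BCH formula shows that the product of two matrix exponentials produces remainder terms made up of Lie brackets. In particular, for the Lie algebra of $SO(3)$, when $\phi_{1}$ or $\phi_{2}$ is small, terms of second order and higher in the small quantity can be ignored, and BCH has a linear approximation:</p><p>$$<br> \ln (\exp( \phi_{1}^{\wedge} )\exp(\phi_{2}^{\wedge}))^\vee \approx \begin{cases}\mathbf{J}_{l}(\phi_{2})^{-1}\phi_{1}+\phi_{2} \quad \text{if } \phi_{1} \text{ is small} \\ \mathbf{J}_{r} (\phi_{1})^{-1}\phi_{2}+\phi_{1} \quad 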
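\text{if } \phi_{2} \text{ is small}\end{cases}<br>$$</p><p>Take the first approximation as an example. It says that when a rotation matrix $\mathbf{R}_2$ (with Lie algebra $\phi_2$) is left-multiplied by a small rotation matrix $\mathbf{R}_1$ (with Lie algebra $\phi_1$), the result can be approximated as adding the term $\mathbf{J}_{l}(\phi_{2})^{-1}\phi_{1}$ to the original Lie algebra element $\phi_2$. Likewise, the second approximation describes right-multiplication by a small rotation. Under the BCH approximation, the Lie algebra therefore has a left-multiplication model and a right-multiplication model, and we must keep track of which one we are using.</p><p>Left BCH approximation Jacobian:</p><p>$$<br>\mathbf{J}_ {l} =\mathbf{J}= \frac {\sin \theta }{\theta } \mathbf{I}+(1- \frac {\sin \theta }{\theta } ) \mathbf{a}\mathbf{a}^ {T} + \frac {1-\cos \theta }{\theta } \mathbf{a}^{\wedge}<br>$$</p><p>Its inverse is:</p><p>$$<br>\mathbf{J}_ {l}^ {-1} = \frac {\theta }{2} \cot \frac {\theta }{2} \mathbf{I}+(1- \frac {\theta }{2} \cot \frac {\theta }{2} ) \mathbf{a}\mathbf{a}^ {T} - \frac {\theta }{2} \mathbf{a}^{\wedge}<br>$$</p><p>The right Jacobian is obtained simply by negating the argument:</p><p>$$<br>\mathbf{J}_r(\phi) = \mathbf{J}_{l}(-\phi)<br>$$</p><h3 id="BCH近似的意义"><a href="#BCH近似的意义" class="headerlink" title="BCH近似的意义"></a><strong>What the BCH Approximation Means</strong></h3><p>Suppose a rotation $\boldsymbol{R}$ has Lie algebra $\phi$, and we left-multiply it by a small rotation $\Delta \boldsymbol{R}$ with Lie algebra $\Delta \phi$. On the Lie group the result is $\Delta \boldsymbol{R} \cdot \boldsymbol{R}$, while on the Lie algebra, by the BCH approximation, it is $\mathbf{J}^{-1}_{l}(\phi) \Delta \phi + \phi$. Combining the two:</p><p>$$<br>\exp( \Delta \phi ^ {\wedge } )\exp( \phi ^{\wedge}) \approx \exp( (\phi +\mathbf{J}_{l}^ {-1}(\phi )\Delta \phi )^ {\wedge } )<br>$$</p><p>Conversely, if we add $\Delta \phi$ to $\phi$ on the Lie algebra, the result can be approximated on the Lie group by left- or right-multiplying by a small rotation whose Lie algebra is $\Delta \phi$ scaled by the left or right Jacobian:</p><p>$$<br>\exp( (\phi +\Delta \phi )^{\wedge}) \approx \exp(( \mathbf{J}_{l}(\phi) \Delta \phi )^{\wedge})\exp(\phi^ \wedge ) \approx \exp( \phi^\wedge )\exp( (\mathbf{J}_{r}(\phi) \Delta \phi )^ {\wedge })<br>$$</p><p>This provides the theoretical basis for doing calculus on the Lie algebra. Similarly, $SE(3)$ has analogous BCH approximations:</p><p>$$<br>\exp( \Delta \xi^{\wedge} )\exp (\xi^\wedge)\approx \exp((\mathcal{J}_{l}^{-1}\Delta \xi +\xi )^{\wedge }), \quad \exp(\xi^\wedge) \exp( \Delta \xi^ \wedge)\approx \exp((\mathcal{J}_{r}^{-1}\Delta \xi +\xi )^{\wedge })<br>$$</p><p>Here $\mathcal{J}_{l}$ has a more complicated form and is a $6 \times 6$ matrix.</p><h2 id="SO-3-李代数上的求导"><a href="#SO-3-李代数上的求导" class="headerlink" 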
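title="$**SO(3)$李代数上的求导"></a><strong>Lie Algebra Derivatives on $SO(3)$</strong></h2><p>In SLAM we estimate the position and orientation of a camera; this pose is described by a rotation matrix in $SO(3)$ or a transformation matrix in $SE(3)$. Suppose that at some moment the robot's pose is $T$ and it observes a point with world coordinates $p$, producing an observation $z$. From the coordinate transformation we have:</p><p>$$<br>z = Tp + w<br>$$</p><p>where $w$ is random noise. Because of the noise, $z$ generally cannot satisfy $z=Tp$ exactly, so we usually compute the error between the ideal observation and the actual data:</p><p>$$<br>e = z - Tp<br>$$</p><p>Suppose there are $N$ such landmarks and observations, giving $N$ equations of this form. Estimating the robot's pose then amounts to finding the optimal $T$ that minimizes the total error:</p><p>$$<br>\min_\mathbf{T} J(\mathbf{T})= \sum_{i=1}^{N}\left\| z_{i} - \mathbf{T} \mathbf{p}_{i}\right\|_{2}^{2}<br>$$</p><p>Solving this problem requires the derivative of the objective function $J$ with respect to the transformation matrix $T$.</p><p>The key is to build functions of the pose and study their derivatives with respect to the pose in order to adjust the current estimate. However, $SO(3)$ and $SE(3)$ are groups and have no well-defined addition. If $\boldsymbol{T}$ were treated as an ordinary matrix during optimization, it would have to be constrained (a rotation matrix must satisfy $\boldsymbol{R}^T\boldsymbol{R}=\boldsymbol{I}$ and $\det\boldsymbol{R}=1$, which makes the computation complicated). The Lie algebra, by contrast, consists of vectors and has a well-behaved addition.</p><p>There are two ways to use the Lie algebra to solve the differentiation problem:</p><ol><li>Represent the pose by its Lie algebra element and differentiate with respect to it, using Lie algebra addition.</li><li>Left- or right-multiply the Lie group element by a small perturbation and differentiate with respect to that perturbation; these are called the left and right perturbation models.</li></ol><p>The first approach corresponds to the Lie algebra derivative model, and the second to the perturbation model.</p><h2 id="李代数求导"><a href="#李代数求导" class="headerlink" title="李代数求导"></a><strong>Differentiating with Respect to the Lie Algebra</strong></h2><p>Consider the case of $SO(3)$. Suppose a spatial point $\boldsymbol{p}$ is rotated, giving $\boldsymbol{Rp}$. We want the derivative of the rotated point's coordinates with respect to the rotation, written $\frac{\partial \boldsymbol{Rp} }{\partial \boldsymbol{R} }$.</p><p>Since $SO(3)$ has no addition, this derivative cannot be computed from the definition of a derivative. Let $\boldsymbol{\phi}$ be the Lie algebra element corresponding to $\boldsymbol{R}$; we compute instead $\frac{\partial\left(\exp \left(\boldsymbol{\phi}^{\wedge}\right) \boldsymbol{p}\right)}{\partial \boldsymbol{\phi} }$.</p><p>Working from the definition of the derivative gives the derivative of the rotated point with respect to the Lie algebra: $\frac{\partial(\boldsymbol{R} \boldsymbol{p})}{\partial \boldsymbol{\phi} }=(-\boldsymbol{R} \boldsymbol{p})^{\wedge} \boldsymbol{J}_{l}$</p><p>Note: this is not a derivative in the sense of matrix calculus; it is only a notation.</p><h3 id="扰动模型(左乘)"><a href="#扰动模型(左乘)" class="headerlink" title="扰动模型(左乘)"></a><strong>Perturbation Model (Left Multiplication)</strong></h3><p>Another approach is to apply a perturbation $\Delta \boldsymbol{R}$ to $\boldsymbol{R}$ and look at the rate of change of the result with respect to the perturbation. The perturbation can be multiplied on the left or on the right, and the two results differ slightly. Taking the left perturbation as an example, let the Lie algebra of $\Delta \boldsymbol{R}$ be $\boldsymbol{\varphi}$ and differentiate with respect to $\boldsymbol{\varphi}$:</p><p>$$<br>\begin{aligned}\frac{\partial(\boldsymbol{R} \boldsymbol{p})}{\partial \boldsymbol{\varphi} } &=\lim_{\boldsymbol{\varphi} \rightarrow \boldsymbol{0} }\frac{\exp(\boldsymbol{\varphi}^\wedge) \exp(\boldsymbol{\phi}^\wedge)\boldsymbol{p} - \exp(\boldsymbol{\phi}^\wedge)\boldsymbol{p} }{\boldsymbol{\varphi} } \\ &=\lim_{\boldsymbol{\varphi} \rightarrow \boldsymbol{0} }\frac{(\boldsymbol{I}+\boldsymbol{\varphi}^\wedge) \exp(\boldsymbol{\phi}^\wedge)\boldsymbol{p} - \exp(\boldsymbol{\phi}^\wedge)\boldsymbol{p} }{\boldsymbol{\varphi} } \\ &= \lim_{\boldsymbol{\varphi} \rightarrow \boldsymbol{0} }\frac{\boldsymbol{\varphi}^\wedge \boldsymbol{Rp} }{\boldsymbol{\varphi} } = \lim_{\boldsymbol{\varphi} \rightarrow \boldsymbol{0} }\frac{-(\boldsymbol{Rp})^\wedge \boldsymbol{\varphi} }{\boldsymbol{\varphi} } = -(\boldsymbol{Rp})^\wedge\end{aligned}<br>$$</p><p>The second step uses the first-order approximation $\exp(\boldsymbol{\varphi}^\wedge)\approx\boldsymbol{I}+\boldsymbol{\varphi}^\wedge$ for small $\boldsymbol{\varphi}$, and the fourth uses $\boldsymbol{a}^\wedge\boldsymbol{b}=-\boldsymbol{b}^\wedge\boldsymbol{a}$. Compared with differentiating directly with respect to the Lie algebra, this avoids computing the Jacobian $\boldsymbol{J}_l$, which makes the perturbation model more practical and important in pose estimation.</p><h3 id="SE-3-上的李代数求导"><a href="#SE-3-上的李代数求导" class="headerlink" title="$SE(3)$上的李代数求导"></a><strong>Lie Algebra Derivatives on $SE(3)$</strong></h3><p>The book gives only the perturbation model on $SE(3)$. Suppose a spatial point $\boldsymbol{p}$ undergoes a transformation $\boldsymbol{T}$ (with Lie algebra $\boldsymbol{\xi}$), giving $\boldsymbol{Tp}$. Left-multiply $\boldsymbol{T}$ by a perturbation $\Delta \boldsymbol{T}= \exp(\delta \boldsymbol{\xi}^\wedge)$ whose Lie algebra is $\delta \boldsymbol{\xi} = [\delta \boldsymbol{\rho},\delta \boldsymbol{\phi}]^T$. Then:</p><p>$$<br>\begin{aligned}\frac{\partial(\boldsymbol{T} \boldsymbol{p})}{\partial \delta \boldsymbol{\xi} } &=\lim_{\delta \boldsymbol{\xi} \rightarrow \boldsymbol{0} }\frac{\exp(\delta \boldsymbol{\xi}^\wedge) \exp(\boldsymbol{\xi}^\wedge)\boldsymbol{p} - \exp(\boldsymbol{\xi}^\wedge)\boldsymbol{p} }{\delta \boldsymbol{\xi} } \\ &=\lim_{\delta \boldsymbol{\xi} \rightarrow \boldsymbol{0} }\frac{(\boldsymbol{I}+\delta \boldsymbol{\xi}^\wedge) \exp(\boldsymbol{\xi}^\wedge)\boldsymbol{p} - \exp(\boldsymbol{\xi}^\wedge)\boldsymbol{p} }{\delta \boldsymbol{\xi} } \\ &= \lim_{\delta \boldsymbol{\xi} \rightarrow \boldsymbol{0} }\frac {\begin{bmatrix} \delta\boldsymbol{\phi}^\wedge & \delta \boldsymbol{\rho} \\ \boldsymbol{0}^T & 0\end{bmatrix}\begin{bmatrix} \boldsymbol{Rp}+\boldsymbol{t} \\ 1\end{bmatrix} }{\delta \boldsymbol{\xi} } \\ &=\lim_{\delta \boldsymbol{\xi} \rightarrow \boldsymbol{0} }\frac {\begin{bmatrix} \delta\boldsymbol{\phi}^\wedge(\boldsymbol{Rp}+\boldsymbol{t}) + \delta \boldsymbol{\rho} \\ 0 \end{bmatrix} }{\begin{bmatrix} \delta \boldsymbol{\rho},\delta \boldsymbol{\phi}\end{bmatrix}^T} \\ &=\begin{bmatrix} \boldsymbol{I} & -(\boldsymbol{Rp}+\boldsymbol{t})^\wedge \\ \boldsymbol{0}^T & \boldsymbol{0}^T\end{bmatrix}\stackrel{\text{def}}{=} (\boldsymbol{Tp})^{\odot}\end{aligned}<br>$$</p><p>The final result is defined as an operator $^{\odot}$, which maps a spatial point in homogeneous coordinates to a 4×6 matrix.</p>]]></content>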
<summary type="html">SLAM math fundamentals: rigid-body motion in 3D space, Lie groups and Lie algebras</summary>
<category term="计算机视觉" scheme="https://thinksky5124.github.io/categories/%E8%AE%A1%E7%AE%97%E6%9C%BA%E8%A7%86%E8%A7%89/"/>
<category term="SLAM" scheme="https://thinksky5124.github.io/categories/%E8%AE%A1%E7%AE%97%E6%9C%BA%E8%A7%86%E8%A7%89/SLAM/"/>
<category term="SLAM" scheme="https://thinksky5124.github.io/tags/SLAM/"/>
<category term="数学原理" scheme="https://thinksky5124.github.io/tags/%E6%95%B0%E5%AD%A6%E5%8E%9F%E7%90%86/"/>
</entry>
<entry>
<title>Weight Initialization</title>
<link href="https://thinksky5124.github.io/2022/08/18/%E6%A8%A1%E5%9E%8B%E4%BC%98%E5%8C%96/"/>
<id>https://thinksky5124.github.io/2022/08/18/%E6%A8%A1%E5%9E%8B%E4%BC%98%E5%8C%96/</id>
<published>2022-08-18T08:00:00.000Z</published>
<updated>2024-04-16T08:57:05.647Z</updated>
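<content type="html"><![CDATA[<h1 id="模型优化"><a href="#模型优化" class="headerlink" title="模型优化"></a>Model Optimization</h1><p><a href="https://mp.weixin.qq.com/s/Pht5qhGWtRiyUVjBDRt8vA">Derivation of Kaiming initialization</a></p><h1 id="权重初始化"><a href="#权重初始化" class="headerlink" title="权重初始化"></a>Weight Initialization</h1><h2 id="为什么需要权重初始化"><a href="#为什么需要权重初始化" class="headerlink" title="为什么需要权重初始化"></a>Why Weight Initialization Is Needed</h2><p>During training, networks are prone to vanishing gradients (gradients extremely close to 0) and exploding gradients (extremely large gradients), so most of the gradients obtained by backpropagation are useless or even harmful. Researchers want a good weight initialization that keeps the convolution outputs in the forward pass and the propagated gradients in the backward pass numerically stable. A suitable variance keeps values sufficiently different from each other while keeping them stable (that is, by initializing the convolution weights properly, the distribution of values stays stable throughout the computation).</p><h2 id="推导的先验知识"><a href="#推导的先验知识" class="headerlink" title="推导的先验知识"></a><strong><strong>Background for the Derivation</strong></strong></h2><p><img src="https://s2.loli.net/2024/03/25/3IfCZkYbwQAX7eK.png" alt="conv_simple.png"></p><p>Refer to the convolution diagram above, where a convolution is applied to the input feature map. What we study is the variance of one output point (the purple point): the yellow input (48 values) and the green convolution parameters (27 values) are used to compute the variance of one output value (purple). <strong>One point</strong> corresponds to what the original paper calls <strong>a response</strong>; this is key to understanding weight initialization. The derivation rests on a strong i.i.d. assumption: every input value is independent and identically distributed, so convolving them with i.i.d. parameters yields results with the same distribution, and the other 3 output points therefore have the same variance. Going further, although the input consists of 48 different values, we can think of it this way: <strong>there is a random variable following some distribution, and sampling it 48 times gives the 48 input values, which are i.i.d. (equivalently, every input pixel is i.i.d.).</strong> The convolution parameters can be viewed the same way. We can therefore use one random variable to represent the 48 inputs, one to represent the 27 convolution parameters, and one to represent the 4 output values.</p><h3 id="公式"><a href="#公式" class="headerlink" title="公式"></a>Formulas</h3><p>$$<br>\mathrm{Var}(X_1+\dots+X_n)=\mathrm{Var}(X_1)+\dots+\mathrm{Var}(X_n)<br>$$</p><p>The variance of a sum of independent random variables equals the sum of their variances. If $X_1$ and $X_2$ are also identically distributed, then $\mathrm{Var}(X_1)=\mathrm{Var}(X_2)$, so $\mathrm{Var}(X_1)+\mathrm{Var}(X_2)=2\mathrm{Var}(X_1)=2\mathrm{Var}(X_2)$. This is applied to the summation step of the convolution (a convolution first multiplies, then sums).</p><p>$$<br>\mathrm{Var}(X)=E(X^2)-(EX)^2<br>$$</p><p>This expresses the variance through expectations: the variance equals the expectation of the square minus the square of the expectation. If $E(X)=0$, then $\mathrm{Var}(X)=E(X^2)$.</p><p>$$<br>\mathrm{Var}(XY)=\mathrm{Var}(X)\mathrm{Var}(Y)+\mathrm{Var}(X)(EY)^2+\mathrm{Var}(Y)(EX)^2<br>$$</p><p>This formula holds for the product of independent variables $X$ and $Y$. If $E(X)=E(Y)=0$, then $\mathrm{Var}(XY)=\mathrm{Var}(X)\mathrm{Var}(Y)$.</p><h2 id="kaiming初始化"><a href="#kaiming初始化" class="headerlink" title="kaiming初始化"></a>Kaiming Initialization</h2><ul><li>In the forward pass, the variance of each layer's convolution output is kept at 1.</li><li>In the backward pass, the variance of the gradient that each layer propagates further back is kept at 1 (each layer computes two gradients: one updates the current layer's weights, and the other keeps propagating to compute the gradients of earlier layers).</li></ul><h3 id="源码"><a href="#源码" class="headerlink" 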
title="源码"></a>Source Code</h3><p>Computing the variance requires two values: <strong>gain</strong> and <strong>fan</strong>. The <strong>gain</strong> value is determined by the activation function; the <strong>fan</strong> value is determined by the number of weight parameters and the direction of propagation: <strong>fan_in</strong> for the forward pass and <strong>fan_out</strong> for the backward pass.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> math</span><br><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">kaiming_normal_</span>(<span class="params">tensor, a=<span class="number">0</span>, mode=<span class="string">'fan_in'</span>, nonlinearity=<span class="string">'leaky_relu'</span></span>):</span><br><span class="line">    fan = _calculate_correct_fan(tensor, mode)</span><br><span class="line">    <span class="comment"># mode selects the forward (fan_in) or backward (fan_out) direction, giving a different fan value</span></span><br><span class="line">    gain = calculate_gain(nonlinearity, a)</span><br><span class="line">    <span class="comment"># the gain value depends on which activation function is used</span></span><br><span class="line">    std = gain / math.sqrt(fan)  <span class="comment"># standard deviation computed from gain and fan</span></span><br><span class="line">    <span class="keyword">with</span> torch.no_grad():</span><br><span class="line">        <span class="keyword">return</span> tensor.normal_(<span class="number">0</span>, std)</span><br></pre></td></tr></tbody></table></figure><p>The code below computes the <strong>fan</strong> value from the <strong>shape of the convolution weight</strong> and from whether the forward or the backward pass is considered.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span
class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">_calculate_fan_in_and_fan_out</span>(<span class="params">tensor</span>):</span><br><span class="line">    dimensions = tensor.dim()  <span class="comment"># number of dimensions of the weight tensor</span></span><br><span class="line">    <span class="keyword">if</span> dimensions < <span class="number">2</span>:</span><br><span class="line">        <span class="keyword">raise</span> ValueError(<span class="string">"Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"</span>)</span><br><span class="line">    <span class="keyword">if</span> dimensions == <span class="number">2</span>:  <span class="comment"># Linear</span></span><br><span class="line">        fan_in = tensor.size(<span class="number">1</span>)</span><br><span class="line">        fan_out = tensor.size(<span class="number">0</span>)</span><br><span class="line">    <span class="keyword">else</span>:</span><br><span class="line">        num_input_fmaps = tensor.size(<span class="number">1</span>)  <span class="comment"># number of input channels of the convolution</span></span><br><span class="line">        num_output_fmaps = tensor.size(<span class="number">0</span>)  <span class="comment"># number of output channels of the convolution</span></span><br><span class="line">        receptive_field_size = <span class="number">1</span></span><br><span class="line">        <span class="keyword">if</span> tensor.dim() > <span class="number">2</span>:</span><br><span class="line">            receptive_field_size = tensor[<span
class="number">0</span>][<span class="number">0</span>].numel()  <span class="comment"># kernel size, e.g. k*k for a 2D convolution</span></span><br><span class="line">        fan_in = num_input_fmaps * receptive_field_size  <span class="comment"># input channels * kernel size, used for the forward pass</span></span><br><span class="line">        fan_out = num_output_fmaps * receptive_field_size  <span class="comment"># output channels * kernel size, used for the backward pass</span></span><br><span class="line"></span><br><span class="line">    <span class="keyword">return</span> fan_in, fan_out</span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">_calculate_correct_fan</span>(<span class="params">tensor, mode</span>):</span><br><span class="line">    mode = mode.lower()</span><br><span class="line">    valid_modes = [<span class="string">'fan_in'</span>, <span class="string">'fan_out'</span>]</span><br><span class="line">    <span class="keyword">if</span> mode <span class="keyword">not</span> <span class="keyword">in</span> valid_modes:</span><br><span class="line">        <span class="keyword">raise</span> ValueError(<span class="string">"Mode {} not supported, please use one of {}"</span>.<span class="built_in">format</span>(mode, valid_modes))</span><br><span class="line"></span><br><span class="line">    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)</span><br><span class="line">    <span class="keyword">return</span> fan_in <span class="keyword">if</span> mode == <span class="string">'fan_in'</span> <span class="keyword">else</span> fan_out</span><br></pre></td></tr></tbody></table></figure><p>The function below returns a <strong>gain</strong> value for each activation function; as its docstring notes, these are recommended values and can be changed.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span
class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> math</span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">calculate_gain</span>(<span class="params">nonlinearity, param=<span class="literal">None</span></span>):</span><br><span class="line">    <span class="string">r"""Return the recommended gain value for the given nonlinearity function.</span></span><br><span class="line"><span class="string">    The values are as follows:</span></span><br><span class="line"><span class="string"></span></span><br><span class="line"><span class="string">    ================= ====================================================</span></span><br><span class="line"><span class="string">    nonlinearity      gain</span></span><br><span class="line"><span class="string">    ================= ====================================================</span></span><br><span class="line"><span class="string">    Linear / Identity :math:`1`</span></span><br><span class="line"><span class="string">    Conv{1,2,3}D      :math:`1`</span></span><br><span
class="line"><span class="string">    Sigmoid           :math:`1`</span></span><br><span class="line"><span class="string">    Tanh              :math:`\frac{5}{3}`</span></span><br><span class="line"><span class="string">    ReLU              :math:`\sqrt{2}`</span></span><br><span class="line"><span class="string">    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`</span></span><br><span class="line"><span class="string">    ================= ====================================================</span></span><br><span class="line"><span class="string"></span></span><br><span class="line"><span class="string">    Args:</span></span><br><span class="line"><span class="string">        nonlinearity: the non-linear function (`nn.functional` name)</span></span><br><span class="line"><span class="string">        param: optional parameter for the non-linear function</span></span><br><span class="line"><span class="string"></span></span><br><span class="line"><span class="string">    Examples:</span></span><br><span class="line"><span class="string">        >>> gain = nn.init.calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2</span></span><br><span class="line"><span class="string">    """</span></span><br><span class="line">    linear_fns = [<span class="string">'linear'</span>, <span class="string">'conv1d'</span>, <span class="string">'conv2d'</span>, <span class="string">'conv3d'</span>, <span class="string">'conv_transpose1d'</span>, <span class="string">'conv_transpose2d'</span>, <span class="string">'conv_transpose3d'</span>]</span><br><span class="line">    <span class="keyword">if</span> nonlinearity <span class="keyword">in</span> linear_fns <span class="keyword">or</span> nonlinearity == <span class="string">'sigmoid'</span>:</span><br><span class="line">        <span class="keyword">return</span> <span class="number">1</span></span><br><span class="line">    <span class="keyword">elif</span> nonlinearity == <span class="string">'tanh'</span>:</span><br><span class="line">        <span class="keyword">return</span> <span
class="number">5.0</span> / <span class="number">3</span></span><br><span class="line">    <span class="keyword">elif</span> nonlinearity == <span class="string">'relu'</span>:</span><br><span class="line">        <span class="keyword">return</span> math.sqrt(<span class="number">2.0</span>)</span><br><span class="line">    <span class="keyword">elif</span> nonlinearity == <span class="string">'leaky_relu'</span>:</span><br><span class="line">        <span class="keyword">if</span> param <span class="keyword">is</span> <span class="literal">None</span>:</span><br><span class="line">            negative_slope = <span class="number">0.01</span></span><br><span class="line">        <span class="keyword">elif</span> <span class="keyword">not</span> <span class="built_in">isinstance</span>(param, <span class="built_in">bool</span>) <span class="keyword">and</span> <span class="built_in">isinstance</span>(param, <span class="built_in">int</span>) <span class="keyword">or</span> <span class="built_in">isinstance</span>(param, <span class="built_in">float</span>):</span><br><span class="line">            <span class="comment"># True/False are instances of int, hence check above</span></span><br><span class="line">            negative_slope = param</span><br><span class="line">        <span class="keyword">else</span>:</span><br><span class="line">            <span class="keyword">raise</span> ValueError(<span class="string">"negative_slope {} not a valid number"</span>.<span class="built_in">format</span>(param))</span><br><span class="line">        <span class="keyword">return</span> math.sqrt(<span class="number">2.0</span> / (<span class="number">1</span> + negative_slope ** <span class="number">2</span>))</span><br><span class="line">    <span class="keyword">else</span>:</span><br><span class="line">        <span class="keyword">raise</span> ValueError(<span class="string">"Unsupported nonlinearity {}"</span>.<span
class="built_in">format</span>(nonlinearity))</span><br></pre></td></tr></tbody></table></figure><p>Below is Kaiming initialization with a uniform distribution. Why is there a uniform version as well? <strong>The weight-initialization derivation only fixes a variance; it does not require a normal distribution.</strong> A uniform distribution also has a variance, and when its mean is 0 its minimum and maximum follow from that variance: $U(-b, b)$ has variance $b^2/3$, so $b=\sqrt{3}\,\mathrm{std}$, which is the <code>bound</code> computed in the code.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> math</span><br><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">kaiming_uniform_</span>(<span class="params">tensor, a=<span class="number">0</span>, mode=<span class="string">'fan_in'</span>, nonlinearity=<span class="string">'leaky_relu'</span></span>):</span><br><span class="line">    fan = _calculate_correct_fan(tensor, mode)</span><br><span class="line">    gain = calculate_gain(nonlinearity, a)</span><br><span class="line">    std = gain / math.sqrt(fan)</span><br><span class="line">    bound = math.sqrt(<span class="number">3.0</span>) * std  <span class="comment"># Calculate uniform bounds from standard deviation</span></span><br><span class="line">    <span class="keyword">with</span> torch.no_grad():</span><br><span class="line">        <span class="keyword">return</span> tensor.uniform_(-bound, bound)</span><br></pre></td></tr></tbody></table></figure><h3 id="数学原理"><a href="#数学原理" class="headerlink" title="数学原理"></a>Mathematical Principle</h3><p>The derivation of Kaiming initialization covers only convolutions and ReLU activations, assuming a VGG-like network with no residual or concat structures and no BN layers.</p><p>$$<br>Y_l=W_lX_l+B_l<br>$$</p><p>Here, $Y_l$ denotes the output value at one position and $X_l$ the input being convolved, with k×k×</p>]]></content>
<summary type="html">Weight initialization: Kaiming initialization</summary>
<category term="模型优化" scheme="https://thinksky5124.github.io/categories/%E6%A8%A1%E5%9E%8B%E4%BC%98%E5%8C%96/"/>
<category term="深度学习" scheme="https://thinksky5124.github.io/categories/%E6%A8%A1%E5%9E%8B%E4%BC%98%E5%8C%96/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
<category term="深度学习" scheme="https://thinksky5124.github.io/tags/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
<category term="模型优化" scheme="https://thinksky5124.github.io/tags/%E6%A8%A1%E5%9E%8B%E4%BC%98%E5%8C%96/"/>
</entry>
<entry>
<title>Normalization</title>
<link href="https://thinksky5124.github.io/2022/08/18/%E7%BD%91%E7%BB%9C%E6%AD%A3%E5%88%99%E5%8C%96/"/>
<id>https://thinksky5124.github.io/2022/08/18/%E7%BD%91%E7%BB%9C%E6%AD%A3%E5%88%99%E5%8C%96/</id>
<published>2022-08-18T08:00:00.000Z</published>
<updated>2024-04-16T08:57:05.647Z</updated>
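<content type="html"><![CDATA[<h1 id="网络正则化"><a href="#网络正则化" class="headerlink" title="网络正则化"></a>Network Regularization</h1><h1 id="Normalization"><a href="#Normalization" class="headerlink" title="Normalization"></a>Normalization</h1><h2 id="Batch-Normalization"><a href="#Batch-Normalization" class="headerlink" title="Batch Normalization"></a><strong><strong>Batch Normalization</strong></strong></h2><p>The BN formula:</p><p><img src="https://s2.loli.net/2024/03/25/RI5K7EbTkqzYCPQ.png" alt="BN_formula.png"></p><p>When training a deep neural network, BN normalizes the data of each batch with that batch's mean and variance so that the inputs of every layer keep the same distribution, which speeds up training. In addition, since computing the global mean and variance at every iteration is impractical, the mean and variance are updated with a momentum (moving-average) scheme. Because the normalization statistics vary from batch to batch during training, this also effectively injects noise, which can improve the model's robustness and reduce overfitting.</p><h2 id="Layer-Normalization"><a href="#Layer-Normalization" class="headerlink" title="Layer Normalization"></a>Layer <strong><strong>Normalization</strong></strong></h2><p><img src="https://s2.loli.net/2024/03/25/iM2fH1PO7jagWTI.png" alt="layer_norm.png"></p><p><img src="https://s2.loli.net/2024/03/25/Y5aDRWQ9fANZLu4.png" alt="layer_norm1.png"></p><p>BN erases the magnitude relationships between different features but preserves those between different samples.</p><ul><li>The relative relationships of the same channel across different images are preserved, i.e., features of the same channel in different images remain comparable.</li><li>Features of different channels within the same image lose their comparability.</li></ul><p>LN erases the magnitude relationships between different samples but preserves those between different features.</p><ul><li>The relative magnitudes of the components of each word vector within a sentence are preserved; apart from the mean subtraction, LayerNorm changes only the vector's norm, not its direction.</li><li>Word vectors from different sentences lose their comparability.</li></ul><h2 id="Batch-Normalization-注意事项"><a href="#Batch-Normalization-注意事项" class="headerlink" title="Batch Normalization 注意事项"></a><strong><strong>Batch Normalization Pitfalls</strong></strong></h2><p><a href="https://zhuanlan.zhihu.com/p/380620373">A guide to BatchNorm pitfalls</a></p><p><a href="https://arxiv.org/pdf/2105.07576.pdf">Paper link</a></p><p>This discussion focuses on BatchNorm as used in image processing, i.e., BatchNorm2D. During training, for a mini-batch <strong>X</strong> of shape [N,C,H,W], BatchNorm first computes the mean and variance of each channel:</p><p><img src="https://s2.loli.net/2024/03/25/HejY7ZVpOxSIJow.png" alt="bn_var_mean.png"></p><p>and then normalizes the features x:</p><p><img src="https://s2.loli.net/2024/03/25/AEsFmo3nzC9byhf.png" 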
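alt="fe_norm.png"></p><p>The mean and variance computations depend on the batch, hence the name BatchNorm. At test time, BatchNorm uses global statistics (population statistics) estimated during training, $\mu_{pop}$ and $\sigma_{pop}^2$. These are also learned from the training data, but they are not trainable parameters (no backpropagation is involved). The usual practice is to estimate them during training with an EMA (exponential moving average; in <a href="https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization">TensorFlow</a> the EMA mean and variance are called <code>moving_mean</code> and <code>moving_variance</code>, while <a href="https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html">PyTorch</a> calls them <code>running_mean</code> and <code>running_var</code>):</p><p><img src="https://s2.loli.net/2024/03/25/169SXwRF5Vva7pe.png" alt="bn_form.png"></p><p>Training uses mini-batch statistics while testing uses global statistics; this causes BatchNorm's train-test inconsistency, which is discussed in detail later.</p><p>Besides normalization, BatchNorm also applies an affine transform to the features of each channel (to increase representational power):</p><p><img src="https://s2.loli.net/2024/03/25/OxmHYy42NMkvJVp.png" alt="bn_form_1.png"></p><p>Here $\gamma$ and $\beta$ are trainable parameters, but no batch is involved in this step; in implementation it is equivalent to adding an extra depthwise 1 × 1 convolution layer. BatchNorm's difficulties come mainly from computing the mini-batch statistics and normalizing with them; the affine transform is not a factor, so the discussion below focuses on the former.</p><p>Around the problems that <code>batch</code> can cause, the paper discusses four aspects of BatchNorm:</p><ul><li><strong>Population Statistics</strong>: whether EMA can accurately estimate the global statistics, and PreciseBN;</li><li><strong>Batch in Training and Testing</strong>: the inconsistency caused by using mini-batch statistics in training and global statistics in testing;</li><li><strong>Batch from Different Domains</strong>: problems BatchNorm runs into with multiple domains;</li><li><strong>Information Leakage within a Batch</strong>: information leakage caused by BatchNorm.</li></ul><p>The second is probably the best-known problem, but BatchNorm can in fact affect many things, such as domain adaptation and information leakage in contrastive learning. Moreover, these four aspects are not independent; they are often intertwined.</p><h3 id="Population-Statistics"><a href="#Population-Statistics" class="headerlink" title="Population Statistics"></a><strong>Population 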
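Statistics</strong></h3><p>During training the mean and variance are computed on mini-batches, but at inference usually only one sample is processed at a time, so batch-dependent statistics can no longer be computed. BatchNorm therefore estimates the global statistics with an EMA during training, but this estimate may not be good enough: when $\lambda$ is large, each iteration's mini-batch statistics contribute little to the global statistics, so the estimate updates too slowly; when $\lambda$ is small, each iteration's mini-batch statistics dominate, so the estimate does not represent the whole population. In general $\lambda$ is set to a fairly large value such as 0.9 or 0.99; it is a hyperparameter. During ResNet-50 training (256 GPUs, <code>batch_size</code>=32 per GPU), the paper randomly selects one channel of one BatchNorm layer and plots its EMA mean and its population mean, where the population mean is estimated as the average batch mean of the current model over 100 mini-batches, which represents the current model's global statistics. The comparison is shown below. Early in training, panel (a) shows that the EMA mean is far from the current batch mean, and panel (b) shows that the EMA mean lags behind the current model's approximate global statistics; by the middle and late stages of training the EMA mean becomes fairly accurate.</p><p><img src="https://s2.loli.net/2024/03/25/ZUbEKdvcnfOCTiR.png" alt="stable_exp.png"></p><p>This shows that EMA statistics are biased early in training. Accurate global statistics would be the feature mean and variance computed with the entire training set as a single batch, but that is too expensive. The paper instead proposes an approximation: first run a fixed (trained) model on many mini-batches, then aggregate the statistics of each mini-batch into global statistics. Suppose $N$ samples are needed and <code>batch_size</code> is $B$; then $k=N/B$ mini-batches are computed, with statistics $\mu_{B_i}$ and $\sigma^2_{B_i}\ (i=1,\dots,k)$, and the global statistics can be approximated as:</p><p><img src="https://s2.loli.net/2024/03/25/32UhneIoX8RLsQD.png" alt="bn_form_2.png"></p><p>This is only one aggregation scheme; the paper's appendix discusses others, with similar results. This variant is called <code>PreciseBN</code>; for an implementation see <a href="https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/precise_bn.py">fvcore.nn.precise_bn</a>:</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span 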
class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">_PopulationVarianceEstimator</span>:</span><br><span class="line">    <span class="string">"""</span></span><br><span class="line"><span class="string">    Alternatively, one can estimate population variance by the sample variance</span></span><br><span class="line"><span class="string">    of all batches combined. This needs to use the batch size of each batch</span></span><br><span class="line"><span class="string">    in this function to undo the bessel-correction.</span></span><br><span class="line"><span class="string">    This produces better estimation when each batch is small.</span></span><br><span class="line"><span class="string">    See Appendix of the paper "Rethinking Batch in BatchNorm" for details.</span></span><br><span class="line"><span class="string">    In this implementation, we also take into account varying batch sizes.</span></span><br><span class="line"><span class="string">    A batch of N1 samples with a mean of M1 and a batch of N2 samples with a</span></span><br><span class="line"><span class="string">    mean of M2 will produce a population mean of (N1M1+N2M2)/(N1+N2) instead</span></span><br><span class="line"><span class="string">    of (M1+M2)/2.</span></span><br><span class="line"><span class="string">    """</span></span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self, mean_buffer: torch.Tensor, var_buffer: torch.Tensor</span>) -> <span class="literal">None</span>:</span><br><span class="line">        self.pop_mean: torch.Tensor = torch.zeros_like(mean_buffer)  <span class="comment"># population mean</span></span><br><span class="line">        self.pop_square_mean: torch.Tensor = torch.zeros_like(var_buffer)  <span class="comment"># population mean of squares, E[x^2]</span></span><br><span class="line">        self.tot = <span class="number">0</span>  <span class="comment"># total samples</span></span><br><span class="line"></span><br><span class="line">    <span class="comment"># update per mini-batch, is called by `update_bn_stats`</span></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">update</span>(<span class="params"></span></span><br><span class="line"><span class="params">        self, batch_mean: torch.Tensor, batch_var: torch.Tensor, batch_size: <span class="built_in">int</span></span></span><br><span class="line"><span class="params">    </span>) -> <span class="literal">None</span>:</span><br><span class="line">        self.tot += batch_size</span><br><span class="line">        batch_square_mean = batch_mean.square() + batch_var * (</span><br><span class="line">            (batch_size - <span class="number">1</span>) / batch_size</span><br><span class="line">        )</span><br><span class="line">        self.pop_mean += (batch_mean - self.pop_mean) * (batch_size / self.tot)</span><br><span class="line">        self.pop_square_mean += (batch_square_mean - self.pop_square_mean) * (</span><br><span class="line">            batch_size / self.tot</span><br><span class="line">        )</span><br><span class="line"></span><br><span class="line"><span class="meta">    @property</span></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">pop_var</span>(<span class="params">self</span>) -> torch.Tensor:</span><br><span class="line">        <span class="keyword">return</span> self.pop_square_mean - self.pop_mean.square()</span><br></pre></td></tr></tbody></table></figure><p>The paper compares EMA and PreciseBN on ResNet-50 training, as shown below: PreciseBN is more stable than EMA, especially early in training (before the model has converged), although the two eventually reach similar results.</p><p><img src="https://s2.loli.net/2024/03/25/wlPHIycaRTDFX2s.png" alt="train_stable_exp.png"></p><p>Furthermore, with a larger training batch size, the experiments show that EMA fluctuates more during training, while PreciseBN stays stable. Training with larger batches raises the learning rate and reduces the number of EMA updates, both of which strongly affect EMA. <strong>In summary, although EMA and PreciseBN end with similar results (which is why EMA's shortcomings are often overlooked), PreciseBN is more stable early in training before convergence; in scenarios such as reinforcement learning, where the model must be evaluated early in training, PreciseBN gives more stable and reliable results.</strong></p><p><img src="https://s2.loli.net/2024/03/25/tmNWwo8vx5eJOXu.png" alt="train_stable_exp_1.png"></p><p>The paper also shows experimentally that <strong>PreciseBN needs only $10^3$ to $10^4$ samples to give good results</strong>; for ImageNet training, evaluating with PreciseBN adds only 0.5% to the training time.</p><p><img src="https://s2.loli.net/2024/03/25/LUMTa3rCyefqDiN.png" alt="add_exp_time.png"></p><p>The paper also compares how batch size affects PreciseBN. First, two concepts: (1) <code>normalization batch size</code> (NBS): the size of the mini-batch over which the statistics are actually computed; (2) <code>total batch size</code> or <code>SGD batch size</code>: the mini-batch size of each iteration, i.e., the batch size of each SGD step. The two differ in multi-GPU training (where NBS is the per-GPU batch size, while <code>SyncBN</code> makes them equal). The results show that the model gets worse when NBS is small, but <strong>PreciseBN's batch size is independent of NBS, so PreciseBN gives stable results whatever batch size it uses, and improves over EMA when NBS is small</strong>.</p><p><img src="https://s2.loli.net/2024/03/25/FdLsPHol4zAC9Qe.png" alt="mae_exp_time.png"></p><h3 id="Batch-in-Training-and-Testing"><a href="#Batch-in-Training-and-Testing" class="headerlink" title="Batch in Training and Testing"></a><strong><strong>Batch in Training and Testing</strong></strong></h3><p>As mentioned above, BatchNorm uses mini-batch statistics during training but global statistics during testing, which causes a train-test inconsistency that affects model performance. The paper again uses ResNet-50 training to analyze the impact of this inconsistency while also comparing different NBS values (the SGD batch size is fixed at 1024 and NBS varies from 2 to 1024). For each NBS it compares three metrics: (1) classification error on the training set using mini-batch statistics; (2) classification error on the validation set using mini-batch statistics; (3) classification error on the validation set using global statistics. (1) and (3) are the usual evaluations, while (2) is not normally done; however, (1) and (2) are consistent with each other (both training and testing use mini-batch statistics). The results are shown below and lead to three conclusions:</p><ul><li><strong>training 
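noise</strong>: the training error decreases as the NBS grows. This is mainly caused by noise in SGD training. With a small NBS, the mini-batch statistics fluctuate strongly, which makes optimization harder and leads to a larger training error;</li><li><strong>generalization gap</strong>: (1) and (2) both use mini-batch statistics and differ only in the dataset, so the performance difference between them is the generalization gap;</li><li><strong>train-test inconsistency</strong>: (2) and (3) use the same dataset, but (2) uses mini-batch statistics while (3) uses population statistics, so the performance difference between them is caused by the train-test inconsistency.</li></ul><p><img src="https://s2.loli.net/2024/03/25/yYnMCPkjcQaV8gZ.png" alt="exp_fig1.png"></p><p>We can also see that when the NBS is small, the gap between (2) and (3) is large. This shows that <strong>when training with a small NBS, using mini-batch statistics at test time works better</strong>; in this regime <strong>keeping training and testing consistent is crucial</strong>. When the NBS is large, the difference between (2) and (3) is small, because the mini-batch statistics get closer and closer to the population statistics.</p><p>Mini-batch statistics at test time work better when the NBS is small, but this is almost never done in practice, since inference usually processes one sample at a time. There are exceptions, however. In a two-stage detector such as R-CNN, the input to the R-CNN head is a set of regions of interest (RoIs) for each image, typically $10^2-10^3$ RoIs per image, and these RoIs are processed as a batch. During training the batch contains the RoIs of all images, while at test time it contains the RoIs of a single image. In this case mini-batch statistics can be used at test time. The paper uses Mask R-CNN as the experimental model. It replaces the default <code>2fc box head</code> (two fully connected layers) with a <code>4conv1fc head</code> (four convolutional layers plus one fully connected layer), and adds a BatchNorm layer after every convolutional layer in both the box head and the mask head. The results below show that mini-batch statistics outperform population statistics. When training uses only one image per GPU, population statistics at test time perform very poorly. This is caused by a separate overfitting problem (information leakage caused by BatchNorm), discussed later. The R-CNN head also has another kind of train-test inconsistency. During training, the mini-batch is built by sampling positive and negative RoIs; at test time, the top-K RoIs by score are used. As a result, the distribution of the mini-batch changes.</p><p><img src="https://s2.loli.net/2024/03/25/OSZ5NAGDPu419Rc.png" alt="exp_fig2.png"></p><p>Another way to avoid the train-test inconsistency is to use population statistics during training as well. The common approach is Frozen BatchNorm (FrozenBN), since directly using the EMA statistics during training makes the model untrainable. FrozenBN uses fixed population statistics computed in advance, so the only part of BatchNorm that is still optimized is a linear transform. FrozenBN is typically used when transferring a trained model to another task. For example, an ImageNet-trained ResNet is usually used with FrozenBN when transferred to downstream detection tasks. FrozenBN can also be used during training itself. The paper again uses ResNet-50, training with normal BN for the first 80 epochs and with FrozenBN for the last 20 epochs. The comparison below shows that FrozenBN also performs well when the NBS is small, better than using mini-batch statistics at test time, which makes it a good option. Note, however, that when the NBS is large, both FrozenBN and test-time mini-batch statistics are worse than the standard approach (train with BN, test with population statistics).</p><p><img src="https://s2.loli.net/2024/03/25/N3VMORZvClrPjDq.png" alt="exp_fig3.png"></p><p>Training a model with BatchNorm involves two learning processes. First, the main model parameters are learned by SGD (<code>SGD training</code>). Second, the population statistics are learned from the training data via EMA or PreciseBN (<code>population 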
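statistics training</code>). When the training and test data have different distributions, which is called domain shift, the learned population statistics may no longer be valid at test time. Prior work proposed <a href="https://arxiv.org/abs/1603.04779">Adaptive BatchNorm</a> to address this, i.e. recomputing the population statistics on the test data. The paper again uses ResNet-50 (SGD batch size 1024, NBS 32) and evaluates domain shift on ImageNet-C, a set of corrupted versions of ImageNet. Three corruption types are used: contrast, gaussian noise and jpeg compression. The results are as follows:</p><p><img src="https://s2.loli.net/2024/03/25/OUM4XelHdC63DWE.png" alt="exp_fig4.png"></p><p>The table clearly shows that <strong>under domain shift, recomputing the population statistics on the target-domain data with Adaptive BatchNorm improves accuracy</strong>. However, the last row shows that recomputing the statistics on the ImageNet validation set (with the inference-time preprocessing) is slightly worse than the original result (23.4 vs 23.8). This suggests that without a clear domain shift, the original procedure is best.</p><p>Besides domain shift, multi-domain training data also affects BatchNorm, and this problem is even more complicated. The paper uses RetinaNet to illustrate the problems that multiple domains can cause. The RetinaNet head consists of 4 convolutional layers followed by the final classifier and box regressor. Its inputs are features at 5 scales ($P_3, P_4, P_5, P_6, P_7$), which can be viewed as 5 different domains. The head is shared across the 5 feature levels and contains no BatchNorm by default. Once BatchNorm is added after every convolution in the head, things get complicated. The first question is how to compute the mini-batch statistics during SGD training, and there are two options. One is to compute separate mini-batch statistics for each domain's features and normalize each domain separately. The other is to concatenate the features of all domains and normalize them with a single set of mini-batch statistics. The two options are illustrated below:</p><p><img src="https://s2.loli.net/2024/03/25/oxLUhvpTBEAYfFD.png" alt="exp_fig5.png"></p><p>We call these two options for SGD training <code>domain-specific statistics</code> and <code>shared statistics</code>. Learning the population statistics has the same two options: each domain learns its own set of population statistics, or all domains share one set. BatchNorm's affine transform layer also offers two choices: one set of parameters per domain, or shared parameters. The results of the different combinations are listed in the table below:</p><p><img src="https://s2.loli.net/2024/03/25/O9AMcNqm15Vl4H7.png" alt="exp_fig6.png"></p><p>Two conclusions can be drawn from the table. (1) <strong>Keeping SGD training and population statistics training consistent is crucial</strong>: every consistent combination achieves good results (rows 1, 4 and 6). (2) Whether the affine transform layer uses per-domain or shared parameters hardly affects the results. As an aside, simply adding BatchNorm to the head corresponds to row 3. Because the features of different scales are processed one after another, SGD training is effectively domain-specific while the population statistics are shared, so the results are poor. This is why most implementations either use no normalization in the head or use GroupNorm. The code for the different combinations is:</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span 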
class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 简单地加上BN,注意forward时,不同特征是串行处理的,那么SGD training其实是domain-specific的,</span></span><br><span class="line"><span class="comment"># 但是只维持一套全局统计量,所以测试时又是共享的</span></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">RetinaNetHead_Row3</span>:</span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self, num_conv, channel</span>):</span><br><span class="line"> head = []</span><br><span class="line"> <span class="keyword">for</span> _ <span class="keyword">in</span> <span 
class="built_in">range</span>(num_conv):</span><br><span class="line"> head.append(nn.Conv2d(channel, channel, <span class="number">3</span>))</span><br><span class="line"> head.append(nn.BatchNorm2d(channel))</span><br><span class="line"> self.head = nn.Sequential(∗head)</span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">forward</span>(<span class="params">self, inputs: <span class="type">List</span>[Tensor]</span>):</span><br><span class="line"> <span class="keyword">return</span> [self.head(i) <span class="keyword">for</span> i <span class="keyword">in</span> inputs]</span><br><span class="line"></span><br><span class="line"><span class="comment"># 如果要共享,那么在forward时对特征进行concat来统一计算并归一化 </span></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">RetinaNetHead_Row1</span>(<span class="title class_ inherited__">RetinaNetHead_Row3</span>):</span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">forward</span>(<span class="params">self, inputs: <span class="type">List</span>[Tensor]</span>):</span><br><span class="line"> <span class="keyword">for</span> mod <span class="keyword">in</span> self.head:</span><br><span class="line"> <span class="keyword">if</span> <span class="built_in">isinstance</span>(mod, nn.BatchNorm2d):</span><br><span class="line"> <span class="comment"># for BN layer, normalize all inputs together</span></span><br><span class="line"> shapes = [i.shape <span class="keyword">for</span> i <span class="keyword">in</span> inputs]</span><br><span class="line"> spatial_sizes = [s[<span class="number">2</span>] ∗ s[<span class="number">3</span>] <span class="keyword">for</span> s <span class="keyword">in</span> shapes]</span><br><span class="line"> x = [i.flatten(<span class="number">2</span>) <span class="keyword">for</span> i <span class="keyword">in</span> inputs]</span><br><span class="line"> x = torch.cat(x, dim=<span 
class="number">2</span>).unsqueeze(<span class="number">3</span>)</span><br><span class="line"> x = mod(x).split(spatial_sizes, dim=<span class="number">2</span>)</span><br><span class="line"> inputs = [i.view(s) <span class="keyword">for</span> s, i <span class="keyword">in</span> <span class="built_in">zip</span>(shapes, x)]</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> <span class="comment"># for conv layer, apply it separately</span></span><br><span class="line"> inputs = [mod(i) <span class="keyword">for</span> i <span class="keyword">in</span> inputs]</span><br><span class="line"> <span class="keyword">return</span> inputs</span><br><span class="line"></span><br><span class="line"><span class="comment"># 另外一种简单的处理方式是每个特征采用各自的BN</span></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">RetinaNetHead_Row6</span>:</span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self, num_conv, channel, num_features</span>):</span><br><span class="line"> <span class="comment"># num_features: number of features coming from</span></span><br><span class="line"> <span class="comment"># different FPN levels, e.g. 
5</span></span><br><span class="line"> heads = [[] <span class="keyword">for</span> _ <span class="keyword">in</span> <span class="built_in">range</span>(num_levels)]</span><br><span class="line"> <span class="keyword">for</span> _ <span class="keyword">in</span> <span class="built_in">range</span>(num_conv):</span><br><span class="line"> conv = nn.Conv2d(channel, channel, <span class="number">3</span>)</span><br><span class="line"> <span class="keyword">for</span> h <span class="keyword">in</span> heads:</span><br><span class="line"> <span class="comment"># add a shared conv and a domain−specific BN</span></span><br><span class="line"> h.extend([conv, nn.BatchNorm2d(channel)])</span><br><span class="line"> self.heads = [nn.Sequential(∗h) <span class="keyword">for</span> h <span class="keyword">in</span> heads]</span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">forward</span>(<span class="params">self, inputs: <span class="type">List</span>[Tensor]</span>):</span><br><span class="line"> <span class="comment"># end up with one head for each input</span></span><br><span class="line"> <span class="keyword">return</span> [head(i) <span class="keyword">for</span> head, i <span class="keyword">in</span></span><br><span class="line"> <span class="built_in">zip</span>(self.heads, inputs)]</span><br></pre></td></tr></tbody></table></figure><p>对于行2和行4,可以通过训练好的行1和行3模型重新在训练数据上计算domain-specific全局统计量即可,在实现时,可以如下:</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span 
class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">class</span> <span class="title class_">CycleBatchNormList</span>(nn.ModuleList):</span><br><span class="line"> <span class="string">"""</span></span><br><span class="line"><span class="string"> A hacky way to implement domain-specific BatchNorm</span></span><br><span class="line"><span class="string"> if it's guaranteed that a fixed number of domains will be</span></span><br><span class="line"><span class="string"> called with fixed order.</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"></span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self, length, channels</span>):</span><br><span class="line"> <span class="built_in">super</span>().__init__([nn.BatchNorm2d(channels, affine=<span class="literal">False</span>) <span class="keyword">for</span> k <span class="keyword">in</span> <span class="built_in">range</span>(length)])</span><br><span class="line"> <span class="comment"># shared affine, domain-specific BN</span></span><br><span class="line"> self.weight = nn.Parameter(torch.ones(channels))</span><br><span class="line"> self.bias = nn.Parameter(torch.zeros(channels))</span><br><span class="line"> self._pos = <span class="number">0</span></span><br><span class="line"></span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">forward</span>(<span class="params">self, x</span>):</span><br><span class="line"> ret = self[self._pos](x)</span><br><span class="line"> self._pos = (self._pos + <span class="number">1</span>) % <span 
class="built_in">len</span>(self)</span><br><span class="line"></span><br><span class="line"> w = self.weight.reshape(<span class="number">1</span>, -<span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>)</span><br><span class="line"> b = self.bias.reshape(<span class="number">1</span>, -<span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>)</span><br><span class="line"> <span class="keyword">return</span> ret * w + b</span><br><span class="line"></span><br><span class="line"><span class="comment"># 训练好模型,我们可以重新将BN层换成以上实现,就可以在训练数据上重新计算domain-specific全局统计量</span></span><br></pre></td></tr></tbody></table></figure><p>RetinaNet面临的其实是特征层面的multi-domain问题,而且每个batch中的各个domain是均匀的。如果是数据层面的multi-domain,其面临的问题将会复杂,此时domain的分布比例也是多变的(BatchNorm可能会偏向训练数据较多的那个domain),但是总的原则是尽量减少不一致性,因为<strong>consistency is crucial</strong>。</p><h3 id="Information-Leakage-within-a-Batch"><a href="#Information-Leakage-within-a-Batch" class="headerlink" title="Information Leakage within a Batch"></a><strong><strong>Information Leakage within a Batch</strong></strong></h3><p>BatchNorm面临的另外一个挑战,就是可能出现信息泄露,这里所说的信息泄露指的是模型学习到了利用mini-batch的信息来做预测,而这些其实并不是我们要学习的,因为这样模型可能难以对mini-batch里的每个sample单独做预测。</p><p><img src="https://s2.loli.net/2024/03/25/rEbjz38tDIfgNm6.png" alt="exp_fig7.png"></p><p>比如BatchNorm的作者曾做过这样一个实验,在ResNet50的训练过程中,NBS=32,但是保证里面包含16个类别,每个类别有2个图像,这样一种特殊的设计要模型在训练过程中强制记忆了这种模式(可能是每个mini-batch中必须有同类别存在),那么在测试时如果输入不是这种设计,效果就会变差。这个在验证集上不同处理结果如上所示,可以看到测试时无论是采用全局统计量还是random mini-batch统计量,效果均较差,但是如果采用和训练过程同样的模式,效果就比较好。这其实也从侧面说明保持一致是多么的重要。</p><p>前面说过,如果在R-CNN的head加入BatchNorm,那么在测试时采用mini-batch统计量会比全局统计量会效果更好,这里面其实也存在信息泄露的问题。对于每个GPU只有一个image的情况,每个mini-batch里面的RoIs全部来自于一个图像,这时候模型就可能依赖mini-batch来做预测,那么测试时采用全局统计量效果就会差了,对于每个GPU有多个图像时,情况还稍好一些,所以原来的结果中单卡单图像效果最差。一种解决方案是采用shuffle BN,就是head进行处理前,先随机打乱所有卡上的RoIs特征,每个卡分配随机的RoIs,这样就避免前面那个可能出现的信息泄露,head处理完后再shuffle回来,具体处理流程如下所示:</p><p><img 
src="https://s2.loli.net/2024/03/25/DLsXo2dYGUl1ev5.png" alt="exp_fig8.png"></p><p>这个具体的代码实现见<a href="https://github.com/facebookresearch/detectron2/blob/60fd4885d7cfd52d4267d1da9ebb6b2b9a3fc937/projects/Rethinking-BatchNorm/configs/mask_rcnn_BNhead_shuffle.py">mask_rcnn_BNhead_shuffle.py</a>。其实在MoCo中也使用了shuffle BN来防止信息泄露。另外还是可以采用SyncBN来避免这种问题(或者说是global BN,增大了mini-batch,这样就可以减弱上述影响)。具体的对比结果如下所示,可以看到采用shuffle BN和SyncBN均可以避免信息泄露,得到较好的效果。注意shuffle BN的 cross-GPU synchronization要比SyncBN要少,效率更高一些。</p><p>另外一个常见的场景是对比学习中信息泄露,因为对比学习往往需要对同一个图像做不同的augmentations来作为正样本,这其实一个sample既当输入又当目标,mini-batch可能会泄露信息导致模型学习不到好的特征(普通的BN是per-GPU normalize,这意味正样本的计算都在同一个local mini-batch中)。MoCo采用的是shuffle BN(其实是encode_k采用shuffle BN,这样两次正样本的计算就有区别),而SimCLR和BYOL采用的是SyncBN(扩大mini-batch减少影响)。另外旷视提出的<a href="https://arxiv.org/abs/2101.07525">Momentum^2 Teacher</a>来采用moving average statistics来防止信息泄露。一个插曲是这篇<a href="https://generallyintelligent.ai/blog/2020-08-24-understanding-self-supervised-contrastive-learning/">博客</a>指出其实BN才是不需要负样本的BYOL成功的关键,因为BN隐式地引入了负样本从而形成了对比学习,虽然后面BYOL又证明不需要BN也可以取得好的效果,但是还是比BN差一点。这说明BN确实能够隐式编码batch信息。</p><h3 id="总结"><a href="#总结" class="headerlink" title="总结"></a><strong>总结</strong></h3><p>一个简单的BatchNorm,如果我们使用不当,往往会出现一些让人意料的结果,所以要谨慎处理。总结来看,主要有如下结论和指南:</p><ul><li>模型在未收敛时使用EMA统计量来评估模型是不稳定的,一种替代方案是PreciseBN;</li><li>BatchNorm本身存在训练和测试的不一致性,特别是NBS较少时,这种不一致会更强,可用的方案是测试时也采用mini-batch统计量或者采用FrozenBN;</li><li>在domain shift场景中,最好基于target domain数据重新计算全局统计量,在multi-domain数据训练时,要特别注意处理的一致性;</li><li>BatchNorm会存在信息泄露的风险,这处理mini-batch时要防止特殊的出现。</li></ul><p>我个人认为下列两个原则可能是普适的:</p><ul><li>尽量减少训练和测试的不一致行为,不一致行为会导致测试时性能恶化;</li><li>尽量减少训练过程的bias而应适当增加noise,以防止模型训练走捷径而学习到无法泛化的特征。</li></ul><h1 id="Dropout"><a href="#Dropout" class="headerlink" title="Dropout"></a><strong><strong>Dropout</strong></strong></h1><p>dropout在训练时,以一定的概率p来drop掉相应的神经网络节点,以(1-p)的概率来retain相应的神经网络节点,这相当于每一次训练时模型的网络结构都不一样,也可以理解为训练时添加了噪声,所以能够有效减少过拟合。<br>问题呢,是出在测试时,因为训练的时候以概率p 
drop了一些节点,比如dropout设置为0.5,隐藏层共有6个节点,那训练的时候有3个节点的值被丢弃,而测试的时候这6个节点都被保留下来,这就导致了训练和测试的时候以该层节点为输入的下一层的神经网络节点获取的期望会有量级上的差异。为了解决这个问题,在训练时对当前dropout层的输出数据除以(1-p),之后再输入到下一层的神经元节点,以作为失活神经元的补偿,以使得在训练时和测试时每一层的输入有大致相同的期望。</p><h1 id="Normalization和Dropout搭配"><a href="#Normalization和Dropout搭配" class="headerlink" title="Normalization和Dropout搭配"></a><strong><strong>Normalization和Dropout搭配</strong></strong></h1><p>产生的问题:方差偏移</p><p><img src="https://s2.loli.net/2024/03/25/ai7ovKNGmP4zrEs.png" alt="exp_fig9.png"></p><p>首先,先明确dropout和BN结合使用使模型性能下降的连接方式,用通俗的话讲,就是你先在网络的内部使用dropout,随后再跟上一个BN层,而且这个BN层还不止一个。那么问题出在哪呢?原因有二。首先,如上图所示,因为训练时采用了dropout,虽然通过除以(1-p)的方式来使得训练和测试时,每个神经元输入的期望大致相同,但是他们的方差却不一样。第二,BN是采用训练时得到的均值和方差对数据进行归一化的,现在dropout层的方差都不一样了,一步错步步错,最终导致输出不准确,影响最后的性能。</p><p>针对方差偏移,<a href="https://arxiv.org/pdf/1801.05134.pdf">论文</a>给出了两种解决方案:</p><ul><li>拒绝方差偏移,只在所有BN层的后面采用dropout层,现在大部分开源的模型,都在网络的中间加了BN,你也就只能在softmax的前一层加加dropout了,我亲自试过,效果还行,至少不会比不加dropout差。还有另外一种方法是模型训练完后,固定参数,以测试模式对训练数据求BN的均值和方差,再对测试数据进行归一化,论文证明这种方法优于baseline。</li><li>dropout原文提出了一种高斯dropout,论文再进一步对高斯dropout进行扩展,提出了一个均匀分布Dropout,这样做带来了一个好处就是这个形式的Dropout(又称为“Uout”)对方差的偏移的敏感度降低了,总得来说就是整体方差偏地没有那么厉害了。可以看得出来实验性能整体上比第一个方案好,这个方法显得更加稳定。</li></ul>]]></content>
<summary type="html">Normalization的各种方式</summary>
<category term="模型优化" scheme="https://thinksky5124.github.io/categories/%E6%A8%A1%E5%9E%8B%E4%BC%98%E5%8C%96/"/>
<category term="深度学习" scheme="https://thinksky5124.github.io/categories/%E6%A8%A1%E5%9E%8B%E4%BC%98%E5%8C%96/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
<category term="深度学习" scheme="https://thinksky5124.github.io/tags/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
<category term="模型优化" scheme="https://thinksky5124.github.io/tags/%E6%A8%A1%E5%9E%8B%E4%BC%98%E5%8C%96/"/>
</entry>
<entry>
<title>PyTorch Distributed Training Pitfalls</title>
<link href="https://thinksky5124.github.io/2022/08/18/%E8%B8%A9%E5%9D%91%E8%AE%B0/"/>
<id>https://thinksky5124.github.io/2022/08/18/%E8%B8%A9%E5%9D%91%E8%AE%B0/</id>
<published>2022-08-18T08:00:00.000Z</published>
<updated>2024-04-16T08:57:05.647Z</updated>
<content type="html"><![CDATA[<h1 id="踩坑记"><a href="#踩坑记" class="headerlink" title="踩坑记"></a>踩坑记</h1><h1 id="修改网络模块或者获得模型的某个参数"><a href="#修改网络模块或者获得模型的某个参数" class="headerlink" title="修改网络模块或者获得模型的某个参数"></a><strong>修改网络模块或者获得模型的某个参数</strong></h1><p>解决方法:<strong>model后面添加module</strong></p><p>获取到网络模型后,使用并行方法,并将网络模型和参数移到GPU上。<strong>注意,若需要修改网络模块或者获得模型的某个参数,一定要在model后面加上.module,否则会报错,比如:</strong></p><p><code>model.img_size 要改成 model.module.img_size</code></p><h1 id="最后一个batch卡死"><a href="#最后一个batch卡死" class="headerlink" title="最后一个batch卡死"></a>最后一个batch卡死</h1><p>现象:PyTorch 训练时在第一个 epoch 的最后一个 batch 卡死</p><p>原因:Pytorch 的多 GPU 处理接口是 <code>torch.nn.DataParallel(module, device_ids)</code>,该接口还要求输入数据的 batch 数量要不小于所指定的 GPU 数量。另根据官网的解释和注释 (The batch size should be larger than the number of GPUs used.),batch的数量会均分到每块GPU上进行处理,因此要保证一个整数的关系。</p><p>解决方法:一定要注意在使用多块 GPU 训练时,注意 <code>batch_size</code> 的取值,避免出现最后一个 batch 的实际size小于所指定的 GPU 数量的情况。</p><h1 id="模型训练不动了"><a href="#模型训练不动了" class="headerlink" title="模型训练不动了"></a>模型训练不动了</h1><p>现象:显卡利用率100%,但是模型不动了</p><p>原因:在模型训练时使用了一些同步原语,但是,模型的训练iter不一,数据长度也不一,导致进程自己死锁,无法继续进行</p><p>解决方式:自行写通讯结构</p>]]></content>
<summary type="html">Pytorch分布式训练踩坑</summary>
<category term="分布式训练" scheme="https://thinksky5124.github.io/categories/%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83/"/>
<category term="深度学习" scheme="https://thinksky5124.github.io/categories/%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
<category term="分布式训练" scheme="https://thinksky5124.github.io/tags/%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83/"/>
<category term="深度学习" scheme="https://thinksky5124.github.io/tags/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
<category term="踩坑" scheme="https://thinksky5124.github.io/tags/%E8%B8%A9%E5%9D%91/"/>
</entry>
<entry>
<title>CMU Parallel Computing Course Overview</title>
<link href="https://thinksky5124.github.io/2022/07/21/class_info/"/>
<id>https://thinksky5124.github.io/2022/07/21/class_info/</id>
<published>2022-07-21T23:03:20.000Z</published>
<updated>2024-03-25T04:16:50.878Z</updated>
<content type="html"><![CDATA[<p>课程地址:<a href="http://www.cs.cmu.edu/afs/cs/academic/class/15418-f18/www/schedule.html">http://www.cs.cmu.edu/afs/cs/academic/class/15418-f18/www/schedule.html</a></p><h1 id="Lecture-1"><a href="#Lecture-1" class="headerlink" title="Lecture 1"></a>Lecture 1</h1><h2 id="为什么需要并行计算?"><a href="#为什么需要并行计算?" class="headerlink" title="为什么需要并行计算?"></a>为什么需要并行计算?</h2><p>单核CPU的性能几乎成指数级增长,而且英特尔在2004已经到达了单核的功耗墙。</p><p><img src="https://s2.loli.net/2024/03/25/vnUhqJ2XPaFVpdL.jpg" alt="单核CPU的性能"></p><p><img src="https://s2.loli.net/2024/03/25/rhltZgXQe36CzyJ.jpg" alt="CPU参数图"></p><p>在2004年之前想要让程序变得更快,买一个新的机器</p><p>但是2004年之后,需要进行并行编程。</p><h2 id="目前的CPU架构"><a href="#目前的CPU架构" class="headerlink" title="目前的CPU架构"></a>目前的CPU架构</h2><p><img src="https://s2.loli.net/2024/03/25/6ptGgeORBVDzoPN.jpg" alt="IntelSkylake"></p><p>最左边的是GPU,右边排列了4块CPU,中间是通讯总线。可见现代CPU都是采用多核设计,以达到更快的速度。</p><h2 id="什么是并行计算?"><a href="#什么是并行计算?" class="headerlink" title="什么是并行计算?"></a>什么是并行计算?</h2><p>A parallel computer is a collection of processing elements<br>that cooperate to solve problems quickly。</p><p>并行计算器,就是协调多个单核处理器以解决问题。</p><p>加速比</p><p>$$<br>speedup(using \quad P \quad processors) = \frac{execution \quad time(using \quad 1 \quad processors)}{execution \quad time (using \quad P \quad processors)}<br>$$</p><p>并行计算的原则</p><ul><li>交流信息严重限制了加速比</li><li>不平衡的工作分配会限制加速比</li></ul>]]></content>
<summary type="html">个人学习CMU并行计算课程时的笔记、总结等</summary>
<category term="并行计算" scheme="https://thinksky5124.github.io/categories/%E5%B9%B6%E8%A1%8C%E8%AE%A1%E7%AE%97/"/>
<category term="并行计算" scheme="https://thinksky5124.github.io/tags/%E5%B9%B6%E8%A1%8C%E8%AE%A1%E7%AE%97/"/>
<category term="CMU课程" scheme="https://thinksky5124.github.io/tags/CMU%E8%AF%BE%E7%A8%8B/"/>
</entry>
<entry>
<title>CPU Parallel Computing</title>
<link href="https://thinksky5124.github.io/2022/07/21/cpu_computer/"/>
<id>https://thinksky5124.github.io/2022/07/21/cpu_computer/</id>
<published>2022-07-21T23:03:20.000Z</published>
<updated>2024-03-25T04:19:12.246Z</updated>
<content type="html"><![CDATA[<h1 id="Lecture-2-A-Modern-Multi-Core-Processor"><a href="#Lecture-2-A-Modern-Multi-Core-Processor" class="headerlink" title="Lecture 2 A Modern Multi-Core Processor"></a>Lecture 2 A Modern Multi-Core Processor</h1><ul><li>理解并行计算的形式</li><li>理解延迟(latency)和带宽(bandwidth)</li></ul><h2 id="Parallel-Execution"><a href="#Parallel-Execution" class="headerlink" title="Parallel Execution"></a>Parallel Execution</h2><h3 id="单线程执行-编译器定义"><a href="#单线程执行-编译器定义" class="headerlink" title="单线程执行 - 编译器定义"></a>单线程执行 - 编译器定义</h3><p>例程:使用泰勒公式计算$sin(x)$。</p><figure class="highlight c++"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="type">void</span> <span class="title">sinx</span><span class="params">(<span class="type">int</span> N,<span class="type">int</span> terms, <span class="type">float</span>* x,<span class="type">float</span>* result)</span></span>{</span><br><span class="line"> <span class="keyword">for</span>(<span class="type">int</span> i=<span class="number">0</span>;i<N;i++){</span><br><span class="line"> <span class="type">float</span> value = x[i];</span><br><span class="line"> <span class="type">float</span> numer = x[i] * x[i] *x[i];</span><br><span class="line"> <span class="type">int</span> denom = <span class="number">6</span>;</span><br><span class="line"> <span class="type">int</span> sign = <span 
class="number">-1</span>;</span><br><span class="line"></span><br><span class="line"> <span class="keyword">for</span>(<span class="type">int</span> j=<span class="number">1</span>;j<= terms;j++){</span><br><span class="line"> value += sign * numer / denom;</span><br><span class="line"> numer *= x[i] * x[i];</span><br><span class="line"> denom *= (<span class="number">2</span>*j+<span class="number">2</span>)*(<span class="number">2</span>*j+<span class="number">3</span>);</span><br><span class="line"> sign *= <span class="number">-1</span>;</span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> result[i]=value;</span><br><span class="line"> }</span><br><span class="line">}</span><br></pre></td></tr></tbody></table></figure><p><img src="https://s2.loli.net/2024/03/25/JDzKdeycxpHo59h.jpg" alt="单核单线程单流水处理模型"></p><ul><li>解码模块(Fetch/Decode)读取指令</li><li>ALU模块负责执行</li><li>上下文存储器(Exceution Context)负责存储执行数据</li></ul><h3 id="多线程执行-用户定义"><a href="#多线程执行-用户定义" class="headerlink" title="多线程执行 - 用户定义"></a>多线程执行 - 用户定义</h3><p>概念:指令级并行 instruction level parallelism (ILP)</p><p><img src="https://s2.loli.net/2024/03/25/oN83kVxGqgOEJUd.jpg" alt="超标量处理器模型"><br><img src="https://s2.loli.net/2024/03/25/2meGLwOuWCJj4is.jpg" alt="多核处理器模型"></p><p>例程:</p><figure class="highlight c++"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span 
class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">typedef</span> <span class="keyword">struct</span>{</span><br><span class="line"> <span class="type">int</span> N;</span><br><span class="line"> <span class="type">int</span> terms;</span><br><span class="line"> <span class="type">float</span> *x;</span><br><span class="line"> <span class="type">float</span> result;</span><br><span class="line">} my_args;</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="type">void</span> <span class="title">parallel_sinx</span><span class="params">(<span class="type">int</span> N,<span class="type">int</span> terms, <span class="type">float</span>* x,<span class="type">float</span>* result)</span></span>{</span><br><span class="line"> <span class="type">pthread_t</span> thread_id;</span><br><span class="line"> my_args args;</span><br><span class="line"></span><br><span class="line"> args.N=<span class="number">2</span>/N;</span><br><span class="line"> args.terms=terms;</span><br><span class="line"> args.x=x;</span><br><span class="line"> args.result=result;</span><br><span class="line"></span><br><span class="line"> <span class="built_in">pthread_create</span>(&thread_id, <span class="literal">NULL</span>, my_thread_start, &args);<span class="comment">//launch thread</span></span><br><span class="line"> <span 
class="built_in">sinx</span>(N-args.N,terms,x+args.N,result+args.N);</span><br><span class="line"> <span class="built_in">pthread_join</span>(thread_id, <span class="literal">NULL</span>);</span><br><span class="line">}</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="type">void</span> <span class="title">sinx</span><span class="params">(<span class="type">int</span> N,<span class="type">int</span> terms, <span class="type">float</span>* x,<span class="type">float</span>* result)</span></span>{</span><br><span class="line"> <span class="keyword">for</span>(<span class="type">int</span> i=<span class="number">0</span>;i<N;i++){</span><br><span class="line"> <span class="type">float</span> value = x[i];</span><br><span class="line"> <span class="type">float</span> numer = x[i] * x[i] *x[i];</span><br><span class="line"> <span class="type">int</span> denom = <span class="number">6</span>;</span><br><span class="line"> <span class="type">int</span> sign = <span class="number">-1</span>;</span><br><span class="line"></span><br><span class="line"> <span class="keyword">for</span>(<span class="type">int</span> j=<span class="number">1</span>;j<= terms;j++){</span><br><span class="line"> value += sign * numer / denom;</span><br><span class="line"> numer *= x[i] * x[i];</span><br><span class="line"> denom *= (<span class="number">2</span>*j+<span class="number">2</span>)*(<span class="number">2</span>*j+<span class="number">3</span>);</span><br><span class="line"> sign *= <span class="number">-1</span>;</span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> result[i]=value;</span><br><span class="line"> }</span><br><span class="line">}</span><br></pre></td></tr></tbody></table></figure><p>上述例程表示线程级并行的代码描述,通过将整个工作分散分配到不同的处理器核中,以获得加速的效果。工作的分配方式将很大程度上影响加速的效果。</p><h3 id="数据并行-用户定义、编译器定义"><a href="#数据并行-用户定义、编译器定义" class="headerlink" title="数据并行 - 用户定义、编译器定义"></a>数据并行 - 用户定义、编译器定义</h3><p><img 
src="https://s2.loli.net/2024/03/25/T8FkUAmjd2saXMn.jpg" alt="SIMD"></p><p>Concept: SIMD (single instruction, multiple data) processing.</p><p>Code using the AVX instruction set:</p><figure class="highlight c++"><table><tbody><tr><td class="code"><pre><span class="line">#include &lt;immintrin.h&gt;</span><br><span class="line"></span><br><span class="line">// Taylor-series sin(x) for 8 floats at a time using 256-bit AVX registers.</span><br><span class="line">// Assumes N is a multiple of 8 and x/sinx are 32-byte aligned (required by</span><br><span class="line">// _mm256_load_ps/_mm256_store_ps); compile with AVX enabled, e.g. -mavx.</span><br><span class="line">void sinx(int N, int terms, float* x, float* sinx)</span><br><span class="line">{</span><br><span class="line">  float three_fact = 6; // 3!</span><br><span class="line">  for (int i = 0; i &lt; N; i += 8)</span><br><span class="line">  {</span><br><span class="line">    __m256 origx = _mm256_load_ps(&amp;x[i]);</span><br><span class="line">    __m256 x2 = _mm256_mul_ps(origx, origx);          // x^2, reused for every term</span><br><span class="line">    __m256 value = origx;</span><br><span class="line">    __m256 numer = _mm256_mul_ps(origx, x2);          // x^3</span><br><span class="line">    __m256 denom = _mm256_broadcast_ss(&amp;three_fact);  // 3! in all 8 lanes</span><br><span class="line">    float sign = -1.0f;</span><br><span class="line">    for (int j = 1; j &lt;= terms; j++)</span><br><span class="line">    {</span><br><span class="line">      // value += sign * numer / denom</span><br><span class="line">      __m256 tmp = _mm256_div_ps(_mm256_mul_ps(_mm256_set1_ps(sign), numer), denom);</span><br><span class="line">      value = _mm256_add_ps(value, tmp);</span><br><span class="line">      numer = _mm256_mul_ps(numer, x2);</span><br><span class="line">      denom = _mm256_mul_ps(denom, _mm256_set1_ps((float)((2*j+2) * (2*j+3))));</span><br><span class="line">      sign *= -1.0f;</span><br><span class="line">    }</span><br><span class="line">    _mm256_store_ps(&amp;sinx[i], value);</span><br><span class="line">  }</span><br><span class="line">}</span><br></pre></td></tr></tbody></table></figure><p>With the CPU model shown in the figure above, SIMD lets a single core apply one instruction to 8 data elements at once (eight 32-bit floats in a 256-bit AVX register), so the data is processed in parallel.</p><p>Alternatively, if the loop iterations are independent of each other, the programmer can simply declare this and let the compiler generate the vector code. The lecture writes this with a pseudocode <code>forall</code> loop; the version below expresses the same promise with OpenMP's <code>#pragma omp simd</code>:</p><figure class="highlight c++"><table><tbody><tr><td class="code"><pre><span class="line">void sinx(int N, int terms, const float* x, float* result)</span><br><span class="line">{</span><br><span class="line">  // declare independent loop iterations (lecture pseudocode: forall (int i from 0 to N-1));</span><br><span class="line">  // the OpenMP pragma makes the same promise, enabled with -fopenmp or -fopenmp-simd</span><br><span class="line">  #pragma omp simd</span><br><span class="line">  for (int i = 0; i &lt; N; i++)</span><br><span class="line">  {</span><br><span class="line">    float value = x[i];</span><br><span class="line">    float numer = x[i] * x[i] * x[i];</span><br><span class="line">    float denom = 6.0f; // 3! (float, since the factorial quickly overflows int)</span><br><span class="line">    float sign = -1.0f;</span><br><span class="line">    for (int j = 1; j &lt;= terms; j++)</span><br><span class="line">    {</span><br><span class="line">      value += sign * numer / denom;</span><br><span class="line">      numer *= x[i] * x[i];</span><br><span class="line">      denom *= (2*j+2) * (2*j+3);</span><br><span class="line">      sign *= -1.0f;</span><br><span class="line">    }</span><br><span class="line">    result[i] = value;</span><br><span class="line">  }</span><br><span class="line">}</span><br></pre></td></tr></tbody></table></figure><h3 id="条件执行"><a href="#条件执行" class="headerlink" title="条件执行"></a>Conditional execution</h3><p>Suppose a data-parallel processor has to execute the following code for each element <code>i</code>:</p><figure class="highlight c++"><table><tbody><tr><td class="code"><pre><span class="line">// loop body for element i; A, result, kMyConst1 and kMyConst2 are defined elsewhere</span><br><span class="line">float x = A[i];</span><br><span class="line">if (x &gt; 0) {</span><br><span class="line">  float tmp = std::pow(x, 5.f); // expensive "true" path (std::pow from &lt;cmath&gt;)</span><br><span class="line">  x = tmp + kMyConst2;</span><br><span class="line">} else {</span><br><span class="line">  float tmp = kMyConst1;</span><br><span class="line">  x = 2.f * tmp;</span><br><span class="line">}</span><br><span class="line">result[i] = x;</span><br></pre></td></tr></tbody></table></figure><p>All lanes of a SIMD unit share one instruction stream, so when the condition differs between lanes the processor uses masking rather than true branching: it first executes the "true" branch on all ALUs with the lanes whose condition is false masked off, then executes the "false" branch with the lanes whose condition is true masked off. Execution stays parallel, but divergent code wastes ALU cycles; in the worst case only one of the 8 lanes does useful work.</p><p><img src="https://s2.loli.net/2024/03/25/MPSRkb4tX7DHnwu.png" alt="Conditional execution under data parallelism"></p><h3 id="SMID的有关概念"><a href="#SMID的有关概念" class="headerlink" title="SMID的有关概念"></a>SIMD concepts</h3><ul><li>Using SIMD explicitly requires writing special code, e.g., SSE or AVX intrinsics.</li><li>The compiler generally cannot verify that loop iterations are independent; once you declare them independent for parallel execution, guaranteeing that independence is your responsibility.</li><li>GPUs typically deliver much higher data-parallel throughput than CPUs.</li></ul><p>For example:</p><p><img src="https://s2.loli.net/2024/03/25/PAJuyXgRSYkdBnt.png" alt="Comparison of CPU and GPU SIMD capability"></p><h2 id="Accessing-Memory"><a href="#Accessing-Memory" class="headerlink" title="Accessing Memory"></a>Accessing Memory</h2><h3 id="术语"><a href="#术语" class="headerlink" title="术语"></a>Terminology</h3><ul><li>Memory latency: the time it takes for a memory request (e.g., a load from system memory) to be serviced. Typical magnitude: about 100 cycles, about 100 ns.</li><li>Memory bandwidth: the rate at which the memory system can supply data to the processor. Example: 20 GB/s.</li><li>Throughput: the amount of data a chip processes per unit time.</li></ul><p>Memory latency inevitably slows down modern processors, but it can be reduced or "hidden" with techniques such as:</p><ul><li>Multi-level caches</li><li>Prefetching</li><li>Multithreading</li><li>Keeping multiple execution contexts on the core, so that it can switch to another thread while one thread is stalled on memory</li></ul><p><img src="https://s2.loli.net/2024/03/25/Kmn3pBR7udSCboV.png" alt="Comparison of CPU and GPU memory hierarchies"></p>]]></content>
<summary type="html">Personal notes and summaries from studying CMU's parallel computing course</summary>
<category term="并行计算" scheme="https://thinksky5124.github.io/categories/%E5%B9%B6%E8%A1%8C%E8%AE%A1%E7%AE%97/"/>
<category term="并行计算" scheme="https://thinksky5124.github.io/tags/%E5%B9%B6%E8%A1%8C%E8%AE%A1%E7%AE%97/"/>
<category term="CMU课程" scheme="https://thinksky5124.github.io/tags/CMU%E8%AF%BE%E7%A8%8B/"/>
</entry>
<entry>
<title>Digital Image Processing Fundamentals (Part 2)</title>
<link href="https://thinksky5124.github.io/2021/01/28/%E6%95%B0%E5%AD%97%E5%9B%BE%E5%83%8F%E5%A4%84%E7%90%86%E5%9F%BA%E7%A1%802/"/>
<id>https://thinksky5124.github.io/2021/01/28/%E6%95%B0%E5%AD%97%E5%9B%BE%E5%83%8F%E5%A4%84%E7%90%86%E5%9F%BA%E7%A1%802/</id>
<published>2021-01-28T03:55:00.000Z</published>
<updated>2024-04-16T08:57:05.647Z</updated>
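<content type="html"><![CDATA[<p>This series summarizes the fundamentals of digital image processing, based on <a href="https://baike.baidu.com/item/%E6%95%B0%E5%AD%97%E5%9B%BE%E5%83%8F%E5%A4%84%E7%90%86%EF%BC%88%E7%AC%AC%E4%B8%89%E7%89%88%EF%BC%89">Gonzalez, <em>Digital Image Processing</em></a> and <a href="http://vision.stanford.edu/teaching/cs131_fall2021/index.html">Stanford's CS131 course</a>.</p><h1 id="概念"><a href="#概念" class="headerlink" title="概念"></a>Concepts</h1><ul><li>Spatial domain: the image plane itself. Spatial-domain methods operate directly on the pixels of the image.</li><li>Transform domain (e.g., the frequency domain): the image is transformed into the transform domain, processed there, and then inverse-transformed to bring the result back to the spatial domain.</li></ul><h1 id="空间域变换表达式"><a href="#空间域变换表达式" class="headerlink" title="空间域变换表达式"></a>Spatial-domain transformation</h1><p>$$ g(x,y) = T[ f(x,y) ] \tag{1}$$</p><p>where $f(x,y)$ is the input image, $g(x,y)$ is the output image, and $T$ is an operator on $f$ defined over a neighborhood of the point $(x,y)$, also called the transformation function.</p><h1 id="灰度变换"><a href="#灰度变换" class="headerlink" title="灰度变换"></a>Intensity transformations</h1><p>Intensity (gray-level) transformations are an image-enhancement technique, typically used to adjust contrast so that the image meets the requirements of a specific application.</p><p>Common intensity transformation functions include:</p><ul><li>Image negatives</li><li>Log transformations</li><li>Power-law (gamma) transformations</li><li>Piecewise-linear transformations</li><li>Bit-plane slicing, which can be used for image compression and reconstruction</li></ul><h1 id="图像直方图"><a href="#图像直方图" class="headerlink" title="图像直方图"></a>Image histograms</h1><p>Reference: <a href="https://blog.csdn.net/qq_38701868/article/details/89215881">https://blog.csdn.net/qq_38701868/article/details/89215881</a></p><p>Put simply, a histogram is a way of gathering statistics over data, organizing the counts into a series of predefined bins. A bin (it can be thought of as a "bar" or "class interval") holds a statistic computed from the data, where the data may be gradients, orientations, colors, or any other feature. In every case, a histogram gives a statistical picture of how the data are distributed, and it usually has lower dimensionality than the original data.</p><p>An image histogram represents the brightness distribution of a digital image by plotting the number of pixels at each intensity value. The left side of its horizontal axis corresponds to pure black and darker values, and the right side to brighter values and pure white. A dark image therefore has most of its histogram on the left and in the middle, while a bright image with only a few shadows shows the opposite.</p><p>The image histogram is defined mathematically as<br>$$ h(r_k) = n_k , k=0,1,2, \cdots ,L-1 \tag{2}$$<br>where $r_k$ is the $k$-th intensity level and $n_k$ is the number of pixels with intensity $r_k$. More commonly used is the normalized histogram, where $M$ and $N$ are the numbers of image rows and columns:<br>$$ p(r_k) = \frac{h(r_k)}{MN} = \frac{n_k}{MN} \tag{3} $$</p><p><img src="https://img-blog.csdnimg.cn/20190324115843708.png" alt="Image histogram illustration"></p><h1 id="空间滤波基础"><a href="#空间滤波基础" class="headerlink" title="空间滤波基础"></a>Fundamentals of spatial filtering</h1><h2 id="线性滤波器"><a href="#线性滤波器" class="headerlink" title="线性滤波器"></a>Linear filters</h2><p>A linear spatial filter performs a sum-of-products operation between an image $f$ and a filter kernel $w$. The kernel is an array whose size defines the neighborhood of the operation and whose coefficients determine the nature of the filter.</p><p>Correlation filtering with an $m \times n$ kernel, where $m=2a+1$ and $n=2b+1$:<br>$$ g(x,y) = \sum_{s=-a}^{a} \sum_{t=-b}^{b} w(s,t)f(x+s,y+t) \tag{4}$$<br>This formula applies to kernels of any odd size.</p><p>The operation is illustrated below.</p><p><img src="https://img-blog.csdnimg.cn/20200416181230334.gif" alt="Spatial filtering illustration"></p><p>Image source: <a href="https://blog.csdn.net/IT_charge/article/details/105563188">https://blog.csdn.net/IT_charge/article/details/105563188</a></p><p>Note that when the kernel is centered on the top-left pixel, part of it falls outside the image, where values are undefined. This is usually handled by padding the image, most commonly with zeros, although zero padding is not the only option (replicate or mirror padding are also used). The figure above instead shrinks the filtering range, i.e., it does not start at the top-left element.</p><h2 id="相关滤波和卷积滤波"><a href="#相关滤波和卷积滤波" class="headerlink" title="相关滤波和卷积滤波"></a>Correlation and convolution</h2><p>Convolution differs from correlation in that the kernel is rotated by 180°.</p><p>Why is convolution usually preferred over correlation? With the kernel rotated by 180°, filtering a discrete unit impulse yields an exact copy of the kernel, and convolution is commutative, associative, and distributive.</p><p><img src="https://img-blog.csdnimg.cn/20200416184210853.png" alt="1-D correlation and convolution"></p><p>Note: $f=[0\ 0\ 0\ 1\ 0\ 0\ 0]$ is called a discrete unit impulse.</p><p>Rotating a kernel by 180° is equivalent to flipping it once about its horizontal axis and once about its vertical axis (for a square kernel, equivalently, flipping it about both diagonals).</p><p>The convolution formula:<br>$$ (w*f)(x,y) = \sum_{s=-a}^{a} \sum_{t=-b}^{b} w(s,t)f(x-s,y-t) \tag{5}$$<br>Note: in the rest of this series, linear spatial filtering is used synonymously with spatial convolution.</p><h2 id="卷积和相关的一些基本运算性质"><a href="#卷积和相关的一些基本运算性质" class="headerlink" title="卷积和相关的一些基本运算性质"></a>Basic properties of convolution and correlation</h2><p>Here $*$ denotes convolution and $\star$ correlation; a dash means the property does not hold.</p><table><thead><tr><th>Property</th><th align="center">Convolution</th><th align="right">Correlation</th></tr></thead><tbody><tr><td>Commutative</td><td align="center">$f*g=g*f$</td><td align="right">-</td></tr><tr><td>Associative</td><td align="center">$f*(g*h)=(f*g)*h$</td><td align="right">-</td></tr><tr><td>Distributive</td><td align="center">$f*(g+h)=(f*g)+(f*h)$</td><td align="right">$f\star(g+h)=(f\star g)+(f\star h)$</td></tr></tbody></table><h2 id="补充有关卷积的一些知识点"><a href="#补充有关卷积的一些知识点" class="headerlink" title="补充有关卷积的一些知识点"></a>Additional notes on convolution</h2><h3 id="卷积之后图像大小"><a href="#卷积之后图像大小" class="headerlink" title="卷积之后图像大小"></a>Image size after convolution</h3><p>If the kernel is of size $m \times n$ and the image is of size $M \times N$, pad the image with $(m-1)$ rows of zeros at the top and at the bottom and with $(n-1)$ columns of zeros on the left and on the right. Under these conditions, the full convolution has $S_v$ rows and $S_h$ columns, where<br>$$ S_v=m+M-1 \quad S_h=n+N-1 $$</p><h3 id="卷积可压缩"><a href="#卷积可压缩" class="headerlink" title="卷积可压缩"></a>Cascaded convolutions can be combined</h3><p>Principle: convolution is associative (and commutative).</p><p>For a $Q$-stage filter, if an image $f$ is first filtered with kernel $w_1$, the result is then filtered with $w_2$, and so on, the filtering can be done in a single stage as $w*f$, where<br>$$ w=w_1 * w_2 * w_3 * \cdots * w_Q $$</p>]]></content>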
<summary type="html">Intensity transformations and spatial filtering</summary>
<category term="数字图像处理" scheme="https://thinksky5124.github.io/categories/%E6%95%B0%E5%AD%97%E5%9B%BE%E5%83%8F%E5%A4%84%E7%90%86/"/>
<category term="计算机视觉" scheme="https://thinksky5124.github.io/categories/%E6%95%B0%E5%AD%97%E5%9B%BE%E5%83%8F%E5%A4%84%E7%90%86/%E8%AE%A1%E7%AE%97%E6%9C%BA%E8%A7%86%E8%A7%89/"/>
<category term="图像基础" scheme="https://thinksky5124.github.io/categories/%E6%95%B0%E5%AD%97%E5%9B%BE%E5%83%8F%E5%A4%84%E7%90%86/%E8%AE%A1%E7%AE%97%E6%9C%BA%E8%A7%86%E8%A7%89/%E5%9B%BE%E5%83%8F%E5%9F%BA%E7%A1%80/"/>
<category term="数字图像处理" scheme="https://thinksky5124.github.io/tags/%E6%95%B0%E5%AD%97%E5%9B%BE%E5%83%8F%E5%A4%84%E7%90%86/"/>
<category term="图像处理基本概念" scheme="https://thinksky5124.github.io/tags/%E5%9B%BE%E5%83%8F%E5%A4%84%E7%90%86%E5%9F%BA%E6%9C%AC%E6%A6%82%E5%BF%B5/"/>
</entry>
<entry>
<title>Digital Image Processing Fundamentals (Part 1)</title>
<link href="https://thinksky5124.github.io/2021/01/27/%E6%95%B0%E5%AD%97%E5%9B%BE%E5%83%8F%E5%A4%84%E7%90%86%E5%9F%BA%E7%A1%801/"/>
<id>https://thinksky5124.github.io/2021/01/27/%E6%95%B0%E5%AD%97%E5%9B%BE%E5%83%8F%E5%A4%84%E7%90%86%E5%9F%BA%E7%A1%801/</id>
<published>2021-01-27T02:51:00.000Z</published>
<updated>2024-04-16T08:57:05.647Z</updated>
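<content type="html"><![CDATA[<p>This series summarizes the fundamentals of digital image processing, based on <a href="https://baike.baidu.com/item/%E6%95%B0%E5%AD%97%E5%9B%BE%E5%83%8F%E5%A4%84%E7%90%86%EF%BC%88%E7%AC%AC%E4%B8%89%E7%89%88%EF%BC%89">Gonzalez, <em>Digital Image Processing</em></a> and <a href="http://vision.stanford.edu/teaching/cs131_fall2021/index.html">Stanford's CS131 course</a>.</p><h1 id="数字图像表示"><a href="#数字图像表示" class="headerlink" title="数字图像表示"></a>Representing digital images</h1><h2 id="数学形式定义"><a href="#数学形式定义" class="headerlink" title="数学形式定义"></a>Mathematical definition</h2><p>Let $f(s,t)$ denote a continuous image function. Sampling and quantization convert this continuous image into a digital image $f(x,y)$ with $M$ rows and $N$ columns, where $(x,y)$ are discrete coordinates and the value of $f(x,y)$ is the intensity (gray level).</p><p>For convenient storage and processing, computers usually represent an image as a matrix:<br>$$ A=\left[ \begin{matrix} a_{0,0} &amp; a_{0,1} &amp; \cdots &amp; a_{0,N-1} \\ a_{1,0} &amp; a_{1,1} &amp; \cdots &amp; a_{1,N-1} \\ \vdots &amp; \vdots &amp; \ddots &amp; \vdots \\ a_{M-1,0} &amp; a_{M-1,1} &amp; \cdots &amp; a_{M-1,N-1} \end{matrix} \right] \tag{1} $$</p><p>Many image displays scan an image starting at the top-left corner and moving right, one row at a time. By convention, the top-left corner is therefore taken as the first element of the matrix (the origin), which is consistent with a (rotated) Cartesian coordinate system.</p><p>The number of discrete intensity levels $L$ of $f(x,y)$ is usually an integer power of 2:<br>$$ L=2^{k} \tag{2}$$<br>Intensity values lie in the interval $[0,L-1]$; this range is called the dynamic range.</p><p>The number of bits $b$ required to store a digital image is</p><p>$$b=MNk \tag{3} $$</p><h2 id="饱和度、噪声与对比度"><a href="#饱和度、噪声与对比度" class="headerlink" title="饱和度、噪声与对比度"></a>Saturation, noise, and contrast</h2><h3 id="饱和度"><a href="#饱和度" class="headerlink" title="饱和度"></a>Saturation</h3><p>Saturation is the highest value beyond which all intensity values are clipped; saturated pixels are shown at the maximum brightness the display can produce.</p><h3 id="噪声"><a href="#噪声" class="headerlink" title="噪声"></a>Noise</h3><p>Image noise is unwanted or extraneous interference present in the image data.</p><p><img src="https://s2.loli.net/2024/03/25/ye1LdGtQJ6goUI8.jpg" alt="Image noise"></p><h3 id="对比度与反差比"><a href="#对比度与反差比" class="headerlink" title="对比度与反差比"></a>Contrast and contrast ratio</h3><p>Contrast is the difference in intensity between the highest and lowest intensity levels in an image; the contrast ratio is the ratio between them.</p><h2 id="空间分辨率和灰度分辨率"><a href="#空间分辨率和灰度分辨率" class="headerlink" title="空间分辨率和灰度分辨率"></a>Spatial and intensity resolution</h2><h3 id="空间分辨率"><a href="#空间分辨率" class="headerlink" title="空间分辨率"></a>Spatial resolution</h3><p>Spatial resolution is commonly expressed in dots per inch (dpi); the higher the value, the finer the detail the image preserves.</p><h3 id="灰度分辨率"><a href="#灰度分辨率" class="headerlink" 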
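title="灰度分辨率"></a>Intensity resolution</h3><p>Intensity resolution is the smallest discernible change in intensity level. 8 bits is the most common choice; 16 and 32 bits are also used, but less often. Note: with 16 or fewer intensity levels, false contouring appears, resembling the contour lines on a map.</p><h1 id="图像内插"><a href="#图像内插" class="headerlink" title="图像内插"></a>Image interpolation</h1><p>Interpolation is the process of using known data to estimate values at unknown locations. It is used in tasks such as zooming, shrinking, rotating, and geometrically correcting images; zooming and shrinking are image-resampling operations that rely on it.</p><h2 id="最邻近插值法"><a href="#最邻近插值法" class="headerlink" title="最邻近插值法"></a>Nearest-neighbor interpolation</h2><p>Nearest-neighbor interpolation assigns to each new location the intensity of its nearest neighbor in the original image: each point of the target image is mapped back into the original image, and the pixel value at the nearest integer coordinate is output for that point.</p><p><img src="https://s2.loli.net/2024/03/25/kOa5gEhS6QPCGWB.jpg" alt="Nearest-neighbor interpolation"></p><p>The method is simple, but it tends to introduce distortion.</p><h2 id="双线性内插"><a href="#双线性内插" class="headerlink" title="双线性内插"></a>Bilinear interpolation</h2><p>Bilinear interpolation uses the four nearest neighbors to estimate the intensity at a given location. Let $(x,y)$ be the coordinates of the location to be assigned a value and $v(x,y)$ the intensity value there:</p><p>$$ v(x,y) = ax+by +cxy + d \tag{4}$$</p><p>The four coefficients are obtained by solving the four equations in four unknowns written using the four nearest neighbors of $(x,y)$.</p><h2 id="双三次内插"><a href="#双三次内插" class="headerlink" title="双三次内插"></a>Bicubic interpolation</h2><p>Bicubic interpolation uses the 16 nearest neighbors to estimate the intensity at a given location. With $(x,y)$ and $v(x,y)$ defined as above:</p><p>$$ v(x,y) = \sum_{i=0}^{3}\sum_{j=0}^{3}a_{ij}x^iy^j \tag{5}$$</p><p>The 16 coefficients are obtained by solving the sixteen equations in sixteen unknowns written using the 16 nearest neighbors of $(x,y)$. Bicubic interpolation is the standard method in commercial image-editing software such as Adobe Photoshop.</p><h2 id="区别"><a href="#区别" class="headerlink" title="区别"></a>Comparison</h2><p>Amount of detail preserved: bicubic &gt; bilinear &gt; nearest neighbor.</p><h1 id="像素的基本关系"><a href="#像素的基本关系" class="headerlink" title="像素的基本关系"></a>Basic relationships between pixels</h1><p>Reference:<br>Author: Lemon雷<br>Link: <a href="https://www.jianshu.com/p/2aef925ed39e">https://www.jianshu.com/p/2aef925ed39e</a><br>Source: Jianshu</p><h2 id="相邻的定义"><a href="#相邻的定义" class="headerlink" title="相邻的定义"></a>Definition of adjacency</h2><p>Two pixels are connected when both of the following conditions hold:</p><ol><li><p>Their positions are adjacent.</p></li><li><p>Their intensity values satisfy a specified similarity criterion (e.g., both belong to a given set, or are equal).</p></li></ol><p>Let $V$ be the set of intensity values used to define adjacency, for example $V=\{x \mid 0&lt;x&lt;125\}$, where $x$ is a pixel's intensity. Then:</p><h3 id="4连通"><a href="#4连通" class="headerlink" title="4连通"></a>4-adjacency</h3><p>Pixels $p$ and $q$ with values from $V$ are 4-adjacent if $q$ is in the 4-neighborhood of $p$, i.e., in $N_4(p)$.</p><p><img src="http://upload-images.jianshu.io/upload_images/4158488-0aaf776971c469e2.png" alt="4-adjacency illustration"></p><h3 id="8连通"><a href="#8连通" class="headerlink" title="8连通"></a>8-adjacency</h3><p>Pixels $p$ and $q$ with values from $V$ are 8-adjacent if $q$ is in the 8-neighborhood of $p$, i.e., in $N_8(p)$.</p><p><img 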
src="http://upload-images.jianshu.io/upload_images/4158488-d475cf2b8c170276.png" alt="8-adjacency illustration"></p><h3 id="m连通(混合连通)"><a href="#m连通(混合连通)" class="headerlink" title="m连通(混合连通)"></a>m-adjacency (mixed adjacency)</h3><p>Pixels $p$ and $q$ with values from $V$ are m-adjacent if:</p><ol><li><p>$q$ is in $N_4(p)$, or</p></li><li><p>$q$ is in $N_D(p)$ (the diagonal neighbors of $p$) and the set $N_4(p) \cap N_4(q)$ contains no pixels whose values are from $V$.</p></li></ol><p>m-adjacency is thus a mixture of 4-adjacency and diagonal adjacency.</p><p><img src="http://upload-images.jianshu.io/upload_images/4158488-842e7913041974b0.png" alt="m-adjacency illustration"></p><p>Note: m-adjacency is a modification of 8-adjacency, introduced to eliminate the ambiguity (multiple paths between pixels) that 8-adjacency can cause.</p><h2 id="距离测度"><a href="#距离测度" class="headerlink" title="距离测度"></a>Distance measures</h2><p>For pixels $p=(x,y)$ and $q=(u,v)$:</p><h3 id="欧几里得距离"><a href="#欧几里得距离" class="headerlink" title="欧几里得距离"></a>Euclidean distance</h3><p>$$ D_e(p,q) = [ (x-u)^2 + (y-v)^2 ]^{\frac{1}{2}} $$</p><h3 id="D4距离"><a href="#D4距离" class="headerlink" title="D4距离"></a>D4 (city-block) distance</h3><p>$$ D_4(p,q) = |x-u|+|y-v| $$</p><h3 id="D8距离"><a href="#D8距离" class="headerlink" title="D8距离"></a>D8 (chessboard) distance</h3><p>$$ D_8(p,q) = \max(|x-u|,|y-v|) $$</p><h1 id="加性图像降噪的数学原理"><a href="#加性图像降噪的数学原理" class="headerlink" title="加性图像降噪的数学原理"></a>Mathematics of noise reduction by image averaging</h1><p>Suppose $g(x,y)$ is a noisy image formed by adding noise $\eta(x,y)$ to a noiseless image $f(x,y)$:<br>$$ g(x,y)= f(x,y) + \eta(x,y) \tag{6}$$<br>where the noise $\eta(x,y)$ is uncorrelated at every pair of coordinates and has zero mean.</p><p>If the noise satisfies these conditions, it can be shown that if $\bar{g}(x,y)$ is formed by averaging $K$ different noisy images $g_i(x,y)$,<br>$$ \bar{g}(x,y) = \frac{1}{K} \sum_{i=1}^K g_i(x,y) \tag{7}$$<br>then<br>$$ E\{\bar{g}(x,y)\} = f(x,y) \tag{8}$$<br>$$ \sigma_{\bar{g}(x,y)}^2 = \frac{1}{K} \sigma_{\eta(x,y)}^2 \tag{9}$$</p><p>Hence, as $K$ increases, the variability (noise) of the pixel values at each location decreases.</p><p>Below is an image corrupted by additive white Gaussian noise:</p><p><img src="http://accu.cc/img/pil/agwn/jp_agwn.jpg" alt="Noisy image"></p><p>The result after denoising with the averaging principle above:</p><p><img src="http://accu.cc/img/pil/agwn/jp_denoise.jpg" alt="Denoised image"></p><p>Image source: <a href="http://accu.cc/content/pil/agwn/">http://accu.cc/content/pil/agwn/</a></p><h1 id="比较图像"><a href="#比较图像" class="headerlink" title="比较图像"></a>Comparing images</h1><h2 id="相减方法"><a href="#相减方法" class="headerlink" title="相减方法"></a>Image subtraction</h2><p>$$ g(x,y) = f(x,y) - h(x,y) \tag{10}$$<br>where $f(x,y)$ is the template image and $h(x,y)$ is the captured image.</p><h2 id="阴影校正"><a href="#阴影校正" class="headerlink" 
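title="阴影校正"></a>Shading correction</h2><p>Suppose $g(x,y)$ is the sensed image, $f(x,y)$ is the ideal image, and $h(x,y)$ is the shading function:<br>$$ g(x,y) = f(x,y) h(x,y) \tag{11}$$<br>If $h(x,y)$ is known, the ideal image is recovered by multiplying $g(x,y)$ by the inverse of $h(x,y)$, i.e., by dividing $g(x,y)$ by $h(x,y)$.</p><h2 id="作用"><a href="#作用" class="headerlink" title="作用"></a>Applications</h2><p>Image multiplication and division can be used for shading correction and for extracting a region of interest (ROI).</p><h2 id="运算公式"><a href="#运算公式" class="headerlink" title="运算公式"></a>Scaling formulas</h2><p>Arithmetic operations can produce values outside the displayable range. To scale an image $g$ to the range $[0, K]$ (e.g., $K=255$ for 8-bit images):</p><p>$$ g_m = g - \min(g) \tag{12}$$<br>$$ g_s = K [ g_m / \max(g_m) ] \tag{13}$$</p><h2 id="几何运算"><a href="#几何运算" class="headerlink" title="几何运算"></a>Geometric operations</h2><p>Affine transformations can scale, rotate, translate, or shear an image:<br>$$ \left[ \begin{matrix} x' \\ y' \\ 1 \end{matrix} \right]=T\left[ \begin{matrix} x \\ y \\ 1 \end{matrix} \right]= \left[ \begin{matrix} a_{1,1} &amp; a_{1,2} &amp; a_{1,3} \\ a_{2,1} &amp; a_{2,2} &amp; a_{2,3} \\ 0 &amp; 0 &amp; 1 \end{matrix} \right]\left[ \begin{matrix} x \\ y \\ 1 \end{matrix} \right] \tag{14} $$</p><h3 id="典型的仿射变换矩阵"><a href="#典型的仿射变换矩阵" class="headerlink" title="典型的仿射变换矩阵"></a>Typical affine transformation matrices</h3><table><thead><tr><th>Transformation</th><th align="center">Affine matrix $T$</th></tr></thead><tbody><tr><td>Identity</td><td align="center">$\left[\begin{matrix}1 &amp; 0 &amp; 0 \\ 0 &amp; 1 &amp; 0 \\ 0 &amp; 0 &amp; 1\end{matrix}\right]$</td></tr><tr><td>Scaling/reflection</td><td align="center">$\left[\begin{matrix}c_x &amp; 0 &amp; 0 \\ 0 &amp; c_y &amp; 0 \\ 0 &amp; 0 &amp; 1\end{matrix}\right]$</td></tr><tr><td>Rotation about the origin</td><td align="center">$\left[\begin{matrix}\cos\theta &amp; -\sin\theta &amp; 0 \\ \sin\theta &amp; \cos\theta &amp; 0 \\ 0 &amp; 0 &amp; 1\end{matrix}\right]$</td></tr><tr><td>Translation</td><td align="center">$\left[\begin{matrix}1 &amp; 0 &amp; t_x \\ 0 &amp; 1 &amp; t_y \\ 0 &amp; 0 &amp; 1\end{matrix}\right]$</td></tr><tr><td>Vertical shear</td><td align="center">$\left[\begin{matrix}1 &amp; s_v &amp; 0 \\ 0 &amp; 1 &amp; 0 \\ 0 &amp; 0 &amp; 1\end{matrix}\right]$</td></tr><tr><td>Horizontal shear</td><td align="center">$\left[\begin{matrix}1 &amp; 0 &amp; 0 \\ s_h &amp; 1 &amp; 0 \\ 0 &amp; 0 &amp; 1\end{matrix}\right]$</td></tr></tbody></table><p>Note the distinction between forward mapping and inverse (backward) mapping. Forward mapping scans the input pixels and computes where each one lands in the output, which can leave some output pixels unassigned or assign several inputs to the same output; inverse mapping scans the output locations, computes the corresponding input location with $T^{-1}$, and interpolates the value there, which is why it is generally preferred in practice.</p><p><img src="https://img-blog.csdn.net/20150404170230472" alt="Forward mapping"></p><p><img src="https://img-blog.csdn.net/20150404172327459" alt="Inverse mapping"></p><h1 id="色彩空间"><a href="#色彩空间" class="headerlink" title="色彩空间"></a>Color spaces</h1><p>References: Baidu Baike, Wikipedia</p><p>Color is the eye's varying response to light of different frequencies. It is both objective (light of different frequencies) and subjectively perceived, so perception differs between people. Humanity's understanding of color therefore developed over a very long time and was only gradually refined in modern times; even today we cannot claim to understand and describe color completely, and many of its concepts are not easy to grasp. In color science, many color models have been built that represent a color by coordinates in a one-, two-, three-, or even four-dimensional space; the range of colors such a coordinate system can define is its color space (a term sometimes loosely used interchangeably with "gamut"). Commonly used color spaces include RGB, CMYK, and Lab.</p><h2 id="CIE-1931色彩空间"><a href="#CIE-1931色彩空间" class="headerlink" title="CIE 1931色彩空间"></a>CIE 1931 color space</h2><p>In the CIE XYZ color space, the tristimulus values are not the human eye's responses to short, medium, and long wavelengths (S, M, and L), but a set of values called X, Y, and Z that roughly correspond to red, green, and blue. (Note that X, Y, and Z do not actually look red, green, and blue; they are parameters derived from red, green, and blue.) They are computed with the CIE 1931 XYZ color-matching functions.</p><p><img src="https://pic4.zhimg.com/80/v2-34b10192f00ee84add06ee0c74148d97_1440w.jpg" alt="CIE XYZ"></p><p>Normalizing the tristimulus values gives the chromaticity coordinates<br>$$ x=\frac{X}{X+Y+Z},\quad y=\frac{Y}{X+Y+Z},\quad z=\frac{Z}{X+Y+Z},\quad x+y+z=1 \tag{15}$$<br>Since any two coordinates determine the third through Eq. (15), chromaticity can be represented in two dimensions; the enclosed region in the figure above is the range of colors visible to the human eye.</p><h2 id="非线性色彩空间HSV"><a href="#非线性色彩空间HSV" class="headerlink" title="非线性色彩空间HSV"></a>The nonlinear HSV color space</h2><p>Reference: <a href="https://blog.csdn.net/weixin_43269204/article/details/94628987">https://blog.csdn.net/weixin_43269204/article/details/94628987</a></p><p>HSV (Hue, Saturation, Value) is a color space created by A. R. Smith in 1978 based on the intuitive properties of color; it is also called the hexcone model. Hue (H) is the basic attribute of a color, i.e., the everyday name of the color, such as red or yellow. H is given by the angle of rotation around the V axis: red is 0°, green 120°, and blue 240°, and every color differs from its complement by 180°. Saturation (S) is the purity of the color: the higher it is, the purer the color; as it decreases, the color gradually turns gray. Value (V) is the brightness of the color and is related to the luminance of the light source. At the apex of the cone, V=0, which represents black; at the center of the top face, V=1 and S=0, which represents white. The model is shown in the figure below.</p><p><img src="https://img-blog.csdnimg.cn/20190704164221409.png" alt="HSV"></p><h2 id="RBG颜色空间"><a href="#RBG颜色空间" class="headerlink" title="RBG颜色空间"></a>RGB color space</h2><p>The RGB color space produces a rich, wide range of colors by superimposing three primaries, R (red), G (green), and B (blue), in varying amounts, which is why it is commonly called the three-primary model. Nature contains infinitely many colors, while the human eye can distinguish only a limited number. With 256 levels per primary, RGB can represent more than 16 million colors (256³ = 16,777,216), which to the human eye comes very close to natural colors, so it is also called the natural color model. Red, green, and blue are the three primary colors of the visible spectrum, and each is divided into 256 levels according to its brightness. When the three primaries overlap, different mixing ratios produce the various intermediate colors.</p><p><img src="https://img-blog.csdnimg.cn/20190704164624824.png" alt="RGB"></p><h1 id="白平衡"><a href="#白平衡" class="headerlink" title="白平衡"></a>White balance</h1><p>White balance adjusts the image data received by the sensor so that neutral colors (white, gray, etc.) are rendered correctly. Digital cameras perform this adjustment automatically (with custom settings for different light sources); film cameras instead offer a variety of filters and film types for different shooting conditions.</p><p><img src="https://th.bing.com/th/id/R1b55b5e64bd5cb55098a72dbdeb6b8fd?rik=84LAoCaWj8FdAw&riu=http://img.fotomen.cn/2011/11/white-balance07.jpg" alt="White balance"></p>]]></content>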
<summary type="html">Basic operations on digital images, common concepts, and definitions.</summary>
<category term="数字图像处理" scheme="https://thinksky5124.github.io/categories/%E6%95%B0%E5%AD%97%E5%9B%BE%E5%83%8F%E5%A4%84%E7%90%86/"/>
<category term="计算机视觉" scheme="https://thinksky5124.github.io/categories/%E6%95%B0%E5%AD%97%E5%9B%BE%E5%83%8F%E5%A4%84%E7%90%86/%E8%AE%A1%E7%AE%97%E6%9C%BA%E8%A7%86%E8%A7%89/"/>
<category term="图像基础" scheme="https://thinksky5124.github.io/categories/%E6%95%B0%E5%AD%97%E5%9B%BE%E5%83%8F%E5%A4%84%E7%90%86/%E8%AE%A1%E7%AE%97%E6%9C%BA%E8%A7%86%E8%A7%89/%E5%9B%BE%E5%83%8F%E5%9F%BA%E7%A1%80/"/>
<category term="数字图像处理" scheme="https://thinksky5124.github.io/tags/%E6%95%B0%E5%AD%97%E5%9B%BE%E5%83%8F%E5%A4%84%E7%90%86/"/>
<category term="图像处理基本概念" scheme="https://thinksky5124.github.io/tags/%E5%9B%BE%E5%83%8F%E5%A4%84%E7%90%86%E5%9F%BA%E6%9C%AC%E6%A6%82%E5%BF%B5/"/>
</entry>
<entry>
<title>A PyTorch implementation of Faster-RCNN</title>
<link href="https://thinksky5124.github.io/2021/01/12/Faster-RCNN%E7%9A%84pytorch%E5%AE%9E%E7%8E%B0/"/>
<id>https://thinksky5124.github.io/2021/01/12/Faster-RCNN%E7%9A%84pytorch%E5%AE%9E%E7%8E%B0/</id>
<published>2021-01-12T10:43:27.000Z</published>
<updated>2024-04-16T08:57:05.647Z</updated>
<content type="html"><![CDATA[<p>Code source: <a href="https://github.com/jwyang/faster-rcnn.pytorch">https://github.com/jwyang/faster-rcnn.pytorch</a></p><h1 id="代码解读"><a href="#代码解读" class="headerlink" title="代码解读"></a>Code walkthrough</h1><h2 id="把python2的新版本特性导入到python2当前版本"><a href="#把python2的新版本特性导入到python2当前版本" class="headerlink" title="把python2的新版本特性导入到python2当前版本"></a>Importing newer-Python features into Python 2</h2><figure class="highlight python"><table><tbody><tr><td class="code"><pre><span class="line"><span class="keyword">from</span> __future__ <span class="keyword">import</span> absolute_import <span class="comment"># use absolute imports</span></span><br><span class="line"><span class="keyword">from</span> __future__ <span class="keyword">import</span> division <span class="comment"># "/" performs true division</span></span><br><span class="line"><span class="keyword">from</span> __future__ <span class="keyword">import</span> print_function <span class="comment"># print is a function, not a statement</span></span><br></pre></td></tr></tbody></table></figure><h2 id="导入库"><a href="#导入库" class="headerlink" title="导入库"></a>Importing libraries</h2><figure class="highlight python"><table><tbody><tr><td class="code"><pre><span class="line"><span class="keyword">import</span> _init_paths <span class="comment"># local helper module that sets up the import paths</span></span><br><span class="line"><span class="keyword">import</span> os</span><br><span class="line"><span class="keyword">import</span> sys</span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">import</span> argparse</span><br><span class="line"><span class="keyword">import</span> pprint <span class="comment"># pretty-printing</span></span><br><span class="line"><span class="keyword">import</span> pdb <span class="comment"># Python debugger</span></span><br><span class="line"><span class="keyword">import</span> time <span class="comment"># timing utilities, used to measure run time</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">from</span> torch.autograd <span class="keyword">import</span> Variable</span><br><span class="line"><span class="keyword">import</span> torch.nn <span class="keyword">as</span> nn</span><br><span class="line"><span class="keyword">import</span> torch.optim <span class="keyword">as</span> optim</span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> torchvision.transforms <span class="keyword">as</span> transforms <span class="comment"># torchvision: PyTorch's image/video toolkit</span></span><br><span class="line"><span class="keyword">from</span> torch.utils.data.sampler <span class="keyword">import</span> Sampler</span><br><span class="line"></span><br><span class="line"><span class="keyword">from</span> roi_data_layer.roidb <span class="keyword">import</span> combined_roidb</span><br><span class="line"><span class="keyword">from</span> roi_data_layer.roibatchLoader <span class="keyword">import</span> roibatchLoader</span><br><span class="line"><span class="keyword">from</span> model.utils.config <span class="keyword">import</span> cfg, cfg_from_file, cfg_from_list, get_output_dir</span><br><span class="line"><span class="keyword">from</span> model.utils.net_utils <span class="keyword">import</span> weights_normal_init, save_net, load_net, \</span><br><span class="line">    adjust_learning_rate, save_checkpoint, clip_gradient</span><br><span class="line"></span><br><span class="line"><span class="keyword">from</span> model.faster_rcnn.vgg16 <span class="keyword">import</span> vgg16</span><br><span class="line"><span class="keyword">from</span> model.faster_rcnn.resnet <span class="keyword">import</span> resnet</span><br></pre></td></tr></tbody></table></figure><h2 id="命令行参数定义"><a href="#命令行参数定义" class="headerlink" title="命令行参数定义"></a>Command-line argument definitions</h2><figure class="highlight python"><table><tbody><tr><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">parse_args</span>():</span><br><span class="line"> <span class="string">"""</span></span><br><span class="line"><span class="string"> Parse input arguments</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> 
parser = argparse.ArgumentParser(description=<span class="string">'Train a Fast R-CNN network'</span>)</span><br><span class="line"> parser.add_argument(<span class="string">'--dataset'</span>, dest=<span class="string">'dataset'</span>,</span><br><span class="line"> <span class="built_in">help</span>=<span class="string">'training dataset'</span>,</span><br><span class="line"> default=<span class="string">'pascal_voc'</span>, <span class="built_in">type</span>=<span class="built_in">str</span>)</span><br><span class="line"> parser.add_argument(<span class="string">'--net'</span>, dest=<span class="string">'net'</span>,</span><br><span class="line"> <span class="built_in">help</span>=<span class="string">'vgg16, res101'</span>,</span><br><span class="line"> default=<span class="string">'vgg16'</span>, <span class="built_in">type</span>=<span class="built_in">str</span>)</span><br><span class="line"> parser.add_argument(<span class="string">'--start_epoch'</span>, dest=<span class="string">'start_epoch'</span>,</span><br><span class="line"> <span class="built_in">help</span>=<span class="string">'starting epoch'</span>,</span><br><span class="line"> default=<span class="number">1</span>, <span class="built_in">type</span>=<span class="built_in">int</span>)</span><br><span class="line"> parser.add_argument(<span class="string">'--epochs'</span>, dest=<span class="string">'max_epochs'</span>,</span><br><span class="line"> <span class="built_in">help</span>=<span class="string">'number of epochs to train'</span>,</span><br><span class="line"> default=<span class="number">20</span>, <span class="built_in">type</span>=<span class="built_in">int</span>)</span><br><span class="line"> parser.add_argument(<span class="string">'--disp_interval'</span>, dest=<span class="string">'disp_interval'</span>,</span><br><span class="line"> <span class="built_in">help</span>=<span class="string">'number of iterations to display'</span>,</span><br><span class="line"> default=<span 
class="number">100</span>, <span class="built_in">type</span>=<span class="built_in">int</span>)</span><br><span class="line"> parser.add_argument(<span class="string">'--checkpoint_interval'</span>, dest=<span class="string">'checkpoint_interval'</span>,</span><br><span class="line"> <span class="built_in">help</span>=<span class="string">'number of iterations between checkpoints'</span>,</span><br><span class="line"> default=<span class="number">10000</span>, <span class="built_in">type</span>=<span class="built_in">int</span>)</span><br><span class="line"></span><br><span class="line"> parser.add_argument(<span class="string">'--save_dir'</span>, dest=<span class="string">'save_dir'</span>,</span><br><span class="line"> <span class="built_in">help</span>=<span class="string">'directory to save models'</span>, default=<span class="string">"models"</span>,</span><br><span class="line"> <span class="built_in">type</span>=<span class="built_in">str</span>)</span><br><span class="line"> parser.add_argument(<span class="string">'--nw'</span>, dest=<span class="string">'num_workers'</span>,</span><br><span class="line"> <span class="built_in">help</span>=<span class="string">'number of workers to load data'</span>,</span><br><span class="line"> default=<span class="number">0</span>, <span class="built_in">type</span>=<span class="built_in">int</span>)</span><br><span class="line"> parser.add_argument(<span class="string">'--cuda'</span>, dest=<span class="string">'cuda'</span>,</span><br><span class="line"> <span class="built_in">help</span>=<span class="string">'whether use CUDA'</span>,</span><br><span class="line"> action=<span class="string">'store_true'</span>)</span><br><span class="line"> parser.add_argument(<span class="string">'--ls'</span>, dest=<span class="string">'large_scale'</span>,</span><br><span class="line"> <span class="built_in">help</span>=<span class="string">'whether use large image scale'</span>,</span><br><span class="line"> action=<span 
class="string">'store_true'</span>) </span><br><span class="line"> parser.add_argument(<span class="string">'--mGPUs'</span>, dest=<span class="string">'mGPUs'</span>,</span><br><span class="line"> <span class="built_in">help</span>=<span class="string">'whether use multiple GPUs'</span>,</span><br><span class="line"> action=<span class="string">'store_true'</span>)</span><br><span class="line"> parser.add_argument(<span class="string">'--bs'</span>, dest=<span class="string">'batch_size'</span>,</span><br><span class="line"> <span class="built_in">help</span>=<span class="string">'batch_size'</span>,</span><br><span class="line"> default=<span class="number">1</span>, <span class="built_in">type</span>=<span class="built_in">int</span>)</span><br><span class="line"> parser.add_argument(<span class="string">'--cag'</span>, dest=<span class="string">'class_agnostic'</span>,</span><br><span class="line"> <span class="built_in">help</span>=<span class="string">'whether to perform class_agnostic bbox regression'</span>,</span><br><span class="line"> action=<span class="string">'store_true'</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># config optimization</span></span><br><span class="line"> parser.add_argument(<span class="string">'--o'</span>, dest=<span class="string">'optimizer'</span>,</span><br><span class="line"> <span class="built_in">help</span>=<span class="string">'training optimizer'</span>,</span><br><span class="line"> default=<span class="string">"sgd"</span>, <span class="built_in">type</span>=<span class="built_in">str</span>)</span><br><span class="line"> parser.add_argument(<span class="string">'--lr'</span>, dest=<span class="string">'lr'</span>,</span><br><span class="line"> <span class="built_in">help</span>=<span class="string">'starting learning rate'</span>,</span><br><span class="line"> default=<span class="number">0.001</span>, <span class="built_in">type</span>=<span 
class="built_in">float</span>)</span><br><span class="line"> parser.add_argument(<span class="string">'--lr_decay_step'</span>, dest=<span class="string">'lr_decay_step'</span>,</span><br><span class="line"> <span class="built_in">help</span>=<span class="string">'step to do learning rate decay, unit is epoch'</span>,</span><br><span class="line"> default=<span class="number">5</span>, <span class="built_in">type</span>=<span class="built_in">int</span>)</span><br><span class="line"> parser.add_argument(<span class="string">'--lr_decay_gamma'</span>, dest=<span class="string">'lr_decay_gamma'</span>,</span><br><span class="line"> <span class="built_in">help</span>=<span class="string">'learning rate decay ratio'</span>,</span><br><span class="line"> default=<span class="number">0.1</span>, <span class="built_in">type</span>=<span class="built_in">float</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># set training session</span></span><br><span class="line"> parser.add_argument(<span class="string">'--s'</span>, dest=<span class="string">'session'</span>,</span><br><span class="line"> <span class="built_in">help</span>=<span class="string">'training session'</span>,</span><br><span class="line"> default=<span class="number">1</span>, <span class="built_in">type</span>=<span class="built_in">int</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># resume trained model</span></span><br><span class="line"> parser.add_argument(<span class="string">'--r'</span>, dest=<span class="string">'resume'</span>,</span><br><span class="line"> <span class="built_in">help</span>=<span class="string">'resume checkpoint or not'</span>,</span><br><span class="line"> default=<span class="literal">False</span>, <span class="built_in">type</span>=<span class="built_in">bool</span>) <span class="comment"># NB: type=bool turns any non-empty string, even 'False', into True</span></span><br><span class="line"> parser.add_argument(<span class="string">'--checksession'</span>, dest=<span 
class="string">'checksession'</span>,</span><br><span class="line"> <span class="built_in">help</span>=<span class="string">'checksession to load model'</span>,</span><br><span class="line"> default=<span class="number">1</span>, <span class="built_in">type</span>=<span class="built_in">int</span>)</span><br><span class="line"> parser.add_argument(<span class="string">'--checkepoch'</span>, dest=<span class="string">'checkepoch'</span>,</span><br><span class="line"> <span class="built_in">help</span>=<span class="string">'checkepoch to load model'</span>,</span><br><span class="line"> default=<span class="number">1</span>, <span class="built_in">type</span>=<span class="built_in">int</span>)</span><br><span class="line"> parser.add_argument(<span class="string">'--checkpoint'</span>, dest=<span class="string">'checkpoint'</span>,</span><br><span class="line"> <span class="built_in">help</span>=<span class="string">'checkpoint to load model'</span>,</span><br><span class="line"> default=<span class="number">0</span>, <span class="built_in">type</span>=<span class="built_in">int</span>)</span><br><span class="line"><span class="comment"># log and display</span></span><br><span class="line"> parser.add_argument(<span class="string">'--use_tfb'</span>, dest=<span class="string">'use_tfboard'</span>,</span><br><span class="line"> <span class="built_in">help</span>=<span class="string">'whether use tensorboard'</span>,</span><br><span class="line"> action=<span class="string">'store_true'</span>)</span><br><span class="line"></span><br><span class="line"> args = parser.parse_args()</span><br><span class="line"> <span class="keyword">return</span> args</span><br></pre></td></tr></tbody></table></figure><h2 id="定义类"><a href="#定义类" class="headerlink" title="定义类"></a>Class Definitions</h2><p>Sampler class: it shuffles the order of the batches while keeping each batch a block of <code>batch_size</code> consecutive indices; any leftover indices are appended at the end.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span 
class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">class</span> <span class="title class_">sampler</span>(<span class="title class_ inherited__">Sampler</span>):</span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self, train_size, batch_size</span>):</span><br><span class="line"> self.num_data = train_size</span><br><span class="line"> self.num_per_batch = <span class="built_in">int</span>(train_size / batch_size)</span><br><span class="line"> self.batch_size = batch_size</span><br><span class="line"> self.<span class="built_in">range</span> = torch.arange(<span class="number">0</span>,batch_size).view(<span class="number">1</span>, batch_size).long()</span><br><span class="line"> self.leftover_flag = <span class="literal">False</span></span><br><span class="line"> <span class="keyword">if</span> train_size % batch_size:</span><br><span class="line"> self.leftover = torch.arange(self.num_per_batch*batch_size, train_size).long()</span><br><span class="line"> self.leftover_flag = <span class="literal">True</span></span><br><span class="line"></span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">__iter__</span>(<span class="params">self</span>):</span><br><span class="line"> 
rand_num = torch.randperm(self.num_per_batch).view(-<span class="number">1</span>,<span class="number">1</span>) * self.batch_size</span><br><span class="line"> self.rand_num = rand_num.expand(self.num_per_batch, self.batch_size) + self.<span class="built_in">range</span></span><br><span class="line"></span><br><span class="line"> self.rand_num_view = self.rand_num.view(-<span class="number">1</span>)</span><br><span class="line"></span><br><span class="line"> <span class="keyword">if</span> self.leftover_flag:</span><br><span class="line"> self.rand_num_view = torch.cat((self.rand_num_view, self.leftover),<span class="number">0</span>)</span><br><span class="line"></span><br><span class="line"> <span class="keyword">return</span> <span class="built_in">iter</span>(self.rand_num_view)</span><br><span class="line"></span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">__len__</span>(<span class="params">self</span>):</span><br><span class="line"> <span class="keyword">return</span> self.num_data</span><br></pre></td></tr></tbody></table></figure><h2 id="训练主体"><a href="#训练主体" class="headerlink" title="训练主体"></a>Main Training Script</h2><h3 id="训练设置载入"><a href="#训练设置载入" class="headerlink" title="训练设置载入"></a>Loading the Training Configuration</h3><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span 
class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br></pre></td><td class="code"><pre><span class="line">args = parse_args()</span><br><span class="line"></span><br><span class="line"> <span class="built_in">print</span>(<span class="string">'Called with args:'</span>)</span><br><span class="line"> <span class="built_in">print</span>(args)</span><br><span class="line"></span><br><span class="line"> <span class="keyword">if</span> args.dataset == <span class="string">"pascal_voc"</span>:</span><br><span class="line"> args.imdb_name = <span class="string">"voc_2007_trainval"</span></span><br><span class="line"> args.imdbval_name = <span class="string">"voc_2007_test"</span></span><br><span class="line"> args.set_cfgs = [<span class="string">'ANCHOR_SCALES'</span>, <span class="string">'[8, 16, 32]'</span>, <span class="string">'ANCHOR_RATIOS'</span>, <span class="string">'[0.5,1,2]'</span>, <span class="string">'MAX_NUM_GT_BOXES'</span>, <span class="string">'20'</span>]</span><br><span class="line"> <span class="keyword">elif</span> args.dataset == <span class="string">"pascal_voc_0712"</span>:</span><br><span class="line"> args.imdb_name = <span class="string">"voc_2007_trainval+voc_2012_trainval"</span></span><br><span class="line"> args.imdbval_name = <span class="string">"voc_2007_test"</span></span><br><span class="line"> 
args.set_cfgs = [<span class="string">'ANCHOR_SCALES'</span>, <span class="string">'[8, 16, 32]'</span>, <span class="string">'ANCHOR_RATIOS'</span>, <span class="string">'[0.5,1,2]'</span>, <span class="string">'MAX_NUM_GT_BOXES'</span>, <span class="string">'20'</span>]</span><br><span class="line"> <span class="keyword">elif</span> args.dataset == <span class="string">"coco"</span>:</span><br><span class="line"> args.imdb_name = <span class="string">"coco_2014_train+coco_2014_valminusminival"</span></span><br><span class="line"> args.imdbval_name = <span class="string">"coco_2014_minival"</span></span><br><span class="line"> args.set_cfgs = [<span class="string">'ANCHOR_SCALES'</span>, <span class="string">'[4, 8, 16, 32]'</span>, <span class="string">'ANCHOR_RATIOS'</span>, <span class="string">'[0.5,1,2]'</span>, <span class="string">'MAX_NUM_GT_BOXES'</span>, <span class="string">'50'</span>]</span><br><span class="line"> <span class="keyword">elif</span> args.dataset == <span class="string">"imagenet"</span>:</span><br><span class="line"> args.imdb_name = <span class="string">"imagenet_train"</span></span><br><span class="line"> args.imdbval_name = <span class="string">"imagenet_val"</span></span><br><span class="line"> args.set_cfgs = [<span class="string">'ANCHOR_SCALES'</span>, <span class="string">'[4, 8, 16, 32]'</span>, <span class="string">'ANCHOR_RATIOS'</span>, <span class="string">'[0.5,1,2]'</span>, <span class="string">'MAX_NUM_GT_BOXES'</span>, <span class="string">'30'</span>]</span><br><span class="line"> <span class="keyword">elif</span> args.dataset == <span class="string">"vg"</span>:</span><br><span class="line"> <span class="comment"># train sizes: train, smalltrain, minitrain</span></span><br><span class="line"> <span class="comment"># train scale: ['150-50-20', '150-50-50', '500-150-80', '750-250-150', '1750-700-450', '1600-400-20']</span></span><br><span class="line"> args.imdb_name = <span 
class="string">"vg_150-50-50_minitrain"</span></span><br><span class="line"> args.imdbval_name = <span class="string">"vg_150-50-50_minival"</span></span><br><span class="line"> args.set_cfgs = [<span class="string">'ANCHOR_SCALES'</span>, <span class="string">'[4, 8, 16, 32]'</span>, <span class="string">'ANCHOR_RATIOS'</span>, <span class="string">'[0.5,1,2]'</span>, <span class="string">'MAX_NUM_GT_BOXES'</span>, <span class="string">'50'</span>]</span><br><span class="line"></span><br><span class="line"> args.cfg_file = <span class="string">"cfgs/{}_ls.yml"</span>.<span class="built_in">format</span>(args.net) <span class="keyword">if</span> args.large_scale <span class="keyword">else</span> <span class="string">"cfgs/{}.yml"</span>.<span class="built_in">format</span>(args.net)</span><br><span class="line"></span><br><span class="line"> <span class="keyword">if</span> args.cfg_file <span class="keyword">is</span> <span class="keyword">not</span> <span class="literal">None</span>:</span><br><span class="line"> cfg_from_file(args.cfg_file)</span><br><span class="line"> <span class="keyword">if</span> args.set_cfgs <span class="keyword">is</span> <span class="keyword">not</span> <span class="literal">None</span>:</span><br><span class="line"> cfg_from_list(args.set_cfgs)</span><br><span class="line"></span><br><span class="line"> <span class="built_in">print</span>(<span class="string">'Using config:'</span>)</span><br><span class="line"> pprint.pprint(cfg)</span><br><span class="line"> np.random.seed(cfg.RNG_SEED)</span><br><span class="line"></span><br><span class="line"> <span class="comment">#torch.backends.cudnn.benchmark = True</span></span><br><span class="line"> <span class="keyword">if</span> torch.cuda.is_available() <span class="keyword">and</span> <span class="keyword">not</span> args.cuda:</span><br><span class="line"> <span class="built_in">print</span>(<span class="string">"WARNING: You have a CUDA device, so you should probably run with 
--cuda"</span>)</span><br></pre></td></tr></tbody></table></figure><h3 id="训练集准备"><a href="#训练集准备" class="headerlink" title="训练集准备"></a>Preparing the Training Set</h3><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># train set</span></span><br><span class="line"><span class="comment"># -- Note: Use validation set and disable the flipped to enable faster loading.</span></span><br><span class="line">cfg.TRAIN.USE_FLIPPED = <span class="literal">True</span></span><br><span class="line">cfg.USE_GPU_NMS = args.cuda</span><br><span class="line">imdb, roidb, ratio_list, ratio_index = combined_roidb(args.imdb_name)</span><br><span class="line"><span class="comment"># imdb: the image database (dataset) object itself</span></span><br><span class="line"><span class="comment"># roidb: the per-image ROI (region) annotation records</span></span><br><span class="line">train_size = <span class="built_in">len</span>(roidb)</span><br></pre></td></tr></tbody></table></figure>]]></content>
<summary type="html"><p>Code source: <a href="https://github.com/jwyang/faster-rcnn.pytorch">https://github.com/jwyang/faster-rcnn.pytorch</a></p>
<h1 id="代码解读"><a href=</summary>
<category term="人工智能学习" scheme="https://thinksky5124.github.io/categories/%E4%BA%BA%E5%B7%A5%E6%99%BA%E8%83%BD%E5%AD%A6%E4%B9%A0/"/>
<category term="目标检测" scheme="https://thinksky5124.github.io/categories/%E4%BA%BA%E5%B7%A5%E6%99%BA%E8%83%BD%E5%AD%A6%E4%B9%A0/%E7%9B%AE%E6%A0%87%E6%A3%80%E6%B5%8B/"/>
<category term="目标检测" scheme="https://thinksky5124.github.io/tags/%E7%9B%AE%E6%A0%87%E6%A3%80%E6%B5%8B/"/>
<category term="Faster-RCNN" scheme="https://thinksky5124.github.io/tags/Faster-RCNN/"/>
<category term="pytorch" scheme="https://thinksky5124.github.io/tags/pytorch/"/>
</entry>
<entry>
<title>CPU Parallel Computing</title>
<link href="https://thinksky5124.github.io/2021/01/07/CPU%E5%B9%B6%E8%A1%8C%E8%AE%A1%E7%AE%97/"/>
<id>https://thinksky5124.github.io/2021/01/07/CPU%E5%B9%B6%E8%A1%8C%E8%AE%A1%E7%AE%97/</id>
<published>2021-01-07T10:27:06.000Z</published>
<updated>2024-04-16T08:57:05.643Z</updated>
<content type="html"><![CDATA[<h1 id="Lecture-2-A-Modern-Multi-Core-Processor"><a href="#Lecture-2-A-Modern-Multi-Core-Processor" class="headerlink" title="Lecture 2 A Modern Multi-Core Processor"></a>Lecture 2 A Modern Multi-Core Processor</h1><ul><li>Understand the different forms of parallel execution</li><li>Understand latency and bandwidth</li></ul><h2 id="Parallel-Execution"><a href="#Parallel-Execution" class="headerlink" title="Parallel Execution"></a>Parallel Execution</h2><h3 id="单线程执行-编译器定义"><a href="#单线程执行-编译器定义" class="headerlink" title="单线程执行 - 编译器定义"></a>Single-Threaded Execution (Compiler-Defined)</h3><p>Example: compute $\sin(x)$ for every element of an array using its Taylor series, $\sin x = x - \frac{x^3}{3!} + \frac{x^5}{5!} - \frac{x^7}{7!} + \cdots$.</p><figure class="highlight c++"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="type">void</span> <span class="title">sinx</span><span class="params">(<span class="type">int</span> N,<span class="type">int</span> terms, <span class="type">float</span>* x,<span class="type">float</span>* result)</span></span>{</span><br><span class="line"> <span class="keyword">for</span>(<span class="type">int</span> i=<span class="number">0</span>;i<N;i++){</span><br><span class="line"> <span class="type">float</span> value = x[i];</span><br><span class="line"> <span class="type">float</span> numer = x[i] * x[i] *x[i];</span><br><span class="line"> <span class="type">int</span> denom = <span class="number">6</span>;</span><br><span class="line"> <span class="type">int</span> sign = <span 
class="number">-1</span>;</span><br><span class="line"></span><br><span class="line"> <span class="keyword">for</span>(<span class="type">int</span> j=<span class="number">1</span>;j<= terms;j++){</span><br><span class="line"> value += sign * numer / denom;</span><br><span class="line"> numer *= x[i] * x[i];</span><br><span class="line"> denom *= (<span class="number">2</span>*j+<span class="number">2</span>)*(<span class="number">2</span>*j+<span class="number">3</span>);</span><br><span class="line"> sign *= <span class="number">-1</span>;</span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> result[i]=value;</span><br><span class="line"> }</span><br><span class="line">}</span><br></pre></td></tr></tbody></table></figure><div class="gallery-container" data-type="data" data-button=""> <div class="gallery-data">[{"url":"https://s2.loli.net/2024/03/25/JDzKdeycxpHo59h.jpg","alt":"Single-core, single-thread, single-pipeline processor model"}]</div> <div class="gallery-items"> </div> </div><ul><li>The Fetch/Decode unit fetches and decodes instructions</li><li>The ALU executes them</li><li>The Execution Context holds the state (e.g. registers) needed to execute the instruction stream</li></ul><h3 id="多线程执行-用户定义"><a href="#多线程执行-用户定义" class="headerlink" title="多线程执行 - 用户定义"></a>Multi-Threaded Execution (User-Defined)</h3><p>Concepts: instruction-level parallelism (ILP), exploited by a superscalar core that issues independent instructions from a single instruction stream in the same clock cycle; and thread-level parallelism, exploited by a multi-core processor that runs different threads on different cores.</p><div class="gallery-container" data-type="data" data-button=""> <div class="gallery-data">[{"url":"https://s2.loli.net/2024/03/25/oN83kVxGqgOEJUd.jpg","alt":"Superscalar processor model"},{"url":"https://s2.loli.net/2024/03/25/2meGLwOuWCJj4is.jpg","alt":"Multi-core processor model"}]</div> <div class="gallery-items"> </div> </div><p>Example:</p><figure class="highlight c++"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span 
class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">typedef</span> <span class="keyword">struct</span>{</span><br><span class="line"> <span class="type">int</span> N;</span><br><span class="line"> <span class="type">int</span> terms;</span><br><span class="line"> <span class="type">float</span> *x;</span><br><span class="line"> <span class="type">float</span> *result;</span><br><span class="line">} my_args;</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="type">void</span> <span class="title">parallel_sinx</span><span class="params">(<span class="type">int</span> N,<span class="type">int</span> terms, <span class="type">float</span>* x,<span class="type">float</span>* result)</span></span>{</span><br><span class="line"> <span class="type">pthread_t</span> thread_id;</span><br><span class="line"> my_args args;</span><br><span class="line"></span><br><span class="line"> args.N=N/<span class="number">2</span>;</span><br><span class="line"> args.terms=terms;</span><br><span class="line"> args.x=x;</span><br><span class="line"> args.result=result;</span><br><span 
class="line"></span><br><span class="line"> <span class="built_in">pthread_create</span>(&thread_id, <span class="literal">NULL</span>, my_thread_start, &args);<span class="comment">// launch a thread running my_thread_start (not shown), which unpacks args and calls sinx on the first args.N elements</span></span><br><span class="line"> <span class="built_in">sinx</span>(N-args.N,terms,x+args.N,result+args.N);<span class="comment">// the main thread computes the remaining N - args.N elements</span></span><br><span class="line"> <span class="built_in">pthread_join</span>(thread_id, <span class="literal">NULL</span>);</span><br><span class="line">}</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="type">void</span> <span class="title">sinx</span><span class="params">(<span class="type">int</span> N,<span class="type">int</span> terms, <span class="type">float</span>* x,<span class="type">float</span>* result)</span></span>{</span><br><span class="line"> <span class="keyword">for</span>(<span class="type">int</span> i=<span class="number">0</span>;i<N;i++){</span><br><span class="line"> <span class="type">float</span> value = x[i];</span><br><span class="line"> <span class="type">float</span> numer = x[i] * x[i] *x[i];</span><br><span class="line"> <span class="type">int</span> denom = <span class="number">6</span>;</span><br><span class="line"> <span class="type">int</span> sign = <span class="number">-1</span>;</span><br><span class="line"></span><br><span class="line"> <span class="keyword">for</span>(<span class="type">int</span> j=<span class="number">1</span>;j<= terms;j++){</span><br><span class="line"> value += sign * numer / denom;</span><br><span class="line"> numer *= x[i] * x[i];</span><br><span class="line"> denom *= (<span class="number">2</span>*j+<span class="number">2</span>)*(<span class="number">2</span>*j+<span class="number">3</span>);</span><br><span class="line"> sign *= <span class="number">-1</span>;</span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> result[i]=value;</span><br><span class="line"> }</span><br><span 
class="line">}</span><br></pre></td></tr></tbody></table></figure><p>The example above expresses thread-level parallelism. The array is split in two: a spawned thread computes the first half while the main thread computes the second half, so the work runs on two processor cores at once and finishes sooner. How the work is divided among the cores strongly affects the speedup achieved.</p><h3 id="数据并行-用户定义、编译器定义"><a href="#数据并行-用户定义、编译器定义" class="headerlink" title="数据并行 - 用户定义、编译器定义"></a>Data Parallelism (User-Defined or Compiler-Defined)</h3><div class="gallery-container" data-type="data" data-button=""> <div class="gallery-data">[{"url":"https://s2.loli.net/2024/03/25/T8FkUAmjd2saXMn.jpg","alt":"SIMD"}]</div> <div class="gallery-items"> </div> </div><p>Concept: SIMD (single instruction, multiple data), in which one instruction operates on several data elements at once.</p><p>The same computation written with AVX intrinsics (it assumes N is a multiple of 8 and that both arrays are 32-byte aligned, as _mm256_load_ps and _mm256_store_ps require):</p><figure class="highlight c++"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">#<span class="keyword">include</span> <span class="string"><immintrin.h></span></span></span><br><span class="line"><span class="function"><span class="type">void</span> <span class="title">sinx</span><span class="params">(<span class="type">int</span> N, <span class="type">int</span> terms, <span class="type">float</span>* x, <span class="type">float</span>* sinx)</span></span></span><br><span class="line"><span class="function"></span>{</span><br><span class="line"> <span class="type">float</span> three_fact = <span class="number">6</span>; <span class="comment">// 3!</span></span><br><span 
class="line"> <span class="keyword">for</span> (<span class="type">int</span> i=<span class="number">0</span>; i<N; i+=<span class="number">8</span>)</span><br><span class="line"> {</span><br><span class="line"> __m256 origx = _mm256_load_ps(&x[i]);</span><br><span class="line"> __m256 value = origx;</span><br><span class="line"> __m256 numer = _mm256_mul_ps(origx, _mm256_mul_ps(origx, origx));</span><br><span class="line"> __m256 denom = _mm256_broadcast_ss(&three_fact);</span><br><span class="line"> <span class="type">int</span> sign = <span class="number">-1</span>;</span><br><span class="line"> <span class="keyword">for</span> (<span class="type">int</span> j=<span class="number">1</span>; j<=terms; j++)</span><br><span class="line"> {</span><br><span class="line"> <span class="comment">// value += sign * numer / denom</span></span><br><span class="line"> __m256 tmp = _mm256_div_ps(_mm256_mul_ps(_mm256_set1_ps((<span class="type">float</span>)sign),numer),denom);</span><br><span class="line"> value = _mm256_add_ps(value, tmp);</span><br><span class="line"> numer = _mm256_mul_ps(numer, _mm256_mul_ps(origx, origx));</span><br><span class="line"> denom = _mm256_mul_ps(denom, _mm256_set1_ps((<span class="type">float</span>)((<span class="number">2</span>*j+<span class="number">2</span>) * (<span class="number">2</span>*j+<span class="number">3</span>))));</span><br><span class="line"> sign *= <span class="number">-1</span>;</span><br><span class="line"> }</span><br><span class="line"> _mm256_store_ps(&sinx[i], value);</span><br><span class="line"> }</span><br><span class="line">}</span><br></pre></td></tr></tbody></table></figure><p>With the processor model shown in the figure above, SIMD lets a single core process 8 data elements with one instruction (a 256-bit AVX register holds eight 32-bit floats), which is how data parallelism is achieved.</p><p>Alternatively, the loop can be written so that the compiler generates the parallel code automatically, provided the loop iterations are independent of each other. Here <code>forall</code> is pseudocode that declares this independence, not a standard C++ keyword:</p><figure class="highlight c++"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span 
class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="type">void</span> <span class="title">sinx</span><span class="params">(<span class="type">int</span> N, <span class="type">int</span> terms, <span class="type">float</span>* x, <span class="type">float</span>* result)</span></span></span><br><span class="line"><span class="function"></span>{</span><br><span class="line"> <span class="meta">#<span class="keyword">pragma</span> omp simd</span> <span class="comment">// declare independent loop iterations (the lecture's pseudocode writes: forall (int i from 0 to N-1))</span></span><br><span class="line"> <span class="keyword">for</span> (<span class="type">int</span> i=<span class="number">0</span>; i<N; i++)</span><br><span class="line"> {</span><br><span class="line"> <span class="type">float</span> value = x[i];</span><br><span class="line"> <span class="type">float</span> numer = x[i] * x[i] * x[i];</span><br><span class="line"> <span class="type">float</span> denom = <span class="number">6</span>; <span class="comment">// 3! (float, because an int overflows beyond 12!)</span></span><br><span class="line"> <span class="type">int</span> sign = <span class="number">-1</span>;</span><br><span class="line"> <span class="keyword">for</span> (<span class="type">int</span> j=<span class="number">1</span>; j<=terms; j++)</span><br><span class="line"> {</span><br><span class="line"> value += sign * numer / denom;</span><br><span class="line"> numer *= x[i] * x[i];</span><br><span class="line"> denom *= (<span class="number">2</span>*j+<span class="number">2</span>) * (<span class="number">2</span>*j+<span class="number">3</span>);</span><br><span class="line"> sign *= <span class="number">-1</span>;</span><br><span class="line"> 
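 }</span><br><span class="line"> result[i] = value;</span><br><span class="line"> }</span><br><span class="line">}</span><br></pre></td></tr></tbody></table></figure><h3 id="条件执行"><a href="#条件执行" class="headerlink" title="条件执行"></a>Conditional Execution</h3><p>Suppose a data-parallel processor has to execute the following code, with each ALU (SIMD lane) handling a different element i:</p><figure class="highlight c++"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line"><span class="type">float</span> x = A[i];</span><br><span class="line"><span class="keyword">if</span>(x><span class="number">0</span>){</span><br><span class="line"> <span class="type">float</span> tmp = <span class="built_in">powf</span>(x, <span class="number">5.f</span>); <span class="comment">// expensive branch; the slide writes exp(x,5.f), x to the 5th power is assumed here</span></span><br><span class="line"> x = tmp + kMyConst2;</span><br><span class="line">}<span class="keyword">else</span>{</span><br><span class="line"> <span class="type">float</span> tmp = kMyConst1;</span><br><span class="line"> x = <span class="number">2.f</span> * tmp;</span><br><span class="line">}</span><br><span class="line">result[i] = x;</span><br></pre></td></tr></tbody></table></figure><p>A SIMD unit cannot send different lanes down different branches, so this is handled by masking rather than by branch prediction. The condition is first evaluated in every lane. The lanes whose condition is true then execute the if branch in parallel while the remaining lanes are masked off. Afterwards, the lanes whose condition is false execute the else branch. Because the two branches run one after the other, divergence costs efficiency: in the worst case an 8-wide unit delivers only 1/8 of its peak throughput.</p><div class="gallery-container" data-type="data" data-button=""> <div class="gallery-data">[{"url":"https://s2.loli.net/2024/03/25/MPSRkb4tX7DHnwu.png","alt":"Conditional execution under data parallelism"}]</div> <div class="gallery-items"> </div> </div><h3 id="SMID的有关概念"><a href="#SMID的有关概念" class="headerlink" title="SMID的有关概念"></a>SIMD Concepts</h3><ul><li>Using SIMD explicitly requires writing special-purpose code by hand, for example with SSE or AVX intrinsics.</li><li>Once the code is written this way, the compiler cannot check or guarantee that the loop iterations are independent; the programmer has to ensure it.</li><li>GPUs generally offer much higher data-parallel throughput than CPUs.</li></ul><p>For example:</p><div class="gallery-container" data-type="data" data-button=""> <div 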
class="gallery-data">[{"url":"https://s2.loli.net/2024/03/25/PAJuyXgRSYkdBnt.png","alt":"SIMD capability: CPU vs. GPU"}]</div> <div class="gallery-items"> </div> </div><h2 id="Accessing-Memory"><a href="#Accessing-Memory" class="headerlink" title="Accessing Memory"></a>Accessing Memory</h2><h3 id="术语"><a href="#术语" class="headerlink" title="术语"></a>Terminology</h3><ul><li>Memory latency: the delay incurred while waiting for an access to data stored in system memory to complete. Example: 100 cycles, or 100 ns.</li><li>Memory bandwidth: the rate at which the memory system can supply data to the processor. Example: 20 GB/s.</li><li>Throughput: the amount of data a chip processes per unit time.</li></ul><p>In modern processors, memory latency inevitably stalls the CPU and lowers its processing rate. It can, however, be reduced or "hidden" with techniques such as:</p><ul><li>Multi-level caches</li><li>Prefetching</li><li>Multithreading</li><li>Storing multiple execution contexts on chip</li></ul><div class="gallery-container" data-type="data" data-button=""> <div class="gallery-data">[{"url":"https://s2.loli.net/2024/03/25/Kmn3pBR7udSCboV.png","alt":"Memory structure: CPU vs. GPU"}]</div> <div class="gallery-items"> </div> </div>]]></content>
<summary type="html">Notes and summaries from my study of the CMU parallel computing course</summary>
<category term="并行计算" scheme="https://thinksky5124.github.io/categories/%E5%B9%B6%E8%A1%8C%E8%AE%A1%E7%AE%97/"/>
<category term="并行计算" scheme="https://thinksky5124.github.io/tags/%E5%B9%B6%E8%A1%8C%E8%AE%A1%E7%AE%97/"/>
<category term="CMU课程" scheme="https://thinksky5124.github.io/tags/CMU%E8%AF%BE%E7%A8%8B/"/>
</entry>
<entry>
<title>CMU Parallel Computing Course Overview</title>
<link href="https://thinksky5124.github.io/2021/01/07/%E8%AF%BE%E7%A8%8B%E6%A6%82%E8%BF%B0/"/>
<id>https://thinksky5124.github.io/2021/01/07/%E8%AF%BE%E7%A8%8B%E6%A6%82%E8%BF%B0/</id>
<published>2021-01-07T08:45:45.000Z</published>
<updated>2024-04-16T08:57:05.647Z</updated>
<content type="html"><![CDATA[<p>Course website: <a href="http://www.cs.cmu.edu/afs/cs/academic/class/15418-f18/www/schedule.html">http://www.cs.cmu.edu/afs/cs/academic/class/15418-f18/www/schedule.html</a></p><h1 id="Lecture-1"><a href="#Lecture-1" class="headerlink" title="Lecture 1"></a>Lecture 1</h1><h2 id="为什么需要并行计算?"><a href="#为什么需要并行计算?" class="headerlink" title="为什么需要并行计算?"></a>Why Parallel Computing?</h2><p>Single-core CPU performance used to grow almost exponentially, but by 2004 Intel had hit the power wall for single-core designs.</p><div class="gallery-container" data-type="data" data-button=""> <div class="gallery-data">[{"url":"https://s2.loli.net/2024/03/25/vnUhqJ2XPaFVpdL.jpg","alt":"single_chip_performance.jpg"}]</div> <div class="gallery-items"> </div> </div><div class="gallery-container" data-type="data" data-button=""> <div class="gallery-data">[{"url":"https://s2.loli.net/2024/03/25/rhltZgXQe36CzyJ.jpg","alt":"CPUTrends.jpg"}]</div> <div class="gallery-items"> </div> </div><p>Before 2004, the way to make a program run faster was to buy a new machine.</p><p>Since 2004, making programs faster requires parallel programming.</p><h2 id="目前的CPU架构"><a href="#目前的CPU架构" class="headerlink" title="目前的CPU架构"></a>Current CPU Architecture</h2><div class="gallery-container" data-type="data" data-button=""> <div class="gallery-data">[{"url":"https://s2.loli.net/2024/03/25/6ptGgeORBVDzoPN.jpg","alt":"IntelSkylake.jpg"}]</div> <div class="gallery-items"> </div> </div><p>The GPU is on the far left and four CPU cores are lined up on the right, with the communication bus in between. Modern CPUs thus rely on multi-core designs to achieve higher performance.</p><h2 id="什么是并行计算?"><a href="#什么是并行计算?" class="headerlink" title="什么是并行计算?"></a>What Is Parallel Computing?</h2><p>A parallel computer is a collection of processing elements<br>that cooperate to solve problems quickly.</p><p>In other words, a parallel computer coordinates multiple processing elements so that together they solve a problem quickly.</p><p>Speedup:</p><p>$$ \text{speedup}(\text{using } P \text{ processors}) = \frac{\text{execution time (using 1 processor)}}{\text{execution time (using } P \text{ processors)}} $$</p><p>Principles of parallel computing:</p><ul><li>Communication can severely limit the achievable speedup.</li><li>An unbalanced assignment of work limits speedup.</li></ul>]]></content>
<summary type="html">Notes and summaries from my study of the CMU parallel computing course</summary>
<category term="并行计算" scheme="https://thinksky5124.github.io/categories/%E5%B9%B6%E8%A1%8C%E8%AE%A1%E7%AE%97/"/>
<category term="并行计算" scheme="https://thinksky5124.github.io/tags/%E5%B9%B6%E8%A1%8C%E8%AE%A1%E7%AE%97/"/>
<category term="CMU课程" scheme="https://thinksky5124.github.io/tags/CMU%E8%AF%BE%E7%A8%8B/"/>
</entry>
<entry>
<title>SPP-Net Implementation in TensorFlow</title>
<link href="https://thinksky5124.github.io/2021/01/05/SPP-Net%E7%9A%84tensorflow%E5%AE%9E%E7%8E%B0/"/>
<id>https://thinksky5124.github.io/2021/01/05/SPP-Net%E7%9A%84tensorflow%E5%AE%9E%E7%8E%B0/</id>
<published>2021-01-05T03:50:12.000Z</published>
<updated>2024-04-16T08:57:05.647Z</updated>
<content type="html"><![CDATA[<p>Code source: <a href="https://github.com/peace195/sppnet">https://github.com/peace195/sppnet</a>. The code targets TensorFlow 1.x. It relies on tf.placeholder, tf.read_file and the queue-based tf.train input pipeline, none of which exist under those names in TensorFlow 2.</p><h1 id="代码解读"><a href="#代码解读" class="headerlink" title="代码解读"></a>Code Walkthrough</h1><h2 id="导入包"><a href="#导入包" class="headerlink" title="导入包"></a>Imports</h2><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">import</span> os</span><br><span class="line"><span class="keyword">import</span> sys</span><br><span class="line"><span class="keyword">import</span> tarfile</span><br><span class="line"><span class="keyword">from</span> six.moves.urllib.request <span class="keyword">import</span> urlretrieve</span><br><span class="line"><span class="keyword">from</span> six.moves <span class="keyword">import</span> cPickle <span class="keyword">as</span> pickle</span><br><span class="line"><span class="keyword">from</span> PIL <span class="keyword">import</span> Image</span><br><span class="line"><span class="keyword">import</span> math</span><br><span class="line"><span class="keyword">import</span> random</span><br><span class="line"><span class="keyword">import</span> re</span><br><span class="line"><span class="keyword">import</span> scipy.io</span><br><span 
class="keyword">import</span> PIL</span><br><span class="line"><span class="keyword">from</span> numpy <span class="keyword">import</span> *</span><br><span class="line"><span class="keyword">from</span> pylab <span class="keyword">import</span> *</span><br><span class="line"><span class="keyword">from</span> PIL <span class="keyword">import</span> Image</span><br><span class="line"><span class="keyword">from</span> collections <span class="keyword">import</span> defaultdict</span><br><span class="line"><span class="keyword">import</span> tensorflow <span class="keyword">as</span> tf</span><br><span class="line"><span class="keyword">import</span> matplotlib.pyplot <span class="keyword">as</span> plt</span><br></pre></td></tr></tbody></table></figure><h2 id="参数设置"><a href="#参数设置" class="headerlink" title="参数设置"></a>Parameter Settings</h2><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">DROPOUT = <span class="number">0.5</span> <span class="comment"># dropout probability</span></span><br><span class="line">LEARNING_RATE = <span class="number">0.1</span> <span class="comment"># learning rate</span></span><br><span class="line">VALIDATION_SIZE = <span class="number">0</span> <span class="comment"># size of the validation split</span></span><br><span class="line">TRAINING_ITERATIONS = <span class="number">50000</span> <span class="comment"># number of training iterations</span></span><br><span class="line">WEIGHT_DECAY = <span class="number">0.00005</span> <span class="comment"># regularization (weight decay) coefficient</span></span><br></pre></td></tr></tbody></table></figure><h2 id="加载数据"><a href="#加载数据" class="headerlink" title="加载数据"></a>Loading Data</h2><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span 
class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br></pre></td><td class="code"><pre><span class="line">net_data = load(<span class="string">"bvlc_alexnet.npy"</span>,allow_pickle=<span class="literal">True</span>).item() <span class="comment"># pretrained AlexNet weights: layer name -> [W, b]</span></span><br><span class="line"></span><br><span class="line">out_pool_size = [<span class="number">8</span>, <span class="number">6</span>, <span class="number">4</span>] <span class="comment"># output grid size of each pyramid level: three scales</span></span><br><span class="line">hidden_dim = <span class="number">0</span> <span class="comment"># bins per channel: 8*8 + 6*6 + 4*4 = 116</span></span><br><span class="line"><span class="keyword">for</span> item <span class="keyword">in</span> out_pool_size:</span><br><span class="line"> hidden_dim = hidden_dim + item * item</span><br><span class="line"> </span><br><span class="line">data_folder = <span class="string">'./jpg'</span></span><br><span class="line">labels = scipy.io.loadmat(<span class="string">'imagelabels.mat'</span>)</span><br><span class="line">setid = scipy.io.loadmat(<span class="string">'setid.mat'</span>)</span><br><span class="line"></span><br><span class="line">labels = labels[<span class="string">'labels'</span>][<span class="number">0</span>] - <span class="number">1</span> <span class="comment"># MATLAB labels and ids are 1-based; shift to 0-based</span></span><br><span class="line">trnid = np.array(setid[<span class="string">'tstid'</span>][<span class="number">0</span>]) - <span class="number">1</span> <span class="comment"># note: the larger 'tstid' split is used for training</span></span><br><span class="line">tstid = np.array(setid[<span class="string">'trnid'</span>][<span class="number">0</span>]) - <span 
class="number">1</span></span><br><span class="line">valid = np.array(setid[<span class="string">'valid'</span>][<span class="number">0</span>]) - <span class="number">1</span></span><br><span class="line"></span><br><span class="line">num_classes = <span class="number">102</span></span><br><span class="line">data_dir = <span class="built_in">list</span>() <span class="comment"># full paths of all image files</span></span><br><span class="line"><span class="keyword">for</span> img <span class="keyword">in</span> os.listdir(data_folder):</span><br><span class="line"> data_dir.append(os.path.join(data_folder, img))</span><br><span class="line"></span><br><span class="line">data_dir.sort() <span class="comment"># sort so that data_dir[k] lines up with labels[k]</span></span><br></pre></td></tr></tbody></table></figure><h2 id="子函数"><a href="#子函数" class="headerlink" title="子函数"></a>Helper Functions</h2><h3 id="打印迭代训练输出"><a href="#打印迭代训练输出" class="headerlink" title="打印迭代训练输出"></a>Printing a Layer's Name and Output Shape</h3><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">print_activations</span>(<span class="params">t</span>):</span><br><span class="line"> <span class="built_in">print</span>(t.op.name, <span class="string">' '</span>, t.get_shape().as_list())</span><br></pre></td></tr></tbody></table></figure><h3 id="整理标签"><a href="#整理标签" class="headerlink" title="整理标签"></a>Converting Labels to One-Hot Vectors</h3><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">dense_to_one_hot</span>(<span class="params">labels_dense, num_classes</span>):</span><br><span class="line"> num_labels = labels_dense.shape[<span 
class="number">0</span>]</span><br><span class="line"> index_offset = np.arange(num_labels) * num_classes <span class="comment"># flat index of the first element of each row</span></span><br><span class="line"> labels_one_hot = np.zeros((num_labels, num_classes))</span><br><span class="line"> labels_one_hot.flat[index_offset + labels_dense.ravel()] = <span class="number">1</span> <span class="comment"># set column labels_dense[r] in row r</span></span><br><span class="line"> <span class="keyword">return</span> labels_one_hot</span><br></pre></td></tr></tbody></table></figure><h3 id="读取训练图像"><a href="#读取训练图像" class="headerlink" title="读取训练图像"></a>Reading Training Images</h3><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">read_images_from_disk</span>(<span class="params">input_queue</span>):</span><br><span class="line"> label = input_queue[<span class="number">1</span>]</span><br><span class="line"> file_contents = tf.read_file(input_queue[<span class="number">0</span>]) <span class="comment"># TF 1.x API (tf.io.read_file in TF 2)</span></span><br><span class="line"> example = tf.image.decode_jpeg(file_contents, channels=<span class="number">3</span>)</span><br><span class="line"> <span class="comment"># example = tf.cast(example, tf.float32 )</span></span><br><span class="line"> <span class="keyword">return</span> example, label</span><br></pre></td></tr></tbody></table></figure><h3 id="通过函数产生权重数据"><a href="#通过函数产生权重数据" class="headerlink" title="通过函数产生权重数据"></a>Creating Weight Variables</h3><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">weight_variable</span>(<span class="params">shape, name</span>):</span><br><span class="line"> initial = 
tf.truncated_normal(shape, stddev=<span class="number">0.01</span>, name=name) <span class="comment"># truncated normal, stddev 0.01</span></span><br><span class="line"> <span class="keyword">return</span> tf.Variable(initial)</span><br></pre></td></tr></tbody></table></figure><h3 id="通过函数产生偏置数据"><a href="#通过函数产生偏置数据" class="headerlink" title="通过函数产生偏置数据"></a>Creating Bias Variables</h3><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">bias_variable</span>(<span class="params">shape, name</span>):</span><br><span class="line"> initial = tf.constant(<span class="number">0.1</span>, shape=shape, name=name) <span class="comment"># small positive constant bias</span></span><br><span class="line"> <span class="keyword">return</span> tf.Variable(initial)</span><br></pre></td></tr></tbody></table></figure><h3 id="实现caffe多通道卷积,当窗口不够时,舍弃"><a href="#实现caffe多通道卷积,当窗口不够时,舍弃" class="headerlink" title="实现caffe多通道卷积,当窗口不够时,舍弃"></a>Caffe-Style Grouped Multi-Channel Convolution; Windows That Do Not Fit Are Dropped (VALID)</h3><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">conv</span>(<span class="params"><span class="built_in">input</span>, kernel, biases, k_h, k_w, c_o, s_h, s_w, padding = <span class="string">"VALID"</span>, group = <span class="number">1</span></span>):</span><br><span class="line"> <span 
class="string">'''From https://github.com/ethereon/caffe-tensorflow</span></span><br><span class="line"><span class="string"> '''</span></span><br><span class="line"> c_i = <span class="built_in">input</span>.get_shape()[-<span class="number">1</span>]</span><br><span class="line"> <span class="keyword">assert</span> c_i % group == <span class="number">0</span></span><br><span class="line"> <span class="keyword">assert</span> c_o % group == <span class="number">0</span></span><br><span class="line"> convolve = <span class="keyword">lambda</span> i, k: tf.nn.conv2d(i, k, [<span class="number">1</span>, s_h, s_w, <span class="number">1</span>], padding=padding)</span><br><span class="line"></span><br><span class="line"> <span class="keyword">if</span> group == <span class="number">1</span>:</span><br><span class="line"> conv = convolve(<span class="built_in">input</span>, kernel)</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> input_groups = tf.split(axis=<span class="number">3</span>, num_or_size_splits=group, value=<span class="built_in">input</span>) <span class="comment"># split the input channels into groups</span></span><br><span class="line"> kernel_groups = tf.split(axis=<span class="number">3</span>, num_or_size_splits=group, value=kernel) <span class="comment"># split the filters (output channels) into groups</span></span><br><span class="line"> output_groups = [convolve(i, k) <span class="keyword">for</span> i, k <span class="keyword">in</span> <span class="built_in">zip</span>(input_groups, kernel_groups)]</span><br><span class="line"> conv = tf.concat(axis=<span class="number">3</span>, values=output_groups)</span><br><span class="line"> <span class="keyword">return</span> tf.reshape(tf.nn.bias_add(conv, biases), [-<span class="number">1</span>] + conv.get_shape().as_list()[<span class="number">1</span>:])</span><br></pre></td></tr></tbody></table></figure><h3 id="单通道卷积,当窗口不够时,填充"><a href="#单通道卷积,当窗口不够时,填充" class="headerlink" title="单通道卷积,当窗口不够时,填充"></a>Plain (Single-Group) Convolution; Incomplete Windows Are Padded (SAME)</h3><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span 
class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">conv2d</span>(<span class="params">x, W, stride_h, stride_w, padding=<span class="string">'SAME'</span></span>):</span><br><span class="line"> <span class="keyword">return</span> tf.nn.conv2d(x, W, strides=[<span class="number">1</span>, stride_h, stride_w, <span class="number">1</span>], padding=padding)</span><br></pre></td></tr></tbody></table></figure><h3 id="2-2池化操作,当窗口不够时,填充"><a href="#2-2池化操作,当窗口不够时,填充" class="headerlink" title="2*2池化操作,当窗口不够时,填充"></a>2×2 Max Pooling; Incomplete Windows Are Padded (SAME)</h3><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">max_pool_2x2</span>(<span class="params">x</span>):</span><br><span class="line"> <span class="keyword">return</span> tf.nn.max_pool(x, ksize=[<span class="number">1</span>, <span class="number">2</span>, <span class="number">2</span>, <span class="number">1</span>], strides=[<span class="number">1</span>, <span class="number">2</span>, <span class="number">2</span>, <span class="number">1</span>], padding=<span class="string">'SAME'</span>)</span><br></pre></td></tr></tbody></table></figure><h3 id="3-3池化操作,当窗口不够时,填充"><a href="#3-3池化操作,当窗口不够时,填充" class="headerlink" title="3*3池化操作,当窗口不够时,填充"></a>3×3 Max Pooling with Stride 2; Incomplete Windows Are Padded (SAME)</h3><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">max_pool_3x3</span>(<span class="params">x</span>):</span><br><span class="line"> <span class="keyword">return</span> tf.nn.max_pool(x, ksize=[<span class="number">1</span>, <span class="number">3</span>, <span 
class="number">3</span>, <span class="number">1</span>], strides=[<span class="number">1</span>, <span class="number">2</span>, <span class="number">2</span>, <span class="number">1</span>], padding=<span class="string">'SAME'</span>)</span><br></pre></td></tr></tbody></table></figure><h3 id="4-4池化操作,当窗口不够时,填充"><a href="#4-4池化操作,当窗口不够时,填充" class="headerlink" title="4*4池化操作,当窗口不够时,填充"></a>4×4 Max Pooling; Incomplete Windows Are Padded (SAME)</h3><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">max_pool_4x4</span>(<span class="params">x</span>):</span><br><span class="line"> <span class="keyword">return</span> tf.nn.max_pool(x, ksize=[<span class="number">1</span>, <span class="number">4</span>, <span class="number">4</span>, <span class="number">1</span>], strides=[<span class="number">1</span>, <span class="number">4</span>, <span class="number">4</span>, <span class="number">1</span>], padding=<span class="string">'SAME'</span>)</span><br></pre></td></tr></tbody></table></figure><h3 id="金字塔池化操作"><a href="#金字塔池化操作" class="headerlink" title="金字塔池化操作"></a>Spatial Pyramid Pooling</h3><h2 id="将输入的batch-size-height-width-channels的图像池化成batch-size-1-out-pool-size大小"><a href="#将输入的batch-size-height-width-channels的图像池化成batch-size-1-out-pool-size大小" class="headerlink" title="将输入的batch_size*height*width*channels的图像池化成batch_size*1*out_pool_size大小"></a>Pools a batch_size × height × width × channels input into a batch_size × n tensor, where n = (sum of out_pool_size[i]^2) × channels does not depend on height or width</h2><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span 
class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># Spatial Pyramid Pooling block</span></span><br><span class="line"><span class="comment"># https://arxiv.org/abs/1406.4729</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">spatial_pyramid_pool</span>(<span class="params">previous_conv, num_sample, previous_conv_size, out_pool_size</span>):</span><br><span class="line"> <span class="string">"""</span></span><br><span class="line"><span class="string"> previous_conv: output tensor of the previous convolution layer, shape [num_sample, height, width, channels]</span></span><br><span class="line"><span class="string"> num_sample: number of images in the batch (int)</span></span><br><span class="line"><span class="string"> previous_conv_size: [height, width] of the previous layer's feature maps (ints)</span></span><br><span class="line"><span class="string"> out_pool_size: list of ints, the output grid size of each pyramid level</span></span><br><span class="line"><span class="string"> </span></span><br><span class="line"><span class="string"> returns: a tensor of shape [num_sample, n], the concatenation of all pyramid levels</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span 
class="built_in">len</span>(out_pool_size)):</span><br><span class="line"> h_strd = h_size = <span class="built_in">int</span>(math.ceil(<span class="built_in">float</span>(previous_conv_size[<span class="number">0</span>]) / out_pool_size[i])) <span class="comment"># window = stride = ceil(H / n)</span></span><br><span class="line"> w_strd = w_size = <span class="built_in">int</span>(math.ceil(<span class="built_in">float</span>(previous_conv_size[<span class="number">1</span>]) / out_pool_size[i])) <span class="comment"># window = stride = ceil(W / n)</span></span><br><span class="line"> pad_h = <span class="built_in">int</span>(out_pool_size[i] * h_size - previous_conv_size[<span class="number">0</span>]) <span class="comment"># pad bottom so exactly n windows fit</span></span><br><span class="line"> pad_w = <span class="built_in">int</span>(out_pool_size[i] * w_size - previous_conv_size[<span class="number">1</span>]) <span class="comment"># pad right so exactly n windows fit</span></span><br><span class="line"> new_previous_conv = tf.pad(previous_conv, tf.constant([[<span class="number">0</span>, <span class="number">0</span>], [<span class="number">0</span>, pad_h], [<span class="number">0</span>, pad_w], [<span class="number">0</span>, <span class="number">0</span>]]))</span><br><span class="line"> max_pool = tf.nn.max_pool(new_previous_conv,</span><br><span class="line"> ksize=[<span class="number">1</span>, h_size, w_size, <span class="number">1</span>], <span class="comment"># height x width window (the original used h_size twice)</span></span><br><span class="line"> strides=[<span class="number">1</span>,h_strd, w_strd,<span class="number">1</span>],</span><br><span class="line"> padding=<span class="string">'SAME'</span>)</span><br><span class="line"> <span class="keyword">if</span> (i == <span class="number">0</span>):</span><br><span class="line"> spp = tf.reshape(max_pool, [num_sample, -<span class="number">1</span>])</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> spp = tf.concat(axis=<span class="number">1</span>, values=[spp, tf.reshape(max_pool, [num_sample, -<span class="number">1</span>])])</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">return</span> spp</span><br></pre></td></tr></tbody></table></figure><h2 id="训练模型"><a href="#训练模型" class="headerlink" 
title="训练模型"></a>Training the Model</h2><h3 id="设置batch"><a href="#设置batch" class="headerlink" title="设置batch"></a>Building Batches</h3><p>Images whose width and height fall into the same 10-pixel bucket (both dimensions rounded down to a multiple of 10) are grouped together, and each such group supplies one batch.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">size_cluster = defaultdict(<span class="built_in">list</span>)</span><br><span class="line"><span class="keyword">for</span> tid <span class="keyword">in</span> trnid:</span><br><span class="line"> img = Image.<span class="built_in">open</span>(data_dir[tid])</span><br><span class="line"> key = (img.size[<span class="number">0</span>] - img.size[<span class="number">0</span>] % <span class="number">10</span>, img.size[<span class="number">1</span>] - img.size[<span class="number">1</span>] % <span class="number">10</span>) <span class="comment"># img.size is (width, height); round both down to a multiple of 10</span></span><br><span class="line"> size_cluster[key].append(tid)</span><br><span class="line"> </span><br><span class="line">size_cluster_keys = <span class="built_in">list</span>(size_cluster.keys()) <span class="comment"># a list, so it can be indexed in Python 3</span></span><br></pre></td></tr></tbody></table></figure><h3 id="初始化变量"><a href="#初始化变量" class="headerlink" title="初始化变量"></a>Initializing Variables</h3><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">train_accuracies = []</span><br><span class="line">train_cost = []</span><br><span class="line">validation_accuracies = []</span><br><span class="line">x_range = []</span><br><span class="line">batch_size = <span class="number">20</span> <span class="comment"># default batch size</span></span><br><span class="line"><span class="built_in">print</span>(<span class="string">'Training 
...'</span>)</span><br></pre></td></tr></tbody></table></figure><h3 id="训练部分"><a href="#训练部分" class="headerlink" title="训练部分"></a>Training Loop</h3><h4 id="迭代计数部分"><a href="#迭代计数部分" class="headerlink" title="迭代计数部分"></a>Iteration Counter</h4><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># Training block</span></span><br><span class="line"><span class="comment"># 1. Combine all images of the same size into a batch.</span></span><br><span class="line"><span class="comment"># 2. Then, train the parameters on that batch.</span></span><br><span class="line"><span class="comment"># 3. Transfer the trained parameters to the next batch.</span></span><br><span class="line">it = <span class="number">0</span> <span class="comment"># iteration counter</span></span><br><span class="line"><span class="keyword">while</span> it < TRAINING_ITERATIONS:</span><br><span class="line"> graph = tf.Graph() <span class="comment"># a fresh graph each iteration; the input shape depends on the current size cluster</span></span><br><span class="line"> <span class="keyword">with</span> graph.as_default():</span><br></pre></td></tr></tbody></table></figure><h4 id="循环主体"><a href="#循环主体" class="headerlink" title="循环主体"></a>Loop Body</h4><h5 id="设置batch-1"><a href="#设置batch-1" class="headerlink" title="设置batch"></a>Building a Batch</h5><p>The number of images per batch can differ from one size cluster to the next, and each iteration trains on the next cluster in turn (cycling through size_cluster_keys). Before the training images are fed in, their width and height are halved.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span 
class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br></pre></td><td class="code"><pre><span class="line">y_train = labels[size_cluster[size_cluster_keys[it%<span class="built_in">len</span>(size_cluster_keys)]]]</span><br><span class="line"><span class="keyword">if</span> <span class="built_in">len</span>(y_train) < <span class="number">50</span>:</span><br><span class="line"> batch_size = <span class="built_in">len</span>(y_train)</span><br><span class="line"></span><br><span class="line">y_train = dense_to_one_hot(y_train, num_classes)</span><br><span class="line">x_train = [data_dir[i] <span class="keyword">for</span> i <span class="keyword">in</span> size_cluster[size_cluster_keys[it%<span class="built_in">len</span>(size_cluster_keys)]]]</span><br><span class="line"></span><br><span class="line">input_queue_train = tf.train.slice_input_producer([x_train, y_train],</span><br><span class="line"> num_epochs=<span class="literal">None</span>,</span><br><span class="line"> shuffle=<span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line">x_train, y_train = read_images_from_disk(input_queue_train)</span><br><span class="line"></span><br><span class="line"><span class="built_in">print</span>(size_cluster_keys[it%<span class="built_in">len</span>(size_cluster_keys)])  <span class="comment"># print the (width, height) key of the current size cluster</span></span><br><span class="line"></span><br><span class="line">x_train = tf.image.resize_images(x_train,</span><br><span class="line"> [size_cluster_keys[it%<span class="built_in">len</span>(size_cluster_keys)][<span class="number">1</span>]//<span class="number">2</span>,</span><br><span class="line"> size_cluster_keys[it%<span class="built_in">len</span>(size_cluster_keys)][<span class="number">0</span>]//<span class="number">2</span>],</span><br><span class="line"> method=<span class="number">1</span>, align_corners=<span class="literal">False</span>)  <span class="comment"># method=1: nearest neighbour; // keeps the target size an integer</span></span><br><span class="line"></span><br><span class="line">x_train, y_train = tf.train.batch([x_train, y_train], batch_size = batch_size)</span><br></pre></td></tr></tbody></table></figure><h5 id="加载预训练的Alexnet模型"><a href="#加载预训练的Alexnet模型" class="headerlink" title="加载预训练的Alexnet模型"></a>Loading the pretrained AlexNet model</h5><p>The pretrained fc6 weight matrix W is discarded and re-initialized, because spatial pyramid pooling changes the input dimension of fc6 to <code>hidden_dim * 256</code>. The fc8 weights and biases are also re-initialized so that fc8 outputs one score per class in the training set.</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br></pre></td><td class="code"><pre><span class="line">x = tf.placeholder(<span class="string">'float'</span>, shape = x_train.get_shape())  <span class="comment"># input placeholder</span></span><br><span class="line">y_ = tf.placeholder(<span class="string">'float'</span>, shape = [<span class="literal">None</span>, num_classes])  <span class="comment"># one-hot label placeholder</span></span><br><span class="line"></span><br><span class="line">conv1W = tf.Variable(net_data[<span class="string">"conv1"</span>][<span class="number">0</span>])</span><br><span class="line">conv1b = tf.Variable(net_data[<span class="string">"conv1"</span>][<span class="number">1</span>])</span><br><span class="line">conv2W = tf.Variable(net_data[<span 
class="string">"conv2"</span>][<span class="number">0</span>])</span><br><span class="line">conv2b = tf.Variable(net_data[<span class="string">"conv2"</span>][<span class="number">1</span>])</span><br><span class="line">conv3W = tf.Variable(net_data[<span class="string">"conv3"</span>][<span class="number">0</span>])</span><br><span class="line">conv3b = tf.Variable(net_data[<span class="string">"conv3"</span>][<span class="number">1</span>])</span><br><span class="line">conv4W = tf.Variable(net_data[<span class="string">"conv4"</span>][<span class="number">0</span>])</span><br><span class="line">conv4b = tf.Variable(net_data[<span class="string">"conv4"</span>][<span class="number">1</span>])</span><br><span class="line">conv5W = tf.Variable(net_data[<span class="string">"conv5"</span>][<span class="number">0</span>])</span><br><span class="line">conv5b = tf.Variable(net_data[<span class="string">"conv5"</span>][<span class="number">1</span>])</span><br><span class="line">fc6W = weight_variable([hidden_dim * <span class="number">256</span>, <span class="number">4096</span>], <span class="string">'fc6W'</span>)  <span class="comment"># re-initialized: SPP changes the input size of fc6</span></span><br><span class="line">fc6b = tf.Variable(net_data[<span class="string">"fc6"</span>][<span class="number">1</span>])</span><br><span class="line">fc7W = tf.Variable(net_data[<span class="string">"fc7"</span>][<span class="number">0</span>])</span><br><span class="line">fc7b = tf.Variable(net_data[<span class="string">"fc7"</span>][<span class="number">1</span>])</span><br><span class="line">fc8W = weight_variable([<span class="number">4096</span>, num_classes], <span class="string">'W_fc8'</span>)  <span class="comment"># re-initialized: one output per class</span></span><br><span class="line">fc8b = bias_variable([num_classes], <span class="string">'b_fc8'</span>)</span><br><span class="line">keep_prob = tf.placeholder(<span class="string">'float'</span>)  <span class="comment"># dropout keep probability</span></span><br></pre></td></tr></tbody></table></figure><h5 id="前向传播通路"><a href="#前向传播通路" class="headerlink" title="前向传播通路"></a>Forward pass</h5><p>Layer 1</p><pre><code>conv1     11x11 kernel, stride 4, SAME padding, ReLU activation
lrn1      local response normalization
maxpool1  3x3 max pooling, stride 2, VALID padding</code></pre><p>Layer 2</p><pre><code>conv2     5x5 kernel, stride 1, SAME padding, 2 groups, ReLU activation
lrn2      local response normalization
maxpool2  3x3 max pooling, stride 2, VALID padding</code></pre><p>Layer 3</p><pre><code>conv3     3x3 kernel, stride 1, SAME padding, ReLU activation</code></pre><p>Layer 4</p><pre><code>conv4     3x3 kernel, stride 1, SAME padding, 2 groups, ReLU activation</code></pre><p>Layer 5</p><pre><code>conv5     3x3 kernel, stride 1, SAME padding, 2 groups, ReLU activation
spp5      spatial pyramid pooling</code></pre><p>Layer 6</p><pre><code>fc6       fully connected, ReLU activation, dropout during training</code></pre><p>Layer 7</p><pre><code>fc7       fully connected, ReLU activation, dropout during training</code></pre><p>Layer 8</p><pre><code>fc8       fully connected, no activation (outputs the logits)</code></pre><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span 
class="line">44</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">model</span>(<span class="params">x</span>):</span><br><span class="line"> <span class="comment"># conv1: 11x11 kernel, stride 4, SAME padding, ReLU</span></span><br><span class="line"> conv1 = tf.nn.relu(conv(x, conv1W, conv1b, <span class="number">11</span>, <span class="number">11</span>, <span class="number">96</span>, <span class="number">4</span>, <span class="number">4</span>, padding=<span class="string">"SAME"</span>, group=<span class="number">1</span>))</span><br><span class="line"> <span class="comment"># lrn1: local response normalization</span></span><br><span class="line"> <span class="comment"># lrn(2, 2e-05, 0.75, name='norm1')</span></span><br><span class="line"> lrn1 = tf.nn.local_response_normalization(conv1,</span><br><span class="line"> depth_radius=<span class="number">5</span>,</span><br><span class="line"> alpha=<span class="number">0.0001</span>,</span><br><span class="line"> beta=<span class="number">0.75</span>,</span><br><span class="line"> bias=<span class="number">1.0</span>)</span><br><span class="line"> <span class="comment"># maxpool1: 3x3 max pooling, stride 2, VALID padding</span></span><br><span class="line"> maxpool1 = tf.nn.max_pool(lrn1, ksize=[<span class="number">1</span>, <span class="number">3</span>, <span class="number">3</span>, <span class="number">1</span>], strides=[<span class="number">1</span>, <span class="number">2</span>, <span class="number">2</span>, <span class="number">1</span>], padding=<span class="string">'VALID'</span>)</span><br><span class="line"> <span class="comment"># conv2: 5x5 kernel, stride 1, SAME padding, 2 groups, ReLU</span></span><br><span class="line"> conv2 = tf.nn.relu(conv(maxpool1, conv2W, conv2b, <span class="number">5</span>, <span class="number">5</span>, <span class="number">256</span>, <span class="number">1</span>, <span class="number">1</span>, padding=<span class="string">"SAME"</span>, group=<span 
class="number">2</span>))</span><br><span class="line"> <span class="comment"># lrn2: local response normalization</span></span><br><span class="line"> <span class="comment"># lrn(2, 2e-05, 0.75, name='norm2')</span></span><br><span class="line"> lrn2 = tf.nn.local_response_normalization(conv2,</span><br><span class="line"> depth_radius=<span class="number">5</span>,</span><br><span class="line"> alpha=<span class="number">0.0001</span>,</span><br><span class="line"> beta=<span class="number">0.75</span>,</span><br><span class="line"> bias=<span class="number">1.0</span>)</span><br><span class="line"> <span class="comment"># maxpool2: 3x3 max pooling, stride 2, VALID padding</span></span><br><span class="line"> maxpool2 = tf.nn.max_pool(lrn2, ksize=[<span class="number">1</span>, <span class="number">3</span>, <span class="number">3</span>, <span class="number">1</span>], strides=[<span class="number">1</span>, <span class="number">2</span>, <span class="number">2</span>, <span class="number">1</span>], padding=<span class="string">'VALID'</span>)</span><br><span class="line"> <span class="comment"># conv3: 3x3 kernel, stride 1, SAME padding, ReLU</span></span><br><span class="line"> conv3 = tf.nn.relu(conv(maxpool2, conv3W, conv3b, <span class="number">3</span>, <span class="number">3</span>, <span class="number">384</span>, <span class="number">1</span>, <span class="number">1</span>, padding=<span class="string">"SAME"</span>, group=<span class="number">1</span>))</span><br><span class="line"> <span class="comment"># conv4: 3x3 kernel, stride 1, SAME padding, 2 groups, ReLU</span></span><br><span class="line"> conv4 = tf.nn.relu(conv(conv3, conv4W, conv4b, <span class="number">3</span>, <span class="number">3</span>, <span class="number">384</span>, <span class="number">1</span>, <span class="number">1</span>, padding=<span class="string">"SAME"</span>, group=<span class="number">2</span>))</span><br><span class="line"> <span class="comment"># conv5: 3x3 kernel, stride 1, SAME padding, 2 groups, ReLU</span></span><br><span class="line"> conv5 = tf.nn.relu(conv(conv4, conv5W, conv5b, <span class="number">3</span>, <span class="number">3</span>, <span class="number">256</span>, <span class="number">1</span>, <span class="number">1</span>, padding=<span class="string">"SAME"</span>, group=<span class="number">2</span>))</span><br><span class="line"> <span class="built_in">print</span>(<span class="built_in">int</span>(conv5.get_shape()[<span class="number">0</span>]), <span class="built_in">int</span>(conv5.get_shape()[<span class="number">1</span>]), <span class="built_in">int</span>(conv5.get_shape()[<span class="number">2</span>]))</span><br><span class="line"> <span class="comment"># spp5: spatial pyramid pooling to a fixed-length output</span></span><br><span class="line"> maxpool5 = spatial_pyramid_pool(conv5,</span><br><span class="line"> <span class="built_in">int</span>(conv5.get_shape()[<span class="number">0</span>]),</span><br><span class="line"> [<span class="built_in">int</span>(conv5.get_shape()[<span class="number">1</span>]), <span class="built_in">int</span>(conv5.get_shape()[<span class="number">2</span>])],</span><br><span class="line"> out_pool_size)</span><br><span class="line"> <span class="comment"># fc6: fully connected, ReLU, dropout</span></span><br><span class="line"> fc6 = tf.nn.relu_layer(tf.reshape(maxpool5, [-<span class="number">1</span>, <span class="built_in">int</span>(prod(maxpool5.get_shape()[<span class="number">1</span>:]))]), fc6W, fc6b)</span><br><span class="line"> fc6_drop = tf.nn.dropout(fc6, keep_prob)</span><br><span class="line"> <span class="comment"># fc7: fully connected, ReLU, dropout</span></span><br><span class="line"> fc7 = tf.nn.relu_layer(fc6_drop, fc7W, fc7b)</span><br><span class="line"> fc7_drop = tf.nn.dropout(fc7, keep_prob)</span><br><span class="line"> <span class="comment"># fc8: fully connected, no activation (logits)</span></span><br><span class="line"> fc8 = tf.nn.xw_plus_b(fc7_drop, fc8W, fc8b)</span><br><span class="line"> <span class="keyword">return</span> fc8</span><br></pre></td></tr></tbody></table></figure><h5 id="损失函数计算反向传播"><a href="#损失函数计算反向传播" 
class="headerlink" title="损失函数计算反向传播"></a>Loss function and backpropagation</h5><h2 id="使用交叉熵softmax函数计算score,损失函数为:交叉熵损失-正则化系数-所有权重的L2,动态调整学习率-评价模型计算模型正确率-训练设置-输出评价图"><a href="#使用交叉熵softmax函数计算score,损失函数为:交叉熵损失-正则化系数-所有权重的L2,动态调整学习率-评价模型计算模型正确率-训练设置-输出评价图" class="headerlink" title="使用交叉熵softmax函数计算score,损失函数为:交叉熵损失+正则化系数*所有权重的L2,动态调整学习率##### 评价模型计算模型正确率##### 训练设置##### 输出评价图"></a>Softmax cross-entropy is used to score the predictions. The loss is the cross-entropy plus the regularization coefficient <code>WEIGHT_DECAY</code> times the L2 penalty on all weights and biases, and the learning rate is adjusted during training with an exponential decay schedule.<br><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br></pre></td><td class="code"><pre><span class="line">logits = model(x)</span><br><span class="line">cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_))</span><br><span class="line">regularizers = tf.nn.l2_loss(conv1W) + tf.nn.l2_loss(conv1b) + \</span><br><span class="line"> tf.nn.l2_loss(conv2W) + tf.nn.l2_loss(conv2b) + \</span><br><span class="line"> tf.nn.l2_loss(conv3W) + tf.nn.l2_loss(conv3b) + \</span><br><span class="line"> tf.nn.l2_loss(conv4W) + tf.nn.l2_loss(conv4b) + \</span><br><span class="line"> tf.nn.l2_loss(conv5W) + tf.nn.l2_loss(conv5b) + \</span><br><span class="line"> tf.nn.l2_loss(fc6W) + tf.nn.l2_loss(fc6b) + \</span><br><span class="line"> tf.nn.l2_loss(fc7W) + tf.nn.l2_loss(fc7b) + \</span><br><span class="line"> tf.nn.l2_loss(fc8W) + tf.nn.l2_loss(fc8b)</span><br><span class="line"></span><br><span class="line">loss = tf.reduce_mean(cross_entropy + WEIGHT_DECAY * regularizers)</span><br><span class="line"></span><br><span class="line"><span class="comment"># optimise the loss with Adagrad</span></span><br><span class="line">global_step = tf.Variable(it, trainable=<span class="literal">False</span>)  <span class="comment"># start at the overall iteration count, because the graph is rebuilt every round</span></span><br><span class="line">learning_rate = tf.train.exponential_decay(LEARNING_RATE, global_step, <span class="number">1000</span>, <span class="number">0.9</span>, staircase=<span class="literal">True</span>)  <span class="comment"># multiply the rate by 0.9 every 1000 steps</span></span><br><span class="line">train_step = tf.train.AdagradOptimizer(learning_rate).minimize(loss, global_step=global_step)  <span class="comment"># minimize() increments global_step, so the decay can take effect</span></span><br></pre></td></tr></tbody></table></figure><br>Evaluating the model<br>Compute the accuracy of the model.<br><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># evaluation</span></span><br><span class="line">correct_prediction = tf.equal(tf.argmax(logits, <span class="number">1</span>), tf.argmax(y_, <span class="number">1</span>))</span><br><span class="line">accuracy = tf.reduce_mean(tf.cast(correct_prediction, <span class="string">'float'</span>))</span><br><span class="line">predict = tf.argmax(logits, <span class="number">1</span>)</span><br><span class="line">saver = tf.train.Saver({v.op.name: v <span class="keyword">for</span> v <span class="keyword">in</span> [conv1W, conv1b,</span><br><span class="line"> conv2W, conv2b,</span><br><span class="line"> conv3W, conv3b,</span><br><span class="line"> conv4W, conv4b,</span><br><span 
class="line"> conv5W, conv5b,</span><br><span class="line"> fc6W, fc6b,</span><br><span class="line"> fc7W, fc7b,</span><br><span class="line"> fc8W, fc8b]})</span><br></pre></td></tr></tbody></table></figure><br>Training setup<br><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">with</span> tf.Session(graph=graph) <span class="keyword">as</span> sess:</span><br><span class="line">    init = tf.global_variables_initializer()</span><br><span class="line">    sess.run(init)</span><br><span class="line">    coord = tf.train.Coordinator()</span><br><span class="line">    threads = tf.train.start_queue_runners(coord=coord)</span><br><span class="line">    <span class="keyword">if</span> os.path.exists(<span class="string">'./alex_model_spp.ckpt.index'</span>):  <span class="comment"># V2 checkpoints are written as .index/.data files, not a bare .ckpt file</span></span><br><span class="line">        saver.restore(sess, <span class="string">'./alex_model_spp.ckpt'</span>)</span><br><span class="line"></span><br><span class="line">    cnt_tmp = <span class="number">0</span></span><br><span class="line">    xtrain, ytrain = sess.run([x_train, y_train])  <span class="comment"># fetch one batch from the input queue</span></span><br><span class="line">    <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">10</span>):</span><br><span class="line">        it = it + <span class="number">1</span></span><br><span class="line">        _, train_accuracy, cost = sess.run([train_step, accuracy, cross_entropy],</span><br><span class="line">                                           feed_dict = {x: xtrain,</span><br><span class="line">                                                        y_: ytrain,</span><br><span class="line">                                                        keep_prob: <span class="number">1.0</span>})  <span class="comment"># keep_prob=1.0 keeps every unit, so dropout is off here</span></span><br><span class="line"></span><br><span class="line">        <span class="built_in">print</span>(<span class="string">'training_accuracy => %.4f, cost value => %.4f for step %d'</span></span><br><span class="line">              %(train_accuracy, cost, it))</span><br><span class="line"></span><br><span class="line">        <span class="keyword">if</span> (train_accuracy > <span class="number">0.95</span>):</span><br><span class="line">            cnt_tmp = cnt_tmp + <span class="number">1</span></span><br><span class="line"></span><br><span class="line">        <span class="keyword">if</span> (cnt_tmp > <span class="number">10</span>):</span><br><span class="line">            <span class="keyword">break</span></span><br><span class="line"></span><br><span class="line">        train_accuracies.append(train_accuracy)</span><br><span class="line">        x_range.append(it)</span><br><span class="line">        train_cost.append(cost)</span><br><span class="line"></span><br><span class="line">    saver.save(sess, <span class="string">'./alex_model_spp.ckpt'</span>)</span><br><span class="line">    coord.request_stop()</span><br><span class="line">    coord.join(threads)</span><br><span class="line">    sess.close()</span><br><span class="line">    <span class="keyword">del</span> sess</span><br></pre></td></tr></tbody></table></figure><br>Plotting the evaluation curves<br><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># Plot the loss and accuracy curves</span></span><br><span class="line">plt.plot(x_range, train_cost,<span class="string">'-b'</span>)</span><br><span class="line">plt.ylabel(<span class="string">'spp_cost'</span>)</span><br><span class="line">plt.xlabel(<span class="string">'step'</span>)</span><br><span class="line">plt.savefig(<span class="string">'spp_cost.png'</span>)</span><br><span class="line">plt.close()</span><br><span class="line">plt.plot(x_range, train_accuracies,<span class="string">'-b'</span>)</span><br><span class="line">plt.ylabel(<span class="string">'spp_accuracies'</span>)</span><br><span class="line">plt.ylim(top=<span class="number">1.1</span>)</span><br><span class="line">plt.xlabel(<span class="string">'step'</span>)</span><br><span class="line">plt.savefig(<span class="string">'spp_accuracy.png'</span>)</span><br></pre></td></tr></tbody></table></figure></h2><h2 id="测试模型"><a href="#测试模型" class="headerlink" title="测试模型"></a>Testing the model</h2><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span 
class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span 
class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span class="line">105</span><br><span class="line">106</span><br><span class="line">107</span><br><span class="line">108</span><br><span class="line">109</span><br><span class="line">110</span><br><span class="line">111</span><br><span class="line">112</span><br><span class="line">113</span><br><span class="line">114</span><br><span class="line">115</span><br><span class="line">116</span><br><span class="line">117</span><br><span class="line">118</span><br><span class="line">119</span><br><span class="line">120</span><br><span class="line">121</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># Testing block</span></span><br><span class="line"><span class="comment"># 1. Load each test image on its own, at its own size (a batch of one)</span></span><br><span class="line"><span class="comment"># 2. Feed it to Alexnet_SPP to predict its label</span></span><br><span class="line">it = <span class="number">0</span></span><br><span class="line">result = <span class="built_in">list</span>()</span><br><span class="line">f = <span class="built_in">open</span>(<span class="string">'result_spp.txt'</span>, <span class="string">'w'</span>)</span><br><span class="line"><span class="keyword">while</span> it < <span class="built_in">len</span>(tstid):</span><br><span class="line"> <span class="keyword">if</span> (it % <span class="number">10</span> == <span class="number">0</span>):</span><br><span class="line"> <span class="built_in">print</span>(it)</span><br><span class="line"> graph = tf.Graph()</span><br><span class="line"> <span class="keyword">with</span> graph.as_default():</span><br><span class="line"> <span class="comment"># with tf.device('/cpu:0'):</span></span><br><span class="line"> img = Image.<span class="built_in">open</span>(data_dir[tstid[it]])</span><br><span class="line"> filename_queue = tf.train.string_input_producer([data_dir[tstid[it]]])</span><br><span class="line"> reader = tf.WholeFileReader()</span><br><span class="line"> key, value = reader.read(filename_queue)</span><br><span class="line"> my_img = tf.image.decode_jpeg(value, channels = <span class="number">3</span>)</span><br><span class="line"> <span class="comment"># my_img = tf.cast(my_img, tf.float32)</span></span><br><span class="line"> my_img = tf.image.resize_images(my_img,</span><br><span class="line"> [img.size[<span class="number">1</span>] // <span class="number">2</span>,</span><br><span class="line"> img.size[<span class="number">0</span>] // <span class="number">2</span>],</span><br><span class="line"> method = <span class="number">1</span>,</span><br><span class="line"> align_corners = <span class="literal">False</span>)</span><br><span class="line"></span><br><span class="line"> my_img = tf.expand_dims(my_img, <span class="number">0</span>)</span><br><span 
class="line"></span><br><span class="line"> x = tf.placeholder(<span class="string">'float'</span>, shape=my_img.get_shape())</span><br><span class="line"> <span class="built_in">print</span>(my_img.get_shape())</span><br><span class="line"> conv1W = tf.Variable(net_data[<span class="string">"conv1"</span>][<span class="number">0</span>])</span><br><span class="line"> conv1b = tf.Variable(net_data[<span class="string">"conv1"</span>][<span class="number">1</span>])</span><br><span class="line"> conv2W = tf.Variable(net_data[<span class="string">"conv2"</span>][<span class="number">0</span>])</span><br><span class="line"> conv2b = tf.Variable(net_data[<span class="string">"conv2"</span>][<span class="number">1</span>])</span><br><span class="line"> conv3W = tf.Variable(net_data[<span class="string">"conv3"</span>][<span class="number">0</span>])</span><br><span class="line"> conv3b = tf.Variable(net_data[<span class="string">"conv3"</span>][<span class="number">1</span>])</span><br><span class="line"> conv4W = tf.Variable(net_data[<span class="string">"conv4"</span>][<span class="number">0</span>])</span><br><span class="line"> conv4b = tf.Variable(net_data[<span class="string">"conv4"</span>][<span class="number">1</span>])</span><br><span class="line"> conv5W = tf.Variable(net_data[<span class="string">"conv5"</span>][<span class="number">0</span>])</span><br><span class="line"> conv5b = tf.Variable(net_data[<span class="string">"conv5"</span>][<span class="number">1</span>])</span><br><span class="line"> fc6W = weight_variable([hidden_dim * <span class="number">256</span>, <span class="number">4096</span>], <span class="string">'fc6W'</span>)</span><br><span class="line"> fc6b = tf.Variable(net_data[<span class="string">"fc6"</span>][<span class="number">1</span>])</span><br><span class="line"> fc7W = tf.Variable(net_data[<span class="string">"fc7"</span>][<span class="number">0</span>])</span><br><span class="line"> fc7b = tf.Variable(net_data[<span 
class="string">"fc7"</span>][<span class="number">1</span>])</span><br><span class="line"> fc8W = weight_variable([<span class="number">4096</span>, num_classes], <span class="string">'W_fc8'</span>)</span><br><span class="line"> fc8b = bias_variable([num_classes], <span class="string">'b_fc8'</span>)</span><br><span class="line"> keep_prob = tf.placeholder(<span class="string">'float'</span>)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">model</span>(<span class="params">x</span>):</span><br><span class="line"> <span class="comment"># conv1</span></span><br><span class="line"> conv1 = tf.nn.relu(conv(x, conv1W, conv1b, <span class="number">11</span>, <span class="number">11</span>, <span class="number">96</span>, <span class="number">4</span>, <span class="number">4</span>, padding=<span class="string">"SAME"</span>, group=<span class="number">1</span>))</span><br><span class="line"> <span class="comment"># lrn1</span></span><br><span class="line"> <span class="comment"># lrn(2, 2e-05, 0.75, name='norm1')</span></span><br><span class="line"> lrn1 = tf.nn.local_response_normalization(conv1,</span><br><span class="line"> depth_radius=<span class="number">5</span>,</span><br><span class="line"> alpha=<span class="number">0.0001</span>,</span><br><span class="line"> beta=<span class="number">0.75</span>,</span><br><span class="line"> bias=<span class="number">1.0</span>)</span><br><span class="line"> <span class="comment"># maxpool1</span></span><br><span class="line"> maxpool1 = tf.nn.max_pool(lrn1, ksize=[<span class="number">1</span>, <span class="number">3</span>, <span class="number">3</span>, <span class="number">1</span>], strides=[<span class="number">1</span>, <span class="number">2</span>, <span class="number">2</span>, <span class="number">1</span>], padding=<span class="string">'VALID'</span>)</span><br><span class="line"> <span class="comment"># conv2</span></span><br><span class="line"> conv2 = tf.nn.relu(conv(maxpool1, conv2W, conv2b, <span class="number">5</span>, <span class="number">5</span>, <span class="number">256</span>, <span class="number">1</span>, <span class="number">1</span>, padding=<span class="string">"SAME"</span>, group=<span class="number">2</span>))</span><br><span class="line"> <span class="comment"># lrn2</span></span><br><span class="line"> <span class="comment"># lrn(2, 2e-05, 0.75, name='norm2')</span></span><br><span class="line"> lrn2 = tf.nn.local_response_normalization(conv2,</span><br><span class="line"> depth_radius=<span class="number">5</span>,</span><br><span class="line"> alpha=<span class="number">0.0001</span>,</span><br><span class="line"> beta=<span class="number">0.75</span>,</span><br><span class="line"> bias=<span class="number">1.0</span>)</span><br><span class="line"> <span class="comment"># maxpool2</span></span><br><span class="line"> maxpool2 = tf.nn.max_pool(lrn2, ksize=[<span class="number">1</span>, <span class="number">3</span>, <span class="number">3</span>, <span class="number">1</span>], strides=[<span class="number">1</span>, <span class="number">2</span>, <span class="number">2</span>, <span class="number">1</span>], padding=<span class="string">'VALID'</span>)</span><br><span class="line"> <span class="comment"># conv3</span></span><br><span class="line"> conv3 = tf.nn.relu(conv(maxpool2, conv3W, conv3b, <span class="number">3</span>, <span class="number">3</span>, <span class="number">384</span>, <span class="number">1</span>, <span class="number">1</span>, padding=<span class="string">"SAME"</span>, group=<span class="number">1</span>))</span><br><span class="line"> <span class="comment"># conv4</span></span><br><span class="line"> conv4 = tf.nn.relu(conv(conv3, conv4W, conv4b, <span class="number">3</span>, <span class="number">3</span>, <span class="number">384</span>, <span class="number">1</span>, <span class="number">1</span>, padding=<span 
class="string">"SAME"</span>, group=<span class="number">2</span>))</span><br><span class="line"> <span class="comment"># conv5</span></span><br><span class="line"> conv5 = tf.nn.relu(conv(conv4, conv5W, conv5b, <span class="number">3</span>, <span class="number">3</span>, <span class="number">256</span>, <span class="number">1</span>, <span class="number">1</span>, padding=<span class="string">"SAME"</span>, group=<span class="number">2</span>))</span><br><span class="line"> maxpool5 = spatial_pyramid_pool(conv5,</span><br><span class="line"> <span class="built_in">int</span>(conv5.get_shape()[<span class="number">0</span>]),</span><br><span class="line"> [<span class="built_in">int</span>(conv5.get_shape()[<span class="number">1</span>]), <span class="built_in">int</span>(conv5.get_shape()[<span class="number">2</span>])],</span><br><span class="line"> out_pool_size)</span><br><span class="line"> <span class="comment"># fc6</span></span><br><span class="line"> fc6 = tf.nn.relu_layer(tf.reshape(maxpool5, [-<span class="number">1</span>, <span class="built_in">int</span>(prod(maxpool5.get_shape()[<span class="number">1</span>:]))]), fc6W, fc6b)</span><br><span class="line"> fc6_drop = tf.nn.dropout(fc6, keep_prob)</span><br><span class="line"> <span class="comment"># fc7</span></span><br><span class="line"> fc7 = tf.nn.relu_layer(fc6_drop, fc7W, fc7b)</span><br><span class="line"> fc7_drop = tf.nn.dropout(fc7, keep_prob)</span><br><span class="line"> <span class="comment"># fc8</span></span><br><span class="line"> fc8 = tf.nn.xw_plus_b(fc7_drop, fc8W, fc8b)</span><br><span class="line"> prob = tf.nn.softmax(fc8)</span><br><span class="line"> <span class="keyword">return</span> prob</span><br><span class="line"></span><br><span class="line"> logits = model(x)</span><br><span class="line"> predict = tf.argmax(logits, <span class="number">1</span>)</span><br><span class="line"> saver = tf.train.Saver({v.op.name: v <span class="keyword">for</span> v <span 
class="keyword">in</span> [conv1W, conv1b,</span><br><span class="line"> conv2W, conv2b,</span><br><span class="line"> conv3W, conv3b,</span><br><span class="line"> conv4W, conv4b,</span><br><span class="line"> conv5W, conv5b,</span><br><span class="line"> fc6W, fc6b,</span><br><span class="line"> fc7W, fc7b,</span><br><span class="line"> fc8W, fc8b]})</span><br><span class="line"></span><br><span class="line"> <span class="keyword">with</span> tf.Session(graph=graph) <span class="keyword">as</span> sess:</span><br><span class="line"> init = tf.global_variables_initializer()</span><br><span class="line"> sess.run(init)</span><br><span class="line"> coord = tf.train.Coordinator()</span><br><span class="line"> threads = tf.train.start_queue_runners(coord=coord)</span><br><span class="line"> saver.restore(sess, <span class="string">'./alex_model_spp.ckpt'</span>)</span><br><span class="line"> image = sess.run(my_img)</span><br><span class="line"> predict = predict.<span class="built_in">eval</span>(feed_dict={x: image, keep_prob: <span class="number">1.0</span>})</span><br><span class="line"> result.append(predict[<span class="number">0</span>])</span><br><span class="line"> f.write(data_dir[tstid[it]] + <span class="string">'\t'</span> + <span class="built_in">str</span>(predict[<span class="number">0</span>]) + <span class="string">'\t'</span> + <span class="built_in">str</span>(labels[tstid[it]]))</span><br><span class="line"> f.write(<span class="string">'\n'</span>)</span><br><span class="line"> coord.request_stop()</span><br><span class="line"> coord.join(threads)</span><br><span class="line"> sess.close()</span><br><span class="line"> <span class="keyword">del</span> sess</span><br><span class="line"> it = it + <span class="number">1</span></span><br><span class="line"></span><br><span class="line"><span class="built_in">print</span>(<span class="string">'Test accuracy: %f'</span> %(<span class="built_in">sum</span>(np.array(result) == 
np.array(labels[tstid])).astype(<span class="string">'float'</span>)/<span class="built_in">len</span>(tstid)))</span><br><span class="line">f.close()</span><br></pre></td></tr></tbody></table></figure>]]></content>
<summary type="html"><p>Code source: <a href="https://github.com/peace195/sppnet">https://github.com/peace195/sppnet</a></p>
<h1 id="代码解读"><a href="#代码解读" class="headerl</summary>
<category term="人工智能学习" scheme="https://thinksky5124.github.io/categories/%E4%BA%BA%E5%B7%A5%E6%99%BA%E8%83%BD%E5%AD%A6%E4%B9%A0/"/>
<category term="目标检测" scheme="https://thinksky5124.github.io/categories/%E4%BA%BA%E5%B7%A5%E6%99%BA%E8%83%BD%E5%AD%A6%E4%B9%A0/%E7%9B%AE%E6%A0%87%E6%A3%80%E6%B5%8B/"/>
<category term="目标检测" scheme="https://thinksky5124.github.io/tags/%E7%9B%AE%E6%A0%87%E6%A3%80%E6%B5%8B/"/>
<category term="SPP-Net" scheme="https://thinksky5124.github.io/tags/SPP-Net/"/>
<category term="tensorflow" scheme="https://thinksky5124.github.io/tags/tensorflow/"/>
</entry>
</feed>