【心得】雷泰光電實習 - 技術分享篇
在 面試 後,開始實習,熟悉了一下環境後當然就要開始有生產力囉,所以這篇 focus 在技術的部份,若對實習生活有興趣可以到 上一篇 看看。
技術分享
同面試那篇所說,我要透過平行計算加速 morphological operation,所以一開始有個 CPU version 讓我 trace,我的目標就是瘋狂加速他。
CPU version 超難懂,基本上就是各種 CV 的操作,在沒有 CV 背景的情況下去看超級痛苦,雖然可能才 3, 400 行 code,但我硬是看了好幾天還是沒看懂,最後覺得其實我只要可以瘋狂加速,這些 operation 背後的原理跟原因我不用熟,只要知道在幹嘛就好,抱持著這樣的心態後就大致開始掌握節奏了。
原本的 CPU version 2048 張照片約處理 700 seconds,遠遠不及產線上的需求,產線產生資料的速度太快,所以需要很高的 throughput 才能撐起來,因此需要 GPU 加速。
CPU version 是用 python 實作,要優化他的話我們首先分為以下幾步驟:
- Baseline: Pytorch
- Batch: Pytorch
- Optimized: Triton
以下將針對各個 part 做敘述。
Baseline - Pytorch
所有事情都要先有個 baseline,所以我打算先用 Pytorch 做出一個基本的版本可以跑在 GPU 上,再去 profile bottleneck 在哪再看怎麼進一步優化。
修改的過程中比較多的時間是在學 Pytorch,Pytorch 其實是一個 Deep learning framework,但我們要用他來做科學計算等等的操作,其實在很多地方上他沒有很直接的支援(至少不像 Deep learning 時直接就有一個 function 可以使用),所以很多時候其實是在看怎麼套 Pytorch 的 API。
我們有 3 種算法,其中一種可以很直接的用 Pytorch 做完,但另外 2 種有較複雜的操作,Pytorch 超卡,要繞超大一圈才能做到,效能也變很差,所以就直接開始實作 Triton 的版本ㄌ。
Baseline - Pytorch + Triton
Before Triton
跟原先計畫不同,原本是預計 Pytorch 做完再看 Triton,結果 Pytorch 太難搞了,最後就直接來用 Triton。
而什麼是 Triton 呢?他是一個 OpenAI 開發,基於 Compiler 的平行計算 framework。可以看看他的 官網 ,相比 Pytorch,他有非常好的效能,原因在於他可以將多個 operation fuse 起來,做到 reuse SRAM 的效果,基本上概念跟 flash attention 一樣,就是想辦法 reuse share memory。
但 Triton 有一點非常特別,CUDA 是先分 thread block,再區分 thread,一層一層去分配工作;然而 Triton 是以 block 為單位在工作,programmer 不用管到 thread level,只要放心交給 compiler,什麼 coalesced memory access 或者 bank conflict 他都會想辦法幫你解決,他最重要的 feature 就是 programmer 只要 focus 在 high level 的 data distribution 就好,細部交給 compiler。
所以他其實很特別,有點像是 Pytorch + CUDA 的感覺,沒有 Pytorch 那麼 high level,你可以管到 GPU 上的 operation,但又沒有 CUDA 那麼自由那麼 low level,你所有的操作都得基於一個 tensor,並且也有很多限制,像是 data 數量必須是 2 的次方等等(當然,更難 debug,畢竟是 compiler 在運作)。
總而言之,使用 Triton 為的就是 3 個目的:
- 更短的開發週期
- 更好的效能
- 程式碼不僅限於 CUDA device,ROCM 也想要可以跑
所以我們使用 Triton。
Tritoning
剛開始寫得時候因為有 CUDA 的基礎,所以覺得靠北這個比 Pytorch 簡單多了,Pytorch 要想辦法各種繞圈圈,Triton 只要將 data 分配下去就好。
但事情當然沒有這麼簡單,馬上第一個問題就來了,data 不是 2 的次方。
我在想因為 Triton 要讓 compiler 去做事,所以才有這個限制,但這樣其實滿搞的。總之後來的解法就是想辦法把資料 pad 到 2 的次方。
再來的問題就是超級難 debug,CUDA 可以直接 print,但 Triton 的 print(tl.device_print) 會 print 出一堆東西來超難看懂,基本上可以當作沒有。
雖然 Triton 有個 interpret mode,可以讓 Triton 跑在 CPU 上,print 出 intermediate data 做確認,但實際上這個方法有很多 bug,跑在 CPU 上會有一堆問題,很多運算結果都會錯誤,畢竟人家是設計跑在 GPU 上的。
其中有個 bug 很靠北,我特別有印象,我有個 tensor 是用 as_strided 這個方法從另一個 tensor 取出來的,結果 copy 到 GPU 上後,他在 CPU 上的資料是用 stride 所以可能不連續,但 copy 到 GPU 上後是連續,但 stride 又沒有轉換,所以直接用這個 tensor 的 stride 去取資料會拿到錯的,超級靠北,我覺得這是 Triton 跟 Pytorch 沒有結合好的問題,也有可能他們沒有預設到有人是用 as_strided 去取 tensor 吧。
總而言之,就是花了不少時間 debug,還有剛開始上手的時候花了好一陣子才搞懂 Triton 的運作原理,但上手後還是覺得挺不錯的,確實相比於 CUDA 而言有更短的開發週期,並且 portable。
把 3 種算法完成後 profile 一下,大致快了 10 倍左右,到目前為止是一次跑一張照片的速度,但要知道 GPU 強得並不是一次一張照片,是一次幾十甚至幾百張照片。
Batching
沒錯,所以接著就來實作 Batching 的版本啦~
一樣是用 Pytorch,Triton 的部份很好支援 Batching,Pytorch 的話老樣子,要去學怎麼用人家的 API,像是 convolution 中 tensor 的 dimension 的意思,怎麼一次對一張照片套多個 convolution kernel,或者對一個 batch 的照片套各自的 kernel 等等。
實作完成後果然進步神速,這才是 GPU 的威力,到目前為止 2048 張照片只需要 10 秒左右,相比 CPU version 快了 60 倍。
接著就是各種想辦法 optimize 了。
Optimize
首先當然就是 profile bottleneck,發現了 2 個 bottleneck:
- torch.unravel_index
- image shifting, cropping
Bottleneck - torch.unravel_index
第一個很神奇,只是單純把 argmax 的結果轉換成 index,我不懂為什麼 Pytorch 可以花一輩子的時間在做,甚至比 convolution 還要久。
解法也很簡單直接,就是直接用 Triton 寫一個 custom kernel 解掉,有點忘記確切數字,但應該比本來的有快上幾十倍甚至百倍。
Bottleneck - image shifting, cropping
跟上一個類似,也是直接用 Triton kernel 解掉,但是 Triton kernel 不支援 cropping,如同前面說的,他必須要 2 的次方的 data,所以 256 * 256 的照片變成 224 * 224 的照片他做不到笑死。
所以就用 Triton kernel 做 shifting 後,tensor 丟出來再用 Pytorch 做 cropping,像是這樣
1 | tensor[..., 16:-16, 16:-16] |
Other optimizations
解決掉主要的 bottleneck 後,就開始各種小優化,像是想辦法把各種東西 fuse 起來 reuse share memory;或者其實有些東西的計算沒有必要,可以用一些操作跳過;最後當然還有 tune batch size。
Result
透過上面全部的優化後,原先處理 2048 張照片 CPU 要 700 秒,現在 GPU 只要 2.34 秒,接近 300 倍的加速,超出我的預期非常多,所以也很有成就感。
以上~大概是全部寒假實習的內容分享,在實習的過程中學到了很多,Pytorch, Triton 還有 presentation 的方式,怎麼製作簡報等等。
我也成功下學期繼續在這邊做 part time,甚至他們讓我可以 work from home,超爽(薪水還上漲嘿嘿),work from home 的意思就是公司請飲料的時候我會 on site,其他時候在家裡,哇哈哈。
所以呢,說不定還會有下一篇,到時候再看看可以分享什麼~





