从理解 NVCC 到理解 How CUDA Runtime Works?

TL;DR

    现有的参考资料 blog_1blog_2blog_3 非常详细地记录了博主们分析 nvcc 的 CUDA 程序构建过程,但是我发现在我的平台上的编译过程和他们有些许出入,遂进行记录,我的平台参数如 platform 所示。在记录的过程中,我发现如果能对 CUDA 程序的构建过程有一个比较深入的了解,则也能辅助理解构建后的 CUDA 程序是如何利用 CUDA Runtime APIs 工作的,故形成了如标题所示的本文内容。

部件 参数/版本
GPU NVIDIA GeForce RTX 4060 Laptop GPU
CUDA Toolkit 12.2
NVIDIA Driver 536.45
nvcc V12.1.105

构建过程简介和程序准备

CUDA 程序构建与 nvcc

CUDA 程序的编译与链接

推荐读者朋友先阅读我的另一篇文章:Section 与 Segment:从 链接器 到 Runtime 的角度出发

    CUDA 程序的构建过程可以分为两个阶段: A 编译和 B 链接,如 nvcc_process_overview 所示。

    在 A 编译阶段,对于每一个 .cu 文件,一共可以分为 4 步处理: 1 首先将 .cu 文件进行 Host 侧的预处理,包括展开头文件内容,替换 <<<>>> 语法糖等;2 随后还是针对原始 .cu 文件,进行 Device 侧的预处理,然后首先编译出 .ptx 文件,如果 .cu 文件中带有 Kernel 的相关定义,则在 .ptx 文件中就会出现对应的 PTX 程序;3 随后基于 PTX 进行 SASS 的编译,并最终将 PTX 和 SASS 程序一起打包放入 .fatbin 文件中,后者实际上是一段 C 程序代码,其中用 C 结构体封装了已经被编译为二进制的程序;4 最终进行一次编译 (i.e., gcc -c 指令),将 Host 侧程序和已经被编译为 Device 二进制的程序一起,打包为一个待重定位文件 (i.e., .o 文件)。在 nvcc 的官方文档 nvcc_doc 中将上述编译流程定义为 CUDA 编译轨道 (CUDA Compilation Trajectory)

    在 B 链接阶段,TODO

NVCC

    nvcc 在 NVIDIA 官方文档 nvcc_doc 中被定义为 CUDA Compiler Driver,其目的在于掩盖包含了多次程序分离、预处理、编译和合并的 A 编译阶段以及 B 链接阶段。nvcc 并不是一个编译器,它的功能可以视为根据用户传入的命令行参数,按照一定的顺序调用对应的编译器/链接器对程序进行构建的 wrapper,因此才被称为编译器驱动。

    就后文的观察来看,nvcc 调用的编译器包括:

  • gcc: Host 侧程序的编译器,使用 -E 指令进行预处理;使用 -c 指令编译出待重定位文件;
  • cicc: Device 侧 PTX 程序的编译器,其是一个 LLVM-based Compiler cicc_zhihu,负责将 C/C++ 定义的 device function 转化为虚拟架构 PTX 程序;
  • cudafe++: 用于处理 .cu 文件中的 CUDA 语法糖 (e.g., <<<>>> 调用符号,__global__ kernel 定义等)

理解 nvcc 的编译参数

    接着我们理解一下 nvcc 的编译参数 nvcc_param_tips

1
2
3
4
5
nvcc  main.cu gemm_kernel_1.cu gemm_kernel_2.cu help_function.cu  \
-arch=compute_89 \
-code=compute_89,sm_89,sm_90 \
-o gemm_exe \
-rdc=true

    就最常用的参数来说,arch 用于指定虚拟架构,code 用于指定 nvcc 最后将向编译产物 (i.e., cubin 文件) 中放入哪些程序。例如 nvcc_exp_1 所示的命令,则 nvcc 将以 compute_89 的 PTX 代码,编译出 sm_89sm_90 的 SASS 代码,然后最终向 cubin 中放入 1 compute_89 的 PTX,以及 2 sm_893 sm_90 的 SASS 代码。

1
2
3
4
5
6
7
8
nvcc  main.cu gemm_kernel_1.cu gemm_kernel_2.cu help_function.cu  \ 
-gencode=arch=compute_89,code=sm_89 \
-gencode=arch=compute_89,code=sm_90 \
-gencode=arch=compute_90,code=sm_90 \
-gencode=arch=compute_89,code=compute_89 \
-gencode=arch=compute_90,code=compute_90 \
-o gemm_exe \
-rdc=true

    有时候我们又希望编译产物中可以包含基于多种虚拟架构的 PTX 代码,以及他们编译出来的对应 SASS,此时我们可以使用 gencode 选项。如 nvcc_exp_2 所示,编译完成后,编译产物中将包括 1 基于 compute_89 编译形成的 sm_89 SASS 程序;2 基于 compute_89 编译形成的 sm_90 SASS 程序;3 基于 compute_90 编译形成的 sm_90 SASS 程序;4 compute_89 PTX 程序和 5 compute_90 PTX 程序。

nvcc_exp_1nvcc_exp_2 中的 -rdc=true 参数用于使能 Relocatable Device Code,使得在编译的时候不需要将 __device__ __global__ nvidia_rdc

程序准备

    我们进行编译实验的 CUDA 程序结构如 structure 所示,它们之间的关系如 program_structure 所示。简而言之,我们在 gemm_kernel_1.cugemm_kernel_2.cu 中分别定义了 1 个 kernel,其中 gemm_kernel_2.cu 中定义的 kernel 调用了在 help_function.cu 中定义的 device function。另外我们还在 gemm_kernel_cublas.cu 中定义了一个函数,其使用了来自 cuBLAS 库 cublas_doc 的 APIs。程序的主线逻辑则位于 main.cu 中。

1
2
3
4
5
6
7
.
├── gemm_kernel.cuh
├── gemm_kernel_1.cu
├── gemm_kernel_2.cu
├── gemm_kernel_cublas.cu
├── help_function.cu
└── main.cu

    分析这 5 个文件的编译过程,对应 5 个验证目的:

  • main.cu: launch kernel 的 <<<>>> 语法糖在编译过程中被转化为了什么?
  • gemm_kernel_1.cu: kernel 是怎么被编译的?kernel 是怎么被注册和调用的?
  • gemm_kernel_2.cu: 调用了位于其它 .cu 文件定义的 device funtion 的 kernel 是怎么被编译的?以及注册和调用的?
  • gemm_kernel_cublas.cu: 调用了第三方库 (e.g., cuBLAS cublas_doc) 的 .cu 文件是怎么被编译和链接的?
  • help_function.cu: device function 是怎么被编译的?

    就具体文件内容来说,我们首先在 help_function.cu 中定义了一个 device function get_flat_index,如 help_function_cu 所示。

1
2
3
4
5
#include "gemm_kernel.cuh"

__device__ int get_flat_index(const int row_id, const int col_id, const int row_dim){
return row_id * row_dim + col_id;
}

    在 gemm_kernel_1.cugemm_kernel_2.cu 中,我们分别定义了一个用于 GEMM 计算的 kernel,程序分别如 gemm_kernel_1_cugemm_kernel_2_cu 所示。其中,gemm_kernel_2.cu 中定义的 kernel gemm_kernel_2 使用了 device function get_flat_function,而 gemm_kernel_1.cu 中定义的 kernel gemm_kernel_1 没有使用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#include "gemm_kernel.cuh"

// aligned implementation
__global__ void gemm_kernel_1(
/* m*n */ const int *a,
/* k*n */ const int *b_t,
/* m*k */ int *c,
const int m,
const int n,
const int k
){
#pragma unroll
for(int i=blockIdx.x*blockDim.x+threadIdx.x; i<=m*k; i+=blockDim.x){
int row_id = i/k;
int col_id = i-k*row_id;
int c_flat_index = row_id*m+col_id;

c[c_flat_index] = 0;

#pragma unroll
for(int j=0; j<n; j++){
c[c_flat_index] += a[row_id*n+j]*b_t[col_id*n+j];
}
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#include "gemm_kernel.cuh"

// misaligned implementation
__global__ void gemm_kernel_2(
/* m*n */ const int *a,
/* n*k */ const int *b,
/* m*k */ int *c,
const int m,
const int n,
const int k
){
#pragma unroll
for(int i=blockIdx.x*blockDim.x+threadIdx.x; i<=m*k; i+=blockDim.x){
int row_id = i/k;
int col_id = i-k*row_id;
int c_flat_index = get_flat_index(/* row_id */row_id, /* col_id */col_id, /* row_dim */m);

c[c_flat_index] = 0;

#pragma unroll
for(int j=0; j<n; j++){
c[c_flat_index] += a[row_id*n+j]*b[j*n+col_id];
}
}
}

    在 gemm_kernel_cublas.cu 中,我们则定义了一个函数,其调用了 cuBLAS 库的部分 APIs,具体文件内容如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include <stdlib.h>
#include <math.h>
#include <cuda_runtime.h>
#include "cublas_v2.h"
#define M 6
#define N 5
#define IDX2F(i,j,ld) ((((j)-1)*(ld))+((i)-1))
#include "gemm_kernel.cuh"

static __inline__ void modify (cublasHandle_t handle, float *m, int ldm, int n, int p, int q, float alpha, float beta){
cublasSscal (handle, n-q+1, &alpha, &m[IDX2F(p,q,ldm)], ldm);
cublasSscal (handle, ldm-p+1, &beta, &m[IDX2F(p,q,ldm)], 1);
}

void gemm_kernel_cublas(){
cudaError_t cudaStat;
cublasStatus_t stat;
cublasHandle_t handle;
int i, j;
float* devPtrA;
float* a = 0;
a = (float *)malloc (M * N * sizeof (*a));
if (!a) {
printf ("host memory allocation failed");
return;
}
for (j = 1; j <= N; j++) {
for (i = 1; i <= M; i++) {
a[IDX2F(i,j,M)] = (float)((i-1) * N + j);
}
}
cudaStat = cudaMalloc ((void**)&devPtrA, M*N*sizeof(*a));
if (cudaStat != cudaSuccess) {
printf ("device memory allocation failed");
return;
}
stat = cublasCreate(&handle);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("CUBLAS initialization failed\n");
return;
}
stat = cublasSetMatrix (M, N, sizeof(*a), a, M, devPtrA, M);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("data download failed");
cudaFree (devPtrA);
cublasDestroy(handle);
return;
}
modify (handle, devPtrA, M, N, 2, 3, 16.0f, 12.0f);
stat = cublasGetMatrix (M, N, sizeof(*a), devPtrA, M, a, M);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("data upload failed");
cudaFree (devPtrA);
cublasDestroy(handle);
return;
}
cudaFree (devPtrA);
cublasDestroy(handle);
for (j = 1; j <= N; j++) {
for (i = 1; i <= M; i++) {
printf ("%7.0f", a[IDX2F(i,j,M)]);
}
printf ("\n");
}
free(a);
}

    我们把以上所有函数的函数声明放在了 gemm_kernel.cuh 中:

1
2
3
4
5
6
7
8
9
10
#ifndef __GEMM_KERNEL_CUH__
#define __GEMM_KERNEL_CUH__

__device__ int get_flat_index(const int row_id, const int col_id, const int row_dim);
__global__ void gemm_kernel_1(const int *a, const int *b_t, int *c, const int m, const int n, const int k);
__global__ void gemm_kernel_2(const int *a, const int *b, int *c, const int m, const int n, const int k);

void gemm_kernel_cublas();

#endif

    在 main.cu 中,我们则完成了一个简单的 CUDA 程序的定义:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include <cuda.h>
#include <stdio.h>

#include "gemm_kernel.cuh"

int main() {
int m=64, n=128, k=32;
int *da = NULL, *db = NULL, *dc = NULL;
int *ha = NULL, *hb = NULL, *hc = NULL;

cudaMalloc((void **)&da, m*n*sizeof(int));
if (da == NULL) {
printf("GPU alloc fail");
return -1;
}

cudaMalloc((void **)&db, n*k*sizeof(int));
if (db == NULL) {
printf("GPU alloc fail");
return -1;
}

cudaMalloc((void **)&dc, m*k*sizeof(int));
if (dc == NULL) {
printf("GPU alloc fail");
return -1;
}

ha = (int *)malloc(m*n*sizeof(int));
if (ha == NULL) {
printf("CPU alloc fail");
return -1;
}
for(int i=0; i<m*n; i++){
ha[i] = rand();
}

hb = (int *)malloc(n*k*sizeof(int));
if (hb == NULL) {
printf("CPU alloc fail");
return -1;
}
for(int i=0; i<n*k; i++){
hb[i] = rand();
}

hc = (int *)malloc(m*k*sizeof(int));
if (hc == NULL) {
printf("CPU alloc fail");
return -1;
}

cudaMemcpy(da, ha, m*n*sizeof(int), cudaMemcpyHostToDevice);
cudaMemcpy(db, hb, n*k*sizeof(int), cudaMemcpyHostToDevice);

gemm_kernel_1<<<1,256>>>(da, db, dc, m, n, k);
cudaMemcpy(hc, dc, m*k*sizeof(int), cudaMemcpyDeviceToHost);

gemm_kernel_2<<<1,256>>>(da, db, dc, m, n, k);
cudaMemcpy(hc, dc, m*k*sizeof(int), cudaMemcpyDeviceToHost);

gemm_kernel_cublas();

cudaFree(da); cudaFree(db); cudaFree(dc);
free(ha); free(hb); free(hc);
return 0;
}

编译过程 Overview

    为了探究过程完整性,我们运行如 expose_nvcc 所示的指令,使 nvcc 暴露其编译过程:

1
nvcc main.cu gemm_kernel_1.cu gemm_kernel_2.cu -arch=compute_89 -code=compute_89,sm_89,sm_90 --verbose -o gemm_exe

    运行过后,我们可以获得如 compile_output 所示的输出。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
$ _NVVM_BRANCH_=nvvm
$ _SPACE_=
$ _CUDART_=cudart
$ _HERE_=/usr/local/cuda-12.1/bin
$ _THERE_=/usr/local/cuda-12.1/bin
$ _TARGET_SIZE_=
$ _TARGET_DIR_=
$ _TARGET_DIR_=targets/x86_64-linux
$ TOP=/usr/local/cuda-12.1/bin/..
$ NVVMIR_LIBRARY_DIR=/usr/local/cuda-12.1/bin/../nvvm/libdevice
$ LD_LIBRARY_PATH=/usr/local/cuda-12.1/bin/../lib:
$ PATH=/usr/local/cuda-12.1/bin/../nvvm/bin:/usr/local/cuda-12.1/bin:/usr/local/cuda-12.1/nvvm/bin:/usr/local/cuda-12.1/bin:/usr/local/cuda-12.1/nvvm/bin:/usr/local/cuda-12.1/bin:/home/zobin/anaconda3/bin:/home/zobin/anaconda3/condabin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/usr/lib/wsl/lib:/mnt/c/Windows/system32:/mnt/c/Windows:/mnt/c/Windows/System32/Wbem:/mnt/c/Windows/System32/WindowsPowerShell/v1.0:/mnt/c/Windows/System32/OpenSSH:/mnt/c/Program Files (x86)/NVIDIA Corporation/PhysX/Common:/mnt/c/Program Files/NVIDIA Corporation/NVIDIA NvDLISR:/mnt/c/Program Files/Docker/Docker/resources/bin:/mnt/e/applications/git/Git/cmd:/mnt/e/applications/nodejs:/mnt/c/ProgramData/chocolatey/bin:/mnt/e/applications/win32yank-x64:/mnt/c/Users/85280/AppData/Local/Microsoft/WindowsApps:/mnt/e/applications/vscode/Microsoft VS Code/bin:/mnt/c/Users/85280/AppData/Roaming/npm:/home/zobin/Applications/vivado_2022/Vivado/2022.2/bin:/home/zobin/Applications/vivado_2022/Vivado/2022.2/bin
$ INCLUDES="-I/usr/local/cuda-12.1/bin/../targets/x86_64-linux/include"
$ LIBRARIES= "-L/usr/local/cuda-12.1/bin/../targets/x86_64-linux/lib/stubs" "-L/usr/local/cuda-12.1/bin/../targets/x86_64-linux/lib"
$ CUDAFE_FLAGS=
$ PTXAS_FLAGS=

# ================== 处理 main.cu ==================
# 预处理 host 侧代码
$ gcc -D__CUDA_ARCH_LIST__=890 -E -x c++ -D__CUDACC__ -D__NVCC__ "-I/usr/local/cuda-12.1/bin/../targets/x86_64-linux/include" -D__CUDACC_VER_MAJOR__=12 -D__CUDACC_VER_MINOR__=1 -D__CUDACC_VER_BUILD__=105 -D__CUDA_API_VER_MAJOR__=12 -D__CUDA_API_VER_MINOR__=1 -D__NVCC_DIAG_PRAGMA_SUPPORT__=1 -include "cuda_runtime.h" -m64 "main.cu" -o "/tmp/tmpxft_0000047f_00000000-5_main.cpp4.ii"

# 分离出 host 侧代码
$ cudafe++ --c++17 --gnu_version=110400 --display_error_number --orig_src_file_name "main.cu" --orig_src_path_name "/home/zobin/projects/pos/test_cuda/main.cu" --allow_managed --m64 --parse_templates --gen_c_file_name "/tmp/tmpxft_0000047f_00000000-6_main.cudafe1.cpp" --stub_file_name "tmpxft_0000047f_00000000-6_main.cudafe1.stub.c" --gen_module_id_file --module_id_file_name "/tmp/tmpxft_0000047f_00000000-4_main.module_id" "/tmp/tmpxft_0000047f_00000000-5_main.cpp4.ii"

# 预处理 device 侧代码
$ gcc -D__CUDA_ARCH__=890 -D__CUDA_ARCH_LIST__=890 -E -x c++ -DCUDA_DOUBLE_MATH_FUNCTIONS -D__CUDACC__ -D__NVCC__ "-I/usr/local/cuda-12.1/bin/../targets/x86_64-linux/include" -D__CUDACC_VER_MAJOR__=12 -D__CUDACC_VER_MINOR__=1 -D__CUDACC_VER_BUILD__=105 -D__CUDA_API_VER_MAJOR__=12 -D__CUDA_API_VER_MINOR__=1 -D__NVCC_DIAG_PRAGMA_SUPPORT__=1 -include "cuda_runtime.h" -m64 "main.cu" -o "/tmp/tmpxft_0000047f_00000000-17_main.cpp1.ii"

# 编译出 PTX 和 插桩文件
$ cicc --c++17 --gnu_version=110400 --display_error_number --orig_src_file_name "main.cu" --orig_src_path_name "/home/zobin/projects/pos/test_cuda/main.cu" --allow_managed -arch compute_89 -m64 --no-version-ident -ftz=0 -prec_div=1 -prec_sqrt=1 -fmad=1 --include_file_name "tmpxft_0000047f_00000000-3_main.fatbin.c" -tused --module_id_file_name "/tmp/tmpxft_0000047f_00000000-4_main.module_id" --gen_c_file_name "/tmp/tmpxft_0000047f_00000000-6_main.cudafe1.c" --stub_file_name "/tmp/tmpxft_0000047f_00000000-6_main.cudafe1.stub.c" --gen_device_file_name "/tmp/tmpxft_0000047f_00000000-6_main.cudafe1.gpu" "/tmp/tmpxft_0000047f_00000000-17_main.cpp1.ii" -o "/tmp/tmpxft_0000047f_00000000-6_main.ptx"

# 编译出 cubin (SASS)
$ ptxas -arch=sm_90 -m64 "/tmp/tmpxft_0000047f_00000000-6_main.ptx" -o "/tmp/tmpxft_0000047f_00000000-18_main.sm_90.cubin"
$ ptxas -arch=sm_89 -m64 "/tmp/tmpxft_0000047f_00000000-6_main.ptx" -o "/tmp/tmpxft_0000047f_00000000-19_main.sm_89.cubin"

# 将 PTX 和 SASS 合并成为 fatbin.c
$ fatbinary -64 --cicc-cmdline="-ftz=0 -prec_div=1 -prec_sqrt=1 -fmad=1 " "--image3=kind=elf,sm=90,file=/tmp/tmpxft_0000047f_00000000-18_main.sm_90.cubin" "--image3=kind=elf,sm=89,file=/tmp/tmpxft_0000047f_00000000-19_main.sm_89.cubin" "--image3=kind=ptx,sm=89,file=/tmp/tmpxft_0000047f_00000000-6_main.ptx" --embedded-fatbin="/tmp/tmpxft_0000047f_00000000-3_main.fatbin.c"
$ rm /tmp/tmpxft_0000047f_00000000-3_main.fatbin

# 编译出 host 目标文件
$ gcc -D__CUDA_ARCH__=890 -D__CUDA_ARCH_LIST__=890 -c -x c++ -DCUDA_DOUBLE_MATH_FUNCTIONS "-I/usr/local/cuda-12.1/bin/../targets/x86_64-linux/include" -m64 "/tmp/tmpxft_0000047f_00000000-6_main.cudafe1.cpp" -o "/tmp/tmpxft_0000047f_00000000-20_main.o"

# ================== 处理 gemm_kernel_1.cu ==================
# 预处理 host 侧代码
$ gcc -D__CUDA_ARCH_LIST__=890 -E -x c++ -D__CUDACC__ -D__NVCC__ "-I/usr/local/cuda-12.1/bin/../targets/x86_64-linux/include" -D__CUDACC_VER_MAJOR__=12 -D__CUDACC_VER_MINOR__=1 -D__CUDACC_VER_BUILD__=105 -D__CUDA_API_VER_MAJOR__=12 -D__CUDA_API_VER_MINOR__=1 -D__NVCC_DIAG_PRAGMA_SUPPORT__=1 -include "cuda_runtime.h" -m64 "gemm_kernel_1.cu" -o "/tmp/tmpxft_0000047f_00000000-9_gemm_kernel_1.cpp4.ii"

# 分离出 host 侧代码
$ cudafe++ --c++17 --gnu_version=110400 --display_error_number --orig_src_file_name "gemm_kernel_1.cu" --orig_src_path_name "/home/zobin/projects/pos/test_cuda/gemm_kernel_1.cu" --allow_managed --m64 --parse_templates --gen_c_file_name "/tmp/tmpxft_0000047f_00000000-10_gemm_kernel_1.cudafe1.cpp" --stub_file_name "tmpxft_0000047f_00000000-10_gemm_kernel_1.cudafe1.stub.c" --gen_module_id_file --module_id_file_name "/tmp/tmpxft_0000047f_00000000-8_gemm_kernel_1.module_id" "/tmp/tmpxft_0000047f_00000000-9_gemm_kernel_1.cpp4.ii"

# 预处理 device 侧代码
$ gcc -D__CUDA_ARCH__=890 -D__CUDA_ARCH_LIST__=890 -E -x c++ -DCUDA_DOUBLE_MATH_FUNCTIONS -D__CUDACC__ -D__NVCC__ "-I/usr/local/cuda-12.1/bin/../targets/x86_64-linux/include" -D__CUDACC_VER_MAJOR__=12 -D__CUDACC_VER_MINOR__=1 -D__CUDACC_VER_BUILD__=105 -D__CUDA_API_VER_MAJOR__=12 -D__CUDA_API_VER_MINOR__=1 -D__NVCC_DIAG_PRAGMA_SUPPORT__=1 -include "cuda_runtime.h" -m64 "gemm_kernel_1.cu" -o "/tmp/tmpxft_0000047f_00000000-21_gemm_kernel_1.cpp1.ii"

# 编译出 PTX 和插桩文件
$ cicc --c++17 --gnu_version=110400 --display_error_number --orig_src_file_name "gemm_kernel_1.cu" --orig_src_path_name "/home/zobin/projects/pos/test_cuda/gemm_kernel_1.cu" --allow_managed -arch compute_89 -m64 --no-version-ident -ftz=0 -prec_div=1 -prec_sqrt=1 -fmad=1 --include_file_name "tmpxft_0000047f_00000000-7_gemm_kernel_1.fatbin.c" -tused --module_id_file_name "/tmp/tmpxft_0000047f_00000000-8_gemm_kernel_1.module_id" --gen_c_file_name "/tmp/tmpxft_0000047f_00000000-10_gemm_kernel_1.cudafe1.c" --stub_file_name "/tmp/tmpxft_0000047f_00000000-10_gemm_kernel_1.cudafe1.stub.c" --gen_device_file_name "/tmp/tmpxft_0000047f_00000000-10_gemm_kernel_1.cudafe1.gpu" "/tmp/tmpxft_0000047f_00000000-21_gemm_kernel_1.cpp1.ii" -o "/tmp/tmpxft_0000047f_00000000-10_gemm_kernel_1.ptx"

# 编译出 cubin (SASS)
$ ptxas -arch=sm_90 -m64 "/tmp/tmpxft_0000047f_00000000-10_gemm_kernel_1.ptx" -o "/tmp/tmpxft_0000047f_00000000-22_gemm_kernel_1.sm_90.cubin"
$ ptxas -arch=sm_89 -m64 "/tmp/tmpxft_0000047f_00000000-10_gemm_kernel_1.ptx" -o "/tmp/tmpxft_0000047f_00000000-23_gemm_kernel_1.sm_89.cubin"

# 将 PTX 和 SASS 合并成为 fatbin.c
$ fatbinary -64 --cicc-cmdline="-ftz=0 -prec_div=1 -prec_sqrt=1 -fmad=1 " "--image3=kind=elf,sm=90,file=/tmp/tmpxft_0000047f_00000000-22_gemm_kernel_1.sm_90.cubin" "--image3=kind=elf,sm=89,file=/tmp/tmpxft_0000047f_00000000-23_gemm_kernel_1.sm_89.cubin" "--image3=kind=ptx,sm=89,file=/tmp/tmpxft_0000047f_00000000-10_gemm_kernel_1.ptx" --embedded-fatbin="/tmp/tmpxft_0000047f_00000000-7_gemm_kernel_1.fatbin.c"
$ rm /tmp/tmpxft_0000047f_00000000-7_gemm_kernel_1.fatbin

# 编译出 host 目标文件
$ gcc -D__CUDA_ARCH__=890 -D__CUDA_ARCH_LIST__=890 -c -x c++ -DCUDA_DOUBLE_MATH_FUNCTIONS "-I/usr/local/cuda-12.1/bin/../targets/x86_64-linux/include" -m64 "/tmp/tmpxft_0000047f_00000000-10_gemm_kernel_1.cudafe1.cpp" -o "/tmp/tmpxft_0000047f_00000000-24_gemm_kernel_1.o"

# ================== 处理 gemm_kernel_2.cu ==================
# 预处理 host 侧代码
$ gcc -D__CUDA_ARCH_LIST__=890 -E -x c++ -D__CUDACC__ -D__NVCC__ "-I/usr/local/cuda-12.1/bin/../targets/x86_64-linux/include" -D__CUDACC_VER_MAJOR__=12 -D__CUDACC_VER_MINOR__=1 -D__CUDACC_VER_BUILD__=105 -D__CUDA_API_VER_MAJOR__=12 -D__CUDA_API_VER_MINOR__=1 -D__NVCC_DIAG_PRAGMA_SUPPORT__=1 -include "cuda_runtime.h" -m64 "gemm_kernel_2.cu" -o "/tmp/tmpxft_0000047f_00000000-13_gemm_kernel_2.cpp4.ii"

# 分离出 host 侧代码
$ cudafe++ --c++17 --gnu_version=110400 --display_error_number --orig_src_file_name "gemm_kernel_2.cu" --orig_src_path_name "/home/zobin/projects/pos/test_cuda/gemm_kernel_2.cu" --allow_managed --m64 --parse_templates --gen_c_file_name "/tmp/tmpxft_0000047f_00000000-14_gemm_kernel_2.cudafe1.cpp" --stub_file_name "tmpxft_0000047f_00000000-14_gemm_kernel_2.cudafe1.stub.c" --gen_module_id_file --module_id_file_name "/tmp/tmpxft_0000047f_00000000-12_gemm_kernel_2.module_id" "/tmp/tmpxft_0000047f_00000000-13_gemm_kernel_2.cpp4.ii"

# 预处理 device 侧代码
$ gcc -D__CUDA_ARCH__=890 -D__CUDA_ARCH_LIST__=890 -E -x c++ -DCUDA_DOUBLE_MATH_FUNCTIONS -D__CUDACC__ -D__NVCC__ "-I/usr/local/cuda-12.1/bin/../targets/x86_64-linux/include" -D__CUDACC_VER_MAJOR__=12 -D__CUDACC_VER_MINOR__=1 -D__CUDACC_VER_BUILD__=105 -D__CUDA_API_VER_MAJOR__=12 -D__CUDA_API_VER_MINOR__=1 -D__NVCC_DIAG_PRAGMA_SUPPORT__=1 -include "cuda_runtime.h" -m64 "gemm_kernel_2.cu" -o "/tmp/tmpxft_0000047f_00000000-25_gemm_kernel_2.cpp1.ii"

# 编译出 PTX 和插桩文件
$ cicc --c++17 --gnu_version=110400 --display_error_number --orig_src_file_name "gemm_kernel_2.cu" --orig_src_path_name "/home/zobin/projects/pos/test_cuda/gemm_kernel_2.cu" --allow_managed -arch compute_89 -m64 --no-version-ident -ftz=0 -prec_div=1 -prec_sqrt=1 -fmad=1 --include_file_name "tmpxft_0000047f_00000000-11_gemm_kernel_2.fatbin.c" -tused --module_id_file_name "/tmp/tmpxft_0000047f_00000000-12_gemm_kernel_2.module_id" --gen_c_file_name "/tmp/tmpxft_0000047f_00000000-14_gemm_kernel_2.cudafe1.c" --stub_file_name "/tmp/tmpxft_0000047f_00000000-14_gemm_kernel_2.cudafe1.stub.c" --gen_device_file_name "/tmp/tmpxft_0000047f_00000000-14_gemm_kernel_2.cudafe1.gpu" "/tmp/tmpxft_0000047f_00000000-25_gemm_kernel_2.cpp1.ii" -o "/tmp/tmpxft_0000047f_00000000-14_gemm_kernel_2.ptx"

# 编译出 cubin (SASS)
$ ptxas -arch=sm_90 -m64 "/tmp/tmpxft_0000047f_00000000-14_gemm_kernel_2.ptx" -o "/tmp/tmpxft_0000047f_00000000-26_gemm_kernel_2.sm_90.cubin"
$ ptxas -arch=sm_89 -m64 "/tmp/tmpxft_0000047f_00000000-14_gemm_kernel_2.ptx" -o "/tmp/tmpxft_0000047f_00000000-27_gemm_kernel_2.sm_89.cubin"

# 将 PTX 和 SASS 合并成为 fatbin.c
$ fatbinary -64 --cicc-cmdline="-ftz=0 -prec_div=1 -prec_sqrt=1 -fmad=1 " "--image3=kind=elf,sm=90,file=/tmp/tmpxft_0000047f_00000000-26_gemm_kernel_2.sm_90.cubin" "--image3=kind=elf,sm=89,file=/tmp/tmpxft_0000047f_00000000-27_gemm_kernel_2.sm_89.cubin" "--image3=kind=ptx,sm=89,file=/tmp/tmpxft_0000047f_00000000-14_gemm_kernel_2.ptx" --embedded-fatbin="/tmp/tmpxft_0000047f_00000000-11_gemm_kernel_2.fatbin.c"
$ rm /tmp/tmpxft_0000047f_00000000-11_gemm_kernel_2.fatbin

# 编译出 host 目标文件
$ gcc -D__CUDA_ARCH__=890 -D__CUDA_ARCH_LIST__=890 -c -x c++ -DCUDA_DOUBLE_MATH_FUNCTIONS "-I/usr/local/cuda-12.1/bin/../targets/x86_64-linux/include" -m64 "/tmp/tmpxft_0000047f_00000000-14_gemm_kernel_2.cudafe1.cpp" -o "/tmp/tmpxft_0000047f_00000000-28_gemm_kernel_2.o"

# 提取出 sm_90 架构的 SASS 程序
$ nvlink -m64 --arch=sm_90 --register-link-binaries="/tmp/tmpxft_00000564_00000000-15_gemm_exe_dlink.reg.c" "-L/usr/local/cuda-12.1/bin/../targets/x86_64-linux/lib/stubs" "-L/usr/local/cuda-12.1/bin/../targets/x86_64-linux/lib" -cpu-arch=X86_64 -report-arch "/tmp/tmpxft_00000564_00000000-20_main.o" "/tmp/tmpxft_00000564_00000000-24_gemm_kernel_1.o" "/tmp/tmpxft_00000564_00000000-28_gemm_kernel_2.o" -lcudadevrt -o "/tmp/tmpxft_00000564_00000000-29_gemm_exe_dlink.sm_90.cubin" --host-ccbin "gcc"

# 提取出 sm_89 架构的 SASS 程序
$ nvlink -m64 --arch=sm_89 --register-link-binaries="/tmp/tmpxft_00000564_00000000-15_gemm_exe_dlink.reg.c" "-L/usr/local/cuda-12.1/bin/../targets/x86_64-linux/lib/stubs" "-L/usr/local/cuda-12.1/bin/../targets/x86_64-linux/lib" -cpu-arch=X86_64 -report-arch "/tmp/tmpxft_00000564_00000000-20_main.o" "/tmp/tmpxft_00000564_00000000-24_gemm_kernel_1.o" "/tmp/tmpxft_00000564_00000000-28_gemm_kernel_2.o" -lcudadevrt -o "/tmp/tmpxft_00000564_00000000-30_gemm_exe_dlink.sm_89.cubin" --host-ccbin "gcc"

# 将两份 cubin (SASS 程序) 合并为 fatbin.c
$ fatbinary -64 --cicc-cmdline="-ftz=0 -prec_div=1 -prec_sqrt=1 -fmad=1 " -link "--image3=kind=elf,sm=90,file=/tmp/tmpxft_00000564_00000000-29_gemm_exe_dlink.sm_90.cubin" "--image3=kind=elf,sm=89,file=/tmp/tmpxft_00000564_00000000-30_gemm_exe_dlink.sm_89.cubin" --embedded-fatbin="/tmp/tmpxft_00000564_00000000-16_gemm_exe_dlink.fatbin.c"
$ rm /tmp/tmpxft_00000564_00000000-16_gemm_exe_dlink.fatbin

$ gcc -D__CUDA_ARCH_LIST__=890 -c -x c++ -DFATBINFILE="\"/tmp/tmpxft_00000564_00000000-16_gemm_exe_dlink.fatbin.c\"" -DREGISTERLINKBINARYFILE="\"/tmp/tmpxft_00000564_00000000-15_gemm_exe_dlink.reg.c\"" -I. -D__NV_EXTRA_INITIALIZATION= -D__NV_EXTRA_FINALIZATION= -D__CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__ "-I/usr/local/cuda-12.1/bin/../targets/x86_64-linux/include" -D__CUDACC_VER_MAJOR__=12 -D__CUDACC_VER_MINOR__=1 -D__CUDACC_VER_BUILD__=105 -D__CUDA_API_VER_MAJOR__=12 -D__CUDA_API_VER_MINOR__=1 -D__NVCC_DIAG_PRAGMA_SUPPORT__=1 -m64 "/usr/local/cuda-12.1/bin/crt/link.stub" -o "/tmp/tmpxft_00000564_00000000-31_gemm_exe_dlink.o"

$ g++ -D__CUDA_ARCH_LIST__=890 -m64 -Wl,--start-group "/tmp/tmpxft_00000564_00000000-31_gemm_exe_dlink.o" "/tmp/tmpxft_00000564_00000000-20_main.o" "/tmp/tmpxft_00000564_00000000-24_gemm_kernel_1.o" "/tmp/tmpxft_00000564_00000000-28_gemm_kernel_2.o" "-L/usr/local/cuda-12.1/bin/../targets/x86_64-linux/lib/stubs" "-L/usr/local/cuda-12.1/bin/../targets/x86_64-linux/lib" -lcudadevrt -lcudart_static -lrt -lpthread -ldl -Wl,--end-group -o "gemm_exe"

    下面我们按照上面的日志输出内容,逐命令地对编译过程进行分析。

在下面执行的命令中,我们把原始命令中的 tmp 改为了本地临时文件夹 ./tmp,方便我们进行文件增删的分析。

main.cu 的编译

    首先我们对 main.cu 代码的编译过程进行拆解。

预处理 Host 侧源代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
gcc -D__CUDA_ARCH_LIST__=890                                      \
-E \
-x c++ \
-D__CUDACC__ \
-D__NVCC__ \
"-I/usr/local/cuda-12.1/bin/../targets/x86_64-linux/include" \
-D__CUDACC_VER_MAJOR__=12 \
-D__CUDACC_VER_MINOR__=1 \
-D__CUDACC_VER_BUILD__=105 \
-D__CUDA_API_VER_MAJOR__=12 \
-D__CUDA_API_VER_MINOR__=1 \
-D__NVCC_DIAG_PRAGMA_SUPPORT__=1 \
-include "cuda_runtime.h" \
-m64 "main.cu" \
-o "./tmp/tmpxft_0000047f_00000000-5_main.cpp4.ii"

    运行了上述命令后,文件夹结构如下所示:

1
2
3
4
5
6
7
  .
# ├── gemm_kernel.cuh
# ├── gemm_kernel_1.cu
# ├── gemm_kernel_2.cu
# ├── main.cu
└── tmp
   └── tmpxft_0000047f_00000000-5_main.cpp4.ii

    上述命令实际上是 gcc -E 指令,目的是对 main.cu 进行 预处理 (pre-process) / 预编译 (pre-compile) 操作,也即把存在于头文件中的宏、结构体定义等,在 main.cu 中进行定义和展开。同时上述命令还在生成的文件中添加了若干宏定义,生成的文件 main.cpp4.ii 是一个长达 3 万行的代码文件,截取其中内容,如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
// ...

struct __attribute__((device_builtin)) char1
{
signed char x;
};

struct __attribute__((device_builtin)) uchar1
{
unsigned char x;
};


struct __attribute__((device_builtin)) __attribute__((aligned(2))) char2
{
signed char x, y;
};

// ...

# 1 "gemm_kernel.cuh" 1

# 4 "gemm_kernel.cuh"
__attribute__((global)) void gemm_kernel_1(const int *a, const int *b_t, int *c, const int m, const int n, const int k);
__attribute__((global)) void gemm_kernel_2(const int *a, const int *b, int *c, const int m, const int n, const int k);
# 6 "main.cu" 2

int main() {
int m=64, n=128, k=32;
int *da =
# 9 "main.cu" 3 4
__null
# 9 "main.cu"
, *db =
# 9 "main.cu" 3 4
__null
# 9 "main.cu"
, *dc =
# 9 "main.cu" 3 4
__null
# 9 "main.cu"
;
int *ha =
# 10 "main.cu" 3 4
__null
# 10 "main.cu"
, *hb =
# 10 "main.cu" 3 4
__null
# 10 "main.cu"
, *hc =
# 10 "main.cu" 3 4
__null
# 10 "main.cu"
;

cudaMalloc((void **)&da, m*n*sizeof(int));
if (da ==
# 13 "main.cu" 3 4
__null
# 13 "main.cu"
) {
printf("GPU alloc fail");
return -1;
}

cudaMalloc((void **)&db, n*k*sizeof(int));
if (db ==
# 19 "main.cu" 3 4
__null
# 19 "main.cu"
) {
printf("GPU alloc fail");
return -1;
}

cudaMalloc((void **)&dc, m*k*sizeof(int));
if (dc ==
# 25 "main.cu" 3 4
__null
# 25 "main.cu"
) {
printf("GPU alloc fail");
return -1;
}

ha = (int *)malloc(m*n*sizeof(int));
if (ha ==
# 31 "main.cu" 3 4
__null
# 31 "main.cu"
) {
printf("CPU alloc fail");
return -1;
}
for(int i=0; i<m*n; i++){
ha[i] = rand();
}

hb = (int *)malloc(n*k*sizeof(int));
if (hb ==
# 40 "main.cu" 3 4
__null
# 40 "main.cu"
) {
printf("CPU alloc fail");
return -1;
}
for(int i=0; i<n*k; i++){
hb[i] = rand();
}

hc = (int *)malloc(m*k*sizeof(int));
if (hc ==
# 49 "main.cu" 3 4
__null
# 49 "main.cu"
) {
printf("CPU alloc fail");
return -1;
}

cudaMemcpy(da, ha, m*n*sizeof(int), cudaMemcpyHostToDevice);
cudaMemcpy(db, hb, n*k*sizeof(int), cudaMemcpyHostToDevice);

gemm_kernel_1<<<1,256>>>(da, db, dc, m, n, k);
cudaMemcpy(hc, dc, m*k*sizeof(int), cudaMemcpyDeviceToHost);

gemm_kernel_2<<<1,256>>>(da, db, dc, m, n, k);
cudaMemcpy(hc, dc, m*k*sizeof(int), cudaMemcpyDeviceToHost);

cudaFree(da); cudaFree(db); cudaFree(dc);
free(ha); free(hb); free(hc);
return 0;
}

    在上述代码中,我们在 Line 24 和 25 处可以看见 gemm_kernel.cuh 的函数声明被展开到了 main.cpp4.ii 中,并且 __global__ 的前缀被更换为了 __attribute__((global));另外我们在 Line 126 和 Line 129 处仍然可以看见程序使用 <<<>>> 的方式来异步地调用 kernel,可见此时程序只是进行了预处理工作,CUDA device 侧程序和 CUDA Runtime 的程序用法,依然和 Host 侧代码没有分离。

分离 Host 侧源代码

1
2
3
4
5
6
7
8
9
10
11
12
cudafe++  --c++17                                                                                       \
--gnu_version=110400 \
--display_error_number \
--orig_src_file_name "main.cu" \
--orig_src_path_name "/home/zobin/projects/pos/test_cuda/main.cu" \
--allow_managed \
--m64 \
--parse_templates \
--gen_c_file_name "./tmp/tmpxft_0000047f_00000000-6_main.cudafe1.cpp" \
--stub_file_name "tmpxft_0000047f_00000000-6_main.cudafe1.stub.c" \
--gen_module_id_file --module_id_file_name "./tmp/tmpxft_0000047f_00000000-4_main.module_id" \
"./tmp/tmpxft_0000047f_00000000-5_main.cpp4.ii"

    命令运行后,文件夹结构如下所示:

1
2
3
4
5
6
7
8
9
  .
# ├── gemm_kernel.cuh
# ├── gemm_kernel_1.cu
# ├── gemm_kernel_2.cu
# ├── main.cu
└── tmp
   ├── tmpxft_0000047f_00000000-4_main.module_id
#    ├── tmpxft_0000047f_00000000-5_main.cpp4.ii
   └── tmpxft_0000047f_00000000-6_main.cudafe1.cpp

    可见运行后,将会生成两个新文件: main.cudafe1.cppmain.module_id 文件。其中 main.cudafe1.cpp 的文件内容摘抄如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# 4 "gemm_kernel.cuh"
void gemm_kernel_1(const int * a, const int * b_t, int * c, const int m, const int n, const int k);
# 5
void gemm_kernel_2(const int * a, const int * b, int * c, const int m, const int n, const int k);
# 7 "main.cu"
int main() {
# 8
int m = 64, n = 128, k = 32;
# 9
int *da = (__null),
# 9 "main.cu"
*db = (__null),
# 9 "main.cu"
*dc = (__null);
# 10 "main.cu"
int *ha = (__null),
# 10 "main.cu"
*hb = (__null),
# 10 "main.cu"
*hc = (__null);
# 12 "main.cu"
cudaMalloc((void **)(&da), (m * n) * sizeof(int));
# 13
if (da == (__null))
# 13 "main.cu"
{
# 14
printf("GPU alloc fail");
# 15
return -1;
# 16
}
# 18
cudaMalloc((void **)(&db), (n * k) * sizeof(int));
# 19
if (db == (__null))
# 19 "main.cu"
{
# 20
printf("GPU alloc fail");
# 21
return -1;
# 22
}
# 24
cudaMalloc((void **)(&dc), (m * k) * sizeof(int));
# 25
if (dc == (__null))
# 25 "main.cu"
{
# 26
printf("GPU alloc fail");
# 27
return -1;
# 28
}
# 30
ha = ((int *)malloc((m * n) * sizeof(int)));
# 31
if (ha == (__null))
# 31 "main.cu"
{
# 32
printf("CPU alloc fail");
# 33
return -1;
# 34
}
# 35
for (int i = 0; i < (m * n); i++) {
# 36
(ha[i]) = rand();
# 37
}
# 39
hb = ((int *)malloc((n * k) * sizeof(int)));
# 40
if (hb == (__null))
# 40 "main.cu"
{
# 41
printf("CPU alloc fail");
# 42
return -1;
# 43
}
# 44
for (int i = 0; i < (n * k); i++) {
# 45
(hb[i]) = rand();
# 46
}
# 48
hc = ((int *)malloc((m * k) * sizeof(int)));
# 49
if (hc == (__null))
# 49 "main.cu"
{
# 50
printf("CPU alloc fail");
# 51
return -1;
# 52
}
# 54
cudaMemcpy(da, ha, (m * n) * sizeof(int), cudaMemcpyHostToDevice);
# 55
cudaMemcpy(db, hb, (n * k) * sizeof(int), cudaMemcpyHostToDevice);
# 57
(__cudaPushCallConfiguration(1, 256)) ? (void)0 : gemm_kernel_1(da, db, dc, m, n, k);
# 58
cudaMemcpy(hc, dc, (m * k) * sizeof(int), cudaMemcpyDeviceToHost);
# 60
(__cudaPushCallConfiguration(1, 256)) ? (void)0 : gemm_kernel_2(da, db, dc, m, n, k);
# 61
cudaMemcpy(hc, dc, (m * k) * sizeof(int), cudaMemcpyDeviceToHost);
# 63
cudaFree(da); cudaFree(db); cudaFree(dc);
# 64
free(ha); free(hb); free(hc);
# 65
return 0;
# 66
}

# 1 "tmpxft_0000047f_00000000-6_main.cudafe1.stub.c"
#define _NV_ANON_NAMESPACE _GLOBAL__N__9df44bf1_7_main_cu_main
#ifdef _NV_ANON_NAMESPACE
#endif
# 1 "tmpxft_0000047f_00000000-6_main.cudafe1.stub.c"
#include "tmpxft_0000047f_00000000-6_main.cudafe1.stub.c"
# 1 "tmpxft_0000047f_00000000-6_main.cudafe1.stub.c"
#undef _NV_ANON_NAMESPACE

    在 Line 2 和 Line 4 中可以看见,kernel 的定函数声明中 __attribute__((global)) 的编译器注释已经被去掉了,实际上这里留下的是同名的函数接口声明 void gemm_kernel_x(const int*, const int*, int*, const int, const int, const int) (以下简称为 gemm_kernel_x 接口);另外,在 Line 110 和 Line 114 中可以看见,kernel launch 的方式已经被改为了先调用 __cudaPushCallConfiguration 将 Launch 参数 (e.g., gridDim, blockDim, etc.) push 到某处,然后直接运行 gemm_kernel_x 接口。综上,此时的 Host 侧代码已经和 Device 侧代码分离开来了,main.cudafe1.cpp 中包括的是 Host 侧的代码。

    值得注意的是,在 Line 126~133 还可以看见 main.cudafe1.cpp include 了 main.cudafe1.stub.c 文件,后者在此时还没有被生成,实际上后者将间接包含编译后的 Device 侧代码,我们在下文将会看到它的内容。

    另外,main.module_id 文件内容如下所示,该文件存储了当前正在编译的 CUDA module 的 id 信息,我们在后文把它称为 module_id。

1
_9df44bf1_7_main_cu_main

预处理 Device 侧源代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
gcc -D__CUDA_ARCH__=890                                           \
-D__CUDA_ARCH_LIST__=890 \
-E \
-x c++ \
-DCUDA_DOUBLE_MATH_FUNCTIONS \
-D__CUDACC__ \
-D__NVCC__ \
"-I/usr/local/cuda-12.1/bin/../targets/x86_64-linux/include" \
-D__CUDACC_VER_MAJOR__=12 \
-D__CUDACC_VER_MINOR__=1 \
-D__CUDACC_VER_BUILD__=105 \
-D__CUDA_API_VER_MAJOR__=12 \
-D__CUDA_API_VER_MINOR__=1 \
-D__NVCC_DIAG_PRAGMA_SUPPORT__=1 \
-include "cuda_runtime.h" \
-m64 "main.cu" \
-o "./tmp/tmpxft_0000047f_00000000-17_main.cpp1.ii"

    命令运行后,文件夹结构如下所示:

1
2
3
4
5
6
7
8
9
10
  .
# ├── gemm_kernel.cuh
# ├── gemm_kernel_1.cu
# ├── gemm_kernel_2.cu
# ├── main.cu
└── tmp
   ├── tmpxft_0000047f_00000000-17_main.cpp1.ii
#    ├── tmpxft_0000047f_00000000-4_main.module_id
#    ├── tmpxft_0000047f_00000000-5_main.cpp4.ii
#    └── tmpxft_0000047f_00000000-6_main.cudafe1.cpp

    实际上 preprocess_devicepreprocess_host 处展示的命令是一样的,就是对 main.cu 进行预处理。这里的预处理唯一的不同可能就是多加了宏定义 __CUDA_ARCH__,该命令生成的文件是 main.cpp1.ii,由于内容和 main.cpp4.ii 并无明显出入,此处不再赘述。

编译出 PTX 代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
cicc  --c++17                                                                     \
--gnu_version=110400 \
--display_error_number \
--orig_src_file_name "main.cu" \
--orig_src_path_name "/home/zobin/projects/pos/test_cuda/main.cu" \
--allow_managed \
-arch compute_89 \
-m64 \
--no-version-ident \
-ftz=0 \
-prec_div=1 \
-prec_sqrt=1 \
-fmad=1 \
--include_file_name "tmpxft_0000047f_00000000-3_main.fatbin.c" \
-tused \
--module_id_file_name "./tmp/tmpxft_0000047f_00000000-4_main.module_id" \
--gen_c_file_name "./tmp/tmpxft_0000047f_00000000-6_main.cudafe1.c" \
--stub_file_name "./tmp/tmpxft_0000047f_00000000-6_main.cudafe1.stub.c" \
--gen_device_file_name "./tmp/tmpxft_0000047f_00000000-6_main.cudafe1.gpu" \
"./tmp/tmpxft_0000047f_00000000-17_main.cpp1.ii" \
-o "./tmp/tmpxft_0000047f_00000000-6_main.ptx"
cicc 大概率不在系统 binary 搜索路径下,需要手动向 terminal 配置文件中添加:
export PATH=/usr/local/cuda-12.1/nvvm/bin:$PATH

    运行了上述命令后,文件夹结构如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
  .
# ├── gemm_kernel.cuh
# ├── gemm_kernel_1.cu
# ├── gemm_kernel_2.cu
# ├── main.cu
└── tmp
#    ├── tmpxft_0000047f_00000000-17_main.cpp1.ii
#   ├── tmpxft_0000047f_00000000-4_main.module_id
#   ├── tmpxft_0000047f_00000000-5_main.cpp4.ii
├── tmpxft_0000047f_00000000-6_main.cudafe1.c
# ├── tmpxft_0000047f_00000000-6_main.cudafe1.cpp
├── tmpxft_0000047f_00000000-6_main.cudafe1.gpu
   ├── tmpxft_0000047f_00000000-6_main.cudafe1.stub.c
   └── tmpxft_0000047f_00000000-6_main.ptx

    可见一共生成了 4 个新文件,其中较为重要的 2 个文件是 main.ptxcudafe1.stub.c

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-function"
#pragma GCC diagnostic ignored "-Wcast-qual"
#define __NV_CUBIN_HANDLE_STORAGE__ static
#if !defined(__CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__)
#define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
#endif
#include "crt/host_runtime.h"
#include "tmpxft_0000047f_00000000-3_main.fatbin.c"
static void __nv_cudaEntityRegisterCallback(void **);
static void __sti____cudaRegisterAll(void) __attribute__((__constructor__));
static void __nv_cudaEntityRegisterCallback(void **__T2){
__nv_dummy_param_ref(__T2);
__nv_save_fatbinhandle_for_managed_rt(__T2);
}
static void __sti____cudaRegisterAll(void){
__cudaRegisterBinary(__nv_cudaEntityRegisterCallback);
}

#pragma GCC diagnostic pop

    cudafe1.stub.c 文件的内容如 stub 所示,在 Line 11 中我们可以观察到一个被标记为 __attribute__((__constructor__)) 的函数 __sti____cudaRegisterAll,这意味着它会在程序启动时被执行,它的定义位于 Line 16-18,可以看见它调用了 __cudaRegisterBinary,顾名思义是在注册跑在 device 上的二进制程序。__cudaRegisterBinary 是一个 CUDA Runtime 的内部 API,我们可以看到它的逻辑是传入注册回调函数 __nv_cudaEntityRegisterCallback,由于在 main.cu 中我们并没有定义任何 kernels,因此我们在这个用于注册的回调函数中并不能看到过多细节,在 compile_gemm_kernel_1 中我们将看到 1 个 kernel 被注册进 runtime 的更多细节。

    而对于 main.ptx,由于我们在 main.cu 中没有定义任何 kernel,所以在 main_ptx 中并没有任何内容。

1
2
3
4
5
6
7
8
9
10
11
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-32688072
// Cuda compilation tools, release 12.1, V12.1.105
// Based on NVVM 7.0.1
//

.version 8.1
.target sm_89
.address_size 64

编译出 SASS 代码 (cubin)

1
2
ptxas -arch=sm_90 -m64  "./tmp/tmpxft_0000047f_00000000-6_main.ptx"  -o "./tmp/tmpxft_0000047f_00000000-18_main.sm_90.cubin" 
ptxas -arch=sm_89 -m64 "./tmp/tmpxft_0000047f_00000000-6_main.ptx" -o "./tmp/tmpxft_0000047f_00000000-19_main.sm_89.cubin"

    运行了上述命令后,文件夹结构如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
  .
# ├── gemm_kernel.cuh
# ├── gemm_kernel_1.cu
# ├── gemm_kernel_2.cu
# ├── main.cu
└── tmp
#    ├── tmpxft_0000047f_00000000-17_main.cpp1.ii
   ├── tmpxft_0000047f_00000000-18_main.sm_90.cubin
   ├── tmpxft_0000047f_00000000-19_main.sm_89.cubin
#    ├── tmpxft_0000047f_00000000-4_main.module_id
#    ├── tmpxft_0000047f_00000000-5_main.cpp4.ii
#    ├── tmpxft_0000047f_00000000-6_main.cudafe1.c
#    ├── tmpxft_0000047f_00000000-6_main.cudafe1.cpp
#    ├── tmpxft_0000047f_00000000-6_main.cudafe1.gpu
#    ├── tmpxft_0000047f_00000000-6_main.cudafe1.stub.c
#    └── tmpxft_0000047f_00000000-6_main.ptx

    新生成的 main.sm_90.cubinmain.sm_89.cubin 理论上应包含在设备上被执行的 SASS 二进制。同理,由于我们在 main.cu 并没有定义任何 kernels,因此这两个 cubin 文件中不会包含任何内容,如 empty_cubin_1empty_cubin_2 所示。

1
2
$ objdump -s ./tmp/tmpxft_0000047f_00000000-18_main.sm_90.cubin
./tmp/tmpxft_0000047f_00000000-18_main.sm_90.cubin: file format elf64-little
1
2
$ objdump -s ./tmp/tmpxft_0000047f_00000000-19_main.sm_89.cubin
./tmp/tmpxft_0000047f_00000000-19_main.sm_89.cubin: file format elf64-little

合并 PTX 和 SASS 代码 (合并 ptx 和 cubin 生成 fatbin)

1
2
3
4
5
6
7
fatbinary -64                                                                               \
--cicc-cmdline= \
"-ftz=0 -prec_div=1 -prec_sqrt=1 -fmad=1 " \
"--image3=kind=elf,sm=90,file=./tmp/tmpxft_0000047f_00000000-18_main.sm_90.cubin" \
"--image3=kind=elf,sm=89,file=./tmp/tmpxft_0000047f_00000000-19_main.sm_89.cubin" \
"--image3=kind=ptx,sm=89,file=./tmp/tmpxft_0000047f_00000000-6_main.ptx" \
--embedded-fatbin="./tmp/tmpxft_0000047f_00000000-3_main.fatbin.c"

    运行了上述命令后,文件夹结构如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
  .
# ├── gemm_kernel.cuh
# ├── gemm_kernel_1.cu
# ├── gemm_kernel_2.cu
# ├── main.cu
└── tmp
# ├── tmpxft_0000047f_00000000-17_main.cpp1.ii
# ├── tmpxft_0000047f_00000000-18_main.sm_90.cubin
# ├── tmpxft_0000047f_00000000-19_main.sm_89.cubin
├── tmpxft_0000047f_00000000-3_main.fatbin.c
# ├── tmpxft_0000047f_00000000-4_main.module_id
# ├── tmpxft_0000047f_00000000-5_main.cpp4.ii
# ├── tmpxft_0000047f_00000000-6_main.cudafe1.c
# ├── tmpxft_0000047f_00000000-6_main.cudafe1.cpp
# ├── tmpxft_0000047f_00000000-6_main.cudafe1.gpu
# ├── tmpxft_0000047f_00000000-6_main.cudafe1.stub.c
# └── tmpxft_0000047f_00000000-6_main.ptx

    上述命令将 main.sm_89.cubinmain.sm_90.cubinmain.ptx 合并到了 main.fatbin.c 中,后者的文件内容如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#ifndef __SKIP_INTERNAL_FATBINARY_HEADERS
#include "fatbinary_section.h"
#endif
#define __CUDAFATBINSECTION ".nvFatBinSegment"
#define __CUDAFATBINDATASECTION ".nv_fatbin"
asm(
".section .nv_fatbin, \"a\"\n"
".align 8\n"
"fatbinData:\n"
".quad 0x00100001ba55ed50,0x0000000000000658,0x0000004001010002,0x00000000000002b0\n"
".quad 0x0000000000000000,0x0000005a00010007,0x0000000000000000,0x0000000000000011\n"
".quad 0x0000000000000000,0x0000000000000000,0x33010102464c457f,0x0000000000000007\n"
".quad 0x0000007900be0002,0x0000000000000000,0x0000000000000240,0x0000000000000100\n"
".quad 0x003800400059055a,0x0001000500400002,0x7472747368732e00,0x747274732e006261\n"
".quad 0x746d79732e006261,0x746d79732e006261,0x78646e68735f6261,0x7466752e766e2e00\n"
".quad 0x2e007972746e652e,0x006f666e692e766e,0x665f67756265642e,0x00000000656d6172\n"
".quad 0x0000000000000000,0x7368732e00000000,0x732e006261747274,0x732e006261747274\n"
".quad 0x732e006261746d79,0x68735f6261746d79,0x2e766e2e0078646e,0x72746e652e746675\n"
".quad 0x6e692e766e2e0079,0x756265642e006f66,0x00656d6172665f67,0x0000000000000000\n"
".quad 0x0000000000000000,0x0000000000000000,0x0000000000000000,0x0000000000000000\n"
".quad 0x0000000000000000,0x0000000000000000,0x0000000000000000,0x0000000000000000\n"
".quad 0x0000000000000000,0x0000000000000000,0x0000000300000001,0x0000000000000000\n"
".quad 0x0000000000000000,0x0000000000000040,0x000000000000004d,0x0000000000000000\n"
".quad 0x0000000000000001,0x0000000000000000,0x000000030000000b,0x0000000000000000\n"
".quad 0x0000000000000000,0x000000000000009b,0x000000000000004d,0x0000000000000000\n"
".quad 0x0000000000000001,0x0000000000000000,0x0000000200000013,0x0000000000000000\n"
".quad 0x0000000000000000,0x00000000000000e8,0x0000000000000018,0x0000000100000002\n"
".quad 0x0000000000000008,0x0000000000000018,0x0000000100000040,0x0000000000000000\n"
".quad 0x0000000000000000,0x0000000000000100,0x0000000000000000,0x0000000000000000\n"
".quad 0x0000000000000001,0x0000000000000000,0x0000000400000006,0x0000000000000240\n"
".quad 0x0000000000000000,0x0000000000000000,0x0000000000000070,0x0000000000000070\n"
".quad 0x0000000000000008,0x0000000400000001,0x0000000000000240,0x0000000000000000\n"
".quad 0x0000000000000000,0x0000000000000070,0x0000000000000070,0x0000000000000008\n"
".quad 0x0000004001010002,0x00000000000002a8,0x0000000000000000,0x0000005900010007\n"
".quad 0x0000000000000000,0x0000000000000011,0x0000000000000000,0x0000000000000000\n"
".quad 0x33010102464c457f,0x0000000000000007,0x0000007900be0002,0x0000000000000000\n"
".quad 0x0000000000000238,0x00000000000000f8,0x0038004000590559,0x0001000500400002\n"
".quad 0x7472747368732e00,0x747274732e006261,0x746d79732e006261,0x746d79732e006261\n"
".quad 0x78646e68735f6261,0x7466752e766e2e00,0x2e007972746e652e,0x006f666e692e766e\n"
".quad 0x665f67756265642e,0x732e0000656d6172,0x0062617472747368,0x006261747274732e\n"
".quad 0x006261746d79732e,0x5f6261746d79732e,0x6e2e0078646e6873,0x6e652e7466752e76\n"
".quad 0x2e766e2e00797274,0x65642e006f666e69,0x6d6172665f677562,0x0000000000000065\n"
".quad 0x0000000000000000,0x0000000000000000,0x0000000000000000,0x0000000000000000\n"
".quad 0x0000000000000000,0x0000000000000000,0x0000000000000000,0x0000000000000000\n"
".quad 0x0000000000000000,0x0000000000000000,0x0000000000000000,0x0000000300000001\n"
".quad 0x0000000000000000,0x0000000000000000,0x0000000000000040,0x000000000000004d\n"
".quad 0x0000000000000000,0x0000000000000001,0x0000000000000000,0x000000030000000b\n"
".quad 0x0000000000000000,0x0000000000000000,0x000000000000008d,0x000000000000004d\n"
".quad 0x0000000000000000,0x0000000000000001,0x0000000000000000,0x0000000200000013\n"
".quad 0x0000000000000000,0x0000000000000000,0x00000000000000e0,0x0000000000000018\n"
".quad 0x0000000100000002,0x0000000000000008,0x0000000000000018,0x0000000100000040\n"
".quad 0x0000000000000000,0x0000000000000000,0x00000000000000f8,0x0000000000000000\n"
".quad 0x0000000000000000,0x0000000000000001,0x0000000000000000,0x0000000500000006\n"
".quad 0x0000000000000238,0x0000000000000000,0x0000000000000000,0x0000000000000070\n"
".quad 0x0000000000000070,0x0000000000000008,0x0000000500000001,0x0000000000000238\n"
".quad 0x0000000000000000,0x0000000000000000,0x0000000000000070,0x0000000000000070\n"
".quad 0x0000000000000008,0x0000004801010001,0x0000000000000038,0x0000004000000036\n"
".quad 0x0000005900080001,0x0000000000000000,0x0000000000002011,0x0000000000000000\n"
".quad 0x0000000000000038,0x0000000000000000,0x762e21f000010a13,0x38206e6f69737265\n"
".quad 0x677261742e0a312e,0x39385f6d73207465,0x7365726464612e0a,0x3620657a69735f73\n"
".quad 0x0000000a0a0a0a34\n"
".text\n");
#ifdef __cplusplus
extern "C" {
#endif
extern const unsigned long long fatbinData[205];
#ifdef __cplusplus
}
#endif
#ifdef __cplusplus
extern "C" {
#endif
static const __fatBinC_Wrapper_t __fatDeviceText __attribute__ ((aligned (8))) __attribute__ ((section (__CUDAFATBINSECTION)))=
{ 0x466243b1, 1, fatbinData, 0 };
#ifdef __cplusplus
}
#endif

    Line 73 定义的类型为 __fatBinC_Wrapper_t 的变量 __fatDeviceText 即是我们在 main.fatbin.c 的最终产物,可以观察到它最终被包括在了名为 nvFatbinSegment 的 section 中。__fatBinC_Wrapper_t 的结构体定义没有明确的文档说明,但可以参考 Yifan Sun stackoverflow_key 给出的逆向工程分析:

1
2
3
4
5
6
struct {
uint32_t magic; // Always 0x466243b1
uint32_t seq; // Sequence number of the cubin
uint64_t ptr; // The pointer to the real cubin
uint64_t data_ptr; // Some pointer related to the data segment
}

    其首先包含一个 magic number 0x466243b1;然后是当前 cubin 的序列号,当前是我们编译的第一个 cubin,因此其序列号为 1;然后是指向真正 cubin 的指针,在 main_fatbin_c 中我们可以看见 cubin 程序又被包括在了 .nv_fatbin section 中;最后是一个指向 data segment 的指针,main_fatbin_c 中是一个 null 指针。

生成目标文件

1
2
3
4
5
6
7
8
gcc -D__CUDA_ARCH__=890                                           \
-D__CUDA_ARCH_LIST__=890 \
-c \
-x c++ \
-DCUDA_DOUBLE_MATH_FUNCTIONS \
"-I/usr/local/cuda-12.1/bin/../targets/x86_64-linux/include" \
-m64 "/tmp/tmpxft_0000047f_00000000-6_main.cudafe1.cpp" \
-o "/tmp/tmpxft_0000047f_00000000-20_main.o"

    运行了上述命令后,文件夹结构如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
.
# ├── gemm_kernel.cuh
# ├── gemm_kernel_1.cu
# ├── gemm_kernel_2.cu
# ├── main.cu
└── tmp
# ├── tmpxft_0000047f_00000000-17_main.cpp1.ii
# ├── tmpxft_0000047f_00000000-18_main.sm_90.cubin
# ├── tmpxft_0000047f_00000000-19_main.sm_89.cubin
├── tmpxft_0000047f_00000000-20_main.o
# ├── tmpxft_0000047f_00000000-3_main.fatbin.c
# ├── tmpxft_0000047f_00000000-4_main.module_id
# ├── tmpxft_0000047f_00000000-5_main.cpp4.ii
# ├── tmpxft_0000047f_00000000-6_main.cudafe1.c
# ├── tmpxft_0000047f_00000000-6_main.cudafe1.cpp
# ├── tmpxft_0000047f_00000000-6_main.cudafe1.gpu
# ├── tmpxft_0000047f_00000000-6_main.cudafe1.stub.c
# └── tmpxft_0000047f_00000000-6_main.ptx

    上述命令是一个 gcc -c 指令,代表着该指令完成了编译,但是还没有进行链接。其完成了对 main.cudafe1.cpp 文件的编译过程(生成自 cudafe1_cpp),也即编译了 host 侧和 device 侧的程序,但暂未完成跨文件符号的链接操作。

1
objdump -s ./tmp/tmpxft_0000356e_00000000-11_main.o

    通过运行以上命令,我们可以看到编译得到的目标文件中的 section 分布情况,如 main_o_sections 所示。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
./tmp/tmpxft_0000047f_00000000-20_main.o:     file format elf64-x86-64

Contents of section .group:
0000 01000000 06000000 ........

Contents of section .text:
0000 f30f1efa 554889e5 48897df8 488b45f8 ....UH..H.}.H.E.
0010 48890500 00000090 5dc3f30f 1efa5548 H.......].....UH
0020 89e55348 83ec7864 488b0425 28000000 ..SH..xdH..%(...
0030 488945e8 31c0c745 94400000 00c74598 H.E.1..E.@....E.
0040 80000000 c7459c20 00000048 c745a000 .....E. ...H.E..
0050 00000048 c745a800 00000048 c745b000 ...H.E.....H.E..
0060 00000048 c745b800 00000048 c745c000 ...H.E.....H.E..
0070 00000048 c745c800 0000008b 45940faf ...H.E......E...
0080 45984898 488d1485 00000000 488d45a0 E.H.H.......H.E.
0090 4889d648 89c7e800 00000048 8b45a048 H..H.......H.E.H
00a0 85c0751e 488d0500 00000048 89c7b800 ..u.H......H....
00b0 000000e8 00000000 b8ffffff ffe9a303 ................
00c0 00008b45 980faf45 9c489848 8d148500 ...E...E.H.H....
00d0 00000048 8d45a848 89d64889 c7e80000 ...H.E.H..H.....
00e0 0000488b 45a84885 c0751e48 8d050000 ..H.E.H..u.H....
00f0 00004889 c7b80000 0000e800 000000b8 ..H.............
0100 ffffffff e95c0300 008b4594 0faf459c .....\....E...E.
0110 4898488d 14850000 0000488d 45b04889 H.H.......H.E.H.
0120 d64889c7 e8000000 00488b45 b04885c0 .H.......H.E.H..
0130 751e488d 05000000 004889c7 b8000000 u.H......H......
0140 00e80000 0000b8ff ffffffe9 15030000 ................
0150 8b45940f af459848 9848c1e0 024889c7 .E...E.H.H...H..
0160 e8000000 00488945 b848837d b800751e .....H.E.H.}..u.
0170 488d0500 00000048 89c7b800 000000e8 H......H........
0180 00000000 b8ffffff ffe9d702 0000c745 ...............E
0190 8c000000 00eb208b 458c4898 488d1485 ...... .E.H.H...
01a0 00000000 488b45b8 488d1c02 e8000000 ....H.E.H.......
01b0 00890383 458c018b 45940faf 45983945 ....E...E...E.9E
01c0 8c7cd48b 45980faf 459c4898 48c1e002 .|..E...E.H.H...
01d0 4889c7e8 00000000 488945c0 48837dc0 H.......H.E.H.}.
01e0 00751e48 8d050000 00004889 c7b80000 .u.H......H.....
01f0 0000e800 000000b8 ffffffff e9640200 .............d..
0200 00c74590 00000000 eb208b45 90489848 ..E...... .E.H.H
0210 8d148500 00000048 8b45c048 8d1c02e8 .......H.E.H....
0220 00000000 89038345 90018b45 980faf45 .......E...E...E
0230 9c394590 7cd48b45 940faf45 9c489848 .9E.|..E...E.H.H
0240 c1e00248 89c7e800 00000048 8945c848 ...H.......H.E.H
0250 837dc800 751e488d 05000000 004889c7 .}..u.H......H..
0260 b8000000 00e80000 0000b8ff ffffffe9 ................
0270 f1010000 8b45940f af459848 98488d14 .....E...E.H.H..
0280 85000000 00488b45 a0488b75 b8b90100 .....H.E.H.u....
0290 00004889 c7e80000 00008b45 980faf45 ..H........E...E
02a0 9c489848 8d148500 00000048 8b45a848 .H.H.......H.E.H
02b0 8b75c0b9 01000000 4889c7e8 00000000 .u......H.......
02c0 488d45dc b9010000 00ba0100 0000be00 H.E.............
02d0 01000048 89c7e800 00000048 8d45d0b9 ...H.......H.E..
02e0 01000000 ba010000 00be0100 00004889 ..............H.
02f0 c7e80000 0000488b 45dc8b4d e44889ca ......H.E..M.H..
0300 488b7dd0 8b75d841 b9000000 0041b800 H.}..u.A.....A..
0310 00000048 89d14889 c2e80000 000085c0 ...H..H.........
0320 7524488b 55b0488b 75a8488b 45a0448b u$H.U.H.u.H.E.D.
0330 459c8b7d 988b4d94 4589c141 89f84889 E..}..M.E..A..H.
0340 c7e80000 00008b45 940faf45 9c489848 .......E...E.H.H
0350 8d148500 00000048 8b75b048 8b45c8b9 .......H.u.H.E..
0360 02000000 4889c7e8 00000000 488d45dc ....H.......H.E.
0370 b9010000 00ba0100 0000be00 01000048 ...............H
0380 89c7e800 00000048 8d45d0b9 01000000 .......H.E......
0390 ba010000 00be0100 00004889 c7e80000 ..........H.....
03a0 0000488b 45dc8b4d e44889ca 488b7dd0 ..H.E..M.H..H.}.
03b0 8b75d841 b9000000 0041b800 00000048 .u.A.....A.....H
03c0 89d14889 c2e80000 000085c0 7524488b ..H.........u$H.
03d0 55b0488b 75a8488b 45a0448b 459c8b7d U.H.u.H.E.D.E..}
03e0 988b4d94 4589c141 89f84889 c7e80000 ..M.E..A..H.....
03f0 00008b45 940faf45 9c489848 8d148500 ...E...E.H.H....
0400 00000048 8b75b048 8b45c8b9 02000000 ...H.u.H.E......
0410 4889c7e8 00000000 488b45a0 4889c7e8 H.......H.E.H...
0420 00000000 488b45a8 4889c7e8 00000000 ....H.E.H.......
0430 488b45b0 4889c7e8 00000000 488b45b8 H.E.H.......H.E.
0440 4889c7e8 00000000 488b45c0 4889c7e8 H.......H.E.H...
0450 00000000 488b45c8 4889c7e8 00000000 ....H.E.H.......
0460 b8000000 00488b55 e864482b 14252800 .....H.U.dH+.%(.
0470 00007405 e8000000 00488b5d f8c9c3f3 ..t......H.]....
0480 0f1efa55 4889e548 897df848 8b45f848 ...UH..H.}.H.E.H
0490 89050000 00005dc3 f30f1efa 554889e5 ......].....UH..
04a0 488d0500 00000048 89c7e8d0 ffffff48 H......H.......H
04b0 8b050000 00004889 c7e80000 0000905d ......H........]
04c0 c3f30f1e fa554889 e54883ec 1048897d .....UH..H...H.}
04d0 f8488b45 f84889c7 e8000000 00c9c3f3 .H.E.H..........
04e0 0f1efa55 4889e548 83ec0848 897df848 ...UH..H...H.}.H
04f0 8b45f848 89050000 0000488b 45f84889 .E.H......H.E.H.
0500 c7e8fafa ffff90c9 c3f30f1e fa554889 .............UH.
0510 e54883ec 10488d05 00000000 4889c7e8 .H...H......H...
0520 00000000 48890500 00000048 8d05adff ....H......H....
0530 ffff4889 45f8488b 55f8488b 05000000 ..H.E.H.U.H.....
0540 004889c7 ffd2488b 05000000 004889c7 .H....H......H..
0550 e8000000 00488d05 3cffffff 4889c7e8 .....H..<...H...
0560 00000000 90c9c3 .......

Contents of section .text._ZN4dim3C2Ejjj:
0000 f30f1efa 554889e5 48897df8 8975f489 ....UH..H.}..u..
0010 55f0894d ec488b45 f88b55f4 8910488b U..M.H.E..U...H.
0020 45f88b55 f0895004 488b45f8 8b55ec89 E..U..P.H.E..U..
0030 5008905d c3 P..].

Contents of section .rodata:
0000 47505520 616c6c6f 63206661 696c0043 GPU alloc fail.C
0010 50552061 6c6c6f63 20666169 6c00 PU alloc fail.

Contents of section __nv_module_id:
0000 5f5f4e56 5f4d4f44 554c455f 494400 __NV_MODULE_ID.

Contents of section .nv_fatbin:
0000 50ed55ba 01001000 58060000 00000000 P.U.....X.......
0010 02000101 40000000 b0020000 00000000 ....@...........
0020 00000000 00000000 07000100 5a000000 ............Z...
0030 00000000 00000000 11000000 00000000 ................
0040 00000000 00000000 00000000 00000000 ................
0050 7f454c46 02010133 07000000 00000000 .ELF...3........
0060 0200be00 79000000 00000000 00000000 ....y...........
0070 40020000 00000000 00010000 00000000 @...............
0080 5a055900 40003800 02004000 05000100 Z.Y.@.8...@.....
0090 002e7368 73747274 6162002e 73747274 ..shstrtab..strt
00a0 6162002e 73796d74 6162002e 73796d74 ab..symtab..symt
00b0 61625f73 686e6478 002e6e76 2e756674 ab_shndx..nv.uft
00c0 2e656e74 7279002e 6e762e69 6e666f00 .entry..nv.info.
00d0 2e646562 75675f66 72616d65 00000000 .debug_frame....
00e0 00000000 00000000 00000000 2e736873 .............shs
00f0 74727461 62002e73 74727461 62002e73 trtab..strtab..s
0100 796d7461 62002e73 796d7461 625f7368 ymtab..symtab_sh
0110 6e647800 2e6e762e 7566742e 656e7472 ndx..nv.uft.entr
0120 79002e6e 762e696e 666f002e 64656275 y..nv.info..debu
0130 675f6672 616d6500 00000000 00000000 g_frame.........
0140 00000000 00000000 00000000 00000000 ................
0150 00000000 00000000 00000000 00000000 ................
0160 00000000 00000000 00000000 00000000 ................
0170 00000000 00000000 00000000 00000000 ................
0180 00000000 00000000 00000000 00000000 ................
0190 01000000 03000000 00000000 00000000 ................
01a0 00000000 00000000 40000000 00000000 ........@.......
01b0 4d000000 00000000 00000000 00000000 M...............
01c0 01000000 00000000 00000000 00000000 ................
01d0 0b000000 03000000 00000000 00000000 ................
01e0 00000000 00000000 9b000000 00000000 ................
01f0 4d000000 00000000 00000000 00000000 M...............
0200 01000000 00000000 00000000 00000000 ................
0210 13000000 02000000 00000000 00000000 ................
0220 00000000 00000000 e8000000 00000000 ................
0230 18000000 00000000 02000000 01000000 ................
0240 08000000 00000000 18000000 00000000 ................
0250 40000000 01000000 00000000 00000000 @...............
0260 00000000 00000000 00010000 00000000 ................
0270 00000000 00000000 00000000 00000000 ................
0280 01000000 00000000 00000000 00000000 ................
0290 06000000 04000000 40020000 00000000 ........@.......
02a0 00000000 00000000 00000000 00000000 ................
02b0 70000000 00000000 70000000 00000000 p.......p.......
02c0 08000000 00000000 01000000 04000000 ................
02d0 40020000 00000000 00000000 00000000 @...............
02e0 00000000 00000000 70000000 00000000 ........p.......
02f0 70000000 00000000 08000000 00000000 p...............
0300 02000101 40000000 a8020000 00000000 ....@...........
0310 00000000 00000000 07000100 59000000 ............Y...
0320 00000000 00000000 11000000 00000000 ................
0330 00000000 00000000 00000000 00000000 ................
0340 7f454c46 02010133 07000000 00000000 .ELF...3........
0350 0200be00 79000000 00000000 00000000 ....y...........
0360 38020000 00000000 f8000000 00000000 8...............
0370 59055900 40003800 02004000 05000100 Y.Y.@.8...@.....
0380 002e7368 73747274 6162002e 73747274 ..shstrtab..strt
0390 6162002e 73796d74 6162002e 73796d74 ab..symtab..symt
03a0 61625f73 686e6478 002e6e76 2e756674 ab_shndx..nv.uft
03b0 2e656e74 7279002e 6e762e69 6e666f00 .entry..nv.info.
03c0 2e646562 75675f66 72616d65 00002e73 .debug_frame...s
03d0 68737472 74616200 2e737472 74616200 hstrtab..strtab.
03e0 2e73796d 74616200 2e73796d 7461625f .symtab..symtab_
03f0 73686e64 78002e6e 762e7566 742e656e shndx..nv.uft.en
0400 74727900 2e6e762e 696e666f 002e6465 try..nv.info..de
0410 6275675f 6672616d 65000000 00000000 bug_frame.......
0420 00000000 00000000 00000000 00000000 ................
0430 00000000 00000000 00000000 00000000 ................
0440 00000000 00000000 00000000 00000000 ................
0450 00000000 00000000 00000000 00000000 ................
0460 00000000 00000000 00000000 00000000 ................
0470 00000000 00000000 01000000 03000000 ................
0480 00000000 00000000 00000000 00000000 ................
0490 40000000 00000000 4d000000 00000000 @.......M.......
04a0 00000000 00000000 01000000 00000000 ................
04b0 00000000 00000000 0b000000 03000000 ................
04c0 00000000 00000000 00000000 00000000 ................
04d0 8d000000 00000000 4d000000 00000000 ........M.......
04e0 00000000 00000000 01000000 00000000 ................
04f0 00000000 00000000 13000000 02000000 ................
0500 00000000 00000000 00000000 00000000 ................
0510 e0000000 00000000 18000000 00000000 ................
0520 02000000 01000000 08000000 00000000 ................
0530 18000000 00000000 40000000 01000000 ........@.......
0540 00000000 00000000 00000000 00000000 ................
0550 f8000000 00000000 00000000 00000000 ................
0560 00000000 00000000 01000000 00000000 ................
0570 00000000 00000000 06000000 05000000 ................
0580 38020000 00000000 00000000 00000000 8...............
0590 00000000 00000000 70000000 00000000 ........p.......
05a0 70000000 00000000 08000000 00000000 p...............
05b0 01000000 05000000 38020000 00000000 ........8.......
05c0 00000000 00000000 00000000 00000000 ................
05d0 70000000 00000000 70000000 00000000 p.......p.......
05e0 08000000 00000000 01000101 48000000 ............H...
05f0 38000000 00000000 36000000 40000000 8.......6...@...
0600 01000800 59000000 00000000 00000000 ....Y...........
0610 11200000 00000000 00000000 00000000 . ..............
0620 38000000 00000000 00000000 00000000 8...............
0630 130a0100 f0212e76 65727369 6f6e2038 .....!.version 8
0640 2e310a2e 74617267 65742073 6d5f3839 .1..target sm_89
0650 0a2e6164 64726573 735f7369 7a652036 ..address_size 6
0660 340a0a0a 0a000000 4.......

Contents of section .nvFatBinSegment:
0000 b1436246 01000000 00000000 00000000 .CbF............
0010 00000000 00000000 ........

Contents of section .init_array:
0000 00000000 00000000 ........

Contents of section .comment:
0000 00474343 3a202855 62756e74 75203131 .GCC: (Ubuntu 11
0010 2e342e30 2d317562 756e7475 317e3232 .4.0-1ubuntu1~22
0020 2e303429 2031312e 342e3000 .04) 11.4.0.

Contents of section .note.gnu.property:
0000 04000000 10000000 05000000 474e5500 ............GNU.
0010 020000c0 04000000 03000000 00000000 ................

Contents of section .eh_frame:
0000 14000000 00000000 017a5200 01781001 .........zR..x..
0010 1b0c0708 90010000 1c000000 1c000000 ................
0020 00000000 1a000000 00450e10 8602430d .........E....C.
0030 06510c07 08000000 1c000000 3c000000 .Q..........<...
0040 00000000 35000000 00450e10 8602430d ....5....E....C.
0050 066c0c07 08000000 20000000 5c000000 .l...... ...\...
0060 00000000 65040000 00450e10 8602430d ....e....E....C.
0070 06458303 0357040c 07080000 1c000000 .E...W..........
0080 80000000 00000000 19000000 00450e10 .............E..
0090 8602430d 06500c07 08000000 1c000000 ..C..P..........
00a0 a0000000 00000000 29000000 00450e10 ........)....E..
00b0 8602430d 06600c07 08000000 1c000000 ..C..`..........
00c0 c0000000 00000000 1e000000 00450e10 .............E..
00d0 8602430d 06550c07 08000000 1c000000 ..C..U..........
00e0 e0000000 00000000 2a000000 00450e10 ........*....E..
00f0 8602430d 06610c07 08000000 20000000 ..C..a...... ...
0100 00010000 00000000 5e000000 00450e10 ........^....E..
0110 8602430d 0602550c 07080000 00000000 ..C...U.........

    至此,我们完成了从 .cu 文件到 .o 文件的分析过程,但是由于我们在 main.cu 并没有定义 kernel,因此有一些细节我们实际上没能观察到,compile_gemm_kernel_1 中我们将以 gemm_kernel_1.cu 的编译过程为例,展示其与 main.cu 编译过程的区别。

gemm_kernel_1.cu 的编译

    由于 gemm_kernel_1.cugemm_kernel_2.cu 两个文件的编译过程是完全一样的,所以以下只以 gemm_kernel_1.cu 为例。

    cudafe1.stub.c 文件的内容如 stub 所示,Line 27~32 我们可以看见 void sum(int *, int *, int *) 接口的进一步定义,其实际上调用了在同文件下定义的 __device_stub__Z3sumPiS_S_ 函数,Line 17~25 展示了后者的实现细节,其使用了 __cudaLaunchPrologue__cudaSetupArgSimple__cudaLaunch 三个宏,这三个宏的定义如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#define __cudaLaunchPrologue(size) \
void * __args_arr[size]; \
int __args_idx = 0

#define __cudaSetupArg(arg, offset) \
__args_arr[__args_idx] = (void *)__cudaAddressOf(arg); ++__args_idx

#define __cudaSetupArgSimple(arg, offset) \
__args_arr[__args_idx] = (void *)(char *)&arg; ++__args_idx

/* the use of __args_idx in the expression below avoids host compiler warning about it being an
unused variable when the launch has no arguments */
#define __cudaLaunch(fun) \
{ volatile static char *__f __NV_ATTR_UNUSED_FOR_LAUNCH; __f = fun; \
dim3 __gridDim, __blockDim;\
size_t __sharedMem; \
cudaStream_t __stream; \
if (__cudaPopCallConfiguration(&__gridDim, &__blockDim, &__sharedMem, &__stream) != cudaSuccess) \
return; \
if (__args_idx == 0) {\
(void)cudaLaunchKernel(fun, __gridDim, __blockDim, &__args_arr[__args_idx], __sharedMem, __stream);\
} else { \
(void)cudaLaunchKernel(fun, __gridDim, __blockDim, &__args_arr[0], __sharedMem, __stream);\
}\
}

    因此,void sum(int *, int *, int *) 接口的行为实际上是,先使用 __cudaLaunchPrologue 宏,初始化出一个长度与 kernel 参数列表长度相同的数组,然后使用 __cudaSetupArgSimple 宏向数组中填充参数,最后调用 __cudaLaunch 宏启动 kernel。__cudaLaunch 中则实际上显示调用 __cudaPopCallConfiguration,将我们在 cudafe1_cpp 中看到的使用 __cudaPushCallConfiguration 压入 CUDA Runtime 中维护的某个 stack 的 kernel 调用参数 (i.e., gridDim, blockDim, sharedMem, stream) 重新弹出来,最后再调用 cudaLaunchKernel 启动对应的 kernel。

    最后在 stub 中还有值得深挖的一层: 在 Line 15 中我们可以看到有一个标记了 __attribute__((__constructor__)) 的函数 __sti____cudaRegisterAll,也就是说它在程序启动时就会被自动执行。这个函数的定义可以在 Line 46~48 中找到,可以看到其实际上调用了 __cudaRegisterBinary,根据名字推测,其用于向 device 注册 cuda 二进制,其实际上调用了 __cudaRegisterEntry (Line 38) 完成了对 sum 这个 kernel 的注册工作。

    这里有一个有趣的地方:上面我们看到的 cudaLaunchKernel__cudaRegisterEntry 接口,它都接受了 void sum(int *, int *, int *) 接口的函数地址作为参数。细心的读者可能会有疑问:void sum(int *, int *, int *) 接口不是给 Host 侧 C++ 代码使用的么?为什么注册和发射 kernel 也用它来作为参数?实际上此时 void sum(int *, int *, int *) 的函数指针是被 CUDA Runtime 当作 key 来使用的 stackoverflow_key,在 Line 38~43 我们可以看见 __cudaRegisterEntry 的参数列表中有两个重要参数: 一个是 void sum(int *, int *, int *) 的函数指针,另一个名为 _Z3sumPiS_S_ 的名称我们可以在 main_ptx 的 Line 15 中看到,其是 sum kernel 对应的 PTX entry 的名字。__cudaRegisterEntry 这里相当于在内部维护了一个映射:当 cudaLaunchKernel 尝试使用 void sum(int *, int *, int *) 的函数指针发射 kernel 时,Runtime 就知道用户想要运行的是 _Z3sumPiS_S_ 对应的程序了。

合并多个 .cu 文件

链接

1
2
3
4
5
6
7
8
9
nvlink  -m64 --arch=sm_52                                                                   \
--register-link-binaries="/tmp/tmpxft_0000356e_00000000-7_staticthread_dlink.reg.c" \
"-L/usr/local/cuda-12.1/bin/../targets/x86_64-linux/lib/stubs" \
"-L/usr/local/cuda-12.1/bin/../targets/x86_64-linux/lib" \
-cpu-arch=X86_64 \
"./tmp/tmpxft_0000356e_00000000-11_main.o" \
-lcudadevrt \
-o "./tmp/tmpxft_0000356e_00000000-12_staticthread_dlink.sm_52.cubin" \
--host-ccbin "gcc"

    运行了上述命令后,文件夹结构如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
.
# ├── main.cu
└── tmp
# ├── tmpxft_0000356e_00000000-10_main.sm_52.cubin
# ├── tmpxft_0000356e_00000000-11_main.o
├── tmpxft_0000356e_00000000-12_staticthread_dlink.sm_52.cubin
# ├── tmpxft_0000356e_00000000-3_main.fatbin.c
# ├── tmpxft_0000356e_00000000-4_main.module_id
# ├── tmpxft_0000356e_00000000-5_main.cpp4.ii
# ├── tmpxft_0000356e_00000000-6_main.cudafe1.c
# ├── tmpxft_0000356e_00000000-6_main.cudafe1.cpp
# ├── tmpxft_0000356e_00000000-6_main.cudafe1.gpu
# ├── tmpxft_0000356e_00000000-6_main.cudafe1.stub.c
# ├── tmpxft_0000356e_00000000-6_main.ptx
├── tmpxft_0000356e_00000000-7_staticthread_dlink.reg.c
# └── tmpxft_0000356e_00000000-9_main.cpp1.ii

    新增加的文件中,staticthread_dlink.sm_52.cubin 包含了 nvlink 输入的所有目标文件中 Device 侧程序 (i.e., PTX 和 SASS) 的集合,我们此处只有一个目标文件,因此 staticthread_dlink.sm_52.cubin 只包含了前面我们处理过的 main.cu 中包含的 Device 侧程序;staticthread_dlink.reg.c

1
#define NUM_PRELINKED_OBJECTS 0