In [26]:
!mkdir lecture19demo
mkdir: cannot create directory ‘lecture19demo’: File exists
In [12]:
%cd lecture19demo
/content/lecture19demo
In [13]:
%%writefile setup.py
"""Build script for the lecture19_cuda_extension PyTorch CUDA extension."""
from setuptools import Extension, setup
from torch.utils import cpp_extension

# The extension module name doubles as the distribution name.
EXTENSION_NAME = "lecture19_cuda_extension"

# wrapper.cpp holds the pybind11 bindings; kernels.cu holds the CUDA kernels
# and their host-side launchers.
SOURCE_FILES = ["wrapper.cpp", "kernels.cu"]

cuda_extension = cpp_extension.CUDAExtension(EXTENSION_NAME, SOURCE_FILES)

setup(
    name=EXTENSION_NAME,
    ext_modules=[cuda_extension],
    # BuildExtension is torch's build_ext command for C++/CUDA extension sources.
    cmdclass={'build_ext': cpp_extension.BuildExtension},
)
Overwriting setup.py
In [14]:
%%writefile wrapper.cpp
#include <torch/extension.h>
// Forward declarations of the host-side launchers implemented in kernels.cu.
// Each takes a float scalar `a` and three tensors (x, y, out) by reference and
// returns nothing; the result is written into `out`.
void saxpy(float a, torch::Tensor &x, torch::Tensor &y, torch::Tensor &out);
void haxpy(float a, torch::Tensor &x, torch::Tensor &y, torch::Tensor &out);

// Python bindings. TORCH_EXTENSION_NAME is supplied by the build
// (-DTORCH_EXTENSION_NAME=...), so the module name follows setup.py.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("saxpy", &saxpy, "saxpy");
    m.def("haxpy", &haxpy, "haxpy");
}
Overwriting wrapper.cpp
In [15]:
%%writefile kernels.cu
#include <iostream>
#include <cassert>
#include <cstdint>
#include <vector>
#include <utility>
#include <stdlib.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <ATen/ATen.h>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/types.h>
#include <torch/extension.h>
using namespace torch::indexing;
#define FULL_MASK 0xffffffff
#define HALF_MASK 0x0000ffff
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while(false)
#define gpuErrchk(ans) do { gpuAssert((ans), __FILE__, __LINE__); } while (false)
// Elementwise single-precision a*X + Y.
//
// Each thread handles exactly one element, selected by its global thread
// index. There is no bounds check: the caller must launch exactly one thread
// per element (total threads == number of elements in X, Y and out).
__global__ void saxpy_kernel(
    float a,
    const float* X,
    const float* Y,
    float* out
) {
    // Widen to size_t BEFORE multiplying. blockIdx.x, blockDim.x and
    // threadIdx.x are 32-bit unsigned, so the original all-32-bit expression
    // wrapped around for launches covering 2^32 or more elements.
    const size_t tid = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    out[tid] = a * X[tid] + Y[tid];
}
void saxpy(
float a,
torch::Tensor &x,
torch::Tensor &y,
torch::Tensor &out
) {
size_t N = x.numel();
assert(x.dtype() == torch::kFloat32);
assert(y.dtype() == torch::kFloat32);
assert(out.dtype() == torch::kFloat32);
assert(x.sizes() == y.sizes());
assert(x.sizes() == out.sizes());
assert(N % 32 == 0);
const dim3 threads(32);
const dim3 blocks(N/32);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
saxpy_kernel<<<blocks, threads, 0, stream>>>(
a,
x.data_ptr<float>(),
y.data_ptr<float>(),
out.data_ptr<float>()
);
}
// Elementwise half-precision a*X + Y, operating on packed __half2 pairs.
//
// Each thread handles exactly one __half2 (i.e. two half values), selected by
// its global thread index, using the __hfma2 intrinsic. There is no bounds
// check: the caller must launch exactly one thread per __half2 element.
__global__ void haxpy_kernel(
    __half2 a,
    const __half2* __restrict__ X,
    const __half2* __restrict__ Y,
    __half2* __restrict__ out
) {
    // Widen to size_t BEFORE multiplying. blockIdx.x, blockDim.x and
    // threadIdx.x are 32-bit unsigned, so the original all-32-bit expression
    // wrapped around for launches covering 2^32 or more __half2 elements.
    const size_t tid = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    out[tid] = __hfma2(a, X[tid], Y[tid]);
}
void haxpy(
float a,
torch::Tensor &x,
torch::Tensor &y,
torch::Tensor &out
) {
size_t N = x.numel();
assert(x.dtype() == torch::kFloat16);
assert(y.dtype() == torch::kFloat16);
assert(out.dtype() == torch::kFloat16);
assert(x.sizes() == y.sizes());
assert(x.sizes() == out.sizes());
assert(N % 64 == 0);
// convert a to half
__half2 ah = __float2half2_rn(a);
const dim3 threads(32);
const dim3 blocks(N/64);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
haxpy_kernel<<<blocks, threads, 0, stream>>>(
ah,
(const __half2*)x.data_ptr<at::Half>(),
(const __half2*)y.data_ptr<at::Half>(),
(__half2*)out.data_ptr<at::Half>()
);
}
Overwriting kernels.cu
In [16]:
!python3 setup.py build
running build running build_ext W1104 01:20:51.300000 28079 torch/utils/cpp_extension.py:615] Attempted to use ninja as the BuildExtension backend but we could not find ninja.. Falling back to using the slow distutils backend. W1104 01:20:51.305000 28079 torch/utils/cpp_extension.py:507] The detected CUDA version (12.5) has a minor version mismatch with the version that was used to compile PyTorch (12.6). Most likely this shouldn't be a problem. W1104 01:20:51.305000 28079 torch/utils/cpp_extension.py:517] There are no x86_64-linux-gnu-g++ version bounds defined for CUDA version 12.5 building 'lecture19_cuda_extension' extension W1104 01:20:51.392000 28079 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. W1104 01:20:51.392000 28079 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures. /usr/local/cuda/bin/nvcc -I/usr/local/lib/python3.12/dist-packages/torch/include -I/usr/local/lib/python3.12/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/usr/include/python3.12 -c kernels.cu -o build/temp.linux-x86_64-cpython-312/kernels.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1018\" -DTORCH_EXTENSION_NAME=lecture19_cuda_extension -gencode=arch=compute_75,code=compute_75 -gencode=arch=compute_75,code=sm_75 -std=c++17 x86_64-linux-gnu-g++ -fno-strict-overflow -Wsign-compare -DNDEBUG -g -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -fPIC -I/usr/local/lib/python3.12/dist-packages/torch/include -I/usr/local/lib/python3.12/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include 
-I/usr/include/python3.12 -c wrapper.cpp -o build/temp.linux-x86_64-cpython-312/wrapper.o -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1018\" -DTORCH_EXTENSION_NAME=lecture19_cuda_extension -std=c++17 x86_64-linux-gnu-g++ -fno-strict-overflow -Wsign-compare -DNDEBUG -g -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -shared -Wl,-O1 -Wl,-Bsymbolic-functions -Wl,-Bsymbolic-functions -g -fwrapv -O2 build/temp.linux-x86_64-cpython-312/kernels.o build/temp.linux-x86_64-cpython-312/wrapper.o -L/usr/local/lib/python3.12/dist-packages/torch/lib -L/usr/local/cuda/lib64 -L/usr/lib/x86_64-linux-gnu -lc10 -ltorch -ltorch_cpu -ltorch_python -lcudart -lc10_cuda -ltorch_cuda -o build/lib.linux-x86_64-cpython-312/lecture19_cuda_extension.cpython-312-x86_64-linux-gnu.so
In [18]:
!python3 setup.py install
running install
/usr/local/lib/python3.12/dist-packages/setuptools/_distutils/cmd.py:66: SetuptoolsDeprecationWarning: setup.py install is deprecated.
!!
********************************************************************************
Please avoid running ``setup.py`` directly.
Instead, use pypa/build, pypa/installer or other
standards-based tools.
See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
********************************************************************************
!!
self.initialize_options()
/usr/local/lib/python3.12/dist-packages/setuptools/_distutils/cmd.py:66: EasyInstallDeprecationWarning: easy_install command is deprecated.
!!
********************************************************************************
Please avoid running ``setup.py`` and ``easy_install``.
Instead, use pypa/build, pypa/installer or other
standards-based tools.
See https://github.com/pypa/setuptools/issues/917 for details.
********************************************************************************
!!
self.initialize_options()
running bdist_egg
running egg_info
writing lecture19_cuda_extension.egg-info/PKG-INFO
writing dependency_links to lecture19_cuda_extension.egg-info/dependency_links.txt
writing top-level names to lecture19_cuda_extension.egg-info/top_level.txt
W1104 01:23:41.619000 28844 torch/utils/cpp_extension.py:615] Attempted to use ninja as the BuildExtension backend but we could not find ninja.. Falling back to using the slow distutils backend.
reading manifest file 'lecture19_cuda_extension.egg-info/SOURCES.txt'
writing manifest file 'lecture19_cuda_extension.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_ext
W1104 01:23:41.668000 28844 torch/utils/cpp_extension.py:507] The detected CUDA version (12.5) has a minor version mismatch with the version that was used to compile PyTorch (12.6). Most likely this shouldn't be a problem.
W1104 01:23:41.668000 28844 torch/utils/cpp_extension.py:517] There are no x86_64-linux-gnu-g++ version bounds defined for CUDA version 12.5
creating build/bdist.linux-x86_64/egg
copying build/lib.linux-x86_64-cpython-312/lecture19_cuda_extension.cpython-312-x86_64-linux-gnu.so -> build/bdist.linux-x86_64/egg
creating stub loader for lecture19_cuda_extension.cpython-312-x86_64-linux-gnu.so
byte-compiling build/bdist.linux-x86_64/egg/lecture19_cuda_extension.py to lecture19_cuda_extension.cpython-312.pyc
creating build/bdist.linux-x86_64/egg/EGG-INFO
copying lecture19_cuda_extension.egg-info/PKG-INFO -> build/bdist.linux-x86_64/egg/EGG-INFO
copying lecture19_cuda_extension.egg-info/SOURCES.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying lecture19_cuda_extension.egg-info/dependency_links.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying lecture19_cuda_extension.egg-info/top_level.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
writing build/bdist.linux-x86_64/egg/EGG-INFO/native_libs.txt
zip_safe flag not set; analyzing archive contents...
__pycache__.lecture19_cuda_extension.cpython-312: module references __file__
creating 'dist/lecture19_cuda_extension-0.0.0-py3.12-linux-x86_64.egg' and adding 'build/bdist.linux-x86_64/egg' to it
removing 'build/bdist.linux-x86_64/egg' (and everything under it)
Processing lecture19_cuda_extension-0.0.0-py3.12-linux-x86_64.egg
removing '/usr/local/lib/python3.12/dist-packages/lecture19_cuda_extension-0.0.0-py3.12-linux-x86_64.egg' (and everything under it)
creating /usr/local/lib/python3.12/dist-packages/lecture19_cuda_extension-0.0.0-py3.12-linux-x86_64.egg
Extracting lecture19_cuda_extension-0.0.0-py3.12-linux-x86_64.egg to /usr/local/lib/python3.12/dist-packages
Adding lecture19-cuda-extension 0.0.0 to easy-install.pth file
Installed /usr/local/lib/python3.12/dist-packages/lecture19_cuda_extension-0.0.0-py3.12-linux-x86_64.egg
Processing dependencies for lecture19-cuda-extension==0.0.0
Finished processing dependencies for lecture19-cuda-extension==0.0.0
In [19]:
%cd ..
/content
In [1]:
import torch
import lecture19_cuda_extension
In [2]:
# Random inputs and a zero-filled output buffer on the GPU, all 1024x1024.
# No dtype is given, so torch's default dtype is used (float32 unless changed);
# `out` is passed to the extension's saxpy below, which writes into it.
X = torch.randn(1024,1024,device='cuda')
Y = torch.randn(1024,1024,device='cuda')
out = torch.zeros(1024,1024,device='cuda')
In [3]:
0.3 * X + Y
Out[3]:
tensor([[ 2.3551, 0.8725, 1.1007, ..., -0.5701, -0.4147, -0.9135],
[-1.7619, 2.0183, 0.1907, ..., -0.4397, 0.6221, -1.5504],
[ 0.3077, 1.0810, -0.8518, ..., 1.9129, -0.3097, -0.3331],
...,
[-0.5927, 0.2277, -0.7137, ..., -0.1314, 0.7454, 0.0859],
[ 0.7176, -1.3972, 1.1140, ..., 0.5322, -0.2436, 0.0596],
[ 0.4437, 0.3978, -0.8556, ..., 1.6715, -0.4643, 0.7062]],
device='cuda:0')
In [4]:
lecture19_cuda_extension.saxpy(0.3,X,Y,out)
In [5]:
out
Out[5]:
tensor([[ 2.3551, 0.8725, 1.1007, ..., -0.5701, -0.4147, -0.9135],
[-1.7619, 2.0183, 0.1907, ..., -0.4397, 0.6221, -1.5504],
[ 0.3077, 1.0810, -0.8518, ..., 1.9129, -0.3097, -0.3331],
...,
[-0.5927, 0.2277, -0.7137, ..., -0.1314, 0.7454, 0.0859],
[ 0.7176, -1.3972, 1.1140, ..., 0.5322, -0.2436, 0.0596],
[ 0.4437, 0.3978, -0.8556, ..., 1.6715, -0.4643, 0.7062]],
device='cuda:0')
In [7]:
out - (0.3 * X + Y)
Out[7]:
tensor([[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, -1.4901e-08, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
...,
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 7.4506e-09],
[ 5.9605e-08, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 2.6077e-08],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -1.1921e-07,
0.0000e+00, 0.0000e+00]], device='cuda:0')
In [6]:
# result is very close to what torch would give us... but why is there a difference?
(out - (0.3 * X + Y)).square().sum() / out.square().sum()
Out[6]:
tensor(6.2965e-16, device='cuda:0')
In [8]:
# if we compare to the output at double-precision compute, there's zero error! why?
(out - (torch.tensor(0.3,dtype=torch.float32).double() * X.double() + Y.double()).float()).square().sum() / out.square().sum()
Out[8]:
tensor(0., device='cuda:0')
In [2]:
# what about half precision
X = torch.randn(1024,1024,device='cuda',dtype=torch.float16)
Y = torch.randn(1024,1024,device='cuda',dtype=torch.float16)
out = torch.zeros(1024,1024,device='cuda',dtype=torch.float16)
In [3]:
lecture19_cuda_extension.haxpy(0.3,X,Y,out)
In [4]:
out
Out[4]:
tensor([[-0.2129, 0.0614, -0.5029, ..., -0.1725, 0.9058, -2.8555],
[ 1.8701, 0.9106, 0.2498, ..., 1.2354, -0.6157, 1.1787],
[ 0.7583, -1.9717, 0.3303, ..., -0.3806, -2.0820, -1.7461],
...,
[ 1.3398, 0.1453, -0.2715, ..., 0.6475, 0.6055, 0.9404],
[-1.0254, -0.4341, 2.1426, ..., -0.8857, -0.2075, -0.2971],
[ 2.0820, 0.6890, 0.5845, ..., -1.3691, -0.2169, 0.8159]],
device='cuda:0', dtype=torch.float16)
In [5]:
0.3 * X + Y
Out[5]:
tensor([[-0.2129, 0.0615, -0.5029, ..., -0.1726, 0.9058, -2.8555],
[ 1.8701, 0.9102, 0.2496, ..., 1.2354, -0.6157, 1.1787],
[ 0.7583, -1.9727, 0.3303, ..., -0.3806, -2.0820, -1.7461],
...,
[ 1.3398, 0.1451, -0.2715, ..., 0.6475, 0.6055, 0.9404],
[-1.0254, -0.4341, 2.1426, ..., -0.8857, -0.2073, -0.2974],
[ 2.0820, 0.6890, 0.5845, ..., -1.3691, -0.2168, 0.8159]],
device='cuda:0', dtype=torch.float16)
In [9]:
# Relative squared error of the haxpy result against torch's own float16
# evaluation of 0.3 * X + Y; the sums are accumulated in float32.
(out - (0.3 * X + Y)).float().square().sum() / out.float().square().sum()
Out[9]:
tensor(4.7860e-08, device='cuda:0')
In [7]:
# Reference: round 0.3 to float16 first (as the kernel does), evaluate a*X + Y
# in float32, then round the result once to float16 before comparing.
# Fixed: `out.float.square()` was missing the call parentheses on `.float`.
(out - (torch.tensor(0.3,dtype=torch.float16).float() * X.float() + Y.float()).to(torch.float16)).float().square().sum() / out.float().square().sum()
Out[7]:
tensor(0., device='cuda:0')
In [9]:
In [ ]: