Skip to content

Commit decda87

Browse files
committed
Feat: add address spaces for user defined statics
1 parent f79889a commit decda87

File tree

11 files changed

+169
-66
lines changed

11 files changed

+169
-66
lines changed

crates/cuda_std_macros/src/lib.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,3 +208,31 @@ pub fn externally_visible(
208208

209209
func.into_token_stream().into()
210210
}
211+
212+
/// Notifies the codegen to put a `static`/`static mut` inside of a specific memory address space.
213+
/// This is mostly for internal use and/or advanced users, as the codegen and `cuda_std` handle address space placement
214+
/// implicitly. **Improper use of this macro could yield weird or undefined behavior**.
215+
///
216+
/// This macro takes a single argument which can either be `global`, `shared`, `constant`, or `local`.
217+
///
218+
/// This macro does nothing on the CPU.
219+
#[proc_macro_attribute]
220+
pub fn address_space(attr: proc_macro::TokenStream, item: proc_macro::TokenStream) -> TokenStream {
221+
let mut global = syn::parse_macro_input!(item as syn::ItemStatic);
222+
let input = syn::parse_macro_input!(attr as Ident);
223+
224+
let addrspace_num = match input.to_string().as_str() {
225+
"global" => 1,
226+
// what did you do to address space 2 libnvvm??
227+
"shared" => 3,
228+
"constant" => 4,
229+
"local" => 5,
230+
addr => panic!("Invalid address space `{}`", addr),
231+
};
232+
233+
let new_attr =
234+
parse_quote!(#[cfg_attr(target_os = "cuda", nvvm_internal(addrspace(#addrspace_num)))]);
235+
global.attrs.push(new_attr);
236+
237+
global.into_token_stream().into()
238+
}

crates/rustc_codegen_nvvm/CHANGELOG.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,17 @@ Notable changes to this project will be documented in this file.
44

55
## Unreleased
66

7+
### Address Spaces
8+
9+
CUDA Address Spaces have been mostly implemented. Statics that are not mut statics and do not rely on
10+
interior mutability (are "freeze" types) are placed in constant memory (`__constant__` in CUDA C++), otherwise
11+
they are placed in global memory (`__device__`). Currently this only happens for user-defined statics, not for
12+
codegen-internal globals such as intermediate alloc globals.
13+
14+
An `#[address_space(...)]` macro has been added to cuda_std to complement this change. However, this macro
15+
is mostly just for advanced users and internal support for things like shared memory. Improper use can
16+
cause undefined behavior, so its use is generally discouraged.
17+
718
### Dead Code Elimination
819

920
PTX files no longer include useless functions and globals, we have switched to an alternative

crates/rustc_codegen_nvvm/rustc_llvm_wrapper/RustWrapper.cpp

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -187,22 +187,20 @@ extern "C" LLVMValueRef LLVMRustGetOrInsertFunction(LLVMModuleRef M,
187187
}
188188

189189
extern "C" LLVMValueRef
190-
LLVMRustGetOrInsertGlobal(LLVMModuleRef M, const char *Name, LLVMTypeRef Ty, unsigned AddressSpace)
191-
{
192-
// Module *Mod = unwrap(M);
193-
// GlobalVariable *GV = dyn_cast_or_null<GlobalVariable>(Mod->getNamedValue(Name));
194-
// if (!GV)
195-
// {
196-
// GV = new GlobalVariable(unwrap(Ty), false, GlobalValue::ExternalLinkage,
197-
// nullptr, Name, GlobalValue::NotThreadLocal, AddressSpace);
198-
// }
199-
// Type *GVTy = GV->getType();
200-
// PointerType *PTy = PointerType::get(unwrap(Ty), GVTy->getPointerAddressSpace());
201-
// if (GVTy != PTy)
202-
// return wrap(ConstantExpr::getBitCast(GV, PTy));
203-
204-
return wrap(unwrap(M)->getOrInsertGlobal(Name, unwrap(Ty)));
205-
// return wrap(GV);
190+
LLVMRustGetOrInsertGlobal(LLVMModuleRef M, const char *Name, size_t NameLen, LLVMTypeRef Ty, unsigned AddressSpace)
191+
{
192+
Module *Mod = unwrap(M);
193+
StringRef NameRef(Name, NameLen);
194+
195+
// We don't use Module::getOrInsertGlobal because that returns a Constant*,
196+
// which may either be the real GlobalVariable*, or a constant bitcast of it
197+
// if our type doesn't match the original declaration. We always want the
198+
// GlobalVariable* so we can access linkage, visibility, etc.
199+
GlobalVariable *GV = Mod->getGlobalVariable(NameRef, true);
200+
if (!GV)
201+
GV = new GlobalVariable(*Mod, unwrap(Ty), false,
202+
GlobalValue::ExternalLinkage, nullptr, NameRef, nullptr, GlobalValue::NotThreadLocal, AddressSpace);
203+
return wrap(GV);
206204
}
207205

208206
extern "C" LLVMTypeRef LLVMRustMetadataTypeInContext(LLVMContextRef C)

crates/rustc_codegen_nvvm/src/attributes.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::llvm::{self, AttributePlace::*, Value};
2-
use rustc_ast::Attribute;
2+
use rustc_ast::{Attribute, Lit, LitKind};
33
use rustc_attr::{InlineAttr, OptimizeAttr};
44
use rustc_middle::{middle::codegen_fn_attrs::CodegenFnAttrFlags, ty};
55
use rustc_session::{config::OptLevel, Session};
@@ -93,13 +93,15 @@ pub(crate) fn from_fn_attrs<'ll, 'tcx>(
9393
pub struct Symbols {
9494
pub nvvm_internal: Symbol,
9595
pub kernel: Symbol,
96+
pub addrspace: Symbol,
9697
}
9798

9899
// inspired by rust-gpu's attribute handling
/// Codegen-relevant information collected from an item's attributes.
#[derive(Default, Clone, PartialEq)]
pub(crate) struct NvvmAttributes {
    /// Presumably set when the item is marked as a kernel; the code that sets it
    /// is not shown alongside this definition (TODO confirm).
    pub kernel: bool,
    /// Set when a `used` argument is present.
    pub used: bool,
    /// Explicit address space requested through `addrspace(N)`, if any.
    /// `None` means no explicit address space was given.
    pub addrspace: Option<u8>,
}
104106

105107
impl NvvmAttributes {
@@ -116,6 +118,21 @@ impl NvvmAttributes {
116118
if arg.has_name(sym::used) {
117119
nvvm_attrs.used = true;
118120
}
121+
if arg.has_name(cx.symbols.addrspace) {
122+
let args = arg.meta_item_list().unwrap_or_default();
123+
if let Some(arg) = args.first() {
124+
let lit = arg.literal();
125+
if let Some(Lit {
126+
kind: LitKind::Int(val, _),
127+
..
128+
}) = lit
129+
{
130+
nvvm_attrs.addrspace = Some(*val as u8);
131+
} else {
132+
panic!();
133+
}
134+
}
135+
}
119136
}
120137
}
121138
}

crates/rustc_codegen_nvvm/src/builder.rs

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use rustc_middle::ty::layout::{
1717
use rustc_middle::ty::{self, Ty, TyCtxt};
1818
use rustc_span::Span;
1919
use rustc_target::abi::call::FnAbi;
20-
use rustc_target::abi::{self, Align, Size, WrappingRange};
20+
use rustc_target::abi::{self, AddressSpace, Align, Size, WrappingRange};
2121
use rustc_target::spec::{HasTargetSpec, Target};
2222
use std::borrow::Cow;
2323
use std::ffi::{CStr, CString};
@@ -870,6 +870,21 @@ impl<'ll, 'tcx, 'a> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
870870
fn bitcast(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
871871
trace!("Bitcast `{:?}` to ty `{:?}`", val, dest_ty);
872872
unsafe {
873+
let ty = llvm::LLVMRustGetValueType(val);
874+
let kind = llvm::LLVMRustGetTypeKind(ty);
875+
if kind == llvm::TypeKind::Pointer {
876+
let element = llvm::LLVMGetElementType(ty);
877+
let addrspace = llvm::LLVMGetPointerAddressSpace(ty);
878+
let new_ty = self.type_ptr_to_ext(element, AddressSpace::DATA);
879+
if addrspace != 0 {
880+
return llvm::LLVMBuildAddrSpaceCast(
881+
*self.llbuilder.lock().unwrap(),
882+
val,
883+
new_ty,
884+
unnamed(),
885+
);
886+
}
887+
}
873888
llvm::LLVMBuildBitCast(&mut self.llbuilder.lock().unwrap(), val, dest_ty, unnamed())
874889
}
875890
}
@@ -1212,8 +1227,19 @@ impl<'ll, 'tcx, 'a> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
12121227

12131228
impl<'a, 'll, 'tcx> StaticBuilderMethods for Builder<'a, 'll, 'tcx> {
12141229
fn get_static(&mut self, def_id: DefId) -> &'ll Value {
1215-
// Forward to the `get_static` method of `CodegenCx`
1216-
self.cx().get_static(def_id)
1230+
unsafe {
1231+
let mut g = self.cx.get_static(def_id);
1232+
let llty = llvm::LLVMRustGetValueType(g);
1233+
let addrspace = AddressSpace(llvm::LLVMGetPointerAddressSpace(llty));
1234+
if addrspace != AddressSpace::DATA {
1235+
trace!("Remapping global address space of global {:?}", g);
1236+
let llty = llvm::LLVMGetElementType(llty);
1237+
let ty = self.type_ptr_to_ext(llty, AddressSpace::DATA);
1238+
let builder = &*self.llbuilder.lock().unwrap();
1239+
g = llvm::LLVMBuildAddrSpaceCast(builder, g, ty, unnamed());
1240+
}
1241+
g
1242+
}
12171243
}
12181244
}
12191245

crates/rustc_codegen_nvvm/src/const_ty.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,12 @@ impl<'ll, 'tcx> ConstMethods<'tcx> for CodegenCx<'ll, 'tcx> {
110110
}
111111
}
112112

113-
fn scalar_to_backend(&self, cv: Scalar, layout: abi::Scalar, llty: &'ll Type) -> &'ll Value {
113+
fn scalar_to_backend(
114+
&self,
115+
cv: Scalar,
116+
layout: abi::Scalar,
117+
mut llty: &'ll Type,
118+
) -> &'ll Value {
114119
trace!("Scalar to backend `{:?}`, `{:?}`, `{:?}`", cv, layout, llty);
115120
let bitsize = if layout.is_bool() {
116121
1
@@ -152,7 +157,11 @@ impl<'ll, 'tcx> ConstMethods<'tcx> for CodegenCx<'ll, 'tcx> {
152157
GlobalAlloc::Static(def_id) => {
153158
assert!(self.tcx.is_static(def_id));
154159
assert!(!self.tcx.is_thread_local_static(def_id));
155-
(self.get_static(def_id), AddressSpace::DATA)
160+
let val = self.get_static(def_id);
161+
let addrspace = unsafe {
162+
llvm::LLVMGetPointerAddressSpace(llvm::LLVMRustGetValueType(val))
163+
};
164+
(self.get_static(def_id), AddressSpace(addrspace))
156165
}
157166
};
158167
let llval = unsafe {
@@ -162,15 +171,23 @@ impl<'ll, 'tcx> ConstMethods<'tcx> for CodegenCx<'ll, 'tcx> {
162171
1,
163172
)
164173
};
174+
165175
if layout.value != Pointer {
166176
unsafe { llvm::LLVMConstPtrToInt(llval, llty) }
167177
} else {
178+
if base_addr_space != AddressSpace::DATA {
179+
unsafe {
180+
let element = llvm::LLVMGetElementType(llty);
181+
llty = self.type_ptr_to_ext(element, base_addr_space);
182+
}
183+
}
168184
self.const_bitcast(llval, llty)
169185
}
170186
}
171187
};
172188

173189
trace!("...Scalar to backend: `{:?}`", val);
190+
trace!("{:?}", std::backtrace::Backtrace::force_capture());
174191

175192
val
176193
}

crates/rustc_codegen_nvvm/src/consts.rs

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,9 @@ fn check_and_apply_linkage<'ll, 'tcx>(
213213
ty: Ty<'tcx>,
214214
sym: &str,
215215
span_def_id: DefId,
216+
instance: Instance<'tcx>,
216217
) -> &'ll Value {
218+
let addrspace = cx.static_addrspace(instance);
217219
let llty = cx.layout_of(ty).llvm_type(cx);
218220
if let Some(linkage) = attrs.linkage {
219221
// https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#linkage-types-nvvm
@@ -239,7 +241,7 @@ fn check_and_apply_linkage<'ll, 'tcx>(
239241
};
240242
unsafe {
241243
// Declare a symbol `foo` with the desired linkage.
242-
let g1 = cx.declare_global(sym, llty2, AddressSpace::DATA);
244+
let g1 = cx.declare_global(sym, llty2, addrspace);
243245
llvm::LLVMRustSetLinkage(g1, linkage_to_llvm(linkage));
244246

245247
// Declare an internal global `extern_with_linkage_foo` which
@@ -251,7 +253,7 @@ fn check_and_apply_linkage<'ll, 'tcx>(
251253
let mut real_name = "_rust_extern_with_linkage_".to_string();
252254
real_name.push_str(sym);
253255
let g2 = cx
254-
.define_global(&real_name, llty, AddressSpace::DATA)
256+
.define_global(&real_name, llty, addrspace)
255257
.unwrap_or_else(|| {
256258
cx.sess().span_fatal(
257259
cx.tcx.def_span(span_def_id),
@@ -263,7 +265,7 @@ fn check_and_apply_linkage<'ll, 'tcx>(
263265
g2
264266
}
265267
} else {
266-
cx.declare_global(sym, llty, AddressSpace::DATA)
268+
cx.declare_global(sym, llty, addrspace)
267269
}
268270
}
269271

@@ -322,7 +324,8 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
322324
}
323325
}
324326

325-
let g = self.declare_global(sym, llty, AddressSpace::DATA);
327+
let addrspace = self.static_addrspace(instance);
328+
let g = self.declare_global(sym, llty, addrspace);
326329

327330
if !self.tcx.is_reachable_non_generic(def_id) {
328331
unsafe {
@@ -332,7 +335,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
332335

333336
g
334337
} else {
335-
check_and_apply_linkage(self, fn_attrs, ty, sym, def_id)
338+
check_and_apply_linkage(self, fn_attrs, ty, sym, def_id, instance)
336339
};
337340

338341
if fn_attrs.flags.contains(CodegenFnAttrFlags::THREAD_LOCAL) {
@@ -406,12 +409,13 @@ impl<'ll, 'tcx> StaticMethods for CodegenCx<'ll, 'tcx> {
406409
let linkage = llvm::LLVMRustGetLinkage(g);
407410
let visibility = llvm::LLVMRustGetVisibility(g);
408411

412+
let addrspace = self.static_addrspace(instance);
409413
let new_g = llvm::LLVMRustGetOrInsertGlobal(
410414
self.llmod,
411415
name.as_ptr().cast(),
412416
name.len(),
413417
val_llty,
414-
AddressSpace::DATA.0,
418+
addrspace.0,
415419
);
416420

417421
llvm::LLVMRustSetLinkage(new_g, linkage);
@@ -434,7 +438,7 @@ impl<'ll, 'tcx> StaticMethods for CodegenCx<'ll, 'tcx> {
434438
if !is_mutable && self.type_is_freeze(ty) {
435439
// TODO(RDambrosio016): is this the same as putting this in
436440
// the __constant__ addrspace for nvvm? should we set this addrspace explicitly?
437-
llvm::LLVMSetGlobalConstant(g, llvm::True);
441+
// llvm::LLVMSetGlobalConstant(g, llvm::True);
438442
}
439443

440444
debug_info::create_global_var_metadata(self, def_id, g);
@@ -446,6 +450,7 @@ impl<'ll, 'tcx> StaticMethods for CodegenCx<'ll, 'tcx> {
446450
if attrs.flags.contains(CodegenFnAttrFlags::USED) {
447451
self.add_used_global(g);
448452
}
453+
trace!("Codegen static `{:?}`", g);
449454
}
450455
}
451456

crates/rustc_codegen_nvvm/src/context.rs

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
use crate::abi::FnAbiLlvmExt;
2-
use crate::attributes::{self, Symbols};
2+
use crate::attributes::{self, NvvmAttributes, Symbols};
33
use crate::debug_info::{self, compile_unit_metadata, CrateDebugContext};
44
use crate::llvm::{self, BasicBlock, Type, Value};
55
use crate::{target, LlvmMod};
66
use nvvm::NvvmOption;
7-
use rustc_codegen_ssa::traits::ConstMethods;
87
use rustc_codegen_ssa::traits::{BackendTypes, BaseTypeMethods, CoverageInfoMethods, MiscMethods};
8+
use rustc_codegen_ssa::traits::{ConstMethods, DerivedTypeMethods};
99
use rustc_data_structures::base_n;
1010
use rustc_hash::FxHashMap;
1111
use rustc_middle::dep_graph::DepContext;
1212
use rustc_middle::ty::layout::{
13-
FnAbiError, FnAbiOf, FnAbiRequest, HasParamEnv, LayoutError, TyAndLayout,
13+
FnAbiError, FnAbiOf, FnAbiRequest, HasParamEnv, HasTyCtxt, LayoutError, TyAndLayout,
1414
};
1515
use rustc_middle::ty::layout::{FnAbiOfHelpers, LayoutOfHelpers};
1616
use rustc_middle::ty::{Ty, TypeFoldable};
@@ -158,6 +158,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
158158
symbols: Symbols {
159159
nvvm_internal: Symbol::intern("nvvm_internal"),
160160
kernel: Symbol::intern("kernel"),
161+
addrspace: Symbol::intern("addrspace"),
161162
},
162163
dbg_cx,
163164
codegen_args: CodegenArgs::from_session(tcx.sess()),
@@ -261,6 +262,24 @@ impl<'ll, 'tcx> MiscMethods<'tcx> for CodegenCx<'ll, 'tcx> {
261262
}
262263

263264
impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
265+
/// Computes the address space for a static.
266+
pub fn static_addrspace(&self, instance: Instance<'tcx>) -> AddressSpace {
267+
let ty = instance.ty(self.tcx, ty::ParamEnv::reveal_all());
268+
let is_mutable = self.tcx().is_mutable_static(instance.def_id());
269+
let attrs = self.tcx.get_attrs(instance.def_id());
270+
let nvvm_attrs = NvvmAttributes::parse(self, attrs);
271+
272+
if let Some(addr) = nvvm_attrs.addrspace {
273+
return AddressSpace(addr as u32);
274+
}
275+
276+
if !is_mutable && self.type_is_freeze(ty) {
277+
AddressSpace(4)
278+
} else {
279+
AddressSpace::DATA
280+
}
281+
}
282+
264283
/// Declare a global value, returns the existing value if it was already declared.
265284
pub fn declare_global(
266285
&self,

crates/rustc_codegen_nvvm/src/link.rs

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -227,22 +227,10 @@ fn codegen_into_ptx_file(
227227
// in this crate. We must unpack them and devour their bitcode to link in.
228228
for rlib in rlibs {
229229
let mut cgus = Vec::with_capacity(16);
230-
// just pick the first cgu name as the overall name for now.
231-
let mut name = String::new();
232230
for entry in Archive::new(File::open(rlib)?).entries()? {
233231
let mut entry = entry?;
234232
// metadata is where rustc puts rlib metadata, so its not a cgu we are interested in.
235233
if entry.path().unwrap() != Path::new(".metadata") {
236-
if name == String::new() {
237-
name = entry
238-
.path()
239-
.unwrap()
240-
.file_name()
241-
.unwrap()
242-
.to_str()
243-
.unwrap()
244-
.to_string();
245-
}
246234
// std::fs::read adds 1 to the size, so do the same here - see comment:
247235
// https://github.com/rust-lang/rust/blob/72868e017bdade60603a25889e253f556305f996/library/std/src/fs.rs#L200-L202
248236
let mut bitcode = Vec::with_capacity(entry.size() as usize + 1);

0 commit comments

Comments
 (0)