Fix alignment issue of TensorData bytes #2416
base: main
Conversation
introduce max alignment (which depends on platform anyway) and don't serialize that part
fix Clone, Debug, and Eq impls to work on the bytes, not the pointers
Force-pushed from 5e9679a to f84c73a (compare)
Codecov Report

Attention: Patch coverage is

```
@@            Coverage Diff            @@
##             main    #2416     +/-   ##
=========================================
- Coverage   85.26%   82.91%   -2.35%
=========================================
  Files         792      811      +19
  Lines      104516   105103     +587
=========================================
- Hits        89115    87149    -1966
- Misses      15401    17954    +2553
```

☔ View full report in Codecov by Sentry.
Re the Codecov report: the missed lines are for the most part the Debug impl of Bytes. I think it makes sense to invest a bit more time into nice formatting there; any comments on what the output should be would be appreciated.
Sorry for the delayed response! We discussed this offline but forgot to get back to you. First, thanks for identifying the potential issue. While these changes to handle alignment concerns more generically could certainly provide better safety, they also add complexity and might not be immediately necessary. That's why we're a bit hesitant to make this change for now. As you noticed, this is not an issue for most allocators. Of course that doesn't mean it is not an issue, just that it does not affect the majority of use cases. So we should still keep this in mind, and we might circle back in the future thanks to your investigation.
While I understand the added complexity makes it difficult to review, I am a bit disappointed. There is no way around the fact that the current implementation is misusing the allocator contract. I only noticed this when I implemented the deserialization logic, but please note that the current impl will also happily allocate a buffer during deserialization that is not correctly aligned for the element type. In any case, I suppose there is no problem with just leaving this open for the time being. Just one more question: would you be less wary if the new unsafe code was in a tested outside crate such as bytemuck? Some parts such as the (de)serialization I think don't make sense to be outside burn, but most of the "byte container" could possibly make more sense upstream.
No worries, we're on the same page here 🙂 I think our main concerns were more in terms of trade-offs, mostly time and the possible ramifications of this breaking change (though I think my position has changed on that, as you'll see below).
I'm glad you brought this up! Honestly, I think I kind of forgot to take that into consideration; this should be especially important for Burn as we aim for portability. I think our previous position was that the current implementation was not an issue for most modern processors and platforms, but that's not good enough. I should really get a Raspberry Pi myself and start messing around with Burn on it.
I'm not too wary about this, tbh. I gotta run but wanted to make sure to get back to you with my current thoughts. In short, in my previous comment I was open to the change but did not think it was required for the short term. But now I think we should make sure that this is addressed (ideally without performance impacts to de/serialization, so we should make sure to benchmark that for some practical use cases).
No worries, I too think we are on the same page :) While, strictly speaking, violating the allocator contract as the current code does could lead to memory unsafety, I don't think any allocator actually uses the violated guarantees in a meaningful way (but I'm open to surprises). On the other hand, misaligned allocations would be a real concern, but at least as far as I can tell they mostly lead to panics, not undefined behaviour or violations of memory safety, so a relaxed approach to fixing this is fine 👍 Sorry if I came off as a bit panicky.
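To make the contract violation under discussion concrete, here is a minimal sketch (the concrete element type and capacity are illustrative): a `Vec<f32>` is allocated under one `Layout`, but after a pointer cast into `Vec<u8>`, `Vec`'s `Drop` hands the allocator a layout with the same size but a different alignment.

```rust
use std::alloc::Layout;

fn main() {
    // Layout a Vec<f32> with capacity 4 is allocated with:
    let alloc_layout = Layout::array::<f32>(4).unwrap();
    // Layout Vec<u8>'s Drop would hand back after a raw pointer cast:
    let dealloc_layout = Layout::array::<u8>(16).unwrap();

    // The sizes agree...
    assert_eq!(alloc_layout.size(), dealloc_layout.size());
    // ...but the alignments passed to the allocator differ, which is
    // exactly the (de)allocation contract violation being discussed.
    assert_eq!(alloc_layout.align(), 4);
    assert_eq!(dealloc_layout.align(), 1);
}
```

`GlobalAlloc::dealloc` requires the same layout that was used for allocation, so whether this mismatch is benign depends entirely on the allocator.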
@WorldSEnder I was looking at your PR yesterday and I think we'd like to stay away from a new opaque type for now. I also benchmarked some solutions to preserve alignment and none of them come close to the current pointer re-interpret, sadly. I think we can keep both approaches instead, based on some target architectures for now, since we know that heap allocations are overaligned on mainstream systems. And for others with stricter alignment rules, we fall back to something similar to what you implemented:

```rust
/// Initializes a new tensor data structure from the provided values.
fn init<E: Element, S: Into<Vec<usize>>>(
    mut value: Vec<E>,
    shape: S,
    dtype: DType,
) -> Self {
    // Ensure shape is valid
    let shape = shape.into();
    let shape_numel = Self::numel(&shape);
    value.truncate(shape_numel);
    let numel = value.len();
    assert_eq!(
        shape_numel, numel,
        "Shape {:?} is invalid for input of size {:?}",
        shape, numel,
    );
    let bytes = Self::to_bytes(value);
    Self { bytes, shape, dtype }
}

fn to_bytes<E: Element>(value: Vec<E>) -> Vec<u8> {
    // Fast path for architectures with heap allocation overalignment.
    #[cfg(all(
        not(miri),
        any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")
    ))]
    {
        let mut value = value;
        let factor = std::mem::size_of::<E>() / std::mem::size_of::<u8>();
        let len = value.len() * factor;
        let capacity = value.capacity() * factor;
        let ptr = value.as_mut_ptr();
        std::mem::forget(value);
        // SAFETY: Practically, mainstream systems align heap allocations beyond
        // the minimum required alignment for the element type (e.g., 8 bytes on
        // 64-bit systems).
        //
        // This is not guaranteed by Rust's memory model, which follows strict
        // alignment requirements as specified by the type's layout.
        //
        // But we leverage the overalignment and efficient unaligned access to
        // interpret the same memory layout as individual bytes instead of `E`
        // elements. This simple trick makes this pointer cast acceptable (and
        // requires no copy).
        unsafe { Vec::from_raw_parts(ptr as *mut u8, len, capacity) }
    }
    // Fallback for architectures with stricter alignment rules.
    #[cfg(any(
        miri,
        not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64"))
    ))]
    {
        let num_bytes = value.len() * std::mem::size_of::<E>();
        let mut bytes = vec![0u8; num_bytes];
        unsafe {
            std::ptr::copy_nonoverlapping(
                value.as_ptr() as *const u8,
                bytes.as_mut_ptr(),
                num_bytes,
            );
        }
        bytes
    }
}
```

I also added
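For reference, the fallback copy path can be exercised standalone; `to_bytes_copy` below is a hypothetical helper mirroring that branch (copying the element buffer into a fresh `Vec<u8>` rather than reusing the allocation):

```rust
// Hypothetical standalone version of the fallback path: copy the element
// buffer byte-for-byte into a fresh Vec<u8> with 1-byte alignment.
fn to_bytes_copy<E: Copy>(value: &[E]) -> Vec<u8> {
    let num_bytes = std::mem::size_of_val(value);
    let mut bytes = vec![0u8; num_bytes];
    // SAFETY: source and destination are distinct buffers of `num_bytes` bytes.
    unsafe {
        std::ptr::copy_nonoverlapping(value.as_ptr() as *const u8, bytes.as_mut_ptr(), num_bytes);
    }
    bytes
}

fn main() {
    let data = vec![1.0f32, 2.0];
    let bytes = to_bytes_copy(&data);
    assert_eq!(bytes.len(), 8);
    // The first four bytes are the native-endian bit pattern of 1.0f32.
    assert_eq!(&bytes[..4], &1.0f32.to_ne_bytes());
}
```

Unlike the fast path, this costs a copy, but it never lies to the allocator about the layout of the original buffer.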
@laggui I think in the long run, an opaque wrapper should happen at some point (note the wrapper in this PR does no mem-copy, so it should not have the overhead you mention). A publicly exposed field of the underlying bytes works against this.

Secondly, the target detection is perhaps 90% there. Sadly, I don't think there is a way to detect the specific allocator being used, and the check is of course only detecting the target architecture.

What exactly is the concern with an opaque struct? It seems there would be a simple almost-zero-cost conversion from the current representation.

PS: Did you encounter any slowdown with this PR in your benchmarking? If so, this would be a concern and I would put in more work to make sure I didn't miss any unreasonable overhead.
Yeah I think we could (should) remove the public field and users could only get the slice with the current available methods. So no direct access to the underlying bytes. Not sure why we left it public actually 😅
That's what I found as well. Was trying to cover the most important cases with this condition where stricter alignment requirements are necessary, but I realize this is not 100%. I wasn't sure if a new type was warranted at this time but I now realize that it might be required to do this cleanly.
Oh I agree! This was mostly for the current tests with miri 😅
An opaque struct like that could work. Also, a tiny note on the current implementation.
Ran some benchmarks. The results seem to be on par with the current implementation, which is great 🙏 I think I recalled incorrectly that there was a copy in the wrapper.

Current:

This PR:
I was wondering if it was intentional that you could push further bytes after construction, because that seems a bit weird since you could push e.g. a single byte.
After more investigation, and a bit longer look than I had wished, I found that the default allocator on Windows actually does use the passed alignment in a non-trivial way. Specifically, requests at or below a small alignment threshold go straight to `HeapAlloc`, while larger alignments are over-allocated and the original pointer is stored in a header, so deallocating with a different alignment than was used for allocation really can go wrong there.
The reason the current code gets away with it - supposedly - is that on 64-bit platforms the default heap alignment already covers the alignment of the element types in use.
Yeah, the timing close to the release of 0.15 is a bit unfortunate. Not sure about your planned release cycles, I wouldn't mind holding it until the breaking 0.16 comes along naturally.
Thread-safety is indeed a good point. The current
There seems to be a slight slowdown in deserialization, which might perform an extra copy to an aligned buffer. I think in that case we can get away with actually checking the runtime alignment instead (still storing and de-allocating with the alignment of the actual buffer) and not copying, or only copying when someone wants to convert it to an actual vector.
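The runtime alignment check suggested above can be sketched roughly as follows; `try_as_slice` is a hypothetical helper (similar in spirit to bytemuck's `try_cast_slice`), not code from this PR:

```rust
/// Check whether `bytes` can be reinterpreted as a slice of `E` without
/// copying. Hypothetical helper illustrating a runtime alignment check.
fn try_as_slice<E>(bytes: &[u8]) -> Option<&[E]> {
    let align = std::mem::align_of::<E>();
    let size = std::mem::size_of::<E>();
    if (bytes.as_ptr() as usize) % align != 0 || bytes.len() % size != 0 {
        return None;
    }
    // SAFETY: alignment and length were just checked. Only sound for `E`
    // without invalid bit patterns (e.g. integers and floats).
    Some(unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const E, bytes.len() / size) })
}

fn main() {
    let v: Vec<u32> = vec![1, 2, 3];
    let bytes =
        unsafe { std::slice::from_raw_parts(v.as_ptr() as *const u8, v.len() * std::mem::size_of::<u32>()) };
    // The original allocation is u32-aligned, so the borrow succeeds.
    assert_eq!(try_as_slice::<u32>(bytes), Some(&v[..]));
    // A view offset by one byte is misaligned for u32.
    assert!(try_as_slice::<u32>(&bytes[1..5]).is_none());
}
```

With a check like this, the copy to an aligned buffer would only be paid when the deserialized buffer happens to be misaligned, not on every deserialization.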
this already improves readability a bit by separating out alloc/dealloc logic and adding a bunch of safety comments and better error messages
borrowing from the deserializer will not save a copy, and is moreover inefficient when we could take ownership of an existing byte buffer
both changes only target miri's borrowing semantics; operationally the pointers are the same, but they obey different borrow-stack rules
Alright, have added both suggestions.
Checklist

- [x] The `run-checks all` script has been executed.

I have some issues with flaky tests; in particular, `quantize_dynamic_int8` and `downsample_interpolation` are off-by-1 in a few cases. I've ignored them locally, and they seem to be fine in CI. Further, I've stumbled over a test failure seemingly related to my AMD GPU driver.

Related Issues/PRs

Fixes #2375.
Changes

Add a new opaque structure `Bytes` which holds on to tensor data. As explained in the linked issue, using a `Vec<u8>` is incorrect, as this will deallocate the memory with an incorrect alignment/memory layout given to the allocator, which violates the allocator contract. Instead, the structure remembers the alignment of the data and uses that when deallocating.

For serialization, only the bytes are written and no alignment is taken into account. Instead, when deserializing, the data is allocated with over-alignment (currently `align_of::<u128>`) which should be sufficient to interpret a slice of the data as a slice of some larger type, such as `f32` etc.

Note that this means that serialization is not strictly a round-trip, which can show in the difference between how slices and `Vec` work:

- To reinterpret a `[u8]` as a `[E]`, the slice needs to be aligned at least as strictly as the alignment of `E`.
- To reinterpret a `Vec<A>` as a `Vec<B>`, the alignments must be exactly equal.

This means that a deserialized instance of `Bytes` possibly can not be converted via `Bytes::try_into_vec`. In case this is needed in the future, one would either need to get access to the needed alignment during deserialization or copy the data to a correctly aligned buffer. Such an API is possible but not included in this PR.

Testing

Some limited tests of `TensorData` have been run through miri manually to confirm that it doesn't complain after the patch. In case you want tests in the repo, I would need more guidance on how to set that up, since miri should probably only run a limited number of tests for CI performance reasons.
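The alignment-remembering deallocation described in the Changes section can be sketched very roughly as follows; `AlignedBytes` is a hypothetical illustration, not the PR's actual `Bytes` type:

```rust
use std::alloc::{self, Layout};

/// Hypothetical sketch of an alignment-remembering byte container.
/// It records the Layout used at allocation time so Drop can hand the
/// allocator back exactly the same layout, honoring the contract.
struct AlignedBytes {
    ptr: *mut u8,
    len: usize,
    layout: Layout,
}

impl AlignedBytes {
    fn from_elems<E: Copy>(elems: &[E]) -> Self {
        let len = std::mem::size_of_val(elems);
        // Allocate at least one byte so the allocation is never zero-sized.
        let layout = Layout::from_size_align(len.max(1), std::mem::align_of::<E>()).unwrap();
        let ptr = unsafe { alloc::alloc(layout) };
        assert!(!ptr.is_null(), "allocation failed");
        // SAFETY: `ptr` points to a fresh buffer of at least `len` bytes.
        unsafe { std::ptr::copy_nonoverlapping(elems.as_ptr() as *const u8, ptr, len) };
        Self { ptr, len, layout }
    }

    fn as_slice(&self) -> &[u8] {
        // SAFETY: `ptr` is valid for `len` initialized bytes.
        unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
    }
}

impl Drop for AlignedBytes {
    fn drop(&mut self) {
        // Deallocate with the same layout used for allocation.
        unsafe { alloc::dealloc(self.ptr, self.layout) };
    }
}

fn main() {
    let b = AlignedBytes::from_elems(&[1.0f32, 2.0]);
    assert_eq!(b.as_slice().len(), 8);
}
```

This is only the deallocation side of the story; the real type additionally handles serialization, over-aligned deserialization, and thread-safety concerns discussed in the thread above.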