@@ -3,8 +3,8 @@ use rand::Rng;
3
3
use crate :: sync:: atomic:: { AtomicUsize , Ordering } ;
4
4
use crate :: sync:: mpsc:: channel;
5
5
use crate :: sync:: {
6
- Arc , MappedRwLockReadGuard , MappedRwLockWriteGuard , RwLock , RwLockReadGuard , RwLockWriteGuard ,
7
- TryLockError ,
6
+ Arc , Barrier , MappedRwLockReadGuard , MappedRwLockWriteGuard , RwLock , RwLockReadGuard ,
7
+ RwLockWriteGuard , TryLockError ,
8
8
} ;
9
9
use crate :: thread;
10
10
@@ -501,3 +501,123 @@ fn panic_while_mapping_write_unlocked_poison() {
501
501
502
502
drop ( lock) ;
503
503
}
504
+
505
+ #[ test]
506
+ fn test_downgrade_basic ( ) {
507
+ let r = RwLock :: new ( ( ) ) ;
508
+
509
+ let write_guard = r. write ( ) . unwrap ( ) ;
510
+ let _read_guard = RwLockWriteGuard :: downgrade ( write_guard) ;
511
+ }
512
+
513
+ #[ test]
514
+ fn test_downgrade_readers ( ) {
515
+ // This test creates 1 writing thread and `R` reader threads doing `N` iterations.
516
+ const R : usize = 10 ;
517
+ const N : usize = if cfg ! ( target_pointer_width = "64" ) { 100 } else { 20 } ;
518
+
519
+ // The writer thread will constantly update the value inside the `RwLock`, and this test will
520
+ // only pass if every reader observes all values between 0 and `N`.
521
+ let rwlock = Arc :: new ( RwLock :: new ( 0 ) ) ;
522
+ let barrier = Arc :: new ( Barrier :: new ( R + 1 ) ) ;
523
+
524
+ // Create the writing thread.
525
+ let r_writer = rwlock. clone ( ) ;
526
+ let b_writer = barrier. clone ( ) ;
527
+ thread:: spawn ( move || {
528
+ for i in 0 ..N {
529
+ let mut write_guard = r_writer. write ( ) . unwrap ( ) ;
530
+ * write_guard = i;
531
+
532
+ let read_guard = RwLockWriteGuard :: downgrade ( write_guard) ;
533
+ assert_eq ! ( * read_guard, i) ;
534
+
535
+ // Wait for all readers to observe the new value.
536
+ b_writer. wait ( ) ;
537
+ }
538
+ } ) ;
539
+
540
+ for _ in 0 ..R {
541
+ let rwlock = rwlock. clone ( ) ;
542
+ let barrier = barrier. clone ( ) ;
543
+ thread:: spawn ( move || {
544
+ // Every reader thread needs to observe every value up to `N`.
545
+ for i in 0 ..N {
546
+ let read_guard = rwlock. read ( ) . unwrap ( ) ;
547
+ assert_eq ! ( * read_guard, i) ;
548
+ drop ( read_guard) ;
549
+
550
+ // Wait for everyone to read and for the writer to change the value again.
551
+ barrier. wait ( ) ;
552
+
553
+ // Spin until the writer has changed the value.
554
+ loop {
555
+ let read_guard = rwlock. read ( ) . unwrap ( ) ;
556
+ assert ! ( * read_guard >= i) ;
557
+
558
+ if * read_guard > i {
559
+ break ;
560
+ }
561
+ }
562
+ }
563
+ } ) ;
564
+ }
565
+ }
566
+
567
+ #[ test]
568
+ fn test_downgrade_atomic ( ) {
569
+ const NEW_VALUE : i32 = -1 ;
570
+
571
+ // This test checks that `downgrade` is atomic, meaning as soon as a write lock has been
572
+ // downgraded, the lock must be in read mode and no other threads can take the write lock to
573
+ // modify the protected value.
574
+
575
+ // `W` is the number of evil writer threads.
576
+ const W : usize = if cfg ! ( target_pointer_width = "64" ) { 100 } else { 20 } ;
577
+ let rwlock = Arc :: new ( RwLock :: new ( 0 ) ) ;
578
+
579
+ // Spawns many evil writer threads that will try and write to the locked value before the
580
+ // initial writer (who has the exclusive lock) can read after it downgrades.
581
+ // If the `RwLock` behaves correctly, then the initial writer should read the value it wrote
582
+ // itself as no other thread should be able to mutate the protected value.
583
+
584
+ // Put the lock in write mode, causing all future threads trying to access this go to sleep.
585
+ let mut main_write_guard = rwlock. write ( ) . unwrap ( ) ;
586
+
587
+ // Spawn all of the evil writer threads. They will each increment the protected value by 1.
588
+ let handles: Vec < _ > = ( 0 ..W )
589
+ . map ( |_| {
590
+ let rwlock = rwlock. clone ( ) ;
591
+ thread:: spawn ( move || {
592
+ // Will go to sleep since the main thread initially has the write lock.
593
+ let mut evil_guard = rwlock. write ( ) . unwrap ( ) ;
594
+ * evil_guard += 1 ;
595
+ } )
596
+ } )
597
+ . collect ( ) ;
598
+
599
+ // Wait for a good amount of time so that evil threads go to sleep.
600
+ // Note: this is not strictly necessary...
601
+ let eternity = crate :: time:: Duration :: from_millis ( 42 ) ;
602
+ thread:: sleep ( eternity) ;
603
+
604
+ // Once everyone is asleep, set the value to `NEW_VALUE`.
605
+ * main_write_guard = NEW_VALUE ;
606
+
607
+ // Atomically downgrade the write guard into a read guard.
608
+ let main_read_guard = RwLockWriteGuard :: downgrade ( main_write_guard) ;
609
+
610
+ // If the above is not atomic, then it would be possible for an evil thread to get in front of
611
+ // this read and change the value to be non-negative.
612
+ assert_eq ! ( * main_read_guard, NEW_VALUE , "`downgrade` was not atomic" ) ;
613
+
614
+ // Drop the main read guard and allow the evil writer threads to start incrementing.
615
+ drop ( main_read_guard) ;
616
+
617
+ for handle in handles {
618
+ handle. join ( ) . unwrap ( ) ;
619
+ }
620
+
621
+ let final_check = rwlock. read ( ) . unwrap ( ) ;
622
+ assert_eq ! ( * final_check, W as i32 + NEW_VALUE ) ;
623
+ }
0 commit comments