1#[cfg(feature = "alloc")] use core::slice;
12
13#[cfg(feature = "alloc")] use alloc::vec::{self, Vec};
14#[cfg(all(feature = "alloc", not(feature = "std")))]
16use alloc::collections::BTreeSet;
17#[cfg(feature = "std")] use std::collections::HashSet;
18
19#[cfg(feature = "std")]
20use crate::distributions::WeightedError;
21
22#[cfg(feature = "alloc")]
23use crate::{Rng, distributions::{uniform::SampleUniform, Distribution, Uniform}};
24
25#[cfg(feature = "serde1")]
26use serde::{Serialize, Deserialize};
27
/// A vector of indices.
///
/// Indices are stored either as `u32` or as `usize` values depending on the
/// variant; the methods on this type always expose them as `usize`.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum IndexVec {
    /// Indices stored as `u32` values.
    #[doc(hidden)]
    U32(Vec<u32>),
    /// Indices stored as `usize` values.
    #[doc(hidden)]
    USize(Vec<usize>),
}
39
40impl IndexVec {
41 #[inline]
43 pub fn len(&self) -> usize {
44 match *self {
45 IndexVec::U32(ref v) => v.len(),
46 IndexVec::USize(ref v) => v.len(),
47 }
48 }
49
50 #[inline]
52 pub fn is_empty(&self) -> bool {
53 match *self {
54 IndexVec::U32(ref v) => v.is_empty(),
55 IndexVec::USize(ref v) => v.is_empty(),
56 }
57 }
58
59 #[inline]
64 pub fn index(&self, index: usize) -> usize {
65 match *self {
66 IndexVec::U32(ref v) => v[index] as usize,
67 IndexVec::USize(ref v) => v[index],
68 }
69 }
70
71 #[inline]
73 pub fn into_vec(self) -> Vec<usize> {
74 match self {
75 IndexVec::U32(v) => v.into_iter().map(|i| i as usize).collect(),
76 IndexVec::USize(v) => v,
77 }
78 }
79
80 #[inline]
82 pub fn iter(&self) -> IndexVecIter<'_> {
83 match *self {
84 IndexVec::U32(ref v) => IndexVecIter::U32(v.iter()),
85 IndexVec::USize(ref v) => IndexVecIter::USize(v.iter()),
86 }
87 }
88}
89
90impl IntoIterator for IndexVec {
91 type Item = usize;
92 type IntoIter = IndexVecIntoIter;
93
94 #[inline]
96 fn into_iter(self) -> IndexVecIntoIter {
97 match self {
98 IndexVec::U32(v) => IndexVecIntoIter::U32(v.into_iter()),
99 IndexVec::USize(v) => IndexVecIntoIter::USize(v.into_iter()),
100 }
101 }
102}
103
104impl PartialEq for IndexVec {
105 fn eq(&self, other: &IndexVec) -> bool {
106 use self::IndexVec::*;
107 match (self, other) {
108 (&U32(ref v1), &U32(ref v2)) => v1 == v2,
109 (&USize(ref v1), &USize(ref v2)) => v1 == v2,
110 (&U32(ref v1), &USize(ref v2)) => {
111 (v1.len() == v2.len()) && (v1.iter().zip(v2.iter()).all(|(x, y)| *x as usize == *y))
112 }
113 (&USize(ref v1), &U32(ref v2)) => {
114 (v1.len() == v2.len()) && (v1.iter().zip(v2.iter()).all(|(x, y)| *x == *y as usize))
115 }
116 }
117 }
118}
119
120impl From<Vec<u32>> for IndexVec {
121 #[inline]
122 fn from(v: Vec<u32>) -> Self {
123 IndexVec::U32(v)
124 }
125}
126
127impl From<Vec<usize>> for IndexVec {
128 #[inline]
129 fn from(v: Vec<usize>) -> Self {
130 IndexVec::USize(v)
131 }
132}
133
/// Borrowing iterator over the indices of an [`IndexVec`], as returned by
/// `IndexVec::iter`.
#[derive(Debug)]
pub enum IndexVecIter<'a> {
    /// Iterates over borrowed `u32` indices.
    #[doc(hidden)]
    U32(slice::Iter<'a, u32>),
    /// Iterates over borrowed `usize` indices.
    #[doc(hidden)]
    USize(slice::Iter<'a, usize>),
}
142
143impl<'a> Iterator for IndexVecIter<'a> {
144 type Item = usize;
145
146 #[inline]
147 fn next(&mut self) -> Option<usize> {
148 use self::IndexVecIter::*;
149 match *self {
150 U32(ref mut iter) => iter.next().map(|i| *i as usize),
151 USize(ref mut iter) => iter.next().cloned(),
152 }
153 }
154
155 #[inline]
156 fn size_hint(&self) -> (usize, Option<usize>) {
157 match *self {
158 IndexVecIter::U32(ref v) => v.size_hint(),
159 IndexVecIter::USize(ref v) => v.size_hint(),
160 }
161 }
162}
163
// The `size_hint` of the wrapped `slice::Iter` is forwarded unchanged, and
// slice iterators report their exact remaining length.
impl<'a> ExactSizeIterator for IndexVecIter<'a> {}
165
/// Owning iterator over the indices of an [`IndexVec`], as returned by its
/// `IntoIterator` implementation.
#[derive(Clone, Debug)]
pub enum IndexVecIntoIter {
    /// Iterates over owned `u32` indices.
    #[doc(hidden)]
    U32(vec::IntoIter<u32>),
    /// Iterates over owned `usize` indices.
    #[doc(hidden)]
    USize(vec::IntoIter<usize>),
}
174
175impl Iterator for IndexVecIntoIter {
176 type Item = usize;
177
178 #[inline]
179 fn next(&mut self) -> Option<Self::Item> {
180 use self::IndexVecIntoIter::*;
181 match *self {
182 U32(ref mut v) => v.next().map(|i| i as usize),
183 USize(ref mut v) => v.next(),
184 }
185 }
186
187 #[inline]
188 fn size_hint(&self) -> (usize, Option<usize>) {
189 use self::IndexVecIntoIter::*;
190 match *self {
191 U32(ref v) => v.size_hint(),
192 USize(ref v) => v.size_hint(),
193 }
194 }
195}
196
// The `size_hint` of the wrapped `vec::IntoIter` is forwarded unchanged, and
// vector iterators report their exact remaining length.
impl ExactSizeIterator for IndexVecIntoIter {}
198
199
200pub fn sample<R>(rng: &mut R, length: usize, amount: usize) -> IndexVec
223where R: Rng + ?Sized {
224 if amount > length {
225 panic!("`amount` of samples must be less than or equal to `length`");
226 }
227 if length > (::core::u32::MAX as usize) {
228 return sample_rejection(rng, length, amount);
231 }
232 let amount = amount as u32;
233 let length = length as u32;
234
235 if amount < 163 {
240 const C: [[f32; 2]; 2] = [[1.6, 8.0 / 45.0], [10.0, 70.0 / 9.0]];
241 let j = if length < 500_000 { 0 } else { 1 };
242 let amount_fp = amount as f32;
243 let m4 = C[0][j] * amount_fp;
244 if amount > 11 && (length as f32) < (C[1][j] + m4) * amount_fp {
246 sample_inplace(rng, length, amount)
247 } else {
248 sample_floyd(rng, length, amount)
249 }
250 } else {
251 const C: [f32; 2] = [270.0, 330.0 / 9.0];
252 let j = if length < 500_000 { 0 } else { 1 };
253 if (length as f32) < C[j] * (amount as f32) {
254 sample_inplace(rng, length, amount)
255 } else {
256 sample_rejection(rng, length, amount)
257 }
258 }
259}
260
/// Randomly samples exactly `amount` distinct indices from `0..length`,
/// using `weight(i)` as the weight of index `i`.
///
/// Dispatches to `sample_efraimidis_spirakis`, working on `u32` indices when
/// `length` fits in a `u32` and on `usize` indices otherwise. Errors reported
/// by that routine are returned unchanged.
///
/// # Panics
///
/// Panics if `length` fits in a `u32` but `amount` does not, or if the
/// underlying routine panics (it does so when `amount > length`).
#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
pub fn sample_weighted<R, F, X>(
    rng: &mut R, length: usize, weight: F, amount: usize,
) -> Result<IndexVec, WeightedError>
where
    R: Rng + ?Sized,
    F: Fn(usize) -> X,
    X: Into<f64>,
{
    if length > (core::u32::MAX as usize) {
        return sample_efraimidis_spirakis(rng, length, weight, amount);
    }

    // `length` fits in `u32`; make sure narrowing `amount` is lossless too.
    assert!(amount <= core::u32::MAX as usize);
    sample_efraimidis_spirakis(rng, length as u32, weight, amount as u32)
}
294
295
/// Weighted sampling of `amount` distinct indices from `0..length` without
/// replacement, using the key-based scheme this function is named after
/// (Efraimidis & Spirakis).
///
/// Every index `i` gets the key `u.powf(1.0 / weight(i))`, where `u` is a
/// fresh `f64` drawn from `rng`; the `amount` indices with the greatest keys
/// form the result.
///
/// Returns `Err(WeightedError::InvalidWeight)` as soon as a weight fails the
/// `weight >= 0.` test (i.e. it is negative or NaN). When `amount` is zero an
/// empty `IndexVec::U32` is returned before any weight is evaluated.
///
/// # Panics
///
/// Panics if `amount > length`.
#[cfg(feature = "std")]
fn sample_efraimidis_spirakis<R, F, X, N>(
    rng: &mut R, length: N, weight: F, amount: N,
) -> Result<IndexVec, WeightedError>
where
    R: Rng + ?Sized,
    F: Fn(usize) -> X,
    X: Into<f64>,
    N: UInt,
    IndexVec: From<Vec<N>>,
{
    // Nothing to sample: return early without touching `weight` or `rng`.
    if amount == N::zero() {
        return Ok(IndexVec::U32(Vec::new()));
    }

    if amount > length {
        panic!("`amount` of samples must be less than or equal to `length`");
    }

    // An index paired with its random key; ordering and equality look at the
    // key only.
    struct Element<N> {
        index: N,
        key: f64,
    }
    impl<N> PartialOrd for Element<N> {
        fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
            self.key.partial_cmp(&other.key)
        }
    }
    impl<N> Ord for Element<N> {
        fn cmp(&self, other: &Self) -> core::cmp::Ordering {
            // Relies on keys never being NaN: every weight has passed the
            // `weight >= 0.` check before its key is computed.
            self.partial_cmp(other).unwrap()
        }
    }
    impl<N> PartialEq for Element<N> {
        fn eq(&self, other: &Self) -> bool {
            self.key == other.key
        }
    }
    impl<N> Eq for Element<N> {}

    #[cfg(feature = "nightly")]
    {
        // Compute a key for every index, rejecting invalid weights on the way.
        let mut candidates = Vec::with_capacity(length.as_usize());
        let mut index = N::zero();
        while index < length {
            let weight = weight(index.as_usize()).into();
            // Written as a negation so that NaN is rejected as well.
            if !(weight >= 0.) {
                return Err(WeightedError::InvalidWeight);
            }

            let key = rng.gen::<f64>().powf(1.0 / weight);
            candidates.push(Element { index, key });

            index += N::one();
        }

        // Partition around position `length - amount`: `mid` plus everything
        // in `greater` are the `amount` elements with the greatest keys.
        let (_, mid, greater)
            = candidates.select_nth_unstable(length.as_usize() - amount.as_usize());

        let mut result: Vec<N> = Vec::with_capacity(amount.as_usize());
        result.push(mid.index);
        for element in greater {
            result.push(element.index);
        }
        Ok(IndexVec::from(result))
    }

    #[cfg(not(feature = "nightly"))]
    {
        use alloc::collections::BinaryHeap;

        // Max-heap on the keys: popping yields the greatest keys first.
        let mut candidates = BinaryHeap::with_capacity(length.as_usize());
        let mut index = N::zero();
        while index < length {
            let weight = weight(index.as_usize()).into();
            // Written as a negation so that NaN is rejected as well.
            if !(weight >= 0.) {
                return Err(WeightedError::InvalidWeight);
            }

            let key = rng.gen::<f64>().powf(1.0 / weight);
            candidates.push(Element { index, key });

            index += N::one();
        }

        // The heap holds `length` elements and `amount <= length`, so every
        // `pop` below returns `Some`.
        let mut result: Vec<N> = Vec::with_capacity(amount.as_usize());
        while result.len() < amount.as_usize() {
            result.push(candidates.pop().unwrap().index);
        }
        Ok(IndexVec::from(result))
    }
}
408
409fn sample_floyd<R>(rng: &mut R, length: u32, amount: u32) -> IndexVec
416where R: Rng + ?Sized {
417 let floyd_shuffle = amount < 50;
421
422 debug_assert!(amount <= length);
423 let mut indices = Vec::with_capacity(amount as usize);
424 for j in length - amount..length {
425 let t = rng.gen_range(0..=j);
426 if floyd_shuffle {
427 if let Some(pos) = indices.iter().position(|&x| x == t) {
428 indices.insert(pos, j);
429 continue;
430 }
431 } else if indices.contains(&t) {
432 indices.push(j);
433 continue;
434 }
435 indices.push(t);
436 }
437 if !floyd_shuffle {
438 for i in (1..amount).rev() {
440 indices.swap(i as usize, rng.gen_range(0..=i) as usize);
442 }
443 }
444 IndexVec::from(indices)
445}
446
447fn sample_inplace<R>(rng: &mut R, length: u32, amount: u32) -> IndexVec
460where R: Rng + ?Sized {
461 debug_assert!(amount <= length);
462 let mut indices: Vec<u32> = Vec::with_capacity(length as usize);
463 indices.extend(0..length);
464 for i in 0..amount {
465 let j: u32 = rng.gen_range(i..length);
466 indices.swap(i as usize, j as usize);
467 }
468 indices.truncate(amount as usize);
469 debug_assert_eq!(indices.len(), amount as usize);
470 IndexVec::from(indices)
471}
472
/// Private abstraction over the unsigned integer types used for index
/// storage (implemented for `u32` and `usize`), letting the sampling routines
/// be generic over the index width.
trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform
    + core::hash::Hash + core::ops::AddAssign {
    /// Returns the value `0`.
    fn zero() -> Self;
    /// Returns the value `1`.
    fn one() -> Self;
    /// Converts `self` to `usize`.
    fn as_usize(self) -> usize;
}
479impl UInt for u32 {
480 #[inline]
481 fn zero() -> Self {
482 0
483 }
484
485 #[inline]
486 fn one() -> Self {
487 1
488 }
489
490 #[inline]
491 fn as_usize(self) -> usize {
492 self as usize
493 }
494}
495impl UInt for usize {
496 #[inline]
497 fn zero() -> Self {
498 0
499 }
500
501 #[inline]
502 fn one() -> Self {
503 1
504 }
505
506 #[inline]
507 fn as_usize(self) -> usize {
508 self
509 }
510}
511
512fn sample_rejection<X: UInt, R>(rng: &mut R, length: X, amount: X) -> IndexVec
522where
523 R: Rng + ?Sized,
524 IndexVec: From<Vec<X>>,
525{
526 debug_assert!(amount < length);
527 #[cfg(feature = "std")]
528 let mut cache = HashSet::with_capacity(amount.as_usize());
529 #[cfg(not(feature = "std"))]
530 let mut cache = BTreeSet::new();
531 let distr = Uniform::new(X::zero(), length);
532 let mut indices = Vec::with_capacity(amount.as_usize());
533 for _ in 0..amount.as_usize() {
534 let mut pos = distr.sample(rng);
535 while !cache.insert(pos) {
536 pos = distr.sample(rng);
537 }
538 indices.push(pos);
539 }
540
541 debug_assert_eq!(indices.len(), amount.as_usize());
542 IndexVec::from(indices)
543}
544
#[cfg(test)]
mod test {
    use super::*;

    #[cfg(feature = "alloc")] use alloc::vec;

    /// An `IndexVec` must survive a bincode round trip with the same variant
    /// and the same contents.
    #[test]
    #[cfg(feature = "serde1")]
    fn test_serialization_index_vec() {
        let original = IndexVec::from(vec![254_usize, 234, 2, 1]);
        let encoded = bincode::serialize(&original).unwrap();
        let decoded: IndexVec = bincode::deserialize(&encoded).unwrap();
        match (original, decoded) {
            (IndexVec::U32(before), IndexVec::U32(after)) => assert_eq!(before, after),
            (IndexVec::USize(before), IndexVec::USize(after)) => assert_eq!(before, after),
            _ => panic!("failed to seralize/deserialize `IndexVec`"),
        }
    }

    /// Degenerate lengths/amounts, plus a sanity check on the sum of a small
    /// sample drawn from a large range.
    #[test]
    fn test_sample_boundaries() {
        let mut rng = crate::test::rng(404);

        assert!(sample_inplace(&mut rng, 0, 0).is_empty());
        assert!(sample_inplace(&mut rng, 1, 0).is_empty());
        assert_eq!(sample_inplace(&mut rng, 1, 1).into_vec(), vec![0]);

        assert!(sample_rejection(&mut rng, 1u32, 0).is_empty());

        assert!(sample_floyd(&mut rng, 0, 0).is_empty());
        assert!(sample_floyd(&mut rng, 1, 0).is_empty());
        assert_eq!(sample_floyd(&mut rng, 1, 1).into_vec(), vec![0]);

        // Ten samples out of `1 << 25`: the sum must land strictly between
        // `1 << 25` and `25 * (1 << 25)`.
        let lower: usize = 1 << 25;
        let upper: usize = (1 << 25) * 25;

        let total: usize = sample_rejection(&mut rng, 1 << 25, 10u32).into_iter().sum();
        assert!(lower < total && total < upper);

        let total: usize = sample_floyd(&mut rng, 1 << 25, 10).into_iter().sum();
        assert!(lower < total && total < upper);
    }

    /// `sample` must pick the expected algorithm for representative inputs;
    /// checked by comparing against a direct call with the same seed.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_sample_alg() {
        let seed_rng = crate::test::rng;

        // Expected to match `sample_inplace` and to differ from `sample_floyd`.
        let (length, amount): (usize, usize) = (100, 50);
        let via_sample = sample(&mut seed_rng(420), length, amount);
        let via_inplace = sample_inplace(&mut seed_rng(420), length as u32, amount as u32);
        assert!(via_sample.iter().all(|e| e < length));
        assert_eq!(via_sample, via_inplace);

        let via_floyd = sample_floyd(&mut seed_rng(420), length as u32, amount as u32);
        assert_ne!(via_sample, via_floyd);

        // Expected to match `sample_floyd`.
        let (length, amount): (usize, usize) = (1 << 20, 50);
        let via_sample = sample(&mut seed_rng(421), length, amount);
        let via_floyd = sample_floyd(&mut seed_rng(421), length as u32, amount as u32);
        assert!(via_sample.iter().all(|e| e < length));
        assert_eq!(via_sample, via_floyd);

        // Expected to match `sample_rejection`.
        let (length, amount): (usize, usize) = (1 << 20, 600);
        let via_sample = sample(&mut seed_rng(422), length, amount);
        let via_rejection = sample_rejection(&mut seed_rng(422), length as u32, amount as u32);
        assert!(via_sample.iter().all(|e| e < length));
        assert_eq!(via_sample, via_rejection);
    }

    /// Weighted sampling must return `amount` distinct, in-range indices in
    /// `u32` storage for small lengths.
    #[cfg(feature = "std")]
    #[test]
    fn test_sample_weighted() {
        let seed_rng = crate::test::rng;
        for &(amount, len) in [(0, 10), (5, 10), (10, 10)].iter() {
            let sampled = sample_weighted(&mut seed_rng(423), len, |i| i as f64, amount).unwrap();
            let mut indices = match sampled {
                IndexVec::U32(indices) => indices,
                IndexVec::USize(_) => panic!("expected `IndexVec::U32`"),
            };
            assert_eq!(indices.len(), amount);
            // Still `amount` entries after deduplication, so all were distinct.
            indices.sort_unstable();
            indices.dedup();
            assert_eq!(indices.len(), amount);
            assert!(indices.iter().all(|&i| (i as usize) < len));
        }
    }

    /// Pins the first (up to eight) values produced by `sample` for a fixed
    /// seed.
    #[test]
    fn value_stability_sample() {
        let do_test = |length, amount, values: &[u32]| {
            let mut rng = crate::test::rng(410);
            let res = sample(&mut rng, length, amount);

            let mut buf = [0u32; 8];
            let len = res.len().min(buf.len());
            for (slot, index) in buf.iter_mut().zip(res) {
                *slot = index as u32;
            }
            assert_eq!(
                &buf[0..len],
                values,
                "failed sampling {}, {}",
                length,
                amount
            );
        };

        do_test(10, 6, &[8, 0, 3, 5, 9, 6]);
        do_test(25, 10, &[18, 15, 14, 9, 0, 13, 5, 24]);
        do_test(300, 8, &[30, 283, 150, 1, 73, 13, 285, 35]);
        do_test(300, 80, &[31, 289, 248, 154, 5, 78, 19, 286]);
        do_test(300, 180, &[31, 289, 248, 154, 5, 78, 19, 286]);

        do_test(1_000_000, 8, &[
            103717, 963485, 826422, 509101, 736394, 807035, 5327, 632573,
        ]);
        do_test(1_000_000, 180, &[
            103718, 963490, 826426, 509103, 736396, 807036, 5327, 632573,
        ]);
    }
}