rand/distributions/
weighted_index.rs

1// Copyright 2018 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! Weighted index sampling
10
11use crate::distributions::uniform::{SampleBorrow, SampleUniform, UniformSampler};
12use crate::distributions::Distribution;
13use crate::Rng;
14use core::cmp::PartialOrd;
15use core::fmt;
16
17// Note that this whole module is only imported if feature="alloc" is enabled.
18use alloc::vec::Vec;
19
20#[cfg(feature = "serde1")]
21use serde::{Serialize, Deserialize};
22
23/// A distribution using weighted sampling of discrete items
24///
25/// Sampling a `WeightedIndex` distribution returns the index of a randomly
26/// selected element from the iterator used when the `WeightedIndex` was
27/// created. The chance of a given element being picked is proportional to the
28/// value of the element. The weights can use any type `X` for which an
29/// implementation of [`Uniform<X>`] exists.
30///
31/// # Performance
32///
33/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where
34/// `N` is the number of weights. As an alternative,
35/// [`rand_distr::weighted_alias`](https://docs.rs/rand_distr/*/rand_distr/weighted_alias/index.html)
36/// supports `O(1)` sampling, but with much higher initialisation cost.
37///
38/// A `WeightedIndex<X>` contains a `Vec<X>` and a [`Uniform<X>`] and so its
39/// size is the sum of the size of those objects, possibly plus some alignment.
40///
41/// Creating a `WeightedIndex<X>` will allocate enough space to hold `N - 1`
42/// weights of type `X`, where `N` is the number of weights. However, since
43/// `Vec` doesn't guarantee a particular growth strategy, additional memory
44/// might be allocated but not used. Since the `WeightedIndex` object also
45/// contains, this might cause additional allocations, though for primitive
46/// types, [`Uniform<X>`] doesn't allocate any memory.
47///
48/// Sampling from `WeightedIndex` will result in a single call to
49/// `Uniform<X>::sample` (method of the [`Distribution`] trait), which typically
50/// will request a single value from the underlying [`RngCore`], though the
51/// exact number depends on the implementation of `Uniform<X>::sample`.
52///
53/// # Example
54///
55/// ```
56/// use rand::prelude::*;
57/// use rand::distributions::WeightedIndex;
58///
59/// let choices = ['a', 'b', 'c'];
60/// let weights = [2,   1,   1];
61/// let dist = WeightedIndex::new(&weights).unwrap();
62/// let mut rng = thread_rng();
63/// for _ in 0..100 {
64///     // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
65///     println!("{}", choices[dist.sample(&mut rng)]);
66/// }
67///
68/// let items = [('a', 0), ('b', 3), ('c', 7)];
69/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap();
70/// for _ in 0..100 {
71///     // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
72///     println!("{}", items[dist2.sample(&mut rng)].0);
73/// }
74/// ```
75///
76/// [`Uniform<X>`]: crate::distributions::Uniform
77/// [`RngCore`]: crate::RngCore
78#[derive(Debug, Clone, PartialEq)]
79#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
80#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
81pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
82    cumulative_weights: Vec<X>,
83    total_weight: X,
84    weight_distribution: X::Sampler,
85}
86
87impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
88    /// Creates a new a `WeightedIndex` [`Distribution`] using the values
89    /// in `weights`. The weights can use any type `X` for which an
90    /// implementation of [`Uniform<X>`] exists.
91    ///
92    /// Returns an error if the iterator is empty, if any weight is `< 0`, or
93    /// if its total value is 0.
94    ///
95    /// [`Uniform<X>`]: crate::distributions::uniform::Uniform
96    pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError>
97    where
98        I: IntoIterator,
99        I::Item: SampleBorrow<X>,
100        X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
101    {
102        let mut iter = weights.into_iter();
103        let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone();
104
105        let zero = <X as Default>::default();
106        if !(total_weight >= zero) {
107            return Err(WeightedError::InvalidWeight);
108        }
109
110        let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
111        for w in iter {
112            // Note that `!(w >= x)` is not equivalent to `w < x` for partially
113            // ordered types due to NaNs which are equal to nothing.
114            if !(w.borrow() >= &zero) {
115                return Err(WeightedError::InvalidWeight);
116            }
117            weights.push(total_weight.clone());
118            total_weight += w.borrow();
119        }
120
121        if total_weight == zero {
122            return Err(WeightedError::AllWeightsZero);
123        }
124        let distr = X::Sampler::new(zero, total_weight.clone());
125
126        Ok(WeightedIndex {
127            cumulative_weights: weights,
128            total_weight,
129            weight_distribution: distr,
130        })
131    }
132
133    /// Update a subset of weights, without changing the number of weights.
134    ///
135    /// `new_weights` must be sorted by the index.
136    ///
137    /// Using this method instead of `new` might be more efficient if only a small number of
138    /// weights is modified. No allocations are performed, unless the weight type `X` uses
139    /// allocation internally.
140    ///
141    /// In case of error, `self` is not modified.
142    pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError>
143    where X: for<'a> ::core::ops::AddAssign<&'a X>
144            + for<'a> ::core::ops::SubAssign<&'a X>
145            + Clone
146            + Default {
147        if new_weights.is_empty() {
148            return Ok(());
149        }
150
151        let zero = <X as Default>::default();
152
153        let mut total_weight = self.total_weight.clone();
154
155        // Check for errors first, so we don't modify `self` in case something
156        // goes wrong.
157        let mut prev_i = None;
158        for &(i, w) in new_weights {
159            if let Some(old_i) = prev_i {
160                if old_i >= i {
161                    return Err(WeightedError::InvalidWeight);
162                }
163            }
164            if !(*w >= zero) {
165                return Err(WeightedError::InvalidWeight);
166            }
167            if i > self.cumulative_weights.len() {
168                return Err(WeightedError::TooMany);
169            }
170
171            let mut old_w = if i < self.cumulative_weights.len() {
172                self.cumulative_weights[i].clone()
173            } else {
174                self.total_weight.clone()
175            };
176            if i > 0 {
177                old_w -= &self.cumulative_weights[i - 1];
178            }
179
180            total_weight -= &old_w;
181            total_weight += w;
182            prev_i = Some(i);
183        }
184        if total_weight <= zero {
185            return Err(WeightedError::AllWeightsZero);
186        }
187
188        // Update the weights. Because we checked all the preconditions in the
189        // previous loop, this should never panic.
190        let mut iter = new_weights.iter();
191
192        let mut prev_weight = zero.clone();
193        let mut next_new_weight = iter.next();
194        let &(first_new_index, _) = next_new_weight.unwrap();
195        let mut cumulative_weight = if first_new_index > 0 {
196            self.cumulative_weights[first_new_index - 1].clone()
197        } else {
198            zero.clone()
199        };
200        for i in first_new_index..self.cumulative_weights.len() {
201            match next_new_weight {
202                Some(&(j, w)) if i == j => {
203                    cumulative_weight += w;
204                    next_new_weight = iter.next();
205                }
206                _ => {
207                    let mut tmp = self.cumulative_weights[i].clone();
208                    tmp -= &prev_weight; // We know this is positive.
209                    cumulative_weight += &tmp;
210                }
211            }
212            prev_weight = cumulative_weight.clone();
213            core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]);
214        }
215
216        self.total_weight = total_weight;
217        self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone());
218
219        Ok(())
220    }
221}
222
223impl<X> Distribution<usize> for WeightedIndex<X>
224where X: SampleUniform + PartialOrd
225{
226    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
227        use ::core::cmp::Ordering;
228        let chosen_weight = self.weight_distribution.sample(rng);
229        // Find the first item which has a weight *higher* than the chosen weight.
230        self.cumulative_weights
231            .binary_search_by(|w| {
232                if *w <= chosen_weight {
233                    Ordering::Less
234                } else {
235                    Ordering::Greater
236                }
237            })
238            .unwrap_err()
239    }
240}
241
#[cfg(test)]
mod test {
    use super::*;

    /// A `WeightedIndex` must survive a bincode serialise/deserialise round trip.
    #[cfg(feature = "serde1")]
    #[test]
    fn test_weightedindex_serde1() {
        let original = WeightedIndex::new(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap();

        let bytes = bincode::serialize(&original).unwrap();
        let restored: WeightedIndex<i32> = bincode::deserialize(&bytes).unwrap();

        assert_eq!(restored.cumulative_weights, original.cumulative_weights);
        assert_eq!(restored.total_weight, original.total_weight);
    }

    /// NaN weights must be rejected by both `new` and `update_weights`.
    #[test]
    fn test_accepting_nan() {
        let nan = core::f32::NAN;

        // NaN in the first, only, or a later position of the constructor input.
        let constructor_inputs: [&[f32]; 3] = [&[nan, 0.5][..], &[nan][..], &[0.5, nan][..]];
        for weights in constructor_inputs.iter() {
            assert_eq!(
                WeightedIndex::new(*weights).unwrap_err(),
                WeightedError::InvalidWeight,
            );
        }

        // NaN supplied as a replacement weight.
        let mut distr: WeightedIndex<f32> = WeightedIndex::new(&[0.5, 7.0]).unwrap();
        assert_eq!(
            distr.update_weights(&[(0, &nan)]).unwrap_err(),
            WeightedError::InvalidWeight,
        );
    }

    /// Empirical frequencies must track the weights, and degenerate inputs
    /// must produce the documented errors.
    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weightedindex() {
        const N_REPS: u32 = 5000;

        /// Samples `distr` `N_REPS` times and counts how often each index comes up.
        fn tally<R: Rng>(distr: &WeightedIndex<u32>, rng: &mut R) -> [i32; 14] {
            let mut counts = [0i32; 14];
            for _ in 0..N_REPS {
                counts[distr.sample(rng)] += 1;
            }
            counts
        }

        let mut r = crate::test::rng(700);
        let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
        let total_weight = weights.iter().sum::<u32>() as f32;

        // Each observed count must be within 25% of its expectation.
        let verify = |counts: [i32; 14]| {
            for (weight, observed) in weights.iter().zip(counts.iter()) {
                let expected = (weight * N_REPS) as f32 / total_weight;
                let abs_err = (*observed as f32 - expected).abs();
                // Only normalise a non-zero error, which keeps the zero-weight
                // entry from evaluating 0 / 0.
                let rel_err = if abs_err != 0.0 {
                    abs_err / expected
                } else {
                    abs_err
                };
                assert!(rel_err <= 0.25);
            }
        };

        // WeightedIndex from vec
        verify(tally(&WeightedIndex::new(weights.to_vec()).unwrap(), &mut r));

        // WeightedIndex from slice
        verify(tally(&WeightedIndex::new(&weights[..]).unwrap(), &mut r));

        // WeightedIndex from iterator
        verify(tally(&WeightedIndex::new(weights.iter()).unwrap(), &mut r));

        // With a single non-zero weight the outcome is forced.
        let forced: [(&[i32], usize); 3] = [
            (&[0, 1][..], 1),
            (&[1, 0][..], 0),
            (&[0, 0, 0, 0, 10, 0][..], 4),
        ];
        for _ in 0..5 {
            for &(single_choice, expected_index) in forced.iter() {
                assert_eq!(
                    WeightedIndex::new(single_choice).unwrap().sample(&mut r),
                    expected_index
                );
            }
        }

        // Inputs that the constructor must refuse.
        let rejected: [(&[i32], WeightedError); 5] = [
            (&[10][0..0], WeightedError::NoItem),
            (&[0][..], WeightedError::AllWeightsZero),
            (&[10, 20, -1, 30][..], WeightedError::InvalidWeight),
            (&[-10, 20, 1, 30][..], WeightedError::InvalidWeight),
            (&[-10][..], WeightedError::InvalidWeight),
        ];
        for &(bad_weights, expected_error) in rejected.iter() {
            assert_eq!(WeightedIndex::new(bad_weights).unwrap_err(), expected_error);
        }
    }

    /// After `update_weights`, the distribution must equal one built from
    /// scratch with the final weights.
    #[test]
    fn test_update_weights() {
        // (initial weights, updates to apply, weights expected afterwards)
        let data = [
            (
                &[10u32, 2, 3, 4][..],
                &[(1, &100), (2, &4)][..], // positive change
                &[10, 100, 4, 4][..],
            ),
            (
                &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
                &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element
                &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..],
            ),
        ];

        for &(initial, changes, expected) in data.iter() {
            let mut distr = WeightedIndex::new(initial.to_vec()).unwrap();
            assert_eq!(distr.total_weight, initial.iter().sum::<u32>());

            distr.update_weights(changes).unwrap();

            let reference = WeightedIndex::new(expected.to_vec()).unwrap();
            assert_eq!(distr.total_weight, expected.iter().sum::<u32>());
            assert_eq!(distr.total_weight, reference.total_weight);
            assert_eq!(distr.cumulative_weights, reference.cumulative_weights);
        }
    }

    /// Pins the exact sample sequence for a fixed seed so that accidental
    /// changes to the sampling algorithm are noticed.
    #[test]
    fn value_stability() {
        fn check_samples<X, I>(weights: I, buf: &mut [usize], expected: &[usize])
        where
            X: SampleUniform + PartialOrd,
            I: IntoIterator,
            I::Item: SampleBorrow<X>,
            X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
        {
            assert_eq!(buf.len(), expected.len());
            let distr = WeightedIndex::new(weights).unwrap();
            let mut rng = crate::test::rng(701);
            buf.iter_mut().for_each(|slot| *slot = rng.sample(&distr));
            assert_eq!(buf, expected);
        }

        let mut buf = [0; 10];
        check_samples(&[1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[
            0, 6, 2, 6, 3, 4, 7, 8, 2, 5,
        ]);
        check_samples(&[0.7f32, 0.1, 0.1, 0.1], &mut buf, &[
            0, 0, 0, 1, 0, 0, 2, 3, 0, 0,
        ]);
        check_samples(&[1.0f64, 0.999, 0.998, 0.997], &mut buf, &[
            2, 2, 1, 3, 2, 1, 3, 3, 2, 1,
        ]);
    }

    /// Two distributions built from the same weights compare equal.
    #[test]
    fn weighted_index_distributions_can_be_compared() {
        let lhs = WeightedIndex::new(&[1, 2]);
        let rhs = WeightedIndex::new(&[1, 2]);
        assert_eq!(lhs, rhs);
    }
}
427
/// Error type returned from `WeightedIndex::new` and
/// `WeightedIndex::update_weights`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
pub enum WeightedError {
    /// The provided weight collection contains no items.
    NoItem,

    /// A weight is either less than zero, greater than the supported maximum,
    /// NaN, or otherwise invalid.
    ///
    /// `WeightedIndex::update_weights` also reports this when the supplied
    /// indices are not strictly increasing.
    InvalidWeight,

    /// All items in the provided weight collection are zero.
    ///
    /// `WeightedIndex::update_weights` reports this when the total weight
    /// after the update would not be positive.
    AllWeightsZero,

    /// Too many weights are provided (length greater than `u32::MAX`).
    ///
    /// `WeightedIndex::update_weights` also reports this when an index lies
    /// beyond the existing weights.
    TooMany,
}
445
// `WeightedError` has only unit variants and wraps no underlying error, so the
// trait's default method implementations are sufficient. The impl is gated on
// `std` because it names the trait through the `std` crate.
#[cfg(feature = "std")]
impl ::std::error::Error for WeightedError {}
448
449impl fmt::Display for WeightedError {
450    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
451        f.write_str(match *self {
452            WeightedError::NoItem => "No weights provided in distribution",
453            WeightedError::InvalidWeight => "A weight is invalid in distribution",
454            WeightedError::AllWeightsZero => "All weights are zero in distribution",
455            WeightedError::TooMany => "Too many weights (hit u32::MAX) in distribution",
456        })
457    }
458}