zl程序教程

您现在的位置是:首页 >  后端

当前栏目

蓄水池抽样算法

算法 抽样
2023-06-13 09:15:31 时间

问题

从n个数字中随机选取m个数字作为样本,要求这n个数字每个被选到的概率都相等。

常规算法

如果n并不是一个特别大的数字,我们可以一次性把这n个数字加载进内存,每次从中选取1个,选取m次。

123456

def random_m(nums,m): n = len(nums) for i in range(m): j = random.randint(i,n - 1) nums[i],nums[j] = nums[j],nums[i] return nums[:m]

代码逻辑为先从[0,n-1]随机选择一个作为选中数字的索引,代表该索引上的数字已被选中,将其和索引0更换位置,此后索引0上的数字不再发生变化。之后再从[1,n-1]随机选择一个作为选中索引,将其和索引1的数字更换位置,此后索引0和索引1都不会发生变化。以此类推,重复m轮后,被选取的m个数字全部位于数组的前m项,将其返回即可。

我们可用数学公式证明使用该算法每个数字被选取的概率都是相等的,且结果均为 \frac{m}{n}

如果一个数字最终被选中,那么它一定是在这m轮中的某一轮被选中的。如果是在第1轮被选中,那么概率应该为 \frac{1}{n} ,如果是在第2轮被选中,概率应该为第1轮未选中的概率乘以第2轮选中的概率,即 (1-\frac{1}{n})\times\frac{1}{n-1} = \frac{1}{n} 。同理,第3轮被选中的概率为前2次都不选中的概率乘以第3轮选中的概率,即 (1-\frac{1}{n}) \times (1-\frac{1}{n-1}) \times \frac{1}{n-2} = \frac{n-1}{n} \times \frac{n-2}{n-1} \times \frac{1}{n-2} = \frac{1}{n}

综上,每一轮被选中的概率都为 \frac{1}{n} ,总共有m轮,即有m次机会可能被选中,所以最终被选中的概率应该为每轮被选中的概率总和,也即 \frac{m}{n}

蓄水池算法

对于数值较大的n,我们无法一次性将所有数字加载进内存,或者说,如果面向的是数据流,无法确定后续的数字是什么,那么蓄水池算法就可以派上用场了。它可以保证在n巨大或者n最终不确定的情况下,让每个数字被选中的概率均为 \frac{m}{n} 。具体实现步骤:

  1. 构造一个大小为m的池子,所有在池子中的条目,代表被选中;
  2. 如果当前n小于等于m,此时所有数字都应被放入池子,所有数字被选中的概率均为1;
  3. 当n等于m+1时,我们构造一个概率为 \frac{m}{m+1} 的事件,实现方法为从[1,m+1]中随机一个数,如果该数在[1,m]范围内,则事件命中。如果事件命中,我们则将m+1这个条目放入池子,此时该条目被选中概率为 \frac{m}{m+1} 。而已在池子中的某个条目则要被随机选择换到池子外,被换出的概率为 \frac{1}{m} 。所以已在池子中的条目,被换出的概率为 \frac{m}{m+1} \times \frac{1}{m} ,则不被换出的概率为 1-\frac{m}{m+1} \times \frac{1}{m} ,它们一开始进入池子的概率为1,所以最终留在池子的概率为 1\times(1-\frac{m}{m+1} \times \frac{1}{m}) = \frac{m}{m+1} 。此时,所有条目被选中的概率均为 \frac{m}{m+1}
  4. 当n等于m+2时,我们构造概率为 \frac{m}{m+2} 的事件,让m+2条目在事件命中时进入池子,则其选中概率为 \frac{m}{m+2} 。对于其他条目,上一轮中选中概率均为 \frac{m}{m+1} ,该轮不被选出的概率为 (1-\frac{m}{m+2} \times \frac{1}{m}) ,则经过该轮仍在池子中的概率为 \frac{m}{m+1} \times (1-\frac{m}{m+2} \times \frac{1}{m})= \frac{m}{m+2}
  5. 自此可归纳,条目最终保留在池子中的概率P为:P=1\times(1-\frac{m}{m+1}\times\frac{1}{m})\times(1-\frac{m}{m+2}\times\frac{1}{m}) ... \times(1-\frac{m}{n}\times\frac{1}{m})其中1n<=m时必入池子的概率,后面每一项为n>m不被选出池子的概率,而每一个新条目加入池子的概率都和之前条目保留在池子的概率一致,所以所有的条目在池中的概率都可以用该公式表示,简化得:P=\frac{m}{m+1}\times\frac{m+1}{m+2} ... \times \frac{n-1}{n}=\frac{m}{n}

12345678910111213

def sampling_m(nums,m): res = [] n = 0 for num in nums: n += 1 if n <= m: res.append(num) continue i = random.randint(1,n) if i <= m: j = random.randint(1,m) - 1 res[j] = num return res

可以看到,算法实现并不依赖采样原始数据nums的长度大小,当nums过于庞大时,完全可以作为数据流的方式进行读取。不需要一次性将所有数字读入内存,并且能够保证每个数字都能等概率被选中,这就是蓄水池算法的实现目的。