zl程序教程

您现在的位置是:首页 >  其他

当前栏目

【算法优化】记一次不太成功的文本相似性去重算法优化实践【续】

2023-03-20 14:52:09 时间

前一篇文章中,简单的以为可以将计算总数的部分去掉,但是如果去掉了总数的计算,就没法计算每个文章的热度(就是相似文章的数量),这点客户没法接受。

所以,还是得想法子继续优化。

1、回顾前一个版本的算法

在前一个版本的算法里,我们使用numpy实现成矩阵运算的形式,在数据量1.7万的时候,还是比直接使用es查询要慢。

分析这其中的原因,主要是计算量太大了,而es则可以充分利用已有的索引机制来提升效率。

那有没有办法降低算法的时间复杂度呢?

剪枝、分治、贪心、。。。。

典型的分治法,如二分法排序

2、分析数据的特征

一个simhash编码如:

0111001111110101011100100000111100110110010011011000010101110110

关于simhash的文章网上很多,这里不细说。

对于两个相似的文章,我们有理由假设他们的simhash码会有相当长的一个子串是相同的,这个假设显然是合理的。

于是,我们就可以采用分治法,将一个64位的simhash字符串均分成若干段,例如如果我们将上面的simhash串切成长度相等的4段:

0111001111110101
0111001000001111
0011011001001101
1000010101110110

根据前面的假设,那么另一篇相似的文章的simhash码也均分成4段,则至少有一段是和上面对应的一段是完全一样的。(对此,你可能有不同疑问,如果有这样一个文章,它的simhash码和这个文章的只是差了4个位,但是这4个位刚好又导致了这4个子串都不相同。确实,这是可能发生的,如果有人转载文章时,这里改一点,那里改一点,总体改动不多,但是改动的地方却是比较均匀,这就可能发生刚才的情形。不过对于我们来说,可以不考虑这种情况,因为这种情况本省发生的概率就比较小,其实为了降低时间延迟,牺牲一点点精度也是可以接受的)

因此切分之后的simhash子串的数量最大可能有2的16次方个(65536),每个子串可能对应若干个文章ID,最后把有交集的文章ID合并到一个类即可。

3、第三个优化版本

上面说起来好像挺简单,实现起来还是有点点复杂的。

3.1 数据切分:

import time
import json
import numpy as np
from typing import List

# 常量配置
sim_thr = 0.85 * 64    # 85%相似度阈值

# 加载1.7万的文章id及simhash值
with open('./article_simhash_17k.json') as f:
    data = json.load(f)
    
print(len(data))
start = time.time()
new_data = {}
for item in data:
    simhash, article_id = item['simhash'], item['id']
    simhash = ['0']*(64-len(simhash)) + list(simhash)
    # 不足64位的前面补0,文章id放到第一个值
    item = [article_id] + simhash
    simhash = ''.join(simhash)
    for i in range(4):            # 这个方式耗时:28s(会有一定的精度损失)
        new_data.setdefault(str(i)+simhash[i*16:i*16+16], []).append(item)
    # for i in range(8):            # 这个方式整体时间:39s
    #     if i*8+16 <= 64:
    #         new_data.setdefault(str(i)+simhash[i*8:i*8+16], []).append(item)

all_np_data = [np.array(vals) for vals in new_data.values() if len(vals) > 1]
print('time: ', time.time()-start, len(new_data), len(all_np_data))

这个比较简单,耗时约0.46秒。

有一个小技巧,切分生成key时,加上了一个序号ID,这样就能保证只有相同位置的段才会完全相同。

3.2 初步聚类

对每一个切分好的段内部的文章进行聚类:

def cluster(np_data) -> List[List[str]]:
    """将相似的文章聚类在一起"""
    results = []
    while np_data.shape[0] > 1:
        curr_row = np_data[0]
        np_data = np_data[1:]
        counts = np.count_nonzero(curr_row == np_data, axis=1)
        if len(counts) > 0:
            ls = list(np_data[counts > sim_thr][:, 0])
            ls.append(curr_row[0])
            results.append(ls)
        np_data = np_data[counts <= sim_thr]

    return results

all_cls = []   # List[List[str]], 文章id的初步聚类结果,每个类可以包含多个文章id
for np_data in all_np_data:
    all_cls += cluster(np_data)     # TODO 这里是可以并行的,待优化(不过这里耗时不是太多)

results = {item['id']: []  for item in data}   # 每个id对应一个类别,空列表表示没有分到具体的类别里
for cls_id, aids in enumerate(all_cls):    # 循环处理每个原始类别
    for aid in aids:      # 把该类的文章id都对应到类别id
        results[aid].append(cls_id)    # 一个文章可以包含多个初始分类

print('cluster time: ', time.time()-start, len(all_cls))

其中的cluster函数和前一篇文章的思路基本是一致的。

3.3 合并有交集的类别

上面的算法已经聚合了很多的列表,一篇文章最多可能被分到了4个类别上,需要对有交集的类别进行合并:

# 合并所有有交集的分类id
# TODO 这个步骤最耗时间,超过99%的时间都是消耗在这里
clses = [set(vals) for vals in results.values() if len(vals) > 1]
merge_cls = np.array(range(len(all_cls)))      # 每个类别对应的原始id
print('before merge, for len:', len(clses))
for i, cls in enumerate(clses):
    for cls_j in clses[i+1:]:
        # 集合计算比较慢
        # if len(cls.intersection(clses[j])) == 0:
        if len(cls.intersection(cls_j)) == 0:
            continue
        # 有交集则对应的全修改为最小值
        # cls = cls.union(cls_j)     # time: 39s
        cls.update(cls_j)          # time: 38s
    # 获取最小的类别
    merge_cls[list(cls)] = min(cls)

print('merge time', time.time()-start, len(set(merge_cls)))

其实就是循环处理每一个类别,将后面和该类别有交集的类别都合并在一起,并取类别中最小的类别id作为新的类别id。

从打印的数据可以看到,这里循环的长度有1.3万多,两重循环就得2亿左右次,整个程序超过99%的时间都是消耗在了这里。

3.4 计算文章类别及热度

有了前面一步的结果,这个倒是比较容易实现的了:

# 计算每个分类id对应的文章id列表
cls_results = {}
for aid, _cls in results.items():
    if len(_cls) > 0:
        cls_id = merge_cls[_cls[0]]   # 合并后的id
        cls_results.setdefault(cls_id, []).append(aid)

# 生成最后的结果
# TODO 需要验证这里的相似文章的相似度怎么样
cls_results = [vals for vals in cls_results.values()]
print('hot max:', max([len(vals) for vals in cls_results]))

# 计算热度
cls_results = [(vals[0], len(vals)) for vals in cls_results]

# 补全热度为1的数据
for _id, _cls in results.items():
    if len(_cls) == 0:
        cls_results.append((_id, 1))

print('time: ', time.time() - start, ' len: ', len(results), len(cls_results))
assert len(data) == sum([val[1] for val in cls_results])

最后,在我的笔记本上运行大概耗时27秒,比之前的4分钟还是下降了很多的。

4、总结

通过分治法,牺牲一点精度,换来了时间消耗的减少,这是值得的。不过执行一次还需要27秒,这个还是有点多的,而且随着数据量的增大,这个耗时可能是指数级的,还需要继续优化。

合并类别那里还是有不少优化空间的,待续。。。