Optimizing Test-Time Query Represntations for Dense Retrieval 阅读笔记
稠密检索通常采用双塔架构,节省了空间时间,但是检索的精确度有很大丢失,常常不能检索到相关的文章。最近的研究都从问题(pre-trained query)或文章(context encoders)的角度入手,来对检索结果进行适当的优化。本文提出的TouR(Test-Time Optimization of Query Representations)模型,便是对问题表表示进行尝试性时间步优化。
1、背景:
稠密检索(Dense Retrieval):
稠密检索需要对于一个给定的 query(q),找到与它具有高相关性的context(c)。一般而言,它首先通过编码器将 query 向量和 context 向量映射到同一个向量空间中,然后表示相似度(一般是点积):
$$sim(q,c) = q^Tc$$
之后选取 top-k 个 context 作为答案,保证 relevant context 在结果中存在的可能性尽可能大。问题端微调(Query-side Fine-tuning):
在稠密检索中,当测试问题的分布和训练集中问题的分布差异很大时,(差异是否可显性理解为语言表述)稠密检索的表现并不好。所以要对问题端进行微调。在 Lee 的论文 Learning Dense Representations of Phrases at Scale 中,给出了问题得分的全局函数:
$$\mathcal{L}\text{query }=-\sum{q\in\mathcal{Q}\text{train }c\in\mathcal{C}{1:k}^q,c=c^*}\log\sum_{c\in\mathcal{C}_{1:k}^q,c=c^*}P_k(c|q)$$而在本文中则将每个 query 作为一个单独的数据进行微调。我们只需要通过梯度下降的方式最大化得分函数,就能得到最佳的微调模型。
PRF算法(Pseudo Relevance Feedback):
PRF 算法是一种用正例 $c_r$ 、负例 $c_{nr}$ 来更新问题表示的方式。在最经典的 Rocchio 算法中,问题表示被更新为:
$$\begin{aligned}&g(\mathbf{q}t,\mathcal{C}{1:k}^{q_t})=\&\alpha\mathbf{q}t+\beta\frac{1}{|\mathcal{C}r|}\sum{c_r\in\mathcal{C}r}\mathbf{c}r-\gamma\frac{1}{|\mathcal{C}{nr}|}\sum{c{nr}\in\mathcal{C}{nr}}\mathbf{c}{nr}\end{aligned}$$
在本文中,作者得到的最终结果和 PRF 中的问题表示有异曲同工之妙。这也从一个形象的角度为本文计算结果提供了解释。
2、模型架构
模型计算的流程简单如下所示:
1、用传统模型算出 top-k 个 context 和它们的相关性分数
该部分使用了 Lee 在2021年提出的 DensePhrases 模型作为基础模型。其他没什么好说的。
2、用 cross-encoder re-ranker 对 top-k 算 re-ranking 得分并重新排序、更新相关性分数
该 re-ranker 的成分如下:
- Input:三元组 $\mathcal{D}_\text{train }={(q,c_q^+,c_q^-)}$ ,其中 q 表示 query,$c_q^+$ 表示正实例,$c_q^-$ 表示负实例。
- Architecture:使用了 RoBERTa-large model(一个BERT模型的变体)作为基础模型。
- Score:用重排名分数和相关性分数加权,更新相关性分数为:$\begin{aligned}\lambda s_i+(1-\lambda)&\sin(q,c_i).\end{aligned}$
3、设置一个阈值 p,从 top-k 中取正确率超过 p 的所有context
根据 2 中更新后的相关性分数,计算 context 是正确的 context($c^{*}$) 的概率,公式表示为:
$$\begin{aligned}P_k(\tilde{c}=&c^*|q,\phi)=\frac{\exp(\phi(q,\tilde{c})/\tau)}{\sum_{i=1}^k\exp(\phi(q,c_i)/\tau)}\end{aligned}$$
该公式即一个 softmax 函数,t 作为温度参数来调整平滑度。
4、用选取的 context 条件下的 query 得分进行梯度下降,更新模型参数
对于 hard 标签,采用问题背景中提到的局部函数来计算得分:
$$\mathcal{L}{\mathrm{hard}}(q,\mathcal{C}{1:k}^q)=-\log\sum_{\tilde{c}\in\mathcal{C}_{\mathrm{hard}}^q}P_k(\tilde{c}|q)$$
则梯度下降的表达式为:
$$\mathbf{q}_{t+1}\leftarrow\mathbf{q}t-\eta\frac{\partial\mathcal{L}{\mathrm{hard}}(\mathbf{q}t,\mathcal{C}{1:k}^{q_t})}{\partial\mathbf{q}_t}$$
对于 soft 标签,变形局部函数用KL散度来计算得分:
$$\begin{aligned}\mathcal{L}{\mathrm{soft}}(\mathbf{q}t,\mathcal{C}{1:k}^{q_t})&=\&-\sum{i=1}^kP(c_i|q_t,\phi)\log\frac{P_k(c_i|q_t)}{P(c_i|q_t,\phi)}\end{aligned}$$
它的梯度下降的表达式为:
$$\begin{aligned}&g(\mathbf{q}t,\mathcal{C}{1:k}^{q_t})\&=\mathbf{q}t+\eta\sum{i=1}^kP(c_i|q_t,\phi)\mathbf{c}i-\eta\sum{i=1}^kP_k(c_i|q_t)\mathbf{c}_i.\end{aligned}$$
对该公式的解释详见后续内容。
5、迭代重复上述过程,对置信度更高的答案进行挖掘。
由于使用了 cross-encoder 并且需要额外迭代检索,导致该做法具有极高的时间复杂度。为减少计算的规模,一般最多采用 t = 3 层迭代,或者在极高置信度(pseudo-positive or highest relevance score)答案出现时终止。
论文里给的抽象流程图如下所示:
![[CM1.png]]
它的意思估计就是通过从 $q_0$ 到 $q_3$ 的迭代,挖掘到了 1983 这个正确答案。不过个人认为论文后面的 Fig.5 中给的 sample 要更加清晰明了:
![[CM2.png]]