Cluster Triplet Loss for Unsupervised Domain Adaptation on Histology Images
Cluster Triplet Loss for Unsupervised Domain Adaptation on Histology Images
Ruby Wood1 ;Enric Domingo2 ;Viktor Hendrik Koelzer2,3,4 ;Timothy S. Maughan2,5 ;Jens Rittscher 1
1牛津大学工程科学系,英国牛津
2英国牛津大学肿瘤科,牛津
3瑞士苏黎世大学医院病理学和分子病理学系,苏黎世大学,苏黎世,瑞士
4瑞士巴塞尔大学医院医学遗传学和病理学研究所,巴塞尔,瑞士
5英国利物浦大学分子和临床癌症医学系,利物浦,英国
摘要
用于从医学影像预测癌症患者治疗反应的深度学习模型,需要在不同患者群体中具有普适性。然而,由于患者群体的多样性,这是一项挑战。本研究关注从肿瘤活检扫描的组织学图像中预测结直肠癌患者对放疗的反应,并将这一预测模型应用于一个新且外观不同的患者群体。我们提出了一种新的无监督域适应方法,该方法采用聚类三元组损失函数,仅使用源域中的少量信息,从而将目标群体的AUC值从0.544提升至0.818。我们避免使用伪标签和类别特征中心,以防止对适应模型引入噪声和偏差,并通过实验验证我们的模型优于这些最先进方法。我们提出的方法适用于多种复杂的医学成像场景,包括基于图神经网络提取的图像的小型、内存可行表示的预测,用于大型全切片图像的预测。
1.引言
将医学影像领域的深度学习模型从一组患者适应到另一组患者可能具有挑战性,因为患者之间存在广泛的个体差异。本研究中,我们利用深度学习技术,通过分析治疗前肿瘤组织的数字组织学图像,预测结直肠癌(CRC)患者对放疗的反应,并尝试将该模型应用于来自不同地理区域的全新患者群体。本研究特别关注无监督域适应(UDA),因为为了使该预测模型在临床实践中有效,我们需要在使用时无需了解患者的预后情况来调整模型。
尽管在其他领域中,使用域适应技术的研究已经相当广泛,但将其应用于组织学图像则更具挑战性,因为这种成像方式的尺寸和异质性带来了复杂性[12]。
组织学切片是从活检样本中切割出来的,经过苏木精和伊红(H&E)染色并数字化扫描的肿瘤组织切片。这些切片以极高的分辨率扫描,导致文件体积极大。为了适应计算机内存,图像需要被分割成更小的部分,然后采用多实例学习(MIL)方法将这些预测整合为每张切片的一个最终预测。本文提出了一种方法,专注于模型内部的中间特征表示,从而保留了所有可选的MIL方法,用于最终输出。具体而言,我们利用图神经网络(GNN)方法,从自然分割的组织区域进行预测,通过GNN中的特征来帮助模型适应新领域。
社会经济因素可能影响不同地区或国家的患者对癌症的体验[5],而组织学图像中的批次效应通常源于肿瘤活检样本从患者体内取出后的处理过程。由于各医疗中心在组织样本处理上的方法略有差异,这导致数据中存在固有的领域偏移[29]。我们使用来自三个不同医疗中心的患者队列来训练和验证我们的方法,这些中心的组织处理实践各不相同。
我们以一种通用的方法来解决二元预测问题,通过仅调整基础特征以适应新领域,同时保留原有的分类分支,从而避免使用伪标签。与许多其他统一数据适应(UDA)方法不同[19,30,35,40-43,45],我们避免了从源模型中引入偏差和噪声到预测结果中。
此外,我们避免使用基于类别的聚类方法来为每个类别标签找到一个代表,因为许多文献[10,15,19,30,43]已经采用了这种方法。通过一次性对整个特征集进行聚类,可以允许每个类别标签内部存在更多的变异,从而获得一个自然数量的聚类,而不受数据集中类别标签数量的限制。这种方法在处理二元结果数据时效果尤为显著,因为它能够用超过两个聚类来表示整个源数据集。
本文中,我们开发了一种特征对齐的UDA技术,能够在不使用任何目标标签的情况下,将训练好的临床模型应用于未见过的目标群体。我们提出了一种创新的方法,定义了一个损失函数,用于‘源监督’训练,以实现领域适应。该损失函数仅需轻量级的源数据表示,即可指导新目标模型的学习,使其适应新的领域。我们的方法允许对队列调整的模型进行分布式训练,无需对原始模型进行任何训练或更新,从而提供了一种安全的联邦学习技术,能够保护不同地点之间的患者隐私。我们不将结果与所有数据集排列混淆,而是专注于最不相似的数据集作为目标数据集,因为这是最大的挑战。这在临床实践中也有类似的应用,即需要将预训练的模型转移到新的队列领域,以更好地预测患者的治疗结果,而无需事先了解患者对治疗的反应。这种方法不需要任何批次或队列假设,即使只有一个新数据点也可以应用。
2.相关工作
2.1.无监督域适应
2.1.1 聚类
尽管许多论文已经探索了使用聚类进行领域适应的方法,采用了对比或对抗损失方法[13,16,22,43]对齐源和领域分布,但据我们所知,还没有人使用我们建议的轻量级聚类方法。
我们领域适应方法的灵感来源于吸引和分散[40]的概念,该方法旨在将相似特征聚集在一起,而将不相似特征分开。这一无监督方法通过使用k-最近邻算法和伪标签,最大化邻居之间的预测一致性,同时最小化不相似特征预测的相似性。一种类似的方法,结构化正则化深度聚类(SRDC)[30],通过KMeans算法对目标数据的中间网络特征进行聚类,并最小化预测的目标标签与真实源标签分布之间的KL散度,以及学习到的源和目标聚类中心之间的KL散度。另一种利用KMeans的方法是Liang等人提出的源假设转移(SHOT)方法[19]。该方法通过冻结源模型的最终分类器层,其余部分作为目标模型的初始化。这种方法通过预测伪标签并最小化熵来实现无监督学习,类似于加权KMeans,找到目标类别的中心点,并根据最近邻类中心点(使用余弦距离测量)定义目标样本的伪标签。
2.1.2 伪标签
大多数统一数据辅助(UDA)方法使用伪标签来训练模型[19,30,35,40-43,45],这在多类分类中比二分类提供了更多的信息。这些伪标签通常用于掩码或作为指示器,以计算用于损失函数[40,45]的进一步统计信息。使用伪标签的方法很大程度上依赖于教师模型在目标领域具有合理的先验准确性,但这一点并不总是成立,正如Li等人[18]所指出的。重要的是,他们还发现没有通用的方法来评估这些伪标签的质量。尽管许多研究承认了这一问题,并提出了相应的解决方法[18,30,41,43],但这种设计上的缺陷显然会引入不必要的偏差和噪声。张等人也意识到了这一点,在训练过程中通过测量与类别[41]特征中心的距离来对伪标签进行加权正则化[41]。分割与对比的方法将目标数据分为源样或非源样,并合理假设,来自源样目标数据的伪标签比来自特定目标样本的伪标签更准确[43]。SRDC的作者也承认,源模型在目标数据上的不可靠性可能导致一些错误的目标预测,因此他们在损失函数中加入了额外的项,使用伪标签作为预测标签的指示器[30]。
2.1.3 三联体损失
使用中心特征的三元组损失的概念最初由[10]在物体检索中提出,他们提出了三元组中心损失(TCL),旨在将同一类别的特征对齐到一个可学习的类别中心,并排斥不同类别的特征。他们使用欧几里得距离来衡量类别中心与样本特征之间的差异,这一点与我们的方法相同,不过在他们的三元组损失中,他们选择了最近的负样本中心。此外,他们还利用类别标签来确定相应的类别中心,因此该方法并非无监督的。其他研究也采用了类似的方法,通过特征中心的三元组损失来处理不同领域的问题。大多数研究集中在从伪标签中计算特征中心,以找到每个类别在分类问题中的代表中心。
Wieczorek等人提出的质心三元组损失[37]用于图像检索,该方法在目标特征上采用传统的三元组损失,其中正例作为目标类别的中心,负例作为负类别的中心,这与我们提出的方法相似,但不同之处在于我们排除了任何假设或已知的类别信息。Lagunes-Fortiz等人[15]在他们的三元组损失中使用了不同的负样本,即从域本身中选取样本,而不是以特征为中心。此外,三元组损失还被用来定义目标和源集群作为类别引导的约束条件[35],以实现域之间的更好类别对齐。
2.2.组织学领域适应
腐蚀斑
在组织病理学的深度学习领域,不同医院和实验室之间的组织染色和处理方法存在显著差异,因此已经采取措施来减少这些批次染色的影响[8,14,25,44]超越传统的颜色标准化方法[20,31]。然而,有时仅靠这种方法还不足以确保模型的领域泛化能力。Lafarge等人[14]提出了一种领域对抗神经网络(DANN),用于预测样本是否来自特定领域,从而在保留对预测有用的特征的同时去除特定领域的特征。他们还尝试了传统的染色领域适应方法,发现当DANN与颜色增强或染色标准化结合使用时,效果最佳。
特征对齐
在本研究中,我们专注于源域与目标域之间的特征对齐。在组织学领域,大多数基于聚类的方法使用伪标签来预测类别,这有助于更新类别特征中心[6,32]。Distill-SODA
[32]是一种无源统一数据对齐方法,通过蒙特卡洛模拟聚类过程以提高鲁棒性。与我们的方法类似,他们计算聚类中心以在损失函数中与目标特征进行比较;然而,他们的聚类中心不是无标签的,而是每个类别仅限一个,而不是从源域自然推导而来。Jian等人提出了一种特征对齐方法[26],通过训练卷积神经网络(CNN)将目标图像映射到源模型的特征空间,以最小化不同领域的差异。该方法进一步引入了孪生模型,鼓励来自同一全切片图像(WSI)的区域被赋予相同的标签,但这种方法未能考虑到组织样本中自然存在的异质性。Wang等人[36]专注于利用图神经网络(GNN)节点特征,通过对抗损失对结直肠癌(CRC)组织学图像进行核检测的对齐。.Abbet等人[1]则使用少量源标签训练模型,用于结直肠癌组织分类。
二分类
大多数研究集中在多类分类或分割问题上,其中伪标签或类别中心可以提供更多的信息。一些研究则关注二分类问题,例如上皮-间质分类,有一篇论文同时在源域和目标域上训练单一模型,并通过简单地将目标域和源域的最大特征值对应的特征向量进行向量乘法,来调整卷积神经网络的内核以适应目标域[11]。Qi等人[24]也研究了上皮-间质分类,采用课程学习方法,通过测量样本与类别中心之间的余弦相似度来避免可能产生错误伪标签的样本,根据与源域的最大距离选择初始训练样本。
Li等人[17]专注于在乳腺癌、肺癌和结肠癌的组织学切片上对肿瘤进行良性或恶性的分类。尽管没有明确的结果类别,但他们确实使用了多个数据集队列,因此他们的UDA方法在每个源域和目标域上分别训练了一个特征提取器,并利用源标签来学习特征分布的对齐。最优传输也被用于惩罚肿瘤与正常组织二元分类中的域预测[9]。我们未发现有关UDA在从组织学图像预测患者治疗反应模型方面的先前研究。
组织学下的三联体损失
很少有研究将三元组损失应用于组织学的领域适应,尤其是在无监督方法方面。Sikaroudi等人[28]在他们的研究中使用了三联体损失,旨在学习不受医院限制的组织学表示,重点关注不同领域之间的类别条件变化。他们采用监督方法,使用交叉熵损失对目标预测进行优化,同时利用KL散度来对齐特征域,并通过度量损失来区分不同的类别。
3.方法
本研究假设我们已经有一个预训练的源模型,希望将其适应到新的领域。下面我们将描述这个源模型,该模型基于[38]领域的类似先前模型构建,并解释如何使用源模型的权重初始化来训练新模型,以适应新领域的预测。此外,我们还将介绍在源数据上使用的聚类方法,以提取源数据的轻量级表示,这些表示随后被用于我们提出的聚类三元组损失函数中,以训练和适应新模型。
我们首先介绍一些术语。源模型是我们应用任何领域适应之前所使用的初始模型。该源模型之前已在源数据\(x_s\)上进行了训练和验证。目标数据\(x_t\)是我们希望让源模型适应的新未见过的数据集。目标模型是源模型的更新版本,且已经适应了目标数据。
3.1.源模型
我们的源模型是一个图神经网络(GNN),包含三个图同构网络层,特征尺寸分别为64、32和16。我们没有直接将WSI输入到GNN中,而是先对WSI应用了超像素方法,然后从同一区域(大小[1,768])
[38]的补丁特征中计算出超像素特征,这些补丁特征是通过自监督预训练的大组织学模型CTransPath
[34]提取的。基于这些超像素特征,我们构建了每个WSI的图表示,其中节点和节点特征由超像素定义,图的边则通过最近邻关系使用Delaunay三角化定义。这些图随后作为GNN的输入,GNN以半监督方式训练,以预测患者对放疗的反应。
在源验证数据集上,源模型达到了0.931的AUC、0.803的平衡准确率和0.885的加权F1分数。显然,我们的源模型在源队列上表现优异。尽管我们在训练过程中努力使模型具有更强的泛化能力,但当该模型应用于未见过的测试队列时,其泛化能力明显不足,AUC为0.544、平衡准确率为0.500、加权F1分数为0.840,具体数据见表2。为了防止训练队列上的过拟合,我们采取了多种措施,包括在提取特征前对训练图像进行大量数据增强,在图神经网络(GNN)和分类分支中引入高概率的dropout(p=0.5),在多个地理队列的患者上进行训练,并采用多任务学习方法,确保最终特征集不仅包含分子特征信息,还涵盖了空间组织结构等[38]。
在本研究中,我们仅关注中间特征表示,而不涉及模型的最终预测阶段。在训练新的领域适应模型时,我们冻结了目标模型中的分类分支(由于采用了源模型的多任务学习方法,因此存在多个分类分支,其中一个分支用于预测患者对放射治疗的反应),并且仅在这些分类分支之前训练图神经网络层。因此,本研究的重点在于数据集的节点级特征,而非幻灯片级特征。我们将源模型中的节点特征提取器部分称为Fs,而该部分之后的分类器在源模型和目标模型之间保持不变。
3.2.聚类
我们通过在源数据上应用聚类技术,提取出源数据特征集的轻量级、高层次表示。图神经网络(GNNs)为我们提供了来自超像素节点的节点级预测,从而直观、自然地展现了肿瘤内部组织段的形态特征。我们从图神经网络的最终层中提取这些节点的特征,然后将其分为三个预测分支,以实现多任务学习方法。
我们将聚类方法应用于训练中观察到的源数据队列的标准化节点特征向量的连接集合。这些连接的特征向量大小为[N,16],其中N
=
134,132是节点总数,16是每个节点的特征数。为了确定最佳聚类数量kopt,我们计算了聚类数量k
=
2,...,20时的轮廓宽度[27]。我们选择该范围内轮廓宽度最高、Calinski-Harabasz指数[4]最低且David
Bouldin得分[7]最低的聚类作为最佳聚类数量,以确保在无监督设置中聚类的最显著性。由于样本量较大,我们采用了KMeans
MiniBatch方法,该方法在Python库sklearn.cluster(版本1.1.3)
[23]中实现。为了提高效率,我们在源数据集的一个子样本(n =
10,000个节点特征)上拟合了MiniBatch
KMeans,使用了最佳的聚类数量。我们提取了最终的聚类中心C,大小为[kopt,16]。
3.3.聚类三元组损失
为了将我们的模型训练并适应目标数据集,我们提出了簇三元组损失,该方法利用了前一节中的源聚类。
我们提出的聚类三元组损失函数以每个样本为基础,这意味着它可以适应任何规模的队列。对于每个提供的特征向量,它计算该特征向量与固定源聚类中心之间的均方误差损失,类似于传统KMeans算法的一次迭代。基于此,我们选取与输入特征向量最接近和最远的聚类中心,并将它们作为三元组损失计算中的正样本和负样本,以输入特征向量为锚点,将特征向量移动到聚类域中,同时对样本进行聚类。我们对模型训练批次中的所有特征向量同时进行向量化和应用该方法。在实现三元组损失时,我们采用了1的边界,并根据Balntas等人[3]的建议,将输入与负样本中心的距离与正样本与负样本中心的距离进行了交换。
我们首先定义源模型\(\mathcal{M}_{s}=\mathcal{H}_{s}(\mathcal{F}_{s})\),其中\(\mathcal{H}_{s}\)是模型的分类器部分,\(\mathcal{F}_{s}\)是模型的特征部分,后者被调整以适应新领域。接着,我们定义目标模型\(\mathcal{M}_{t}=\mathcal{H}_{s}(\mathcal{F}_{t})\),这里我们使用源模型\(\mathcal{H}_{s}\)是的相同分类器,但更新了源模型的特征部分以生成\(\mathcal{F}_{t}\)。因此,源模型和目标模型具有完全相同的架构,但权重不同。
在我们提出的聚类三元组损失函数中,我们从源数据的最佳聚类得到的聚类中心C出发。对于批量大小为b的输入目标数据\(x_t\),我们计算每个聚类中心与输入之间的欧氏距离\(d_{ij}\), \[d_{i j}=\|x_{t_{i}}-C_{j}\|^{2},\]
其中,i∈[1,b]表示批次中的每个节点输入。
我们利用这些距离来确定最近的(\(C_{j_{pos}}\))和最远的(\(C_{j_{neg}}\))聚类中心,使用 \[j_{p o s_{i}}=\arg\operatorname*{min}_{j}d_{i
j},\quad j_{n e g_{i}}=\arg\operatorname*{max}_{j}d_{i j}.\]
我们在调整后的三元组损失函数中使用这些正负聚类中心,如定义所示 \[L_{i}(x_{t_{i}})=\operatorname*{max}\{||x_{t_{i}}-C_{j_{p
o s_{i}}}||^{2}-\|C_{j_{p o s_{i}}}-C_{j_{n e
g_{i}}}||^{2}+\mu,0\}\]
使用边缘μ=1。
最后,我们通过取批次的平均值来减少输出,并使用批次损失对模型进行反向传播。
\[L_{b}(x_{t};C,\mu)=\frac{1}{b}\sum_{i}L_{i}(x_{t_{i}},j_{p
o s_{i}},j_{n e g_{i}};C,\mu),\]
在簇中心C和边缘μ固定的情况下,\(j_{p o
s_{i}}\)和\(j_{n e
g_{i}}\)会根据方程(1)和(2)的变化而变化。
我们整个方法的算法详见算法1。步骤1-3只需执行一次,然后,给定源数据表示C,从步骤4开始,可以用于在不同领域训练任意数量的目标模型。

算法1:使用聚类三元组损失进行训练
输入:源特征模型\({\mathcal{F}}_{s}\)、源数据\(x_{s}\)以及目标数据\(x_t\)。
1
在分类前,从GNN的最终层中提取特征\({\mathcal{F}}_{s}(x_{s})\);
2 对\({\mathcal{F}}_{s}(x_{s})\)运行KMeans算法,设置k
= 2,...,20个聚类,并通过轮廓宽度计算最优k值\(k_{opt}\);
3 从最佳KMeans中提取\(k_{opt}\)个聚类中心C;
4 将源模型\({\mathcal{F}}_{s}\)的权重应用于目标模型\({\mathcal{F}}_{t}\);
5
while Training do
6 |
从目标模型中提取目标特征\({\mathcal{F}}_{t}(x_{t})\);
7 |
使用等式(1)计算\({\mathcal{F}}_{t}(x_{t})\)与C中每个聚类中心之间的欧氏距离;
8
|
使用等式(2)计算距离,找到与目标特征最接近(Cpos)和最远(Cneg)的聚类;
9
| 使用公式(3)和(4),计算\(x_t\)的平均三元组损失,并进行反向传播。
10
end
输出:调整后的目标特征模型\({\mathcal{F}}_{t}\)
4.实验
4.1.数据
4.2.结果
4.3.与最先进水平(SOTA)的比较
4.4.消融研究
5.讨论
5.1.当前限制与未来工作
我们承认,我们的适应模型仅训练到特征提取阶段,这意味着用于从这些域转换特征预测结果的分类分支没有更新。由于我们将特征域转移到了原始源特征的域上,而现有的分类分支正是基于这些原始特征训练的,因此这部分模型应该无需进一步训练就能适应。然而,在这最后一步中,可能遗漏了一些有用的队列特定信息。
正如我们在消融研究(第4.4节)中所展示的,找到源数据的最佳聚类是该方法取得最佳效果的关键。可以将这项工作扩展,测试这种方法是否能推广到多个目标群体。为了模拟实际应用,建议采用累积方法,在每个新的目标领域重新计算聚类中心,并评估这如何影响模型的适应性。此外,还可以探讨调整后的模型在原始源域中的表现变化。
这种方法的效果在很大程度上取决于源数据中疾病变异对疾病空间的覆盖程度。如果我们确信源模型之前已经处理过特定的疾病变异,那么在调整特征时可以更加积极,同时对于异常值则可以采取较为保守的态度,引入某种加权异常检测方法。
5.2.结论
我们提出了一种新方法,该方法利用图节点特征和源聚类中心,在聚类三元组损失函数中实现组织学深度学习模型的统一域适应。这种方法允许在全息切片图像(WSI)中进行局部域适应,使得一个目标图像中的不同组织切片不必以相同的方式‘移动’。
尽管我们提出的方法并非完全无源,但只需要原始数据的密集表示,这不仅避免了存储内存密集型的数据集,而且在不同医院环境中实施时,能够保护患者数据的匿名性。该方法适用于多种结果类别,并可应用于多种深度学习和多实例学习(MIL)方法。