Segment Anything Model for Medical Images?

Published at MICCAI 2024

[Figure: Testing pipeline of SAM]
Train with box

The following code is from
https://github.com/yuhoo0302/Segment-Anything-Model-for-Medical-Images
The task is medical image segmentation.

------------------------------------------------------------------------------------------------------------
Optimizer and loss function setup:
------------------------------------------------------------------------------------------------------------
# Set up the optimizer; note that only the mask decoder's parameters are trained.
# Hyperparameter tuning can further improve performance here.
optimizer = torch.optim.AdamW(sam_model.mask_decoder.parameters(), lr=args.lr, weight_decay=args.weight_decay)
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
------------------------------------------------------------------------------------------------------------
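`monai.losses.DiceCELoss(sigmoid=True, squared_pred=True)` adds a soft Dice term to binary cross-entropy computed from the raw logits. Roughly what it computes, as a minimal single-channel sketch (`dice_ce_loss` is illustrative, not MONAI's exact implementation):
------------------------------------------------------------------------------------------------------------
import torch
import torch.nn.functional as F

def dice_ce_loss(logits, target, eps=1e-5):
    """Sketch: soft Dice (squared_pred variant) + BCE on logits, mean reduction."""
    prob = torch.sigmoid(logits)
    p, t = prob.flatten(1), target.flatten(1)
    numer = 2 * (p * t).sum(-1) + eps
    denom = (p ** 2).sum(-1) + (t ** 2).sum(-1) + eps  # squared_pred=True squares the denominator terms
    dice = 1 - numer / denom
    bce = F.binary_cross_entropy_with_logits(logits, target.float())
    return dice.mean() + bce
------------------------------------------------------------------------------------------------------------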
Code from one training iteration:
------------------------------------------------------------------------------------------------------------
outputs = []
# no gradients for the frozen image encoder and prompt encoder
with torch.no_grad():
    none_grad_features = {"sparse": {}, "dense": {}}
    for idx, image_record in enumerate(batched_input):
        sparse_embeddings, dense_embeddings = model.prompt_encoder(
            points=None,
            boxes=image_record["box"].to(device),
            masks=None,
        )
        none_grad_features["sparse"][idx] = sparse_embeddings
        none_grad_features["dense"][idx] = dense_embeddings

batched_loss = 0
for id, im_record in enumerate(batched_input):
    # low_res_masks.shape == (B, M, 256, 256); M is set to 1
    low_res_masks, iou_predictions = model.mask_decoder(
        image_embeddings=im_record["img_embed"].unsqueeze(0).to(device),  # (1, 256, 64, 64); 1 = batch size
        image_pe=model.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
        sparse_prompt_embeddings=none_grad_features["sparse"][id],  # (B, 2, 256); B = number of targets, not batch size
        dense_prompt_embeddings=none_grad_features["dense"][id],  # (B, 256, 64, 64); B = number of targets, not batch size
        multimask_output=False,
    )
    # upscale + remove padding + restore to the original size
    masks = model.postprocess_masks(
        low_res_masks,
        input_size=tuple(im_record["size_before_pad"]),
        original_size=tuple(im_record["image_ori_size"]),
    )
    outputs.append({
        "masks": masks,
        "iou_predictions": iou_predictions,
        "low_res_logits": low_res_masks,
        "gt2D": im_record["gt2D"].to(device),
    })
    # both arguments are reshaped to (1, B, ori_H, ori_W) to cover the multi-object case
    batched_loss += criterion(masks.squeeze(1).unsqueeze(0), im_record["gt2D"].to(device).unsqueeze(0))

loss = batched_loss / len(batched_input)
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch_loss += loss.item()
------------------------------------------------------------------------------------------------------------
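The loop above assumes every entry of `batched_input` already carries a precomputed image embedding (`img_embed`), a box prompt (likely in the resized coordinate frame SAM's prompt encoder expects), the pre-padding input size, and the original image size. A minimal sketch of how such an entry could be built with SAM's frozen image encoder; `build_record` and its exact fields are assumptions matched to the loop above, not the repository's actual data pipeline:
------------------------------------------------------------------------------------------------------------
import numpy as np
import torch
from segment_anything.utils.transforms import ResizeLongestSide

@torch.no_grad()
def build_record(sam_model, image, box_xyxy, gt2D, device):
    """Hypothetical helper: precompute the frozen-encoder embedding for one
    image and package it the way the training loop above expects."""
    transform = ResizeLongestSide(sam_model.image_encoder.img_size)  # 1024 for the released SAM models
    resized = transform.apply_image(image)  # HWC uint8; longest side resized to 1024
    x = torch.as_tensor(resized, device=device).permute(2, 0, 1)  # CHW
    size_before_pad = tuple(x.shape[-2:])  # size after resize, before square padding
    x = sam_model.preprocess(x)  # normalize pixel values + pad to a square input
    img_embed = sam_model.image_encoder(x[None])  # (1, 256, 64, 64)
    box = transform.apply_boxes(np.asarray(box_xyxy), image.shape[:2])  # box in resized coords
    return {
        "img_embed": img_embed.squeeze(0),  # (256, 64, 64); unsqueezed again in the loop
        "box": torch.as_tensor(box, dtype=torch.float, device=device),
        "size_before_pad": size_before_pad,
        "image_ori_size": tuple(image.shape[:2]),
        "gt2D": gt2D,
    }
------------------------------------------------------------------------------------------------------------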

Segment Anything in High Quality

Published at NeurIPS 2023


Figure 3: HQ-SAM introduces the HQ-Output Token and global-local feature fusion into SAM for high-quality mask prediction. To preserve SAM's zero-shot capability, the lightweight HQ-Output Token reuses SAM's mask decoder and a new MLP layer to perform a point-wise product with the fused HQ-Features. During training, the pre-trained SAM's model parameters are frozen and only the few learnable parameters in HQ-SAM are trainable. The prompt encoder is omitted here for clarity. At inference, error correction is simply a direct element-wise sum between the predicted logits of SAM's output token and the HQ-Output Token.
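As a tiny sketch of that inference-time error correction (tensor names are assumed; both are mask logits of the same shape):
------------------------------------------------------------------------------------------------------------
corrected_logits = sam_token_logits + hq_token_logits  # element-wise sum of the two tokens' logits
masks = corrected_logits > 0.0  # SAM binarizes mask logits at threshold 0
------------------------------------------------------------------------------------------------------------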

Loss function setup

We supervise mask prediction of the new HQ-Output token with a combination of both BCE Loss and Dice Loss.


Train with box, point, and noise_mask prompts

The following code is from
https://github.com/SysCV/SAM-HQ
The task is high-quality mask prediction for SAM.
------------------------------------------------------------------------------------------------------------
from utils.loss_mask import loss_masks

net = MaskDecoderHQ(args.model_type)
------------------------------------------------------------------------------------------------------------
batched_input = []
for b_i in range(len(imgs)):  # build one prompt dict per image for SAM's batched forward
    dict_input = dict()
    dict_input['image'] = torch.as_tensor(imgs[b_i].astype(dtype=np.uint8), device=sam.device).permute(2, 0, 1).contiguous()
    if input_type == 'box':
        dict_input['boxes'] = labels_box[b_i:b_i+1]
    elif input_type == 'point':
        point_coords = labels_points[b_i:b_i+1]
        dict_input['point_coords'] = point_coords
        dict_input['point_labels'] = torch.ones(point_coords.shape[1], device=point_coords.device)[None, :]
    elif input_type == 'noise_mask':
        dict_input['mask_inputs'] = labels_noisemask[b_i:b_i+1]
    else:
        raise NotImplementedError
    dict_input['original_size'] = imgs[b_i].shape[:2]
    batched_input.append(dict_input)

with torch.no_grad():
    # the frozen SAM produces image embeddings and prompt embeddings;
    # only the MaskDecoderHQ head (`net`) receives gradients
    batched_output, interm_embeddings = sam(batched_input, multimask_output=False)

batch_len = len(batched_output)
encoder_embedding = torch.cat([batched_output[i_l]['encoder_embedding'] for i_l in range(batch_len)], dim=0)
image_pe = [batched_output[i_l]['image_pe'] for i_l in range(batch_len)]
sparse_embeddings = [batched_output[i_l]['sparse_embeddings'] for i_l in range(batch_len)]
dense_embeddings = [batched_output[i_l]['dense_embeddings'] for i_l in range(batch_len)]
masks_hq = net(
    image_embeddings=encoder_embedding,
    image_pe=image_pe,
    sparse_prompt_embeddings=sparse_embeddings,
    dense_prompt_embeddings=dense_embeddings,
    multimask_output=False,
    hq_token_only=True,
    interm_embeddings=interm_embeddings,
)

# GT masks are stored as 0/255 images; normalize them to 0/1
loss_mask, loss_dice = loss_masks(masks_hq, labels / 255.0, len(masks_hq))
loss = loss_mask + loss_dice

loss_dict = {"loss_mask": loss_mask, "loss_dice": loss_dice}

# reduce losses over all GPUs for logging purposes
loss_dict_reduced = misc.reduce_dict(loss_dict)
losses_reduced_scaled = sum(loss_dict_reduced.values())
loss_value = losses_reduced_scaled.item()

optimizer.zero_grad()
loss.backward()
optimizer.step()

metric_logger.update(training_loss=loss_value, **loss_dict_reduced)
------------------------------------------------------------------------------------------------------------
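`misc.reduce_dict` averages the loss values across distributed workers so every GPU logs the same numbers. A sketch of the usual DETR-style helper (assumed here; check SAM-HQ's `utils/misc.py` for the exact version):
------------------------------------------------------------------------------------------------------------
import torch
import torch.distributed as dist

def reduce_dict(input_dict, average=True):
    """All-reduce a dict of scalar loss tensors across workers (for logging only)."""
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = sorted(input_dict.keys())  # identical key order on every rank
        values = torch.stack([input_dict[k] for k in names])
        dist.all_reduce(values)
        if average:
            values /= world_size
        return {k: v for k, v in zip(names, values)}
------------------------------------------------------------------------------------------------------------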
loss_masks:
------------------------------------------------------------------------------------------------------------
def loss_masks(src_masks, target_masks, num_masks, oversample_ratio=3.0):
    """Compute the losses related to the masks: the sigmoid cross-entropy loss and
    the dice loss, evaluated on a PointRend-style sample of points rather than on
    all pixels. src_masks and target_masks are tensors of dim [nb_masks, 1, h, w].
    """
    # No need to upsample predictions as we are using normalized coordinates :)
    with torch.no_grad():
        # sample point_coords, biased toward uncertain (near-boundary) predictions
        point_coords = get_uncertain_point_coords_with_randomness(
            src_masks,
            lambda logits: calculate_uncertainty(logits),
            112 * 112,
            oversample_ratio,
            0.75,
        )
        # get gt labels at the sampled points
        point_labels = point_sample(
            target_masks,
            point_coords,
            align_corners=False,
        ).squeeze(1)

    point_logits = point_sample(
        src_masks,
        point_coords,
        align_corners=False,
    ).squeeze(1)

    loss_mask = sigmoid_ce_loss_jit(point_logits, point_labels, num_masks)
    loss_dice = dice_loss_jit(point_logits, point_labels, num_masks)

    del src_masks
    del target_masks
    return loss_mask, loss_dice
------------------------------------------------------------------------------------------------------------
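`loss_masks` follows PointRend: instead of evaluating the loss densely, it samples 112 × 112 = 12544 points per mask, with 75% of them importance-sampled at the most uncertain locations and the rest uniform. A hypothetical usage sketch:
------------------------------------------------------------------------------------------------------------
import torch

# hypothetical shapes: N predicted mask logits vs. N binary GT masks, both (N, 1, H, W)
pred_logits = torch.randn(4, 1, 256, 256)
gt_masks = (torch.rand(4, 1, 256, 256) > 0.5).float()
loss_mask, loss_dice = loss_masks(pred_logits, gt_masks, num_masks=4)
------------------------------------------------------------------------------------------------------------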
get_uncertain_point_coords_with_randomness:
"""
Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The unceratinties
are calculated for each point using 'uncertainty_func' function that takes point's logit
prediction as input.
See PointRend paper for details.
Args:
coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for
class-specific or class-agnostic prediction.
uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that
contains logit predictions for P points and returns their uncertainties as a Tensor of
shape (N, 1, P).
num_points (int): The number of points P to sample.
oversample_ratio (int): Oversampling parameter.
importance_sample_ratio (float): Ratio of points that are sampled via importnace sampling.
Returns:
point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P
sampled points.
"""
point_sample:
"""
A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside
[0, 1] x [0, 1] square.
Args:
input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid.
point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains
[0, 1] x [0, 1] normalized point coordinates.
Returns:
output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
features for points in `point_coords`. The features are obtained via bilinear
interplation from `input` the same way as :function:`torch.nn.functional.grid_sample`.
"""

def sigmoid_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """
    Args:
        inputs: A float tensor of arbitrary shape.
            The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
    Returns:
        Loss tensor
    """
    loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    return loss.mean(1).sum() / num_masks

# note: torch.jit.script must run after the function it wraps is defined
sigmoid_ce_loss_jit = torch.jit.script(
    sigmoid_ce_loss
)  # type: torch.jit.ScriptModule


def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape.
            The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1)
    numerator = 2 * (inputs * targets).sum(-1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_masks

dice_loss_jit = torch.jit.script(
    dice_loss
)  # type: torch.jit.ScriptModule
------------------------------------------------------------------------------------------------------------