import torch import math class ArcFace(torch.nn.Module): """ ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf): """ def __init__(self, s=64.0, margin=0.5): super(ArcFace, self).__init__() self.scale = s self.cos_m = math.cos(margin) self.sin_m = math.sin(margin) self.theta = math.cos(math.pi - margin) self.sinmm = math.sin(math.pi - margin) * margin self.easy_margin = False def forward(self, logits: torch.Tensor, labels: torch.Tensor): index = torch.where(labels != -1)[0] target_logit = logits[index, labels[index].view(-1)] sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2)) cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m # cos(target+margin) if self.easy_margin: final_target_logit = torch.where( target_logit > 0, cos_theta_m, target_logit) else: final_target_logit = torch.where( target_logit > self.theta, cos_theta_m, target_logit - self.sinmm) logits[index, labels[index].view(-1)] = final_target_logit logits = logits * self.scale return logits class CosFace(torch.nn.Module): def __init__(self, s=64.0, m=0.40): super(CosFace, self).__init__() self.s = s self.m = m def forward(self, logits: torch.Tensor, labels: torch.Tensor): index = torch.where(labels != -1)[0] target_logit = logits[index, labels[index].view(-1)] final_target_logit = target_logit - self.m logits[index, labels[index].view(-1)] = final_target_logit logits = logits * self.s return logits