def forward(self, **kwargs):
    """Compute both classification outputs and, optionally, the embedding.

    The input is looked up in ``kwargs`` under ``self.input_key`` and passed
    through ``self.rgb_bn``, ``self.encoder.forward_features`` and
    ``self.pool`` in that order; the pooled result is the embedding shared by
    every output below.

    Args:
        **kwargs: Keyword arguments; must contain ``self.input_key``.

    Returns:
        A dict that always holds ``OUTPUT_PRED_MODIFICATION_FLAG`` and
        ``OUTPUT_PRED_MODIFICATION_TYPE``. ``OUTPUT_PRED_EMBEDDING`` is added
        when ``self.need_embedding`` is truthy, and
        ``OUTPUT_PRED_EMBEDDING_ARC_MARGIN`` is added when
        ``self.arc_margin`` is not ``None``.

    Raises:
        KeyError: If ``self.input_key`` is missing from ``kwargs``.
    """
    x = kwargs[self.input_key]
    x = self.rgb_bn(x)
    x = self.encoder.forward_features(x)
    embedding = self.pool(x)

    # self.drop is deliberately invoked once per head (not shared), exactly
    # as in the original code, so each classifier receives its own result of
    # self.drop(embedding).
    result = {
        OUTPUT_PRED_MODIFICATION_FLAG: self.flag_classifier(self.drop(embedding)),
        OUTPUT_PRED_MODIFICATION_TYPE: self.type_classifier(self.drop(embedding)),
    }

    # The optional outputs use the embedding without self.drop applied.
    if self.need_embedding:
        result[OUTPUT_PRED_EMBEDDING] = embedding

    # NOTE(review): treated as independent of need_embedding (the original
    # indentation was lost) -- confirm against the upstream file.
    if self.arc_margin is not None:
        result[OUTPUT_PRED_EMBEDDING_ARC_MARGIN] = self.arc_margin(embedding)

    return result
