Source code for expert.core.contradiction.contr_tools.NLIModel
import torch.nn as nn
class BERTNLIModel(nn.Module):
    """Linear classification head stacked on top of a wrapped encoder.

    The wrapped ``model`` is called with ``input_ids``, ``attention_mask`` and
    ``token_type_ids`` keyword arguments; element ``[1]`` of its return value
    is fed through a single ``nn.Linear`` layer to produce the scores.

    NOTE(review): element ``[1]`` is presumably the pooled sentence
    representation of a HuggingFace ``BertModel`` -- confirm against the
    encoder actually passed in by the caller.

    Attributes:
        bert: The wrapped encoder, stored as given.
        out: ``nn.Linear`` mapping ``hidden_size`` features to ``output_dim``
            scores.
    """

    def __init__(self, model, output_dim: int = 3):
        """Wrap ``model`` and build the output layer.

        Args:
            model: Encoder exposing ``config.hidden_size`` and callable with
                ``input_ids``, ``attention_mask`` and ``token_type_ids``.
            output_dim: Number of output scores per example. Defaults to 3,
                the value that used to be hard-coded (presumably the three
                NLI labels -- verify against the training code).
        """
        super().__init__()
        self.bert = model
        # Read the width straight from the config instead of materialising a
        # full dictionary copy of it just to look up one key.
        embedding_dim = self.bert.config.hidden_size
        self.out = nn.Linear(embedding_dim, output_dim)

    def forward(self, sequence, attn_mask, token_type):
        """Run the encoder and project its second output to class scores.

        Args:
            sequence: Passed to the encoder as ``input_ids``.
            attn_mask: Passed to the encoder as ``attention_mask``.
            token_type: Passed to the encoder as ``token_type_ids``.

        Returns:
            The output of the linear layer applied to element ``[1]`` of the
            encoder's result. No softmax or other activation is applied here.
        """
        # Integer indexing (rather than an attribute such as
        # ``pooler_output``) keeps this working for encoders that return a
        # plain tuple.
        embedded = self.bert(
            input_ids=sequence,
            attention_mask=attn_mask,
            token_type_ids=token_type,
        )[1]
        return self.out(embedded)