Source code for expert.core.contradiction.contr_tools.model_tools

from __future__ import annotations

import torch
import torch.nn as nn
from transformers import (
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    logging,
)

from expert.core.contradiction.contr_tools import NLIModel
from expert.core.functional_tools import get_model_weights


logging.set_verbosity_error()


[docs]def create_model(lang: str = "en", device: str = "cpu"): """Function for creating the model. Defines model structure and download weights. Args: lang (str, optional): Speech language for text processing ['ru', 'en']. Defaults to 'en'. device (str, optional): Device type on local machine (GPU recommended). Defaults to 'cpu'. Raises: NotImplementedError: If 'language' is not equal to 'en' or 'ru'. Returns: [torch.model]: Model. """ if lang == "en": model = AutoModel.from_pretrained("prajjwal1/bert-medium") url = "https://drive.google.com/open?id=1sJXQqnXnnJsOEbT3pbDS9c97z4x10SXm&authuser=0" model_name = "bert-nli-medium.pt" cached_file = get_model_weights(model_name=model_name, url=url) model = NLIModel.BERTNLIModel(model).to(device) model.load_state_dict(torch.load(cached_file, map_location=device)) elif lang == "ru": model_checkpoint = "cointegrated/rubert-base-cased-nli-threeway" model = AutoModelForSequenceClassification.from_pretrained( model_checkpoint ) model = model.to(device) else: raise NotImplementedError("'lang' must be 'en' or 'ru'.") return model
def choose_toketizer(lang): if lang == "en": tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-medium") elif lang == "ru": tokenizer = AutoTokenizer.from_pretrained( "cointegrated/rubert-base-cased-nli-threeway" ) else: raise NameError return tokenizer def get_sent1_token_type(sent): try: return [0] * len(sent) except ValueError: return [] def get_sent2_token_type(sent): try: return [1] * len(sent) except ValueError: return [] def tokenize_bert(sentence, tokenizer): tokens = tokenizer.tokenize(sentence) return tokens
[docs]def averaging(prem_type, prem_t, hypo_t, model, tokenizer, device="cpu"): """Function for averaging predictions for long texts (longer than 512 tokens).""" func = nn.Softmax(dim=1) hypo_size = 512 - len(prem_t) parts = [] predictions = [] for i in range(0, len(hypo_t), hypo_size): parts.append(hypo_t[i : i + hypo_size]) for part in parts: hypo_t = part hypo_type = get_sent2_token_type(hypo_t) indexes = prem_t + hypo_t indexes = tokenizer.convert_tokens_to_ids(indexes) indexes_type = prem_type + hypo_type attn_mask = get_sent2_token_type(indexes) indexes = torch.LongTensor(indexes).unsqueeze(0).to(device) indexes_type = torch.LongTensor(indexes_type).unsqueeze(0).to(device) attn_mask = torch.LongTensor(attn_mask).unsqueeze(0).to(device) preds = func(model(indexes, attn_mask, indexes_type)) predictions.append( [float(preds[0][0]), float(preds[0][1]), float(preds[0][2])] ) # Averaging. predictions = torch.tensor(predictions) / len(predictions) prediction = torch.tensor( [ predictions[:, 0].sum(), predictions[:, 1].sum(), predictions[:, 2].sum(), ] ) prediction.unsqueeze_(0) return prediction
[docs]def predict_inference( premise: str, hypothesis: str, model, lang="en", device="cpu" ): """Function for prediction, returns labels: 0 - entailment; 1 - contradiction; 2 - neutral. Args: premise (str): Entered text. hypothesis (str): Text for analysis. model (torch.nn.model): Get model structure and weights. lang (str, optional): Speech language for text processing ['ru', 'en']. Defaults to 'en'. device (torch.device, optional): Device type on local machine (GPU recommended). Defaults to None. Raises: NotImplementedError: If 'language' is not equal to 'en' or 'ru'. Returns: [torch.LongTensor]: Label of prediction. """ if lang not in ["en", "ru"]: raise NotImplementedError("'lang' must be 'en' or 'ru'.") else: tokenizer = choose_toketizer(lang) func = nn.Softmax(dim=1) model.eval() model.to(device) premise = "[CLS] " + premise + " [SEP]" hypothesis = hypothesis + " [SEP]" prem_t = tokenize_bert(premise, tokenizer) if len(prem_t) > 512: return f""" The chosen text is too large (={len(prem_t)}). Sum of tokens should be less or equal to 512 """ hypo_t = tokenize_bert(hypothesis, tokenizer) if len(prem_t) + len(hypo_t) <= 512: prem_type = get_sent1_token_type(prem_t) hypo_type = get_sent2_token_type(hypo_t) indexes = prem_t + hypo_t indexes = tokenizer.convert_tokens_to_ids(indexes) indexes_type = prem_type + hypo_type attn_mask = get_sent2_token_type(indexes) indexes = torch.LongTensor(indexes).unsqueeze(0).to(device) indexes_type = torch.LongTensor(indexes_type).unsqueeze(0).to(device) attn_mask = torch.LongTensor(attn_mask).unsqueeze(0).to(device) prediction = model(indexes, attn_mask, indexes_type) # Models have different output types. if lang == "en": return hypothesis, func(prediction).argmax() elif lang == "ru": return hypothesis, torch.softmax(prediction.logits, -1).argmax() else: hypo_size = 512 - len(prem_t) prem_type = get_sent1_token_type(prem_t) parts = hypothesis.split(".") parts = dict(zip(range(len(parts)), parts)) for key, value in parts.items(): hypo_t = tokenize_bert(value, tokenizer) if len(hypo_t) <= hypo_size: hypo_type = get_sent2_token_type(hypo_t) indexes = prem_t + hypo_t indexes = tokenizer.convert_tokens_to_ids(indexes) indexes_type = prem_type + hypo_type attn_mask = get_sent2_token_type(indexes) indexes = torch.LongTensor(indexes).unsqueeze(0).to(device) indexes_type = ( torch.LongTensor(indexes_type).unsqueeze(0).to(device) ) attn_mask = torch.LongTensor(attn_mask).unsqueeze(0).to(device) prediction = model(indexes, attn_mask, indexes_type) else: hypo_t = tokenize_bert(value, tokenizer) prediction = averaging( prem_type, prem_t, hypo_t, model, tokenizer ) if lang == "en": return hypothesis, func(prediction).argmax() elif lang == "ru": return hypothesis, torch.softmax(prediction.logits, -1).argmax()