Model tools
- expert.core.contradiction.contr_tools.model_tools.create_model(lang: str = 'en', device: str = 'cpu')[source]
Function for creating the model. Defines the model structure and downloads the weights.
- Parameters
lang (str, optional) – Speech language for text processing [‘ru’, ‘en’]. Defaults to ‘en’.
device (str, optional) – Device type on local machine (GPU recommended). Defaults to ‘cpu’.
- Raises
NotImplementedError – If ‘lang’ is not equal to ‘en’ or ‘ru’.
- Returns
Model.
- Return type
[torch.nn.Module]
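The documented language check can be mirrored on the caller's side before the model is created. The following is a minimal sketch, not the library's own code: `SUPPORTED_LANGS` and `check_lang` are hypothetical helper names, and the commented-out call assumes the package is installed and importable under the path shown in the signature above.

```python
# Hypothetical helper: validates the language code before calling create_model.
SUPPORTED_LANGS = ("en", "ru")


def check_lang(lang: str) -> str:
    """Return `lang` unchanged if it is supported, otherwise raise.

    Raises the same exception type that create_model is documented to raise.
    """
    if lang not in SUPPORTED_LANGS:
        raise NotImplementedError(
            f"Unsupported language {lang!r}; expected one of {SUPPORTED_LANGS}"
        )
    return lang


# Assumed usage (requires the package to be installed):
# from expert.core.contradiction.contr_tools.model_tools import create_model
# model = create_model(lang=check_lang("en"), device="cpu")
```

Validating early keeps the failure close to the user input instead of surfacing only when the weights are about to be downloaded.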
- expert.core.contradiction.contr_tools.model_tools.averaging(prem_type, prem_t, hypo_t, model, tokenizer, device='cpu')[source]
Function for averaging predictions for long texts (longer than 512 tokens).
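The exact averaging strategy is not described here. As an illustration only, one common approach is to split the token sequence into chunks that fit the 512-token limit, score each chunk separately, and take the mean of the per-class scores. In the sketch below, `split_into_chunks`, `average_predictions`, and the `score_chunk` callback are hypothetical names and do not correspond to the library's internals.

```python
from typing import Callable, List, Sequence

MAX_TOKENS = 512  # limit taken from the description above


def split_into_chunks(tokens: Sequence[int], size: int = MAX_TOKENS) -> List[List[int]]:
    """Split a token sequence into consecutive chunks of at most `size` tokens."""
    return [list(tokens[i:i + size]) for i in range(0, len(tokens), size)]


def average_predictions(
    tokens: Sequence[int],
    score_chunk: Callable[[List[int]], List[float]],
    size: int = MAX_TOKENS,
) -> List[float]:
    """Score every chunk with `score_chunk` and return the mean score per class."""
    chunks = split_into_chunks(tokens, size)
    if not chunks:
        raise ValueError("tokens must not be empty")
    scores = [score_chunk(chunk) for chunk in chunks]
    n_classes = len(scores[0])
    return [sum(s[k] for s in scores) / len(scores) for k in range(n_classes)]
```

In a real pipeline the chunk size would also have to leave room for the other text in the pair and for the tokenizer's special tokens, which this sketch ignores.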
- expert.core.contradiction.contr_tools.model_tools.predict_inference(premise: str, hypothesis: str, model, lang='en', device='cpu')[source]
Function for prediction. Returns one of the following labels:
0 - entailment; 1 - contradiction; 2 - neutral.
- Parameters
premise (str) – Premise text (the statement taken as given).
hypothesis (str) – Hypothesis text to check against the premise.
model (torch.nn.Module) – Model structure and weights.
lang (str, optional) – Speech language for text processing [‘ru’, ‘en’]. Defaults to ‘en’.
device (torch.device, optional) – Device type on local machine (GPU recommended). Defaults to ‘cpu’.
- Raises
NotImplementedError – If ‘lang’ is not equal to ‘en’ or ‘ru’.
- Returns
Label of prediction.
- Return type
[torch.LongTensor]
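Since the function returns a numeric label, a small lookup makes the result readable. This is a sketch under the label mapping listed above; `LABELS` and `label_name` are hypothetical names that are not part of the library, and the commented-out usage assumes the package is installed.

```python
# Label mapping as listed in the predict_inference description.
LABELS = {0: "entailment", 1: "contradiction", 2: "neutral"}


def label_name(prediction) -> str:
    """Convert a predicted label (plain int or single-element tensor) to its name."""
    # Tensor-like objects expose .item(); plain ints are used as they are.
    index = prediction.item() if hasattr(prediction, "item") else prediction
    return LABELS[int(index)]


# Assumed usage (requires the package to be installed):
# from expert.core.contradiction.contr_tools.model_tools import (
#     create_model, predict_inference,
# )
# model = create_model(lang="en", device="cpu")
# label = predict_inference("It is raining.", "The sky is clear.", model)
# print(label_name(label))
```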