diff --git a/src/NERDA/predictions.py b/src/NERDA/predictions.py index 14dd557..c5045fb 100644 --- a/src/NERDA/predictions.py +++ b/src/NERDA/predictions.py @@ -6,12 +6,20 @@ from .preprocessing import create_dataloader import torch import numpy as np -from tqdm import tqdm from nltk.tokenize import sent_tokenize, word_tokenize from typing import List, Callable import transformers import sklearn.preprocessing +try: + from IPython import get_ipython + if 'IPKernelApp' in get_ipython().config: + from tqdm.notebook import tqdm + else: + from tqdm import tqdm +except: + from tqdm import tqdm + def predict(network: torch.nn.Module, sentences: List[List[str]], transformer_tokenizer: transformers.PreTrainedTokenizer, diff --git a/src/NERDA/training.py b/src/NERDA/training.py index 7c52c37..34d4a78 100644 --- a/src/NERDA/training.py +++ b/src/NERDA/training.py @@ -4,7 +4,15 @@ from transformers import AdamW, get_linear_schedule_with_warmup import random import torch -from tqdm import tqdm + +try: + from IPython import get_ipython + if 'IPKernelApp' in get_ipython().config: + from tqdm.notebook import tqdm + else: + from tqdm import tqdm +except: + from tqdm import tqdm def train(model, data_loader, optimizer, device, scheduler, n_tags): """One Iteration of Training"""