From 8807b5833645a7f0a3a5327e575cba9a5ca7485e Mon Sep 17 00:00:00 2001 From: furkan-celik Date: Thu, 8 Apr 2021 19:49:05 +0300 Subject: [PATCH] Implemented notebook tqdm support --- src/NERDA/predictions.py | 10 +++++++++- src/NERDA/training.py | 10 +++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/NERDA/predictions.py b/src/NERDA/predictions.py index 14dd557..c5045fb 100644 --- a/src/NERDA/predictions.py +++ b/src/NERDA/predictions.py @@ -6,12 +6,20 @@ from .preprocessing import create_dataloader import torch import numpy as np -from tqdm import tqdm from nltk.tokenize import sent_tokenize, word_tokenize from typing import List, Callable import transformers import sklearn.preprocessing +try: + from IPython import get_ipython + if 'IPKernelApp' in get_ipython().config: + from tqdm.notebook import tqdm + else: + from tqdm import tqdm +except: + from tqdm import tqdm + def predict(network: torch.nn.Module, sentences: List[List[str]], transformer_tokenizer: transformers.PreTrainedTokenizer, diff --git a/src/NERDA/training.py b/src/NERDA/training.py index 7c52c37..34d4a78 100644 --- a/src/NERDA/training.py +++ b/src/NERDA/training.py @@ -4,7 +4,15 @@ from transformers import AdamW, get_linear_schedule_with_warmup import random import torch -from tqdm import tqdm + +try: + from IPython import get_ipython + if 'IPKernelApp' in get_ipython().config: + from tqdm.notebook import tqdm + else: + from tqdm import tqdm +except: + from tqdm import tqdm def train(model, data_loader, optimizer, device, scheduler, n_tags): """One Iteration of Training"""