Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ __pycache__/

#data
data.json
data.json*

#logs
logs/
Expand Down
2 changes: 2 additions & 0 deletions configs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ loss_alpha: 0.75
loss_gamma: 0
label_smoothing: 0
loss_reduction: "sum"
negative_rate: 0.75
neg_span_masking: "global_w_threshold"

# Learning Rate and weight decay Configuration
lr_encoder: 1e-5
Expand Down
2 changes: 2 additions & 0 deletions gliner/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(self,
max_types: int = 25,
max_len: int = 384,
words_splitter_type: str = "whitespace",
neg_span_masking: str = None,
has_rnn: bool = True,
fuse_layers: bool = False,
embed_ent_token: bool = True,
Expand Down Expand Up @@ -68,6 +69,7 @@ def __init__(self,
self.embed_ent_token = embed_ent_token
self.ent_token = ent_token
self.sep_token = sep_token
self.neg_span_masking=neg_span_masking

# Register the configuration
from transformers import CONFIG_MAPPING
Expand Down
36 changes: 35 additions & 1 deletion gliner/modeling/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def forward(self,

def loss(self, scores, labels, prompts_embedding_mask, mask_label,
alpha: float = -1., gamma: float = 0.0, label_smoothing: float = 0.0,
reduction: str = 'sum', **kwargs):
reduction: str = 'sum', negative_rate: float = 0.75, neg_span_masking: str = None, **kwargs):

batch_size = scores.shape[0]
num_classes = prompts_embedding_mask.shape[-1]
Expand All @@ -274,6 +274,40 @@ def loss(self, scores, labels, prompts_embedding_mask, mask_label,

all_losses = all_losses * mask_label.float()

if neg_span_masking is not None :

if neg_span_masking == "global_w_threshold":

mask_negative_examples = (torch.rand_like(labels, dtype=torch.float) + labels > negative_rate).float()
all_losses = all_losses * mask_negative_examples

elif neg_span_masking == "global_wo_threshold" :

p = torch.sigmoid(scores)
random_mask = torch.bernoulli(1 - p) + labels
mask_negative_examples = torch.where(labels == 1, torch.ones_like(labels), random_mask)
all_losses = all_losses*mask_negative_examples

elif neg_span_masking == "entity_w_threshold":

mask_negative_examples = labels.clone()
zero_rows = labels.sum(dim=1) == 0
mask_negative_examples[zero_rows] = (torch.rand((zero_rows.sum(), labels.size(1))) >= negative_rate).float()
all_losses = all_losses*mask_negative_examples

elif neg_span_masking == "entity_wo_threshold":

p = torch.sigmoid(scores)
mask = labels.clone()
rows_to_sample = labels.sum(dim=1) == 0
mask[rows_to_sample] = torch.bernoulli(p[rows_to_sample])

else:

warnings.warn(
f"Invalid Value for config 'neg_span_masking': '{neg_span_masking}. ")


if reduction == "mean":
loss = all_losses.mean()
elif reduction == 'sum':
Expand Down