From a74772e24bcbb3ec3181463325ce6af2dbff728e Mon Sep 17 00:00:00 2001 From: Bhuvanesh Verma Date: Mon, 19 Sep 2022 14:48:26 +0200 Subject: [PATCH 1/4] add weak span based f1 metric --- src/pie_utils/metric/__init__.py | 0 src/pie_utils/metric/weak_span_based_f1.py | 236 +++++++++++++++ tests/metric/__init__.py | 0 tests/metric/test_weak_span_based_f1.py | 328 +++++++++++++++++++++ 4 files changed, 564 insertions(+) create mode 100644 src/pie_utils/metric/__init__.py create mode 100644 src/pie_utils/metric/weak_span_based_f1.py create mode 100644 tests/metric/__init__.py create mode 100644 tests/metric/test_weak_span_based_f1.py diff --git a/src/pie_utils/metric/__init__.py b/src/pie_utils/metric/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pie_utils/metric/weak_span_based_f1.py b/src/pie_utils/metric/weak_span_based_f1.py new file mode 100644 index 0000000..429ac23 --- /dev/null +++ b/src/pie_utils/metric/weak_span_based_f1.py @@ -0,0 +1,236 @@ +from collections import defaultdict +from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple + +import torch +from torchmetrics import Metric + +from pie_utils.sequence_tagging import tag_sequence_to_token_spans +from pie_utils.span.slice import get_overlap_len + +TypedStringSpan = Tuple[str, Tuple[int, int]] +TAGS_TO_SPANS_FUNCTION_TYPE = Callable[[List[str], Optional[List[str]]], List[TypedStringSpan]] + + +def has_weak_overlap(indices_1: Tuple[int, int], indices_2: Tuple[int, int]) -> bool: + # checks if overlap in span is at least half of the length of the shorter span + min_len = min(indices_1[1] - indices_1[0], indices_2[1] - indices_2[0]) + overlap_len = get_overlap_len(indices_1, indices_2) + return 2 * overlap_len >= min_len + + +def increase_span_end_index( + span: Tuple[str, Tuple[int, int]], offset: int +) -> Tuple[str, Tuple[int, int]]: + # format of span is (label,(start,end)) + return span[0], (span[1][0], span[1][1] + offset) + + +def get_weak_match( + span: Tuple[str, Tuple[int, int]], + gold_spans: List[Tuple[str, Tuple[int, int]]], + inclusive_end_index: bool = False, +) -> Optional[Tuple[str, Tuple[int, int]]]: + """This method checks if the predicted span is weakly matched with any of the gold spans. If + predicted type and gold type matches then we check if their respective indices are weakly + overlapping or not. Weak overlap between gold and predicted span is defined in Lauscher et al. + (2018) as overlap which should be at least half of the length of shorter span. If they are + weakly overlapping as well then we return the matched span. In addition to this, we use + inclusive_end_index boolean which if set adds an offset to the end index of each span in gold + spans list and also to predicted span. Once a match is found we revert back changes to end + index of matched span. Due to AllenNLP token based predictions, a span containing single token + would have length of 0 since start and end index of span would be same. That is why we add an + offset to end index. + + :param span: Predicted span instance as a tuple with span label and indices(start and end) of span. + :param gold_spans: List of gold span instances as tuple with span label and indices(start and end) of span. + :param inclusive_end_index: if set adds an offset to the end index of each span in gold spans list and also to + predicted span. Once a match is found we revert back changes to end index of matched span. + :return: gold span instance if matched with predicted span instance else None + """ + if inclusive_end_index: + span = increase_span_end_index(span, offset=1) + gold_spans = [increase_span_end_index(gold_span, offset=1) for gold_span in gold_spans] + + match_found = None + predicted_type, predicted_indices = span + for gold_type, gold_indices in gold_spans: + if predicted_type == gold_type and has_weak_overlap(predicted_indices, gold_indices): + match_found = gold_type, gold_indices + if inclusive_end_index: + match_found = increase_span_end_index(match_found, offset=-1) + break + return match_found + + +def get_span_classes(label_vocabulary): + return { + label.split("-")[1] + for label in list(label_vocabulary.values()) + if len(label.split("-")) == 2 + } + + +class SpanBasedF1WeakMeasure(Metric): + def __init__( + self, + label_to_id: Dict[str, int], + weak: bool = True, + return_metric: str = "micro/f1", + label_encoding: str = "IOB2", + ignore_classes: List[str] = None, + dist_sync_on_step=False, + ): + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self._label_encoding = label_encoding + self._ignore_classes: List[str] = ignore_classes or [] + self._weak = weak + self._label_vocabulary = dict(zip(label_to_id.values(), label_to_id.keys())) + self._true_positives: Dict[str, int] = defaultdict(int) + self._false_positives: Dict[str, int] = defaultdict(int) + self._false_negatives: Dict[str, int] = defaultdict(int) + self._span_classes = list(get_span_classes(self._label_vocabulary)) + self._num_classes = len(self._span_classes) + self._return_metric = return_metric + self._span_classes_to_index = {c: i for i, c in enumerate(self._span_classes)} + + def default(): + return torch.zeros([self._num_classes], dtype=torch.long) + + for s in ("tp", "fp", "tn", "fn"): + self.add_state(s, default=default(), dist_reduce_fx="sum") + + def update( + self, + preds: torch.Tensor, + targets: torch.Tensor, + masks: Optional[torch.BoolTensor] = None, + prediction_map: Optional[torch.Tensor] = None, + ): + self.calculate_span_based_metric( + preds=preds, targets=targets, masks=masks, prediction_map=prediction_map + ) + + def calculate_span_based_metric( + self, + preds: torch.Tensor, + targets: torch.Tensor, + masks: Optional[torch.BoolTensor] = None, + prediction_map: Optional[torch.Tensor] = None, + ): + if masks is None: + # masks = torch.ones_like(targets).bool() This will result in a tensor with all values True. + # It will result in error since targets contain -100 as value which has no label. + masks = targets != -100 + """ + If you actually passed gradient-tracking Tensors to a Metric, there will be + a huge memory leak, because it will prevent garbage collection for the computation + graph. This method ensures the tensors are detached. + Check if it's actually a tensor in case something else was passed. + """ + predictions, gold_labels, mask, prediction_map = ( + x.detach() if isinstance(x, torch.Tensor) else x + for x in (preds, targets, masks, prediction_map) + ) + + sequence_lengths = masks.sum(-1) + argmax_predictions = preds.argmax(dim=2) + + if prediction_map is not None: + argmax_predictions = torch.gather(prediction_map, 1, argmax_predictions) + gold_labels = torch.gather(prediction_map, 1, gold_labels.long()) + + argmax_predictions = argmax_predictions.float() + + batch_size = gold_labels.size(0) + for i in range(batch_size): + sequence_prediction = argmax_predictions[i, :] + sequence_gold_label = gold_labels[i, :] + length = sequence_lengths[i] + mask = masks[i, :] + if length == 0: + # It is possible to call this metric with sequences which are + # completely padded. These contribute nothing, so we skip these rows. + continue + + predicted_string_labels = [ + self._label_vocabulary[label_id] for label_id in sequence_prediction[mask].tolist() + ] + gold_string_labels = [ + self._label_vocabulary[label_id] for label_id in sequence_gold_label[mask].tolist() + ] + + predicted_spans = tag_sequence_to_token_spans( + tag_sequence=predicted_string_labels, + coding_scheme=self._label_encoding, + classes_to_ignore=self._ignore_classes, + ) + gold_spans = tag_sequence_to_token_spans( + tag_sequence=gold_string_labels, + coding_scheme=self._label_encoding, + classes_to_ignore=self._ignore_classes, + ) + + # Sorting spans so that it is deterministic all the time (handle_continued_spans may not maintain the order) + predicted_spans = sorted(predicted_spans) + gold_spans = sorted(gold_spans) + for span in predicted_spans: + span_original = span + if self._weak: + span = get_weak_match(span, gold_spans, inclusive_end_index=True) + if (not self._weak and span in gold_spans) or (self._weak and span): + self.tp[self._span_classes_to_index[span[0]]] += 1 + gold_spans.remove(span) + else: + if self._weak: + span = span_original + self.fp[self._span_classes.index(span[0])] += 1 + # These spans weren't predicted. + for span in gold_spans: + self.fn[self._span_classes.index(span[0])] += 1 + + def compute(self): + """Scores is a matrix of dimensions num_span_classes + 2 x 3. + + Here 3 signifies precision, recall and f1 and 2 signifies micro and macro averaged metric + scores + """ + + scores = torch.zeros([self._num_classes + 2, 3]) + for i, tag in enumerate(self._span_classes): + scores[i] = compute_metrics(self.tp[i], self.fp[i], self.fn[i]) + scores[self._num_classes + 1] += scores[i] + + # macro averaged metrics + scores[self._num_classes + 1] = scores[self._num_classes + 1] / 3 + + # micro averaged metrics + scores[self._num_classes] = compute_metrics( + true_positives=sum(self.tp), + false_positives=sum(self.fp), + false_negatives=sum(self.fn), + ) + + return self.get_return_metric(scores) + + def get_return_metric(self, scores): + """It returns metric based on return_metric parameter. + + return_metric parameter is defined as TAG/METRIC. + """ + tag_name, _metric_name = self._return_metric.split("/") + metric_to_idx = {"precision": 0, "recall": 1, "f1": 2} + if tag_name == "micro": + index = self._num_classes + elif tag_name == "macro": + index = self._num_classes + 1 + else: + index = self._span_classes.index(tag_name) + return scores[index, metric_to_idx[_metric_name]] + + +def compute_metrics(true_positives: int, false_positives: int, false_negatives: int): + precision = true_positives / (true_positives + false_positives + 1e-13) + recall = true_positives / (true_positives + false_negatives + 1e-13) + f1_measure = 2.0 * (precision * recall) / (precision + recall + 1e-13) + return torch.tensor([precision, recall, f1_measure]) diff --git a/tests/metric/__init__.py b/tests/metric/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/metric/test_weak_span_based_f1.py b/tests/metric/test_weak_span_based_f1.py new file mode 100644 index 0000000..aeca9da --- /dev/null +++ b/tests/metric/test_weak_span_based_f1.py @@ -0,0 +1,328 @@ +import pytest +import torch + +from pie_utils.metric.weak_span_based_f1 import ( + SpanBasedF1WeakMeasure, + compute_metrics, + get_weak_match, + has_weak_overlap, + increase_span_end_index, +) + +# Predicted Tag Sequence : +# [[O,B-own_claim,I-own_claim,I-own_claim,I-own_claim,I-own_claim],[I-own_claim,I-own_claim,I-own_claim,B-data,I-data,I-data]] +LOGITS = torch.tensor( + [ + [ + [ + 10.351420402526855, + -3.2028937339782715, + -2.790834665298462, + -2.648576021194458, + -3.457158327102661, + -2.2753586769104004, + -3.092430591583252, + ], + [ + -0.728147566318512, + -0.3243664503097534, + -3.784968137741089, + -0.566190779209137, + -3.6368443965911865, + 11.714195251464844, + 0.874237596988678, + ], + [ + -2.2986810207366943, + -3.2249557971954346, + -2.213132381439209, + -3.176085948944092, + -1.201798439025879, + -2.5730767250061035, + 10.175483703613281, + ], + [ + -2.399078845977783, + -3.7054524421691895, + -2.732715606689453, + -3.1169371604919434, + -1.3757351636886597, + -2.5311129093170166, + 10.664148330688477, + ], + [ + -1.685034155845642, + -3.0339670181274414, + -2.960148572921753, + -3.0828864574432373, + -1.1170631647109985, + -2.528775215148926, + 10.603050231933594, + ], + [ + -2.307305097579956, + -3.424372673034668, + -2.1860435009002686, + -3.231368064880371, + -1.3228000402450562, + -2.4107933044433594, + 10.90400218963623, + ], + ], + [ + [ + -2.7964279651641846, + -2.7935304641723633, + -2.4361226558685303, + -2.8135015964508057, + -1.101932406425476, + 0.03466780483722687, + 10.032923698425293, + ], + [ + -2.841414451599121, + -2.925645351409912, + -2.197348117828369, + -3.377918004989624, + -1.2031382322311401, + -1.7914681434631348, + 10.230854034423828, + ], + [ + -2.645219087600708, + -3.736269235610962, + -2.1186366081237793, + -3.4355688095092773, + -0.8101659417152405, + -2.3279330730438232, + 10.49844741821289, + ], + [ + -2.3094680309295654, + -3.7092173099517822, + -1.942619800567627, + 10.281001091003418, + -3.1347908973693848, + -1.1283811330795288, + -2.2268154621124268, + ], + [ + -3.747760772705078, + -2.015471935272217, + -3.379420757293701, + -2.101996421813965, + 10.566914558410645, + -3.555516242980957, + -2.556119441986084, + ], + [ + -3.747760772705078, + -2.015471935272217, + -3.379420757293701, + -2.101996421813965, + 10.566914558410645, + -3.555516242980957, + -2.556119441986084, + ], + ], + ] +) + +# Original Tag sequence: +# [[PAD,O,B-own_claim,I-own_claim,I-own_claim,PAD],[PAD,PAD,B-background_claim,I-background_claim,B-data,PAD]] +TARGETS = torch.tensor( + [ + [ + -100, + 0, + 5, + 6, + 6, + -100, + ], + [ + -100, + -100, + 1, + 2, + 3, + -100, + ], + ] +) + +MASKS = TARGETS != -100 + +LABEL_TO_ID = { + "O": 0, + "B-background_claim": 1, + "I-background_claim": 2, + "B-data": 3, + "I-data": 4, + "B-own_claim": 5, + "I-own_claim": 6, +} + + +@pytest.mark.parametrize( + "masks", + [MASKS, None], +) +def test_update(masks): + """Given instance contains four spans, two of which are own_claim and one each for data and + background_claim. Model predicts one span labelled as own_claim correctly. Moreover, it + predicts a span as data falsely. Therefore, resulting metric should contain 1 true positive + count for 'own_claim', 1 false positive count for 'data'. and 1 false negative count for each + label. + + predicted spans : [own_claim (1,5)] , [own_claim (0,2), data (3,5)] + gold spans : [own_claim (2,4)] , [background_claim (2,3), data (4)] + tp : own_claim=1, data=1, background_claim=0 + fp : own_claim=1, data=0, background_claim=0 + fn : own_claim=0, data=0, background_claim=1 + """ + + metric = SpanBasedF1WeakMeasure(label_to_id=LABEL_TO_ID, return_metric="micro/f1") + assert torch.equal(metric.tp, torch.zeros([3], dtype=torch.int64)) + assert torch.equal(metric.fp, torch.zeros([3], dtype=torch.int64)) + assert torch.equal(metric.fn, torch.zeros([3], dtype=torch.int64)) + metric.update(preds=torch.tensor(LOGITS), targets=torch.tensor(TARGETS), masks=masks) + expected_true_positives = torch.zeros([3], dtype=torch.int64) + expected_true_positives[metric._span_classes.index("own_claim")] = 1 + expected_true_positives[metric._span_classes.index("data")] = 1 + assert torch.equal(metric.tp, expected_true_positives) + + expected_false_positives = torch.zeros([3], dtype=torch.int64) + expected_false_positives[metric._span_classes.index("own_claim")] = 1 + assert torch.equal(metric.fp, expected_false_positives) + + expected_false_negatives = torch.zeros([3], dtype=torch.int64) + expected_false_negatives[metric._span_classes.index("background_claim")] = 1 + assert torch.equal(metric.fn, expected_false_negatives) + + +@pytest.mark.parametrize( + "return_metric", + ["own_claim/f1", "own_claim/precision", "own_claim/recall", "macro/f1", "micro/f1"], +) +def test_compute(return_metric): + """Given instance contains four spans, two of which are own_claim and one each for data and + background_claim. Model predicts one span labelled as own_claim correctly. Moreover, it + predicts a span as data falsely. Based on these predictions, scores will be calculated for + different return metrics. + + own_claim/f1 : 2*1*0.5/(0.5+1) = 0.667 + own_claim/precision : 1/(1+1) = 0.5 + own_claim/recall : 1/(1+0) = 1 + macro/f1 : (0.667 + 1 + 0)/3 = 0.556 + micro/f1 : (2*0.667*0.667)/(0.667+0.667) = 0.667 + """ + metric = SpanBasedF1WeakMeasure(label_to_id=LABEL_TO_ID, return_metric=return_metric) + metric.update( + preds=torch.tensor(LOGITS), targets=torch.tensor(TARGETS), masks=torch.tensor(MASKS) + ) + scores = metric.compute() + if return_metric == "own_claim/f1": + assert pytest.approx(scores) == torch.tensor(0.6666666865348816) + if return_metric == "own_claim/precision": + assert torch.eq(scores, torch.tensor(0.5)) + if return_metric == "own_claim/recall": + assert torch.eq(scores, torch.tensor(1)) + if return_metric == "macro/f1": + assert pytest.approx(scores) == torch.tensor(0.5555555820465088) + if return_metric == "micro/f1": + assert pytest.approx(scores) == torch.tensor(0.6666666865348816) + + +def test_has_weak_overlap(): + # checks if overlap in span is at least half of the length of the shorter span. + + weak_overlapping_indices = ((0, 5), (3, 6)) + no_weak_overlapping_indices = ((0, 5), (4, 9)) + touching_indices = ((0, 5), (5, 7)) + non_touching_indices = ((0, 5), (6, 7)) + containing_indices = ((0, 9), (3, 6)) + + assert has_weak_overlap(weak_overlapping_indices[0], weak_overlapping_indices[1]) + assert not has_weak_overlap(no_weak_overlapping_indices[0], no_weak_overlapping_indices[1]) + assert not has_weak_overlap(touching_indices[0], touching_indices[1]) + assert not has_weak_overlap(non_touching_indices[0], non_touching_indices[1]) + assert has_weak_overlap(containing_indices[0], containing_indices[1]) + + assert has_weak_overlap(weak_overlapping_indices[1], weak_overlapping_indices[0]) + assert not has_weak_overlap(no_weak_overlapping_indices[1], no_weak_overlapping_indices[0]) + assert not has_weak_overlap(touching_indices[1], touching_indices[0]) + assert not has_weak_overlap(non_touching_indices[1], non_touching_indices[0]) + assert has_weak_overlap(containing_indices[1], containing_indices[0]) + + +def test_compute_metrics(): + true_positives = 10 + false_positives = 8 + false_negatives = 4 + true_precision = 10 / (10 + 8) + true_recall = 10 / (10 + 4) + true_f1 = (2 * true_recall * true_precision) / (true_recall + true_precision) + precision, recall, f1 = compute_metrics( + true_positives=true_positives, + false_positives=false_positives, + false_negatives=false_negatives, + ) + assert precision == true_precision + assert recall == true_recall + assert f1 == true_f1 + + +def test_increase_span_end_index(): + # Resulting span should have end index adjusted by the value of the offset, i.e. the offset value should be added + # to the end index of the span + span = ("city", (0, 0)) + new_span = increase_span_end_index(span=span, offset=1) + assert new_span == ("city", (0, 1)) + + original_span = increase_span_end_index(span=new_span, offset=-1) + assert original_span == span + + +def test_get_weak_match(): + # Since predicted span is contained in first span of the gold span list, it will be considered as the weak match. + span = ("city", (0, 0)) + spans = [("city", (0, 3)), ("person", (3, 5)), ("person", (6, 9))] + match = get_weak_match(span=span, gold_spans=spans) + assert match == ("city", (0, 3)) + + # Predicted span is not contained in the second gold span, therefore it is not considered as a match. + span = ("person", (3, 5)) + spans = [("city", (0, 3)), ("person", (5, 9))] + match = get_weak_match(span=span, gold_spans=spans) + assert match is None + + # Here predicted span is partly inside the second gold span, there we have a weak match + span = ("person", (3, 5)) + spans = [("city", (0, 3)), ("person", (4, 7))] + match = get_weak_match(span=span, gold_spans=spans) + assert match == ("person", (4, 7)) + + # Here predicted span is partly inside the second gold span but not enough to be considered as a weak match. + span = ("person", (3, 6)) + spans = [("city", (0, 3)), ("person", (5, 9))] + match = get_weak_match(span=span, gold_spans=spans) + assert match is None + + +@pytest.mark.parametrize( + "inclusive_end_index", + [True, False], +) +def test_get_weak_match_with_inclusive_end_index(inclusive_end_index): + # Here in the given span if we consider end index to be part of span then we have weak match with the first span + # in the list of gold spans. However, if we do not consider end index as inclusive to the span, then there is no + # overlap and hence no weak match. + span = ("city", (0, 2)) + spans = [("city", (2, 3)), ("person", (3, 5))] + + match = get_weak_match(span=span, gold_spans=spans, inclusive_end_index=inclusive_end_index) + if inclusive_end_index: + assert match == ("city", (2, 3)) + else: + assert match is None From 4ef76fdd88522f74fcbaa70eb7832bf703a6894f Mon Sep 17 00:00:00 2001 From: Bhuvanesh Verma Date: Tue, 20 Sep 2022 11:15:49 +0200 Subject: [PATCH 2/4] add documentation --- src/pie_utils/metric/weak_span_based_f1.py | 176 +++++++++++++++++---- 1 file changed, 141 insertions(+), 35 deletions(-) diff --git a/src/pie_utils/metric/weak_span_based_f1.py b/src/pie_utils/metric/weak_span_based_f1.py index 429ac23..ec1eedc 100644 --- a/src/pie_utils/metric/weak_span_based_f1.py +++ b/src/pie_utils/metric/weak_span_based_f1.py @@ -1,5 +1,4 @@ -from collections import defaultdict -from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple +from typing import Callable, Dict, List, Optional, Tuple import torch from torchmetrics import Metric @@ -12,7 +11,20 @@ def has_weak_overlap(indices_1: Tuple[int, int], indices_2: Tuple[int, int]) -> bool: - # checks if overlap in span is at least half of the length of the shorter span + """This method checks if span overlap is at least half the length of the shorter span. + + # Parameters: + + indices_1: `Tuple[int, int]` , required + Span slice or indices of the first span + indices_2: `Tuple[int, int]` , required + Span slice or indices of the second span + + # Returns: + `bool` + if two slices are weakly overlap or not + """ + min_len = min(indices_1[1] - indices_1[0], indices_2[1] - indices_2[0]) overlap_len = get_overlap_len(indices_1, indices_2) return 2 * overlap_len >= min_len @@ -21,7 +33,19 @@ def has_weak_overlap(indices_1: Tuple[int, int], indices_2: Tuple[int, int]) -> def increase_span_end_index( span: Tuple[str, Tuple[int, int]], offset: int ) -> Tuple[str, Tuple[int, int]]: - # format of span is (label,(start,end)) + """Increase end index of a span by offset. + + # Parameters + + span: `Tuple[str, Tuple[int, int]]`, required + span whose end index is to be updated. The format of span is (label,(start,end)) + offset: `int`, required + integer value added to end index of the span + + # Returns + `Tuple[str, Tuple[int, int]]` + Updated span + """ return span[0], (span[1][0], span[1][1] + offset) @@ -30,22 +54,28 @@ def get_weak_match( gold_spans: List[Tuple[str, Tuple[int, int]]], inclusive_end_index: bool = False, ) -> Optional[Tuple[str, Tuple[int, int]]]: - """This method checks if the predicted span is weakly matched with any of the gold spans. If - predicted type and gold type matches then we check if their respective indices are weakly - overlapping or not. Weak overlap between gold and predicted span is defined in Lauscher et al. - (2018) as overlap which should be at least half of the length of shorter span. If they are - weakly overlapping as well then we return the matched span. In addition to this, we use - inclusive_end_index boolean which if set adds an offset to the end index of each span in gold - spans list and also to predicted span. Once a match is found we revert back changes to end - index of matched span. Due to AllenNLP token based predictions, a span containing single token - would have length of 0 since start and end index of span would be same. That is why we add an - offset to end index. - - :param span: Predicted span instance as a tuple with span label and indices(start and end) of span. - :param gold_spans: List of gold span instances as tuple with span label and indices(start and end) of span. - :param inclusive_end_index: if set adds an offset to the end index of each span in gold spans list and also to - predicted span. Once a match is found we revert back changes to end index of matched span. - :return: gold span instance if matched with predicted span instance else None + """This method determines if the predicted span is weakly overlapped with any of the gold + spans. If the predicted type and the gold type matches then we check if their respective + indices are weakly overlapped or not. If they are weakly overlapped then we return the matched + span. In addition to this, we use inclusive_end_index which adds an offset to the end index of + each span in the gold spans list and also to the predicted span. Once a match is found we + revert changes to the end index of the matched span. A span containing a single token might + have length of 0 since the start and the end index of a span would be same. That is why we add + an offset to the end index. + + # Parameters + + span: `Tuple[str, Tuple[int, int]]` , required + Predicted span instance as a tuple with span label and indices(start and end) of span. + gold_spans: `List[Tuple[str, Tuple[int, int]]]` , required + List of gold span instances as tuple with span label and indices(start and end) of span. + inclusive_end_index: `bool` , optional (default = False) + if set adds an offset to the end index of each span in gold spans list and also to + predicted span. Once a match is found we revert changes to end index of matched span. + + # Returns + `Tuple[str, Tuple[int, int]] or None` + gold span instance if matched with predicted span instance else None """ if inclusive_end_index: span = increase_span_end_index(span, offset=1) @@ -62,7 +92,18 @@ def get_weak_match( return match_found -def get_span_classes(label_vocabulary): +def get_span_classes(label_vocabulary: Dict[int, str]): + """This method uses label vocabulary to get the span classes. + + # Parameters + + label_vocabulary: `Dict[int, str]` , required + It is a mapping from integer to span labels. + + # Returns + `Set(str)` + set of span labels + """ return { label.split("-")[1] for label in list(label_vocabulary.values()) @@ -71,6 +112,34 @@ def get_span_classes(label_vocabulary): class SpanBasedF1WeakMeasure(Metric): + """The SpanBasedF1WeakMeasure computes the F1 score based on the span overlap. It creates four + states: true positive (tp), false positive (fp), false negative (fn) and true negative (tn). + These states are updated iteratively for the different span sequences which is ultimately used + to calculate the required F1 score. + + # Parameters + + label_to_id: `Dict[str, int]` , required + It is a dictionary mapping span labels to the id. It is also used to create label_vocabulary + weak: `bool , optional (default = True) + This parameter determines the overlapping criteria for the successfully predicted span. If this parameter + is false then we expect the predicted and gold span to have a complete (exact) overlap, + otherwise the predicted and gold span can have a weak overlap. A weak overlap between the gold and + predicted span is defined in Lauscher et al. (2018) as an overlap which should be at + least half of the length of the shorter span. + return_metric: `str , optional (default = micro/f1) + It is the type of F1 measure that is to be computed and returned. It can be used to + get F1 score for the individual classes or for all classes using macro or micro averaging + criteria. Example: micro/f1, macro/f1, own_claim/f1 + label_encoding: `str , optional (default = IOB2) + It represents the type of encoding scheme for the spans. Encoding can be IOB2, BIOUL + and BOUL. + ignore_classes: `List[str]` , optional (default = None) + List of span labels that is to be ignored for the calculation of the metric. + dist_sync_on_step: `bool`, optional (default = False) + This parameter determines if the metric state should synchronize on forward(). + """ + def __init__( self, label_to_id: Dict[str, int], @@ -78,7 +147,7 @@ def __init__( return_metric: str = "micro/f1", label_encoding: str = "IOB2", ignore_classes: List[str] = None, - dist_sync_on_step=False, + dist_sync_on_step: bool = False, ): super().__init__(dist_sync_on_step=dist_sync_on_step) @@ -86,9 +155,6 @@ def __init__( self._ignore_classes: List[str] = ignore_classes or [] self._weak = weak self._label_vocabulary = dict(zip(label_to_id.values(), label_to_id.keys())) - self._true_positives: Dict[str, int] = defaultdict(int) - self._false_positives: Dict[str, int] = defaultdict(int) - self._false_negatives: Dict[str, int] = defaultdict(int) self._span_classes = list(get_span_classes(self._label_vocabulary)) self._num_classes = len(self._span_classes) self._return_metric = return_metric @@ -107,6 +173,7 @@ def update( masks: Optional[torch.BoolTensor] = None, prediction_map: Optional[torch.Tensor] = None, ): + """Updates the defined states in init using the predictions and targets.""" self.calculate_span_based_metric( preds=preds, targets=targets, masks=masks, prediction_map=prediction_map ) @@ -118,6 +185,25 @@ def calculate_span_based_metric( masks: Optional[torch.BoolTensor] = None, prediction_map: Optional[torch.Tensor] = None, ): + """This method calculates and then updates the values of the different states of the + metric. It converts the predictions and target tensors into a tag sequence which is then + converted to the token spans. Target and predicted spans are then compared to calculate the + true positive, false positive and false negative. + + # Parameters + + preds: `torch.Tensor` , required + This is the output predicted by the classification model in the shape B x T X C + where B represents batch size, T represents number of tokens and C represents + number of classes. + targets: `torch.Tensor` , required + This is the gold values for the given sequence in the shape B x T. + masks: `torch.BoolTensor` , optional (default = None) + This tensor is used to masks the padded tokens in the given sequence. If it is None, + then it is calculated using the targets. + prediction_map: `torch.BoolTensor` , optional (default = None) + ??? + """ if masks is None: # masks = torch.ones_like(targets).bool() This will result in a tensor with all values True. # It will result in error since targets contain -100 as value which has no label. @@ -171,9 +257,6 @@ def calculate_span_based_metric( classes_to_ignore=self._ignore_classes, ) - # Sorting spans so that it is deterministic all the time (handle_continued_spans may not maintain the order) - predicted_spans = sorted(predicted_spans) - gold_spans = sorted(gold_spans) for span in predicted_spans: span_original = span if self._weak: @@ -184,16 +267,17 @@ def calculate_span_based_metric( else: if self._weak: span = span_original - self.fp[self._span_classes.index(span[0])] += 1 + self.fp[self._span_classes_to_index[span[0]]] += 1 # These spans weren't predicted. for span in gold_spans: - self.fn[self._span_classes.index(span[0])] += 1 + self.fn[self._span_classes_to_index[span[0]]] += 1 def compute(self): - """Scores is a matrix of dimensions num_span_classes + 2 x 3. + """Scores is a matrix of dimensions num_span_classes + 2 x 3. Here 3 signifies precision, + recall and f1 and 2 signifies micro and macro averaged metric scores. - Here 3 signifies precision, recall and f1 and 2 signifies micro and macro averaged metric - scores + # Returns `torch.Tensor(float)` value of the metric based on the return_metric + parameter """ scores = torch.zeros([self._num_classes + 2, 3]) @@ -214,9 +298,16 @@ def compute(self): return self.get_return_metric(scores) def get_return_metric(self, scores): - """It returns metric based on return_metric parameter. + """It calculates the metric based on the return_metric parameter using the given scores. + + # Parameter - return_metric parameter is defined as TAG/METRIC. + scores: Tensor, required + Tensor containing the scores for each class against precision, recall and f1 + + # Returns + `torch.Tensor(float)` + value of the metric based on the return_metric parameter """ tag_name, _metric_name = self._return_metric.split("/") metric_to_idx = {"precision": 0, "recall": 1, "f1": 2} @@ -230,6 +321,21 @@ def get_return_metric(self, scores): def compute_metrics(true_positives: int, false_positives: int, false_negatives: int): + """Calculates precision, recall and f1 measure using the given true positive, false positive + and false negative values. + + # Parameters: + + true_positives: int , required + count for true positives + false_positives: int , required + count for false positives + false_negatives: int , required + count for false negatives + + # Returns + Tensor containing the values for precision, recall and f1 measure respectively + """ precision = true_positives / (true_positives + false_positives + 1e-13) recall = true_positives / (true_positives + false_negatives + 1e-13) f1_measure = 2.0 * (precision * recall) / (precision + recall + 1e-13) From 5904536e1f928aa264fca8b87e46615953dc92a8 Mon Sep 17 00:00:00 2001 From: Bhuvanesh Verma Date: Fri, 23 Sep 2022 13:34:14 +0200 Subject: [PATCH 3/4] update code for prediction_map --- src/pie_utils/metric/weak_span_based_f1.py | 3 +- tests/metric/test_weak_span_based_f1.py | 96 ++++++++++++++++++++++ 2 files changed, 98 insertions(+), 1 deletion(-) diff --git a/src/pie_utils/metric/weak_span_based_f1.py b/src/pie_utils/metric/weak_span_based_f1.py index ec1eedc..0bd15e5 100644 --- a/src/pie_utils/metric/weak_span_based_f1.py +++ b/src/pie_utils/metric/weak_span_based_f1.py @@ -224,7 +224,8 @@ def calculate_span_based_metric( if prediction_map is not None: argmax_predictions = torch.gather(prediction_map, 1, argmax_predictions) - gold_labels = torch.gather(prediction_map, 1, gold_labels.long()) + # gold labels contain padding token which is -100 and there is no prediction mapping for padding token. + # gold_labels = torch.gather(prediction_map, 1, gold_labels.long()) argmax_predictions = argmax_predictions.float() diff --git a/tests/metric/test_weak_span_based_f1.py b/tests/metric/test_weak_span_based_f1.py index aeca9da..8fbefa2 100644 --- a/tests/metric/test_weak_span_based_f1.py +++ b/tests/metric/test_weak_span_based_f1.py @@ -125,6 +125,62 @@ -2.556119441986084, ], ], + [ + [ + -2.7964279651641846, + -2.7935304641723633, + -2.4361226558685303, + -2.8135015964508057, + -1.101932406425476, + 0.03466780483722687, + 10.032923698425293, + ], + [ + -2.841414451599121, + -2.925645351409912, + -2.197348117828369, + -3.377918004989624, + -1.2031382322311401, + -1.7914681434631348, + 10.230854034423828, + ], + [ + -2.645219087600708, + -3.736269235610962, + -2.1186366081237793, + -3.4355688095092773, + -0.8101659417152405, + -2.3279330730438232, + 10.49844741821289, + ], + [ + -2.3094680309295654, + -3.7092173099517822, + -1.942619800567627, + 10.281001091003418, + -3.1347908973693848, + -1.1283811330795288, + -2.2268154621124268, + ], + [ + -3.747760772705078, + -2.015471935272217, + -3.379420757293701, + -2.101996421813965, + 10.566914558410645, + -3.555516242980957, + -2.556119441986084, + ], + [ + -3.747760772705078, + -2.015471935272217, + -3.379420757293701, + -2.101996421813965, + 10.566914558410645, + -3.555516242980957, + -2.556119441986084, + ], + ], ] ) @@ -148,6 +204,14 @@ 3, -100, ], + [ + -100, + -100, + -100, + -100, + -100, + -100, + ], ] ) @@ -201,6 +265,38 @@ def test_update(masks): assert torch.equal(metric.fn, expected_false_negatives) +def test_update_with_prediction_map(): + """This test is similar to the last test but uses the prediction_map to obtain correct label in + each batch of sequence. + + In our case, the label classes are same for all batch sequences. So prediction_map doesn't really make much + difference. Accurate use case can be found here: + https://github.com/allenai/allennlp/blob/39c40fe38cd2fd36b3465b0b3c031f54ec824160/tests/training/metrics/span_based_f1_measure_test.py#L39 + """ + prediction_map = torch.tensor( + [list(LABEL_TO_ID.values()), list(LABEL_TO_ID.values()), list(LABEL_TO_ID.values())] + ) + metric = SpanBasedF1WeakMeasure(label_to_id=LABEL_TO_ID, return_metric="micro/f1") + assert torch.equal(metric.tp, torch.zeros([3], dtype=torch.int64)) + assert torch.equal(metric.fp, torch.zeros([3], dtype=torch.int64)) + assert torch.equal(metric.fn, torch.zeros([3], dtype=torch.int64)) + metric.update( + preds=torch.tensor(LOGITS), targets=torch.tensor(TARGETS), prediction_map=prediction_map + ) + expected_true_positives = torch.zeros([3], dtype=torch.int64) + expected_true_positives[metric._span_classes.index("own_claim")] = 1 + expected_true_positives[metric._span_classes.index("data")] = 1 + assert torch.equal(metric.tp, expected_true_positives) + + expected_false_positives = torch.zeros([3], dtype=torch.int64) + expected_false_positives[metric._span_classes.index("own_claim")] = 1 + assert torch.equal(metric.fp, expected_false_positives) + + expected_false_negatives = torch.zeros([3], dtype=torch.int64) + expected_false_negatives[metric._span_classes.index("background_claim")] = 1 + assert torch.equal(metric.fn, expected_false_negatives) + + @pytest.mark.parametrize( "return_metric", ["own_claim/f1", "own_claim/precision", "own_claim/recall", "macro/f1", "micro/f1"], From e701050eff570dc380dac265ea51dfb0f2bc2ad8 Mon Sep 17 00:00:00 2001 From: Bhuvanesh Verma Date: Fri, 23 Sep 2022 13:35:06 +0200 Subject: [PATCH 4/4] update documentation --- src/pie_utils/metric/weak_span_based_f1.py | 52 ++++++++++++---------- 1 file changed, 29 insertions(+), 23 deletions(-) diff --git a/src/pie_utils/metric/weak_span_based_f1.py b/src/pie_utils/metric/weak_span_based_f1.py index 0bd15e5..7210edf 100644 --- a/src/pie_utils/metric/weak_span_based_f1.py +++ b/src/pie_utils/metric/weak_span_based_f1.py @@ -56,12 +56,12 @@ def get_weak_match( ) -> Optional[Tuple[str, Tuple[int, int]]]: """This method determines if the predicted span is weakly overlapped with any of the gold spans. If the predicted type and the gold type matches then we check if their respective - indices are weakly overlapped or not. If they are weakly overlapped then we return the matched - span. In addition to this, we use inclusive_end_index which adds an offset to the end index of - each span in the gold spans list and also to the predicted span. Once a match is found we - revert changes to the end index of the matched span. A span containing a single token might - have length of 0 since the start and the end index of a span would be same. That is why we add - an offset to the end index. + indices are weakly overlapped or not. If they weakly overlapped, we return the matched span. In + addition to this, we use inclusive_end_index which adds an offset to the end index of each span + in the gold spans list and also to the predicted span. Once a match is found we revert changes + to the end index of the matched span. A span containing a single token might have length of 0 + since the start and the end index of a span would be same. That is why we add an offset to the + end index. # Parameters @@ -112,10 +112,13 @@ def get_span_classes(label_vocabulary: Dict[int, str]): class SpanBasedF1WeakMeasure(Metric): - """The SpanBasedF1WeakMeasure computes the F1 score based on the span overlap. It creates four - states: true positive (tp), false positive (fp), false negative (fn) and true negative (tn). - These states are updated iteratively for the different span sequences which is ultimately used - to calculate the required F1 score. + """The SpanBasedF1WeakMeasure implements span-based precision, recall and F1 measure for + different tagging schemes. A span can either have exact match (strict) or it can be partially + overlapped (weak) with predicted span. This metric allows both strict and weak version. It + creates four states: true positive (tp), false positive (fp), false negative (fn) and true + negative (tn). These states are updated iteratively for the different sequences which is + ultimately used to calculate the desired metric. It produces precision, recall and F1 measures + per tag, as well as overall statistics. # Parameters @@ -193,27 +196,30 @@ def calculate_span_based_metric( # Parameters preds: `torch.Tensor` , required - This is the output predicted by the classification model in the shape B x T X C - where B represents batch size, T represents number of tokens and C represents - number of classes. + This is the output predicted by the classification model in the shape + (batch_size, sequence_length, num_classes) targets: `torch.Tensor` , required - This is the gold values for the given sequence in the shape B x T. + This is the gold values for the given sequence in the shape (batch_size, sequence_length). masks: `torch.BoolTensor` , optional (default = None) This tensor is used to masks the padded tokens in the given sequence. If it is None, then it is calculated using the targets. prediction_map: `torch.BoolTensor` , optional (default = None) - ??? + A tensor of size (batch_size, num_classes) which provides a mapping from the index of predictions + to the indices of the label vocabulary. If provided, the output label at each timestep will be + `vocabulary.get_index_to_token_vocabulary(prediction_map[batch, argmax(predictions[batch, t]))`, + rather than simply `vocabulary.get_index_to_token_vocabulary(argmax(predictions[batch, t]))`. + This is useful in cases where each Instance in the dataset is associated with a different possible + subset of labels from a large label-space (IE FrameNet, where each frame has a different set of + possible roles associated with it). """ if masks is None: # masks = torch.ones_like(targets).bool() This will result in a tensor with all values True. # It will result in error since targets contain -100 as value which has no label. masks = targets != -100 - """ - If you actually passed gradient-tracking Tensors to a Metric, there will be - a huge memory leak, because it will prevent garbage collection for the computation - graph. This method ensures the tensors are detached. - Check if it's actually a tensor in case something else was passed. - """ + + # If you actually passed gradient-tracking Tensors to a Metric, there will be a huge memory leak, because it + # will prevent garbage collection for the computation graph. This method ensures the tensors are detached. Check + # if it's actually a tensor in case something else was passed. predictions, gold_labels, mask, prediction_map = ( x.detach() if isinstance(x, torch.Tensor) else x for x in (preds, targets, masks, prediction_map) @@ -274,10 +280,10 @@ def calculate_span_based_metric( self.fn[self._span_classes_to_index[span[0]]] += 1 def compute(self): - """Scores is a matrix of dimensions num_span_classes + 2 x 3. Here 3 signifies precision, + """Scores is a matrix of dimensions (num_span_classes + 2) x 3. Here 3 signifies precision, recall and f1 and 2 signifies micro and macro averaged metric scores. - # Returns `torch.Tensor(float)` value of the metric based on the return_metric + # Returns `torch.Tensor(float)` value of the metric based on the return_metric parameter """