diff --git a/src/pie_utils/metric/__init__.py b/src/pie_utils/metric/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pie_utils/metric/weak_span_based_f1.py b/src/pie_utils/metric/weak_span_based_f1.py new file mode 100644 index 0000000..7210edf --- /dev/null +++ b/src/pie_utils/metric/weak_span_based_f1.py @@ -0,0 +1,349 @@ +from typing import Callable, Dict, List, Optional, Tuple + +import torch +from torchmetrics import Metric + +from pie_utils.sequence_tagging import tag_sequence_to_token_spans +from pie_utils.span.slice import get_overlap_len + +TypedStringSpan = Tuple[str, Tuple[int, int]] +TAGS_TO_SPANS_FUNCTION_TYPE = Callable[[List[str], Optional[List[str]]], List[TypedStringSpan]] + + +def has_weak_overlap(indices_1: Tuple[int, int], indices_2: Tuple[int, int]) -> bool: + """This method checks if span overlap is at least half the length of the shorter span. + + # Parameters: + + indices_1: `Tuple[int, int]` , required + Span slice or indices of the first span + indices_2: `Tuple[int, int]` , required + Span slice or indices of the second span + + # Returns: + `bool` + if two slices are weakly overlap or not + """ + + min_len = min(indices_1[1] - indices_1[0], indices_2[1] - indices_2[0]) + overlap_len = get_overlap_len(indices_1, indices_2) + return 2 * overlap_len >= min_len + + +def increase_span_end_index( + span: Tuple[str, Tuple[int, int]], offset: int +) -> Tuple[str, Tuple[int, int]]: + """Increase end index of a span by offset. + + # Parameters + + span: `Tuple[str, Tuple[int, int]]`, required + span whose end index is to be updated. The format of span is (label,(start,end)) + offset: `int`, required + integer value added to end index of the span + + # Returns + `Tuple[str, Tuple[int, int]]` + Updated span + """ + return span[0], (span[1][0], span[1][1] + offset) + + +def get_weak_match( + span: Tuple[str, Tuple[int, int]], + gold_spans: List[Tuple[str, Tuple[int, int]]], + inclusive_end_index: bool = False, +) -> Optional[Tuple[str, Tuple[int, int]]]: + """This method determines if the predicted span is weakly overlapped with any of the gold + spans. If the predicted type and the gold type matches then we check if their respective + indices are weakly overlapped or not. If they weakly overlapped, we return the matched span. In + addition to this, we use inclusive_end_index which adds an offset to the end index of each span + in the gold spans list and also to the predicted span. Once a match is found we revert changes + to the end index of the matched span. A span containing a single token might have length of 0 + since the start and the end index of a span would be same. That is why we add an offset to the + end index. + + # Parameters + + span: `Tuple[str, Tuple[int, int]]` , required + Predicted span instance as a tuple with span label and indices(start and end) of span. + gold_spans: `List[Tuple[str, Tuple[int, int]]]` , required + List of gold span instances as tuple with span label and indices(start and end) of span. + inclusive_end_index: `bool` , optional (default = False) + if set adds an offset to the end index of each span in gold spans list and also to + predicted span. Once a match is found we revert changes to end index of matched span. + + # Returns + `Tuple[str, Tuple[int, int]] or None` + gold span instance if matched with predicted span instance else None + """ + if inclusive_end_index: + span = increase_span_end_index(span, offset=1) + gold_spans = [increase_span_end_index(gold_span, offset=1) for gold_span in gold_spans] + + match_found = None + predicted_type, predicted_indices = span + for gold_type, gold_indices in gold_spans: + if predicted_type == gold_type and has_weak_overlap(predicted_indices, gold_indices): + match_found = gold_type, gold_indices + if inclusive_end_index: + match_found = increase_span_end_index(match_found, offset=-1) + break + return match_found + + +def get_span_classes(label_vocabulary: Dict[int, str]): + """This method uses label vocabulary to get the span classes. + + # Parameters + + label_vocabulary: `Dict[int, str]` , required + It is a mapping from integer to span labels. + + # Returns + `Set(str)` + set of span labels + """ + return { + label.split("-")[1] + for label in list(label_vocabulary.values()) + if len(label.split("-")) == 2 + } + + +class SpanBasedF1WeakMeasure(Metric): + """The SpanBasedF1WeakMeasure implements span-based precision, recall and F1 measure for + different tagging schemes. A span can either have exact match (strict) or it can be partially + overlapped (weak) with predicted span. This metric allows both strict and weak version. It + creates four states: true positive (tp), false positive (fp), false negative (fn) and true + negative (tn). These states are updated iteratively for the different sequences which is + ultimately used to calculate the desired metric. It produces precision, recall and F1 measures + per tag, as well as overall statistics. + + # Parameters + + label_to_id: `Dict[str, int]` , required + It is a dictionary mapping span labels to the id. It is also used to create label_vocabulary + weak: `bool , optional (default = True) + This parameter determines the overlapping criteria for the successfully predicted span. If this parameter + is false then we expect the predicted and gold span to have a complete (exact) overlap, + otherwise the predicted and gold span can have a weak overlap. A weak overlap between the gold and + predicted span is defined in Lauscher et al. (2018) as an overlap which should be at + least half of the length of the shorter span. + return_metric: `str , optional (default = micro/f1) + It is the type of F1 measure that is to be computed and returned. It can be used to + get F1 score for the individual classes or for all classes using macro or micro averaging + criteria. Example: micro/f1, macro/f1, own_claim/f1 + label_encoding: `str , optional (default = IOB2) + It represents the type of encoding scheme for the spans. Encoding can be IOB2, BIOUL + and BOUL. + ignore_classes: `List[str]` , optional (default = None) + List of span labels that is to be ignored for the calculation of the metric. + dist_sync_on_step: `bool`, optional (default = False) + This parameter determines if the metric state should synchronize on forward(). + """ + + def __init__( + self, + label_to_id: Dict[str, int], + weak: bool = True, + return_metric: str = "micro/f1", + label_encoding: str = "IOB2", + ignore_classes: List[str] = None, + dist_sync_on_step: bool = False, + ): + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self._label_encoding = label_encoding + self._ignore_classes: List[str] = ignore_classes or [] + self._weak = weak + self._label_vocabulary = dict(zip(label_to_id.values(), label_to_id.keys())) + self._span_classes = list(get_span_classes(self._label_vocabulary)) + self._num_classes = len(self._span_classes) + self._return_metric = return_metric + self._span_classes_to_index = {c: i for i, c in enumerate(self._span_classes)} + + def default(): + return torch.zeros([self._num_classes], dtype=torch.long) + + for s in ("tp", "fp", "tn", "fn"): + self.add_state(s, default=default(), dist_reduce_fx="sum") + + def update( + self, + preds: torch.Tensor, + targets: torch.Tensor, + masks: Optional[torch.BoolTensor] = None, + prediction_map: Optional[torch.Tensor] = None, + ): + """Updates the defined states in init using the predictions and targets.""" + self.calculate_span_based_metric( + preds=preds, targets=targets, masks=masks, prediction_map=prediction_map + ) + + def calculate_span_based_metric( + self, + preds: torch.Tensor, + targets: torch.Tensor, + masks: Optional[torch.BoolTensor] = None, + prediction_map: Optional[torch.Tensor] = None, + ): + """This method calculates and then updates the values of the different states of the + metric. It converts the predictions and target tensors into a tag sequence which is then + converted to the token spans. Target and predicted spans are then compared to calculate the + true positive, false positive and false negative. + + # Parameters + + preds: `torch.Tensor` , required + This is the output predicted by the classification model in the shape + (batch_size, sequence_length, num_classes) + targets: `torch.Tensor` , required + This is the gold values for the given sequence in the shape (batch_size, sequence_length). + masks: `torch.BoolTensor` , optional (default = None) + This tensor is used to masks the padded tokens in the given sequence. If it is None, + then it is calculated using the targets. + prediction_map: `torch.BoolTensor` , optional (default = None) + A tensor of size (batch_size, num_classes) which provides a mapping from the index of predictions + to the indices of the label vocabulary. If provided, the output label at each timestep will be + `vocabulary.get_index_to_token_vocabulary(prediction_map[batch, argmax(predictions[batch, t]))`, + rather than simply `vocabulary.get_index_to_token_vocabulary(argmax(predictions[batch, t]))`. + This is useful in cases where each Instance in the dataset is associated with a different possible + subset of labels from a large label-space (IE FrameNet, where each frame has a different set of + possible roles associated with it). + """ + if masks is None: + # masks = torch.ones_like(targets).bool() This will result in a tensor with all values True. + # It will result in error since targets contain -100 as value which has no label. + masks = targets != -100 + + # If you actually passed gradient-tracking Tensors to a Metric, there will be a huge memory leak, because it + # will prevent garbage collection for the computation graph. This method ensures the tensors are detached. Check + # if it's actually a tensor in case something else was passed. + predictions, gold_labels, mask, prediction_map = ( + x.detach() if isinstance(x, torch.Tensor) else x + for x in (preds, targets, masks, prediction_map) + ) + + sequence_lengths = masks.sum(-1) + argmax_predictions = preds.argmax(dim=2) + + if prediction_map is not None: + argmax_predictions = torch.gather(prediction_map, 1, argmax_predictions) + # gold labels contain padding token which is -100 and there is no prediction mapping for padding token. + # gold_labels = torch.gather(prediction_map, 1, gold_labels.long()) + + argmax_predictions = argmax_predictions.float() + + batch_size = gold_labels.size(0) + for i in range(batch_size): + sequence_prediction = argmax_predictions[i, :] + sequence_gold_label = gold_labels[i, :] + length = sequence_lengths[i] + mask = masks[i, :] + if length == 0: + # It is possible to call this metric with sequences which are + # completely padded. These contribute nothing, so we skip these rows. + continue + + predicted_string_labels = [ + self._label_vocabulary[label_id] for label_id in sequence_prediction[mask].tolist() + ] + gold_string_labels = [ + self._label_vocabulary[label_id] for label_id in sequence_gold_label[mask].tolist() + ] + + predicted_spans = tag_sequence_to_token_spans( + tag_sequence=predicted_string_labels, + coding_scheme=self._label_encoding, + classes_to_ignore=self._ignore_classes, + ) + gold_spans = tag_sequence_to_token_spans( + tag_sequence=gold_string_labels, + coding_scheme=self._label_encoding, + classes_to_ignore=self._ignore_classes, + ) + + for span in predicted_spans: + span_original = span + if self._weak: + span = get_weak_match(span, gold_spans, inclusive_end_index=True) + if (not self._weak and span in gold_spans) or (self._weak and span): + self.tp[self._span_classes_to_index[span[0]]] += 1 + gold_spans.remove(span) + else: + if self._weak: + span = span_original + self.fp[self._span_classes_to_index[span[0]]] += 1 + # These spans weren't predicted. + for span in gold_spans: + self.fn[self._span_classes_to_index[span[0]]] += 1 + + def compute(self): + """Scores is a matrix of dimensions (num_span_classes + 2) x 3. Here 3 signifies precision, + recall and f1 and 2 signifies micro and macro averaged metric scores. + + # Returns `torch.Tensor(float)` value of the metric based on the return_metric + parameter + """ + + scores = torch.zeros([self._num_classes + 2, 3]) + for i, tag in enumerate(self._span_classes): + scores[i] = compute_metrics(self.tp[i], self.fp[i], self.fn[i]) + scores[self._num_classes + 1] += scores[i] + + # macro averaged metrics + scores[self._num_classes + 1] = scores[self._num_classes + 1] / 3 + + # micro averaged metrics + scores[self._num_classes] = compute_metrics( + true_positives=sum(self.tp), + false_positives=sum(self.fp), + false_negatives=sum(self.fn), + ) + + return self.get_return_metric(scores) + + def get_return_metric(self, scores): + """It calculates the metric based on the return_metric parameter using the given scores. + + # Parameter + + scores: Tensor, required + Tensor containing the scores for each class against precision, recall and f1 + + # Returns + `torch.Tensor(float)` + value of the metric based on the return_metric parameter + """ + tag_name, _metric_name = self._return_metric.split("/") + metric_to_idx = {"precision": 0, "recall": 1, "f1": 2} + if tag_name == "micro": + index = self._num_classes + elif tag_name == "macro": + index = self._num_classes + 1 + else: + index = self._span_classes.index(tag_name) + return scores[index, metric_to_idx[_metric_name]] + + +def compute_metrics(true_positives: int, false_positives: int, false_negatives: int): + """Calculates precision, recall and f1 measure using the given true positive, false positive + and false negative values. + + # Parameters: + + true_positives: int , required + count for true positives + false_positives: int , required + count for false positives + false_negatives: int , required + count for false negatives + + # Returns + Tensor containing the values for precision, recall and f1 measure respectively + """ + precision = true_positives / (true_positives + false_positives + 1e-13) + recall = true_positives / (true_positives + false_negatives + 1e-13) + f1_measure = 2.0 * (precision * recall) / (precision + recall + 1e-13) + return torch.tensor([precision, recall, f1_measure]) diff --git a/tests/metric/__init__.py b/tests/metric/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/metric/test_weak_span_based_f1.py b/tests/metric/test_weak_span_based_f1.py new file mode 100644 index 0000000..8fbefa2 --- /dev/null +++ b/tests/metric/test_weak_span_based_f1.py @@ -0,0 +1,424 @@ +import pytest +import torch + +from pie_utils.metric.weak_span_based_f1 import ( + SpanBasedF1WeakMeasure, + compute_metrics, + get_weak_match, + has_weak_overlap, + increase_span_end_index, +) + +# Predicted Tag Sequence : +# [[O,B-own_claim,I-own_claim,I-own_claim,I-own_claim,I-own_claim],[I-own_claim,I-own_claim,I-own_claim,B-data,I-data,I-data]] +LOGITS = torch.tensor( + [ + [ + [ + 10.351420402526855, + -3.2028937339782715, + -2.790834665298462, + -2.648576021194458, + -3.457158327102661, + -2.2753586769104004, + -3.092430591583252, + ], + [ + -0.728147566318512, + -0.3243664503097534, + -3.784968137741089, + -0.566190779209137, + -3.6368443965911865, + 11.714195251464844, + 0.874237596988678, + ], + [ + -2.2986810207366943, + -3.2249557971954346, + -2.213132381439209, + -3.176085948944092, + -1.201798439025879, + -2.5730767250061035, + 10.175483703613281, + ], + [ + -2.399078845977783, + -3.7054524421691895, + -2.732715606689453, + -3.1169371604919434, + -1.3757351636886597, + -2.5311129093170166, + 10.664148330688477, + ], + [ + -1.685034155845642, + -3.0339670181274414, + -2.960148572921753, + -3.0828864574432373, + -1.1170631647109985, + -2.528775215148926, + 10.603050231933594, + ], + [ + -2.307305097579956, + -3.424372673034668, + -2.1860435009002686, + -3.231368064880371, + -1.3228000402450562, + -2.4107933044433594, + 10.90400218963623, + ], + ], + [ + [ + -2.7964279651641846, + -2.7935304641723633, + -2.4361226558685303, + -2.8135015964508057, + -1.101932406425476, + 0.03466780483722687, + 10.032923698425293, + ], + [ + -2.841414451599121, + -2.925645351409912, + -2.197348117828369, + -3.377918004989624, + -1.2031382322311401, + -1.7914681434631348, + 10.230854034423828, + ], + [ + -2.645219087600708, + -3.736269235610962, + -2.1186366081237793, + -3.4355688095092773, + -0.8101659417152405, + -2.3279330730438232, + 10.49844741821289, + ], + [ + -2.3094680309295654, + -3.7092173099517822, + -1.942619800567627, + 10.281001091003418, + -3.1347908973693848, + -1.1283811330795288, + -2.2268154621124268, + ], + [ + -3.747760772705078, + -2.015471935272217, + -3.379420757293701, + -2.101996421813965, + 10.566914558410645, + -3.555516242980957, + -2.556119441986084, + ], + [ + -3.747760772705078, + -2.015471935272217, + -3.379420757293701, + -2.101996421813965, + 10.566914558410645, + -3.555516242980957, + -2.556119441986084, + ], + ], + [ + [ + -2.7964279651641846, + -2.7935304641723633, + -2.4361226558685303, + -2.8135015964508057, + -1.101932406425476, + 0.03466780483722687, + 10.032923698425293, + ], + [ + -2.841414451599121, + -2.925645351409912, + -2.197348117828369, + -3.377918004989624, + -1.2031382322311401, + -1.7914681434631348, + 10.230854034423828, + ], + [ + -2.645219087600708, + -3.736269235610962, + -2.1186366081237793, + -3.4355688095092773, + -0.8101659417152405, + -2.3279330730438232, + 10.49844741821289, + ], + [ + -2.3094680309295654, + -3.7092173099517822, + -1.942619800567627, + 10.281001091003418, + -3.1347908973693848, + -1.1283811330795288, + -2.2268154621124268, + ], + [ + -3.747760772705078, + -2.015471935272217, + -3.379420757293701, + -2.101996421813965, + 10.566914558410645, + -3.555516242980957, + -2.556119441986084, + ], + [ + -3.747760772705078, + -2.015471935272217, + -3.379420757293701, + -2.101996421813965, + 10.566914558410645, + -3.555516242980957, + -2.556119441986084, + ], + ], + ] +) + +# Original Tag sequence: +# [[PAD,O,B-own_claim,I-own_claim,I-own_claim,PAD],[PAD,PAD,B-background_claim,I-background_claim,B-data,PAD]] +TARGETS = torch.tensor( + [ + [ + -100, + 0, + 5, + 6, + 6, + -100, + ], + [ + -100, + -100, + 1, + 2, + 3, + -100, + ], + [ + -100, + -100, + -100, + -100, + -100, + -100, + ], + ] +) + +MASKS = TARGETS != -100 + +LABEL_TO_ID = { + "O": 0, + "B-background_claim": 1, + "I-background_claim": 2, + "B-data": 3, + "I-data": 4, + "B-own_claim": 5, + "I-own_claim": 6, +} + + +@pytest.mark.parametrize( + "masks", + [MASKS, None], +) +def test_update(masks): + """Given instance contains four spans, two of which are own_claim and one each for data and + background_claim. Model predicts one span labelled as own_claim correctly. Moreover, it + predicts a span as data falsely. Therefore, resulting metric should contain 1 true positive + count for 'own_claim', 1 false positive count for 'data'. and 1 false negative count for each + label. + + predicted spans : [own_claim (1,5)] , [own_claim (0,2), data (3,5)] + gold spans : [own_claim (2,4)] , [background_claim (2,3), data (4)] + tp : own_claim=1, data=1, background_claim=0 + fp : own_claim=1, data=0, background_claim=0 + fn : own_claim=0, data=0, background_claim=1 + """ + + metric = SpanBasedF1WeakMeasure(label_to_id=LABEL_TO_ID, return_metric="micro/f1") + assert torch.equal(metric.tp, torch.zeros([3], dtype=torch.int64)) + assert torch.equal(metric.fp, torch.zeros([3], dtype=torch.int64)) + assert torch.equal(metric.fn, torch.zeros([3], dtype=torch.int64)) + metric.update(preds=torch.tensor(LOGITS), targets=torch.tensor(TARGETS), masks=masks) + expected_true_positives = torch.zeros([3], dtype=torch.int64) + expected_true_positives[metric._span_classes.index("own_claim")] = 1 + expected_true_positives[metric._span_classes.index("data")] = 1 + assert torch.equal(metric.tp, expected_true_positives) + + expected_false_positives = torch.zeros([3], dtype=torch.int64) + expected_false_positives[metric._span_classes.index("own_claim")] = 1 + assert torch.equal(metric.fp, expected_false_positives) + + expected_false_negatives = torch.zeros([3], dtype=torch.int64) + expected_false_negatives[metric._span_classes.index("background_claim")] = 1 + assert torch.equal(metric.fn, expected_false_negatives) + + +def test_update_with_prediction_map(): + """This test is similar to the last test but uses the prediction_map to obtain correct label in + each batch of sequence. + + In our case, the label classes are same for all batch sequences. So prediction_map doesn't really make much + difference. Accurate use case can be found here: + https://github.com/allenai/allennlp/blob/39c40fe38cd2fd36b3465b0b3c031f54ec824160/tests/training/metrics/span_based_f1_measure_test.py#L39 + """ + prediction_map = torch.tensor( + [list(LABEL_TO_ID.values()), list(LABEL_TO_ID.values()), list(LABEL_TO_ID.values())] + ) + metric = SpanBasedF1WeakMeasure(label_to_id=LABEL_TO_ID, return_metric="micro/f1") + assert torch.equal(metric.tp, torch.zeros([3], dtype=torch.int64)) + assert torch.equal(metric.fp, torch.zeros([3], dtype=torch.int64)) + assert torch.equal(metric.fn, torch.zeros([3], dtype=torch.int64)) + metric.update( + preds=torch.tensor(LOGITS), targets=torch.tensor(TARGETS), prediction_map=prediction_map + ) + expected_true_positives = torch.zeros([3], dtype=torch.int64) + expected_true_positives[metric._span_classes.index("own_claim")] = 1 + expected_true_positives[metric._span_classes.index("data")] = 1 + assert torch.equal(metric.tp, expected_true_positives) + + expected_false_positives = torch.zeros([3], dtype=torch.int64) + expected_false_positives[metric._span_classes.index("own_claim")] = 1 + assert torch.equal(metric.fp, expected_false_positives) + + expected_false_negatives = torch.zeros([3], dtype=torch.int64) + expected_false_negatives[metric._span_classes.index("background_claim")] = 1 + assert torch.equal(metric.fn, expected_false_negatives) + + +@pytest.mark.parametrize( + "return_metric", + ["own_claim/f1", "own_claim/precision", "own_claim/recall", "macro/f1", "micro/f1"], +) +def test_compute(return_metric): + """Given instance contains four spans, two of which are own_claim and one each for data and + background_claim. Model predicts one span labelled as own_claim correctly. Moreover, it + predicts a span as data falsely. Based on these predictions, scores will be calculated for + different return metrics. + + own_claim/f1 : 2*1*0.5/(0.5+1) = 0.667 + own_claim/precision : 1/(1+1) = 0.5 + own_claim/recall : 1/(1+0) = 1 + macro/f1 : (0.667 + 1 + 0)/3 = 0.556 + micro/f1 : (2*0.667*0.667)/(0.667+0.667) = 0.667 + """ + metric = SpanBasedF1WeakMeasure(label_to_id=LABEL_TO_ID, return_metric=return_metric) + metric.update( + preds=torch.tensor(LOGITS), targets=torch.tensor(TARGETS), masks=torch.tensor(MASKS) + ) + scores = metric.compute() + if return_metric == "own_claim/f1": + assert pytest.approx(scores) == torch.tensor(0.6666666865348816) + if return_metric == "own_claim/precision": + assert torch.eq(scores, torch.tensor(0.5)) + if return_metric == "own_claim/recall": + assert torch.eq(scores, torch.tensor(1)) + if return_metric == "macro/f1": + assert pytest.approx(scores) == torch.tensor(0.5555555820465088) + if return_metric == "micro/f1": + assert pytest.approx(scores) == torch.tensor(0.6666666865348816) + + +def test_has_weak_overlap(): + # checks if overlap in span is at least half of the length of the shorter span. + + weak_overlapping_indices = ((0, 5), (3, 6)) + no_weak_overlapping_indices = ((0, 5), (4, 9)) + touching_indices = ((0, 5), (5, 7)) + non_touching_indices = ((0, 5), (6, 7)) + containing_indices = ((0, 9), (3, 6)) + + assert has_weak_overlap(weak_overlapping_indices[0], weak_overlapping_indices[1]) + assert not has_weak_overlap(no_weak_overlapping_indices[0], no_weak_overlapping_indices[1]) + assert not has_weak_overlap(touching_indices[0], touching_indices[1]) + assert not has_weak_overlap(non_touching_indices[0], non_touching_indices[1]) + assert has_weak_overlap(containing_indices[0], containing_indices[1]) + + assert has_weak_overlap(weak_overlapping_indices[1], weak_overlapping_indices[0]) + assert not has_weak_overlap(no_weak_overlapping_indices[1], no_weak_overlapping_indices[0]) + assert not has_weak_overlap(touching_indices[1], touching_indices[0]) + assert not has_weak_overlap(non_touching_indices[1], non_touching_indices[0]) + assert has_weak_overlap(containing_indices[1], containing_indices[0]) + + +def test_compute_metrics(): + true_positives = 10 + false_positives = 8 + false_negatives = 4 + true_precision = 10 / (10 + 8) + true_recall = 10 / (10 + 4) + true_f1 = (2 * true_recall * true_precision) / (true_recall + true_precision) + precision, recall, f1 = compute_metrics( + true_positives=true_positives, + false_positives=false_positives, + false_negatives=false_negatives, + ) + assert precision == true_precision + assert recall == true_recall + assert f1 == true_f1 + + +def test_increase_span_end_index(): + # Resulting span should have end index adjusted by the value of the offset, i.e. the offset value should be added + # to the end index of the span + span = ("city", (0, 0)) + new_span = increase_span_end_index(span=span, offset=1) + assert new_span == ("city", (0, 1)) + + original_span = increase_span_end_index(span=new_span, offset=-1) + assert original_span == span + + +def test_get_weak_match(): + # Since predicted span is contained in first span of the gold span list, it will be considered as the weak match. + span = ("city", (0, 0)) + spans = [("city", (0, 3)), ("person", (3, 5)), ("person", (6, 9))] + match = get_weak_match(span=span, gold_spans=spans) + assert match == ("city", (0, 3)) + + # Predicted span is not contained in the second gold span, therefore it is not considered as a match. + span = ("person", (3, 5)) + spans = [("city", (0, 3)), ("person", (5, 9))] + match = get_weak_match(span=span, gold_spans=spans) + assert match is None + + # Here predicted span is partly inside the second gold span, there we have a weak match + span = ("person", (3, 5)) + spans = [("city", (0, 3)), ("person", (4, 7))] + match = get_weak_match(span=span, gold_spans=spans) + assert match == ("person", (4, 7)) + + # Here predicted span is partly inside the second gold span but not enough to be considered as a weak match. + span = ("person", (3, 6)) + spans = [("city", (0, 3)), ("person", (5, 9))] + match = get_weak_match(span=span, gold_spans=spans) + assert match is None + + +@pytest.mark.parametrize( + "inclusive_end_index", + [True, False], +) +def test_get_weak_match_with_inclusive_end_index(inclusive_end_index): + # Here in the given span if we consider end index to be part of span then we have weak match with the first span + # in the list of gold spans. However, if we do not consider end index as inclusive to the span, then there is no + # overlap and hence no weak match. + span = ("city", (0, 2)) + spans = [("city", (2, 3)), ("person", (3, 5))] + + match = get_weak_match(span=span, gold_spans=spans, inclusive_end_index=inclusive_end_index) + if inclusive_end_index: + assert match == ("city", (2, 3)) + else: + assert match is None