Changed: logic in calculating evaluation metrics #476

Merged · 25 commits · Jul 16, 2024

Commits
30a77d7
Changed: logic in calculating evaluation metrics
iftwigs Mar 25, 2024
47d7ab1
Changed: confusion matrix calculator
iftwigs Apr 18, 2024
df33bcf
Added: EvaluationConfusionMatrix class
iftwigs Apr 24, 2024
e1f4ee3
Merge branch 'master' into 12397-evaluation-fix
iftwigs Apr 25, 2024
838a03a
Changed: class structure, method moved outside compare() function
iftwigs Apr 29, 2024
b54f66e
Changed: integrating the GT/TN into compare() and existing Extraction…
iftwigs May 2, 2024
7c13a39
Added: Confusion matrix into the Extraction evaluation
iftwigs May 3, 2024
4331405
Added: docstrings for confusion matrix class
iftwigs May 7, 2024
39d9cc8
Fixed: affected tests
iftwigs May 9, 2024
e056367
Fixed: affected tests
iftwigs May 9, 2024
a755b87
Merge branch 'master' into 12397-evaluation-fix
iftwigs May 21, 2024
a49ea75
Fix: Evaluation all zero for label of type multiple
nengelmann Jun 5, 2024
b19cb3f
Merge branch 'master' into 12397-evaluation-fix
iftwigs Jun 6, 2024
1fe3752
Merge branch '12397-evaluation-fix' into 12529-fix-evavluation-for-la…
iftwigs Jun 6, 2024
01ae48e
Fix: Assignment, Changed: Ordering Spans
nengelmann Jun 6, 2024
4b2b8e1
Merge branch 'master' into 12397-evaluation-fix
iftwigs Jun 13, 2024
471da75
Merge branch '12397-evaluation-fix' of https://github.com/konfuzio-ai…
iftwigs Jun 13, 2024
fdc660e
Merge branch '12397-evaluation-fix' into 12529-fix-evavluation-for-la…
iftwigs Jun 13, 2024
007082e
Merge pull request #501 from konfuzio-ai/12529-fix-evavluation-for-la…
iftwigs Jun 13, 2024
93eb1ec
Fixed: conditions in compare()
iftwigs Jun 13, 2024
44a7a06
Merge branch '12397-evaluation-fix' of https://github.com/konfuzio-ai…
iftwigs Jun 13, 2024
ce4a7ab
Fix: Reference before assignment, remove outdated span prioritize row…
nengelmann Jun 14, 2024
9513388
Added: parametrized tests to demonstrate different cases of ground tr…
iftwigs Jun 20, 2024
e576933
Fixed: numpy version for compatibility with different python versions
iftwigs Jun 20, 2024
c0a3f5e
Merge branch 'master' into 12397-evaluation-fix
MohamedAmineDHIAB Jul 15, 2024
147 changes: 112 additions & 35 deletions konfuzio_sdk/evaluate.py
@@ -1,4 +1,5 @@
"""Calculate the accuracy on any level in a Document."""

import logging
from typing import Dict, List, Optional, Tuple, Union

@@ -49,6 +50,8 @@
'is_correct_id_',
'duplicated',
'duplicated_predicted',
'tmp_id_', # a temporary ID used solely for enumerating the predicted annotations
'disambiguated_id', # an ID for multi-span annotations
]

logger = logging.getLogger(__name__)
@@ -81,13 +84,39 @@ def grouped(group, target: str):
return group


def prioritize_rows(group):
"""
Apply a filter when a Label should only appear once per AnnotationSet but has been predicted multiple times.

After the TPs, FPs and FNs have been calculated for the Document, we filter out the case where a Label should
only appear once per AnnotationSet but has been predicted multiple times. If any of the predictions is a TP,
we keep one and discard the FPs/FNs. If there are no TPs but there is an FP, we keep one FP and discard the
FNs. If there are neither, we keep an FN. The prediction we keep is always the first in terms of start_offset.
"""
group = group[~(group['label_has_multiple_top_candidates_predicted'].astype(bool))]
if group.empty:
return group

first_true_positive = group[group['true_positive']].head(1)
first_false_positive = group[group['false_positive']].head(1)
first_false_negative = group[group['false_negative']].head(1)
if not first_true_positive.empty:
return first_true_positive
elif not first_false_positive.empty:
return first_false_positive
else:
return first_false_negative
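
# Illustration only (not part of this PR's diff): a minimal sketch of prioritize_rows on
# a toy DataFrame with just the columns the function reads. Of one FP row and one TP row
# in the same group, only the TP row is kept.
#
#     import pandas as pd
#     toy = pd.DataFrame(
#         {
#             'label_has_multiple_top_candidates_predicted': [False, False],
#             'true_positive': [False, True],
#             'false_positive': [True, False],
#             'false_negative': [False, False],
#         }
#     )
#     prioritize_rows(toy)  # -> keeps only the second row, the true positive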


def compare(
doc_a,
doc_b,
only_use_correct=False,
use_view_annotations=False,
ignore_below_threshold=False,
strict=True,
id_counter: int = 1,
custom_threshold=None,
) -> pd.DataFrame:
"""Compare the Annotations of two potentially empty Documents wrt. to **all** Annotations.
@@ -107,13 +136,20 @@ def compare(
:return: Evaluation DataFrame
"""
df_a = pd.DataFrame(doc_a.eval_dict(use_correct=only_use_correct))
df_a_ids = df_a[['id_']].copy()  # copy to avoid a SettingWithCopyWarning when adding the helper column
duplicated_ids = df_a_ids['id_'].duplicated(keep=False)
df_a_ids['disambiguated_id'] = df_a_ids['id_'].astype(str)
df_a_ids.loc[duplicated_ids, 'disambiguated_id'] += '_' + (df_a_ids.groupby('id_').cumcount() + 1).astype(str)
df_a['disambiguated_id'] = df_a_ids['disambiguated_id']
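# Illustration only (assumed toy values): ground-truth ids [1, 1, 2] become
# disambiguated ids ['1_1', '1_2', '2'] -- only duplicated ids (multi-span
# Annotations) receive a per-id counter suffix.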
df_b = pd.DataFrame(
doc_b.eval_dict(
use_view_annotations=strict and use_view_annotations, # view_annotations only available for strict=True
use_correct=False,
ignore_below_threshold=ignore_below_threshold,
)
),
)
df_b['tmp_id_'] = list(range(id_counter, id_counter + len(df_b)))

if doc_a.category != doc_b.category:
raise ValueError(f'Categories of {doc_a} with {doc_a.category} and {doc_b} with {doc_b.category} do not match.')
if strict: # many to many inner join to keep all Spans of both Documents
@@ -231,12 +267,12 @@ def compare(
& ((~spans['is_matched']) | (~spans['above_predicted_threshold']) | (spans['label_id_predicted'].isna()))
)

spans['false_positive'] = ( # commented out on purpose (spans["is_correct"]) &
spans['false_positive'] = (
(spans['above_predicted_threshold'])
& (~spans['false_negative'])
& (~spans['true_positive'])
& (~spans['duplicated_predicted'])
& ( # Something is wrong
& (
(~spans['is_correct_label'])
| (~spans['is_correct_label_set'])
| (~spans['is_correct_annotation_set_id'])
@@ -246,40 +282,64 @@
)

if not strict:

def prioritize_rows(group):
"""
Apply a filter when a Label should only appear once per AnnotationSet but has been predicted multiple times.

After we have calculated the TPs, FPs, FNs for the Document, we filter out the case where a Label should
only appear once per AnnotationSet but has been predicted multiple times. In this case, if any of the
predictions is a TP then we keep one and discard FPs/FNs. If no TPs, if any of the predictions is a FP
then we keep one and discard the FNs. If no FPs, then we keep a FN. The prediction we keep is always the
first in terms of start_offset.
"""
group = group[~(group['label_has_multiple_top_candidates_predicted'].astype(bool))]
if group.empty:
return group

first_true_positive = group[group['true_positive']].head(1)
first_false_positive = group[group['false_positive']].head(1)
first_false_negative = group[group['false_negative']].head(1)
if not first_true_positive.empty:
return first_true_positive
elif not first_false_positive.empty:
return first_false_positive
else:
return first_false_negative

spans = spans.groupby(['annotation_set_id_predicted', 'label_id_predicted']).apply(prioritize_rows)
# Apply prioritize_rows only to entries whose Label does not allow multiple top candidates
labels = doc_a.project.labels
label_ids_multiple = [label.id_ for label in labels if label.has_multiple_top_candidates]
label_ids_not_multiple = [label.id_ for label in labels if not label.has_multiple_top_candidates]
spans_not_multiple = spans[spans['label_id'].isin(label_ids_not_multiple)]
spans_not_multiple = spans_not_multiple.groupby(['annotation_set_id_predicted', 'label_id_predicted']).apply(
prioritize_rows
)
spans_multiple = spans[spans['label_id'].isin(label_ids_multiple)]
spans = pd.concat([spans_not_multiple, spans_multiple])
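# Illustration only: a Label with has_multiple_top_candidates=True keeps every predicted
# Span, while any other Label is reduced to at most one row per predicted AnnotationSet
# by prioritize_rows above.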
spans = spans.sort_values(by='is_matched', ascending=False)

spans = spans.replace({np.nan: None})
# one Span must not be defined as TP or FP or FN more than once
quality = (spans[['true_positive', 'false_positive', 'false_negative']].sum(axis=1) <= 1).all()
assert quality
# how many times annotations with this label occur in the ground truth data
spans['frequency'] = spans.groupby('label_id')['label_id'].transform('size')
spans['frequency'] = spans['frequency'].fillna(0).astype(int)
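# Illustration only (assumed toy values): if three rows in the data share label_id 5,
# each of those rows gets frequency == 3.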

if not strict:
# one Span must not be defined as TP or FP or FN more than once
quality = (spans[['true_positive', 'false_positive', 'false_negative']].sum(axis=1) <= 1).all()
assert quality
return spans


class ExtractionConfusionMatrix:
"""Check how all predictions are mapped to the ground-truth Annotations."""

def __init__(self, data: pd.DataFrame):
"""
Initialize the class.

:param data: Raw evaluation data.
"""
self.matrix = self.calculate(data=data)

def calculate(self, data: pd.DataFrame):
"""
Calculate the matrix.

:param data: Raw evaluation data.
"""
data = data.reset_index(drop=True)
data['id_'] = data['id_'].fillna('no_match')
data['tmp_id_'] = data['tmp_id_'].fillna('no_match')

data['relation'] = data.apply(
lambda x: 'TP'
if x['true_positive']
else ('FP' if x['false_positive'] else ('FN' if x['false_negative'] else 'TN')),
axis=1,
)

matrix = pd.pivot(data, index='disambiguated_id', columns='tmp_id_', values=['relation'])
matrix.fillna('TN', inplace=True)
return matrix
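
# Usage sketch (illustration only, assuming `data` is an evaluation DataFrame produced
# by compare()):
#
#     matrix = ExtractionConfusionMatrix(data=data).matrix
#     # rows: disambiguated ground-truth Annotation ids, columns: ('relation', tmp_id_)
#     # of each prediction; every cell holds 'TP', 'FP', 'FN' or 'TN'.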


class EvaluationCalculator:
"""Calculate precision, recall, f1, based on TP, FP, FN."""

@@ -420,6 +480,7 @@ def __init__(
def calculate(self):
"""Calculate and update the data stored within this Evaluation."""
evaluations = [] # start anew, the configuration of the Evaluation might have changed.
id_counter = 1
for ground_truth, predicted in self.documents:
evaluation = compare(
doc_a=ground_truth,
@@ -428,8 +489,11 @@ def calculate(self):
strict=self.strict,
use_view_annotations=self.use_view_annotations,
ignore_below_threshold=self.ignore_below_threshold,
id_counter=id_counter,
)
evaluations.append(evaluation)
id_counter += len(evaluation)
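# id_counter keeps tmp_id_ unique across all evaluated Documents, so the confusion
# matrix built from the concatenated data gets exactly one column per prediction.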

self.data = pd.concat(evaluations)

def calculate_thresholds(self):
@@ -527,9 +591,11 @@ def fn(self, search=None) -> int:

def tn(self, search=None) -> int:
"""Return the True Negatives of all Spans."""
return (
len(self._query(search=search)) - self.tp(search=search) - self.fn(search=search) - self.fp(search=search)
)
return len(self._query(search=None)) - self.tp(search=search) - self.fn(search=search) - self.fp(search=search)

def gt(self, search=None) -> int:
"""Return the number of ground-truth Annotations for a given Label."""
return len(self._query(search=search).dropna(subset=['label_id']))
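# Illustration only (assumed counts): with 10 evaluated Spans in total and tp=4, fn=1,
# fp=2 for a given Label, tn(search=label) = 10 - 4 - 1 - 2 = 3; gt(search=label) counts
# the rows that carry a ground-truth label_id.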

def tokenizer_tp(self, search=None) -> int:
"""Return the tokenizer True Positives of all Spans."""
@@ -650,6 +716,9 @@ def get_wrong_vertical_merge(self):
self.data.groupby('id_local_predicted').apply(lambda group: self._apply(group, 'wrong_merge'))
return self.data[self.data['wrong_merge']]

def confusion_matrix(self):
"""Return the ExtractionConfusionMatrix computed over this Evaluation's data."""
return ExtractionConfusionMatrix(data=self.data)
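# Usage sketch (illustration only, assuming `evaluation` is an instance of this class
# with its data already calculated):
#
#     matrix = evaluation.confusion_matrix().matrix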


class CategorizationEvaluation:
"""Calculated evaluation measures for the classification task of Document categorization."""
@@ -791,6 +860,10 @@ def tn(self, category: Optional[Category] = None) -> int:
"""Return the True Negatives of all Documents."""
return self._base_metric('tn', category)

def gt(self, category: Optional[Category] = None) -> int:
"""Placeholder for compatibility with Server."""
return 0

def get_evaluation_data(self, search: Category = None, allow_zero: bool = True) -> EvaluationCalculator:
"""
Get precision, recall, f1, based on TP, TN, FP, FN.
@@ -1042,6 +1115,10 @@ def tn(self, search: Category = None) -> int:
"""
return self._query('true_negatives', search)

def gt(self, search: Category = None) -> int:
"""Placeholder for compatibility with Server."""
return 0

def precision(self, search: Category = None) -> float:
"""
Return precision.