Skip to content

Commit

Permalink
Merge pull request #22 from biaslyze-dev/counterfactual-text-augmentor
Browse files Browse the repository at this point in the history
Implement a CounterfactualTextAugmentor class to abstract the keyword replacement process
  • Loading branch information
tsterbak committed Jul 14, 2023
2 parents 305f8c8 + b8ee325 commit c1ac4ff
Show file tree
Hide file tree
Showing 22 changed files with 2,151 additions and 739 deletions.
12 changes: 12 additions & 0 deletions biaslyze/augmentors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""This module contains classes for augmenting text data with bias concepts."""

from typing import List, Optional, Union


class CounterfactualTextAugmentor:
    """Augment text data with counterfactuals.

    The class is currently a placeholder: it stores no state and offers no
    methods beyond construction.
    """

    def __init__(self) -> None:
        """Create an augmentor; there is nothing to initialise yet."""
36 changes: 24 additions & 12 deletions biaslyze/bias_detectors/counterfactual_biasdetector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
CounterfactualDetectionResult,
CounterfactualSample,
)
from biaslyze.augmentors import CounterfactualTextAugmentor
from biaslyze.text_representation import TextRepresentation, process_texts_with_spacy


Expand Down Expand Up @@ -56,15 +57,18 @@ class CounterfactualBiasDetector:
Attributes:
use_tokenizer: If keywords should only be searched in tokenized text. Can be useful for short keywords like 'she'.
concept_detector: an instance of KeywordConceptDetector
text_augmentor: an instance of CounterfactualTextAugmentor
"""

def __init__(
self,
use_tokenizer: bool = False,
concept_detector: KeywordConceptDetector = KeywordConceptDetector(),
text_augmentor: CounterfactualTextAugmentor = CounterfactualTextAugmentor(),
):
self.use_tokenizer = use_tokenizer
self.concept_detector = concept_detector
self.text_augmentor = text_augmentor

# overwrite use_tokenizer
self.concept_detector.use_tokenizer = self.use_tokenizer
Expand Down Expand Up @@ -125,7 +129,7 @@ def process(
if max_counterfactual_samples:
max_counterfactual_samples_per_text = max_counterfactual_samples // len(
detected_texts
)
) + 1

results = []
for concept in self.concepts:
Expand All @@ -141,17 +145,23 @@ def process(
n_texts=max_counterfactual_samples_per_text,
)
if not counterfactual_samples:
logger.warning(f"No samples containing {concept} found. Skipping.")
logger.warning(f"No samples containing {concept.name} found. Skipping.")
continue

# calculate counterfactual scores for each keyword
for keyword in tqdm(concept.keywords):
# get the counterfactual scores
counterfactual_scores = _calculate_counterfactual_scores(
bias_keyword=keyword.text,
predict_func=predict_func,
samples=counterfactual_samples,
)
try:
counterfactual_scores = _calculate_counterfactual_scores(
bias_keyword=keyword.text,
predict_func=predict_func,
samples=counterfactual_samples,
)
except ValueError:
logger.warning(
f"Could not calculate counterfactual scores for keyword {keyword.text}. Skipping."
)
continue
# add to score dict
score_dict[keyword.text] = counterfactual_scores
# add scores to samples
Expand Down Expand Up @@ -195,6 +205,7 @@ def _extract_counterfactual_concept_samples(
texts: List[str],
labels: Optional[List[str]] = None,
n_texts: Optional[int] = None,
respect_function: bool = True,
) -> List[CounterfactualSample]:
"""Extract counterfactual samples for a given concept from a list of texts.
Expand All @@ -207,22 +218,23 @@ def _extract_counterfactual_concept_samples(
tokenizer: The tokenizer to use for tokenization.
labels: Optional. Used to add labels to the counterfactual results.
n_texts: Optional. The number of counterfactual texts to return. Defaults to None, which returns all possible counterfactual texts.
respect_function: If True, only replace keywords with the same function.
Returns:
A list of CounterfactualSample objects.
"""
counterfactual_samples = []
original_texts = []
text_representations: List[TextRepresentation] = process_texts_with_spacy(texts)
for idx, (text, text_representation) in tqdm(
enumerate(zip(texts, text_representations)), total=len(texts)
for idx, text_representation in tqdm(
enumerate(text_representations), total=len(text_representations)
):
present_keywords = concept.get_present_keywords(text_representation)
if present_keywords:
original_texts.append(text)
original_texts.append(text_representation.text)
for orig_keyword in present_keywords:
counterfactual_texts = concept.get_counterfactual_texts(
orig_keyword, text_representation, n_texts=n_texts
orig_keyword, text_representation, n_texts=n_texts, respect_function=respect_function
)
for counterfactual_text, counterfactual_keyword in counterfactual_texts:
counterfactual_samples.append(
Expand All @@ -233,7 +245,7 @@ def _extract_counterfactual_concept_samples(
concept=concept.name,
tokenized=text_representation,
label=labels[idx] if labels else None,
source_text=text,
source_text=text_representation.text,
)
)
logger.info(
Expand Down
50 changes: 39 additions & 11 deletions biaslyze/concept_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ class Keyword:
category (str): The category of the keyword.
"""

def __init__(self, text: str, function: List[str], category: str):
def __init__(self, text: str, functions: List[str], category: str):
"""The constructor for the Keyword class."""
self.text = text
self.function = function
self.functions = functions
self.category = category

def __str__(self) -> str:
Expand All @@ -32,8 +32,16 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return f"Keyword({self.text}, {self.function}, {self.category})"

def can_replace_token(self, token: Token) -> bool:
"""Returns True if the keyword can replace the given token."""
def can_replace_token(self, token: Token, respect_function: bool = False) -> bool:
"""Returns True if the keyword can replace the given token.
Args:
token (Token): The token to replace.
respect_function (bool): Whether to respect the function of the keyword. Defaults to False.
"""
if respect_function:
return True
# return token.function in self.functions
return True

def equal_to_token(self, token: Token) -> bool:
Expand All @@ -43,8 +51,24 @@ def equal_to_token(self, token: Token) -> bool:
return False

def get_keyword_in_style_of_token(self, token: Token) -> str:
"""Returns the keyword text in the style of the given token."""
return self.text
"""Returns the keyword text in the style of the given token.
Uses the shape of the token to determine the style.
Args:
token (Token): The token to get the style from.
Returns:
str: The keyword text in the style of the given token.
"""
if "X" not in token.shape:
return self.text.lower()
elif "x" not in token.shape:
return self.text.upper()
elif token.shape[0] == "X":
return self.text.capitalize()
else:
return self.text


class Concept:
Expand All @@ -68,9 +92,9 @@ def from_dict_keyword_list(cls, name: str, keywords: List[dict]):
for keyword in keywords:
keyword_list.append(
Keyword(
keyword["keyword"],
keyword["function"],
keyword.get("category", None),
text=keyword["keyword"],
functions=keyword["function"],
category=keyword.get("category", None),
)
)
return cls(name, keyword_list)
Expand All @@ -90,13 +114,15 @@ def get_counterfactual_texts(
keyword: Keyword,
text_representation: TextRepresentation,
n_texts: Optional[int] = None,
respect_function: bool = True,
) -> List[Tuple[str, Keyword]]:
"""Returns a counterfactual texts based on a specific keyword for the given text representation.
Args:
keyword (Keyword): The keyword in the text to replace.
text_representation (TextRepresentation): The text representation to replace the keyword in.
n_texts (Optional[int]): The number of counterfactual texts to return. Defaults to None, which returns all possible counterfactual texts.
respect_function (bool): Whether to respect the function of the keyword. Defaults to True.
Returns:
List[Tuple[str, Keyword]]: A list of tuples containing the counterfactual text and the keyword that was inserted as the replacement.
Expand All @@ -110,7 +136,9 @@ def get_counterfactual_texts(
# create a counterfactual text for each keyword until n_texts is reached
for counterfactual_keyword in self.keywords:
# check if the keyword can be replaced by another keyword
if counterfactual_keyword.can_replace_token(token):
if counterfactual_keyword.can_replace_token(
token, respect_function
):
# create the counterfactual text
counterfactual_text = (
text_representation.text[: token.start]
Expand All @@ -123,7 +151,7 @@ def get_counterfactual_texts(
(counterfactual_text, counterfactual_keyword)
)
# check if n_texts is reached and return the counterfactual texts
if len(counterfactual_texts) == n_texts:
if n_texts and (len(counterfactual_texts) >= n_texts):
return counterfactual_texts
return counterfactual_texts

Expand Down
Loading

0 comments on commit c1ac4ff

Please sign in to comment.