"""
Benchmark code. It's the main entry point for comparing strategies using
evaluation metrics such as precision, recall, and F1-score.
"""
import logging
import time
from enum import Enum
from itertools import islice
from typing import Dict, Tuple, List, Set

from acres.evaluation import metrics
from acres.model import expansion_standard, detection_standard, topic_list
from acres.resolution import resolver
from acres.stats import senses
from acres.util import text
from acres.util.acronym import Acronym

logger = logging.getLogger(__name__)
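
# The per-row diagnostics below (FOUND / NOT FOUND / INCORRECT) are emitted at DEBUG
# level; a caller can surface them with the standard library alone, for example:
#   logging.basicConfig(level=logging.DEBUG)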


class Level(Enum):
    """
    Enum that holds acronym-solving levels.
    """
    TOKEN = 1
    TYPE = 2
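
# For example, if the topic list contained the acronym "EKG" in three different
# contexts, TOKEN-level evaluation would score all three rows, while TYPE-level
# evaluation keeps only the first occurrence of each acronym (see `_valid` below).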


def analyze(contextualized_acronym: Acronym, true_expansions: Set[str],
            strategy: resolver.Strategy, max_tries: int) -> Dict[str, bool]:
    """
    Analyze a given row of the gold standard.

    :param contextualized_acronym: An acronym embedded in its left and right context.
    :param true_expansions: A set of gold-standard expansions for this acronym.
    :param strategy: The expansion strategy to evaluate.
    :param max_tries: The maximum number of candidate expansions to consider.
    :return: A dictionary with boolean values for the keys 'found', 'correct', and 'ignored'.
    """
    ret = {'found': False, 'correct': False, 'ignored': False}

    # Normalize context so that we can match it to the training data
    left_context = text.clean(contextualized_acronym.left_context)
    acronym = contextualized_acronym.acronym
    right_context = text.clean(contextualized_acronym.right_context)

    # Normalize true expansions so that we can match them to the training data
    true_expansions = set(map(text.clean, true_expansions))

    # Remove context to improve cache hit
    # XXX We currently support context only for n-grams
    if strategy != resolver.Strategy.FASTNGRAM:
        left_context = ""
        right_context = ""

    logger.debug("%s [%s] %s => %s", left_context, acronym, right_context, true_expansions)

    filtered_expansions = resolver.filtered_resolve(acronym, left_context, right_context, strategy)
    possible_expansions = list(islice(filtered_expansions, max_tries))

    if possible_expansions:
        logger.debug("FOUND: %s => %s", acronym, possible_expansions)
        ret['found'] = True
        ret['correct'] = test_input(true_expansions, possible_expansions, max_tries)

    if not ret['found']:
        logger.debug("NOT FOUND: %s", acronym)
    elif not ret['correct']:
        logger.debug("INCORRECT: %s", acronym)
    return ret
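

# NOTE: `test_input` is called by `analyze` above but is not defined in this excerpt.
# The sketch below is a minimal stand-in, assuming a candidate counts as correct when
# it exactly matches one of the normalized true expansions; the original matching
# criterion (e.g. allowing partial matches) may be more permissive.
def test_input(true_expansions: Set[str], possible_expansions: List[str],
               max_tries: int) -> bool:
    """
    Check whether any of the top `max_tries` candidate expansions matches a true expansion.
    """
    for possible_expansion in islice(possible_expansions, max_tries):
        if possible_expansion in true_expansions:
            return True
    return False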


def _valid(topics: List[Acronym], valid_standard: Set[str], level: Level) -> List[Acronym]:
    """
    Filter out invalid acronyms and, at TYPE level, repeated acronym types.

    :param topics: A list of contextualized acronyms from the gold standard.
    :param valid_standard: The set of acronyms considered valid by the detection standard.
    :param level: The acronym-solving level (TOKEN or TYPE).
    :return: A list of valid contextualized acronyms.
    """
    valid = []  # type: List[Acronym]
    types = set()  # type: Set[str]
    for contextualized_acronym in topics:
        acronym = contextualized_acronym.acronym

        # Ignore repeated types
        if level == Level.TYPE and acronym in types:
            # logger.debug("REPEATED {%s}: repeated type.", acronym)
            continue
        types.add(acronym)

        if acronym not in valid_standard:
            logger.debug("IGNORED {%s}: invalid acronym.", acronym)
            continue

        valid.append(contextualized_acronym)
    return valid
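
# A minimal usage sketch (assuming `Acronym` is a NamedTuple-like record with the
# fields `acronym`, `left_context`, and `right_context` used above):
#   topics = [Acronym(left_context="...", acronym="EKG", right_context="...")]
#   _valid(topics, valid_standard={"EKG"}, level=Level.TYPE)  # -> keeps one "EKG" row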


def evaluate(topics: List[Acronym], valid_standard: Set[str], standard: Dict[str, Dict[str, int]],
             strategy: resolver.Strategy, level: Level, max_tries: int,
             lenient: bool) -> Tuple[List[Acronym], List[Acronym], List[Acronym]]:
    """
    Analyze a gold standard with text excerpts centered on an acronym, followed by `n` valid
    expansions.

    :param topics: A list of contextualized acronyms from the gold standard.
    :param valid_standard: The set of acronyms considered valid by the detection standard.
    :param standard: The expansion standard, mapping acronyms to rated expansions.
    :param strategy: The expansion strategy to evaluate.
    :param level: The acronym-solving level (TOKEN or TYPE).
    :param max_tries: The maximum number of candidate expansions to consider.
    :param lenient: Whether to consider partial matches (annotated as 1) as a valid sense.
    :return: A tuple with lists containing correct, found, and valid contextualized acronyms.
    """
    found = []  # type: List[Acronym]
    correct = []  # type: List[Acronym]

    acronym_senses = senses.map_senses_acronym(standard, lenient)

    valid = _valid(topics, valid_standard, level)
    for contextualized_acronym in valid:
        acronym = contextualized_acronym.acronym
        true_expansions = acronym_senses[acronym]
        row_analysis = analyze(contextualized_acronym, true_expansions, strategy, max_tries)
        if row_analysis['found']:
            found.append(contextualized_acronym)
        if row_analysis['correct']:
            correct.append(contextualized_acronym)

    invalid_absolute = len(topics) - len(valid)
    logger.info("Total: %d", len(topics))
    logger.info("Ignored: %d", invalid_absolute)
    logger.info("Valid: %d", len(valid))
    logger.info("Found: %d", len(found))
    logger.info("Correct: %d", len(correct))
    return correct, found, valid
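
# Note the containment chain: `correct` is a subset of `found` (a row can only be
# correct if some expansion was found), and `found` is a subset of `valid`.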


def do_analysis(topics_file: str, detection_file: str, expansion_file: str,
                strategy: resolver.Strategy, level: Level, max_tries: int,
                lenient: bool) -> Tuple[List[Acronym], List[Acronym], List[Acronym]]:
    """
    Analyze a given expansion standard.

    :param topics_file: Path to the topics file.
    :param detection_file: Path to the detection standard file.
    :param expansion_file: Path to the expansion standard file.
    :param strategy: The expansion strategy to evaluate.
    :param level: The acronym-solving level (TOKEN or TYPE).
    :param max_tries: The maximum number of candidate expansions to consider.
    :param lenient: Whether to consider partial matches as a valid sense.
    :return: A tuple with lists containing correct, found, and valid contextualized acronyms.
    """
    topics = topic_list.parse(topics_file)
    valid_standard = detection_standard.parse_valid(detection_file)
    standard = expansion_standard.parse(expansion_file)

    start_time = time.time()
    (correct, found, valid) = evaluate(topics, valid_standard, standard, strategy, level, max_tries,
                                       lenient)
    end_time = time.time()
    print("Time (s): ", end_time - start_time)

    final_precision = metrics.calculate_precision(len(correct), len(found))
    final_recall = metrics.calculate_recall(len(correct), len(valid))
    final_f1 = metrics.calculate_f1(final_precision, final_recall)

    print("Strategy: ", strategy)
    print("Rank: ", max_tries)
    print("Lenient: ", lenient)
    print("Precision: ", final_precision)
    print("Recall: ", final_recall)
    print("F1: ", final_f1)
    return correct, found, valid
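
# The printed metrics can be recomputed from the returned lists, assuming the standard
# definitions implemented by `acres.evaluation.metrics`:
#   precision = len(correct) / len(found)  # of the expansions proposed, how many are right
#   recall = len(correct) / len(valid)     # of the valid rows, how many are solved correctly
#   f1 = 2 * precision * recall / (precision + recall)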


def plot_data(topics_file: str, detection_file: str, expansion_file: str):
    """
    Run all strategies using different ranks and lenient approaches and generate a TSV file to be
    used as input for the `plots.R` script.

    :param topics_file: Path to the topics file.
    :param detection_file: Path to the detection standard file.
    :param expansion_file: Path to the expansion standard file.
    :return:
    """
    with open("stats.tsv", "w+", encoding="utf-8") as output:
        output.write("k\tMajority\tSI\tfastType\tword2vec\n")
        for lenient in [True, False]:
            for rank in range(1, 21):
                output.write(str(rank) + "\t")
                for strategy in [resolver.Strategy.BASELINE, resolver.Strategy.DICTIONARY,
                                 resolver.Strategy.FASTTYPE, resolver.Strategy.WORD2VEC]:
                    correct, found, valid = do_analysis(topics_file, detection_file, expansion_file,
                                                        strategy, Level.TYPE, rank, lenient)
                    precision = metrics.calculate_precision(len(correct), len(found))
                    recall = metrics.calculate_recall(len(correct), len(valid))
                    fone = metrics.calculate_f1(precision, recall)
                    output.write(str(fone) + "\t")
                output.write("\n")
                output.flush()
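
# The resulting `stats.tsv` holds one row per rank (k = 1..20) with one F1 column per
# strategy; the 20 ranks are written twice, first for lenient=True and then for
# lenient=False, with no extra column distinguishing the two blocks.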


def summary(topics_file: str, detection_file: str, expansion_file: str, level: Level,
            max_tries: int, lenient: bool):
    """
    Save a summary table in TSV format that can be used to run statistical tests
    (e.g. the McNemar test).

    :param topics_file: Path to the topics file.
    :param detection_file: Path to the detection standard file.
    :param expansion_file: Path to the expansion standard file.
    :param level: The acronym-solving level (TOKEN or TYPE).
    :param max_tries: The maximum number of candidate expansions to consider.
    :param lenient: Whether to consider partial matches as a valid sense.
    :return:
    """
    strategies = [resolver.Strategy.BASELINE, resolver.Strategy.DICTIONARY,
                  resolver.Strategy.FASTTYPE, resolver.Strategy.WORD2VEC]

    topics = topic_list.parse(topics_file)
    valid_standard = detection_standard.parse_valid(detection_file)
    valid = _valid(topics, valid_standard, level)

    # Most methods are optimized to run in alphabetical order by keeping a large chunk of data
    # in memory. It is therefore more efficient to run each method over all instances first and
    # then transpose the results.
    results = {}  # type: Dict[resolver.Strategy, Set[Acronym]]
    for strategy in strategies:
        logger.info("Strategy: %s", strategy)
        correct, _, _ = do_analysis(topics_file, detection_file, expansion_file, strategy,
                                    level, max_tries, lenient)
        results[strategy] = set(correct)

    with open("summary.tsv", "w+", encoding="utf-8") as output:
        output.write("instance\tMajority\tSI\tfastType\tword2vec\n")
        for contextualized_acronym in valid:
            output.write(contextualized_acronym.acronym + "\t")
            assessments = []
            for strategy in strategies:
                assessment = "T" if contextualized_acronym in results[strategy] else "F"
                assessments.append(assessment)
            output.write("\t".join(assessments))
            output.write("\n")
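
# Each row of `summary.tsv` pairs one acronym instance with a T/F assessment per
# strategy (T if that strategy expanded the instance correctly), e.g., with
# illustrative values only:
#   EKG	T	F	T	T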


if __name__ == "__main__":
    do_analysis("resources/topic_list.tsv",
                "resources/detection_standard.tsv",
                "resources/expansion_standard.tsv",
                resolver.Strategy.WORD2VEC, Level.TYPE, 1, True)