Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "doesn't match" evaluation to KeyedVectors #2765

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
282 changes: 281 additions & 1 deletion gensim/models/keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@

from __future__ import division # py3 "true division"

from itertools import chain
from itertools import chain, combinations
import logging
from numbers import Integral

Expand All @@ -174,6 +174,9 @@
ndarray, sum as np_sum, prod, argmax
import numpy as np

import re
from gensim.models import Word2Vec

from gensim import utils, matutils # utility fnc for pickling, common scipy operations etc
from gensim.corpora.dictionary import Dictionary
from six import string_types, integer_types
Expand Down Expand Up @@ -569,6 +572,118 @@ def most_similar(self, positive=None, negative=None, topn=10, restrict_vocab=Non
result = [(self.index2word[sim], float(dists[sim])) for sim in best if sim not in all_words]
return result[:topn]


def evaluate_top_k_similar(self, file, model, k, topk_in_cat=True, debug=False):
"""Compute the performance of a model on a topk similarity test
for a given txt file.

Parameters
----------
file : str
Path to the .txt file of categories
model : str
Location of word2vec model to be loaded
k : int
Number of words from the same category for a given comparsion.
e.g (3_in -> total group size of 4)
topk_in_cat : bool, optional
Determines which accuracy to compute. topk_in_cat finds the
ratio of the number of matches to the number
of words in the category which generated the topk list.
cat_in_topk finds the ratio of the number of matches to the
number of words in the topk list.
debug : bool, optional
A flag used to print the comparisons to the console (default is
False)

Returns
-------
group score : dictionary of {str-str: float}
A dictionary of cross-category scores where each key identifies the categories being compared
and the value is the accuracy of the odd-one-out task for that pairing.

"""

# read in test set file as dictionary where each category is a key
# and the corresponding value is the list of words belonging to it.
cats = {}
with open(file, 'r') as f:
for line in f:
if re.search('^:',line):
key = line.rstrip()
elif not len(line.strip()) == 0:
cats.update({key: line.split()})

# load the word2vec model to be tested
test_model = Word2Vec.load(model)

# initialize running totals
total_score = 0
total_count = 0
# denominator for topk_in_cat accuracy
total_denom = 0
# for storing accuracies
cat_acc = {}

# find the topk most similar for each word in all categories
for key, value in cats.items():
# counters for the categories
cat_score = 0
cat_count = 0
cat_denom = 0
for entry in value:
# update counters
total_denom += len(value)-1
cat_denom += len(value)-1
total_count += 1
cat_count += 1
# zero out score for new word
word_score = 0
# find top_k similar words for a given entry
top_k = test_model.wv.most_similar(positive=entry, topn=k)
if debug:
print('entry=',entry)
print('topk=',top_k)
# items in each category
for i in value:
# items in topk list
for j in range(len(top_k)):
# ignore word that generated top_k list
if i != entry:
# find the number of matches between category list and top_k list
word_score += i in top_k[j]
cat_score += i in top_k[j]
total_score += i in top_k[j]
if debug:
print('word:', i)
print('check=',i,'against', 'top_k['+str(j)+']=', top_k[j])
print('in_top_k=', i in top_k[j])
# decide accuracy metric
if topk_in_cat:
# update score dict
cat_acc.update({str(key)+'-'+str(k):cat_score/cat_denom})
print('entry_score=', word_score/(len(value)-1))
print('cat_score=', cat_score/cat_denom)
print('total_score=', total_score/total_denom)

# cat_in_topk accuracy
else:
# update score dict
cat_acc.update({str(key+'-'+str(k)):cat_score/(k*cat_count)})
print('entry_score=', word_score/len(top_k))
print('cat_score=', cat_score/(k*cat_count))
print('total_score=', total_score/(k*total_count))

# summary statistics
print('number of comparisons=',total_count)
if topk_in_cat:
print('total_score=', total_score/total_denom)
else:
print('total accuracy =', total_score/(k*total_count))
print(cat_acc)
return cat_acc


def similar_by_word(self, word, topn=10, restrict_vocab=None):
"""Find the top-N most similar words.

Expand Down Expand Up @@ -878,6 +993,171 @@ def doesnt_match(self, words):
mean = matutils.unitvec(vectors.mean(axis=0)).astype(REAL)
dists = dot(vectors, mean)
return sorted(zip(dists, used_words))[0][1]

def evaluate_doesnt_match(self, cat_file, model, k_in=3, eval_dupes=False, debug=False):
"""Compute the performance of a model on the doesnt match task
for a given test set.

Parameters
----------
cat_file : str
Path to the .txt file of categories
model : str
Location of word2vec model to be loaded
k_in : int
Number of words from the same category for a given comparsion.
e.g (3_in -> total group size of 4)
eval_dupes : bool, optional
Determines whether to evaluate groups with a word that belongs to
two or more categories. Default is to ignore these comparisons.
debug : bool, optional
A flag used to print the comparisons to the console (default is
False)

Returns
-------
group score : dictionary of {str-str: float}
A dictionary of cross-category scores where each key identifies the categories being compared
and the value is the accuracy of the odd-one-out task for that pairing.

Raises
-------
KeyError
if file contains word(s) that are not in the specified model's vocabulary

ValueError
if not all categories contain atleast k_in words.

"""

# for identifying k-combos
regex = str(k_in)+'$'

# read in test set file as dictionary where each category is a key
# and the corresponding value is the list of words belonging to it.
cats = {}
with open(cat_file, 'r') as f:
for line in f:
if re.search('^:',line):
key = line.rstrip()
# check category contains enough values
elif not len(line.strip()) == 0 and len(line.strip().split()) < k_in:
raise ValueError('Atleast one category does not contain enough values for specified size. Each category must contain a minimum of k_in values')
elif not len(line.strip()) == 0:
cats.update({key: line.split()})

# load the word2vec model to be tested
test_model = Word2Vec.load(model)

# verify that all words are in the vocabulary
# else raise an error
for key, value in cats.items():
for word in value:
if word in test_model.wv.vocab:
pass
else:
raise KeyError('word '+word+' is not in vocabulary')

# odd-one-out testing
combos = {} # place to store all possible category combos
one_out = [] # place to store the predicted odd-one-out
labels = [] # place to store the actual odd-one-out

# for storing accuracies
cross_cat_acc = {}
# store list of comparisons with dupes
dupes = []

# find all k-combos and 1-combos for each category and store in dictionary
for key in cats:
k_combos = list(combinations(cats[key],k_in))
one_combos = list(combinations(cats[key],1))
combos.update({str(key)+str(k_in): [k_combos], str(key)+'1': one_combos})
# pair each k-combo with all possible other 1-combos
test_set = list(combinations(combos.keys(),2))
if(debug):
print('len(test_set)=',len(test_set))

for pair in range(len(test_set)):
# ignore comparisons between two groups of k
if(re.search(regex,test_set[pair][0]) and re.search(regex,test_set[pair][1])):
pass
# ignore comparisons between two groups of 1
elif(re.search('1$',test_set[pair][0]) and re.search('1$',test_set[pair][1])):
pass
# ignore pairings with a 1-comb and 3-comb from the same category
elif(test_set[pair][0][0:-1] == test_set[pair][1][0:-1]):
pass
# store the value from each pair for accessibility
else:
if(debug):
print('test pair =', test_set[pair][0][0:-1], test_set[pair][1][0:-1])
ls1 = combos[test_set[pair][0]] # access values from combo dict for 1st key in pair
ls2 = combos[test_set[pair][1]] # access values from combo dict for 2nd key in pair
# create key for categories being compared
cross_cat_acc.update({test_set[pair][0]+'-'+test_set[pair][1]: 0})
# zero out score for this comparison pair
score = 0
count = 0
# loop through all value comparisons for this pair
for i in range(len(ls1)):
for j in range(len(ls1[i])):
for k in range(len(ls2)):
for l in range(len(ls2[k])):
# find which of the value is a 1-comb and therefore the actual odd-one-out
if isinstance(ls1[i][j], tuple):
actual = ls2[k][l]
result = list(ls1[i][j])
result.append(actual)

else:
actual = ls1[i][j]
result = list(ls2[k][l])
result.append(actual)

if debug:
print(pair,i,j,k,l,'this is the comparison=',result)

# check for duplicate words in result list
if not eval_dupes:
unique = set(result)
is_dupe = len(unique) != len(result)
if debug:
print('unique=',unique)
print('result=',result)
print('is_dupe=',is_dupe)
if is_dupe:
dupes.append(result)

# only evaluate non-dupes unless overridden by user
if eval_dupes or not is_dupe:
# predict odd-one-out using gensims doesnt_match function
pred = test_model.wv.doesnt_match(result)
if(debug):
print('predicted doesnt match=',pred)
# append predicted and actual to master lists
one_out.append(pred)
labels.append(actual)
# score
score += int(pred == actual)
count += 1
# update accuracy dict
cross_cat_acc.update({test_set[pair][0]+'-'+test_set[pair][1]: score/count})


# compare the prediction list with the labels list to calculate total accuracy
correct = 0.0
total = len(one_out)
for prediction in range(total):
correct += int(one_out[prediction] == labels[prediction])
accuracy = correct/total
print('group scores =',cross_cat_acc)
print('total accuracy =',accuracy)
print('number of comparisons =',total)
if not eval_dupes:
print('comparisons with duplicates=',dupes)
return cross_cat_acc


@staticmethod
def cosine_similarities(vector_1, vectors_all):
Expand Down