diff --git a/applications/neural_search/recall/in_batch_negative/evaluate.py b/applications/neural_search/recall/in_batch_negative/evaluate.py index 89da8e5157fe..ea9dde6c6107 100644 --- a/applications/neural_search/recall/in_batch_negative/evaluate.py +++ b/applications/neural_search/recall/in_batch_negative/evaluate.py @@ -73,17 +73,15 @@ def recall(rs, N=10): with open(args.recall_result_file, "r", encoding="utf-8") as f: relevance_labels = [] for index, line in enumerate(f): - if index % args.recall_num == 0 and index != 0: - rs.append(relevance_labels) - relevance_labels = [] - text, recalled_text, cosine_sim = line.rstrip().split("\t") if text2similar[text] == recalled_text: relevance_labels.append(1) else: relevance_labels.append(0) - rs.append(relevance_labels) + if (index + 1) % args.recall_num == 0: + rs.append(relevance_labels) + relevance_labels = [] recall_N = [] recall_num = [1, 5, 10, 20, 50] diff --git a/applications/neural_search/recall/simcse/evaluate.py b/applications/neural_search/recall/simcse/evaluate.py index 63d3b2fe1634..bc991250b9df 100644 --- a/applications/neural_search/recall/simcse/evaluate.py +++ b/applications/neural_search/recall/simcse/evaluate.py @@ -57,17 +57,16 @@ def recall(rs, N=10): with open(args.recall_result_file, "r", encoding="utf-8") as f: relevance_labels = [] for index, line in enumerate(f): - - if index % args.recall_num == 0 and index != 0: - rs.append(relevance_labels) - relevance_labels = [] - text, recalled_text, cosine_sim = line.rstrip().split("\t") if text2similar[text] == recalled_text: relevance_labels.append(1) else: relevance_labels.append(0) + if (index + 1) % args.recall_num == 0: + rs.append(relevance_labels) + relevance_labels = [] + recall_N = [] recall_num = [1, 5, 10, 20, 50] result = open("result.tsv", "a")