From d4a0b48e4fa321fcfa4c4992f74eb757e375ab73 Mon Sep 17 00:00:00 2001 From: hwelsters Date: Thu, 5 Oct 2023 11:26:05 -0700 Subject: [PATCH] feat: tag clustering using ML --- .../management/commands/curator_clean_tags.py | 16 +- .../commands/curator_cluster_tags.py | 64 +++++ .../commands/curator_edit_clusters.py | 16 ++ .../management/commands/curator_map_tags.py | 77 +++++ .../commands/curator_modify_canon.py | 16 ++ ...nicaltag_canonicaltagmapping_tagcluster.py | 44 +++ django/curator/models.py | 86 +++++- django/curator/tag_deduplication.py | 264 ++++++++++++++++++ .../curator/tests/test_tag_deduplication.py | 90 ++++++ django/requirements.txt | 1 + 10 files changed, 654 insertions(+), 20 deletions(-) create mode 100644 django/curator/management/commands/curator_cluster_tags.py create mode 100644 django/curator/management/commands/curator_edit_clusters.py create mode 100644 django/curator/management/commands/curator_map_tags.py create mode 100644 django/curator/management/commands/curator_modify_canon.py create mode 100644 django/curator/migrations/0003_canonicaltag_canonicaltagmapping_tagcluster.py create mode 100644 django/curator/tag_deduplication.py create mode 100644 django/curator/tests/test_tag_deduplication.py diff --git a/django/curator/management/commands/curator_clean_tags.py b/django/curator/management/commands/curator_clean_tags.py index 2a03bbed7..9dcfe513c 100644 --- a/django/curator/management/commands/curator_clean_tags.py +++ b/django/curator/management/commands/curator_clean_tags.py @@ -31,7 +31,7 @@ def add_arguments(self, parser): def handle_load(self, restore_directory): path = restore_directory.joinpath(PENDING_TAG_CLEANUPS_FILENAME) - print("Loading data from path {}".format(str(path))) + logger.debug("Loading data from path %s", path) tag_cleanups = TagCleanup.load(path) TagCleanup.objects.bulk_create(tag_cleanups) @@ -50,11 +50,11 @@ def handle_run(self): def handle_view(self): qs = TagCleanup.objects.filter(is_active=True) if qs.count() > 0: - print("Tag Cleanups\n--------------------\n") + logger.debug("Tag Cleanups\n--------------------\n") for tag_cleanup in qs.iterator(): - print(tag_cleanup) + logger.debug(tag_cleanup) else: - print("No Pending Tag Cleanups!") + logger.debug("No Pending Tag Cleanups!") def handle(self, *args, **options): run = options["run"] @@ -70,11 +70,11 @@ def handle(self, *args, **options): elif method: self.handle_method(method) elif dump: - print( - "Dumping tag curation data to {}".format( - load_directory.joinpath(PENDING_TAG_CLEANUPS_FILENAME) - ) + logger.debug( + "Dumping tag curation data to %s", + load_directory.joinpath(PENDING_TAG_CLEANUPS_FILENAME), ) + TagCleanup.objects.dump( load_directory.joinpath(PENDING_TAG_CLEANUPS_FILENAME) ) diff --git a/django/curator/management/commands/curator_cluster_tags.py b/django/curator/management/commands/curator_cluster_tags.py new file mode 100644 index 000000000..6f9ffcfce --- /dev/null +++ b/django/curator/management/commands/curator_cluster_tags.py @@ -0,0 +1,64 @@ +import logging + +from django.core.management.base import BaseCommand + +from curator.tag_deduplication import TagClusterer, TagClusterManager + + +class Command(BaseCommand): + help = """ + Cluster Tags using dedupe. This command takes the rows of tags available in the database and clusters the tags together using Dedupe. + It takes in the rows available in the Tag table and attempts to create CanonicalTag objects that are stored in the database. + """ + + def add_arguments(self, parser): + parser.add_argument( + "--label", + "-l", + help="label the training data for the clustering model using the console.", + action="store_true", + default=False, + ) + + parser.add_argument( + "--reset", + "-r", + help="""remove all unlabelled clusters from the database.""", + action="store_true", + default=False, + ) + + parser.add_argument( + "--threshold", + "-t", + help="""float between [0,1]. Blank defaults to 0.5. + Defines how much confidence to require from the model before tags are clustered. + Higher thresholds cluster less tags and require more training data labels.""", + default=0.5, + ) + + def handle(self, *args, **options): + """ + `curator_cluster_tags should be used only if the curator would like for a large amount of unlabelled tags to be clustered. + For individual tags, the TagGazetteer is more preferred. + """ + if TagClusterManager.has_unlabelled_clusters() and not options["reset"]: + logging.warn( + "There are still some unlabelled clusters. Finish labelling those using curator_edit_clusters or run this command with the --reset option to remove all unlabelled clusters." + ) + return + + TagClusterManager.reset() + tag_clusterer = TagClusterer(clustering_threshold=options["threshold"]) + + if options["label"]: + tag_clusterer.console_label() + tag_clusterer.save_to_training_file() + + if not tag_clusterer.training_file_exists(): + logging.warn( + "Your model does not have any labelled data. Run this command with --label and try again." + ) + + clusters = tag_clusterer.cluster_tags() + tag_clusterer.save_clusters(clusters) diff --git a/django/curator/management/commands/curator_edit_clusters.py b/django/curator/management/commands/curator_edit_clusters.py new file mode 100644 index 000000000..99af26eab --- /dev/null +++ b/django/curator/management/commands/curator_edit_clusters.py @@ -0,0 +1,16 @@ +import logging + +from django.core.management.base import BaseCommand + +from curator.tag_deduplication import TagClusterer, TagClusterManager + + +class Command(BaseCommand): + # TODO: Expand upon this + help = """ + Edit clusters. + """ + + def handle(self, *args, **options): + """ """ + TagClusterManager.console_label() diff --git a/django/curator/management/commands/curator_map_tags.py b/django/curator/management/commands/curator_map_tags.py new file mode 100644 index 000000000..c1284b157 --- /dev/null +++ b/django/curator/management/commands/curator_map_tags.py @@ -0,0 +1,77 @@ +import logging + +from django.core.management.base import BaseCommand +from taggit.models import Tag + +from curator.tag_deduplication import TagGazetteer +from curator.models import CanonicalTag, CanonicalTagMapping + +logger = logging.getLogger(__name__) + + +class Command(BaseCommand): + help = "Matches tags to a canonical list of tags using dedupe. This command finds a canonical tag " + + def add_arguments(self, parser): + parser.add_argument( + "--label", + "-l", + help="label the training data for the gazetteering model using the console", + action="store_true", + default=False, + ) + parser.add_argument( + "--threshold", + "-t", + help="""float between [0,1]. Blank defaults to 0.5. + Defines how much confidence to require from the model before tags are selected from the canonical list. + Higher thresholds matches less tags to those in a canonical list and require more training data labels.""", + default=0.5, + ) + + def handle(self, *args, **options): + """ + `curator_gazetteer_tags` searches for CanonicalTags that most closely match a certain tag. + From a canonical list, the canonical tag that most closely matches is selected. + """ + if not CanonicalTag.objects.exists(): + logger.warn( + "Canonical list is empty, populating canonical list using the curator_cluster_tags command" + ) + return + + tag_gazetteer = TagGazetteer(float(options["threshold"])) + + if options["label"]: + tag_gazetteer.console_label() + tag_gazetteer.save_to_training_file() + + if not tag_gazetteer.training_file_exists(): + logging.warn( + "Your model does not have any labelled data. Run this command with --label and try again." + ) + + tags = Tag.objects.filter(canonicaltagmapping=None) + is_unmatched = False + for tag in tags: + matches = tag_gazetteer.text_search(tag.name) + if matches: + match = matches[0] + canonical_tag_mapping = CanonicalTagMapping( + tag=tag, canonical_tag=match[0], confidence_score=match[1] + ) + + is_correct = input( + f"Does the following mapping make sense?:\n{str(canonical_tag_mapping)}\n(y)es/(n)o\n" + ) + + if is_correct == "y": + print("Mapped tag!") + canonical_tag_mapping.save() + else: + is_unmatched = True + + if is_unmatched: + logging.warn( + "There are some Tags that could not be matched to CanonicalTags. Either lower the threshold or increase the training data size." + ) diff --git a/django/curator/management/commands/curator_modify_canon.py b/django/curator/management/commands/curator_modify_canon.py new file mode 100644 index 000000000..525e219a6 --- /dev/null +++ b/django/curator/management/commands/curator_modify_canon.py @@ -0,0 +1,16 @@ +import logging + +from django.core.management.base import BaseCommand +from taggit.models import Tag + +from curator.tag_deduplication import TagGazetteer, TagClusterManager +from curator.models import CanonicalTag, CanonicalTagMapping + +logger = logging.getLogger(__name__) + + +class Command(BaseCommand): + help = "Matches tags to a canonical list of tags using dedupe. This command finds a canonical tag " + + def handle(self, *args, **options): + TagClusterManager.console_canonicalize_edit() diff --git a/django/curator/migrations/0003_canonicaltag_canonicaltagmapping_tagcluster.py b/django/curator/migrations/0003_canonicaltag_canonicaltagmapping_tagcluster.py new file mode 100644 index 000000000..0dd8e4442 --- /dev/null +++ b/django/curator/migrations/0003_canonicaltag_canonicaltagmapping_tagcluster.py @@ -0,0 +1,44 @@ +# Generated by Django 3.2.21 on 2023-10-05 18:30 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('taggit', '0005_auto_20220424_2025'), + ('curator', '0002_rename_tagcleanup_permission'), + ] + + operations = [ + migrations.CreateModel( + name='CanonicalTag', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.TextField(unique=True)), + ], + ), + migrations.CreateModel( + name='TagCluster', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('canonical_tag_name', models.TextField()), + ('confidence_score', models.FloatField()), + ('date_created', models.DateTimeField(auto_now_add=True, null=True)), + ('tags', models.ManyToManyField(to='taggit.Tag')), + ], + ), + migrations.CreateModel( + name='CanonicalTagMapping', + fields=[ + ('tag', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, primary_key=True, serialize=False, to='taggit.tag')), + ('confidence_score', models.FloatField()), + ('date_created', models.DateTimeField(auto_now=True)), + ('canonical_tag', models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, to='curator.canonicaltag')), + ('curator', models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, to=settings.AUTH_USER_MODEL)), + ], + ), + ] diff --git a/django/curator/models.py b/django/curator/models.py index 90a8e92cc..ac271d116 100644 --- a/django/curator/models.py +++ b/django/curator/models.py @@ -1,13 +1,16 @@ import json import logging -from collections import defaultdict - -import modelcluster.fields import os import re + +from collections import defaultdict +from django.core.exceptions import FieldDoesNotExist +from django.contrib.auth.models import User from django.contrib.postgres.aggregates import ArrayAgg +from django.contrib.postgres.fields import ArrayField from django.db import models, transaction from django.urls import reverse +from modelcluster import fields from nltk.corpus import stopwords from nltk.stem import PorterStemmer from nltk.tokenize import word_tokenize @@ -23,8 +26,8 @@ def has_parental_object_content_field(model): try: field = model._meta.get_field("content_object") - return isinstance(field, modelcluster.fields.ParentalKey) - except models.FieldDoesNotExist: + return isinstance(field, fields.ParentalKey) + except FieldDoesNotExist: return False @@ -119,7 +122,7 @@ def __init__(self, name, regex): self.regex = regex -def pl_regex(name, flags=re.I): +def pl_regex(name, flags=re.IGNORECASE): return re.compile( r"\b{}(?:,|\b|\s+|\Z|v?\d+\.\d+\.\d+(?:-?\w[\w-]*)*|\d+)".format(name), flags=flags, @@ -140,7 +143,7 @@ def pl_regex(name, flags=re.I): Matcher("Jade", pl_regex("jade")), Matcher("Jason", pl_regex("jason")), Matcher("Java", pl_regex("java")), - Matcher("James II", pl_regex("james\s+ii")), + Matcher("James II", pl_regex("james\\s+ii")), Matcher("Logo", pl_regex("logo")), Matcher("NetLogo", pl_regex("netlogo")), Matcher("Mason", pl_regex("mason")), @@ -148,7 +151,7 @@ def pl_regex(name, flags=re.I): Matcher("MatLab", pl_regex("matlab")), Matcher("Objective-C", pl_regex(r"objective(:?[-\s]+)?c")), Matcher("Pandora", pl_regex("pandora")), - Matcher("Powersim Studio", pl_regex("powersim\s+studio")), + Matcher("Powersim Studio", pl_regex("powersim\\s+studio")), Matcher("Python", pl_regex("python")), Matcher("R", pl_regex("r")), Matcher("Repast", pl_regex("repast")), @@ -170,7 +173,7 @@ class TagCleanupTransaction(models.Model): date_created = models.DateTimeField(auto_now_add=True) def __str__(self): - return "{}".format(self.date_created.strftime("%c")) + return self.date_created.strftime("%c") class TagCleanup(models.Model): @@ -183,9 +186,7 @@ class TagCleanup(models.Model): objects = TagCleanupQuerySet.as_manager() def __str__(self): - return "id={} new_name={}, old_name={}".format( - self.id, repr(self.new_name), repr(self.old_name) - ) + return f"id={self.id} new_name={self.new_name}, old_name={self.old_name}" @classmethod def find_groups_by_porter_stemmer(cls): @@ -300,3 +301,64 @@ def migrate(self, new_names, old_name): for model in through_models: self.copy_through_model_refs(model, new_tags=new_tags, old_tag=old_tag) old_tag.delete() + + +class TagCluster(models.Model): + canonical_tag_name = models.TextField() + tags = models.ManyToManyField(Tag) + confidence_score = models.FloatField() + date_created = models.DateTimeField(auto_now_add=True, null=True) + + def add_tag_by_name(self, name: str): + tag = Tag.objects.filter(name=name) + if tag: + self.tags.add(tag.first()) + return tag + + def save_mapping(self): + canonical_tag, _ = CanonicalTag.objects.get_or_create( + name=self.canonical_tag_name + ) + + previous_tag_mappings = CanonicalTagMapping.objects.filter( + canonical_tag=canonical_tag + ) + for previous_tag_mapping in previous_tag_mappings: + previous_tag_mapping.canonical_tag = None + previous_tag_mapping.save() + + tag_mappings = [ + CanonicalTagMapping( + tag=tag, + canonical_tag=canonical_tag, + confidence_score=self.confidence_score, + ) + for tag in self.tags.all() + ] + for tag_mapping in tag_mappings: + tag_mapping.save() + + return canonical_tag, tag_mappings + + def __str__(self): + return f"canonical_tag_name={self.canonical_tag_name} tags={self.tags.all()}" + + +class CanonicalTag(models.Model): + name = models.TextField(unique=True) + + def __str__(self): + return f"name={self.name}" + + +class CanonicalTagMapping(models.Model): + tag = models.OneToOneField(Tag, on_delete=models.deletion.CASCADE, primary_key=True) + canonical_tag = models.ForeignKey( + CanonicalTag, null=True, on_delete=models.SET_NULL + ) + curator = models.ForeignKey(User, null=True, on_delete=models.SET_NULL) + confidence_score = models.FloatField() + date_created = models.DateTimeField(auto_now=True) + + def __str__(self): + return f"tag={self.tag} canonical_tag={self.canonical_tag.name} confidence={self.confidence_score}" diff --git a/django/curator/tag_deduplication.py b/django/curator/tag_deduplication.py new file mode 100644 index 000000000..a7eeb6bda --- /dev/null +++ b/django/curator/tag_deduplication.py @@ -0,0 +1,264 @@ +import abc +import logging +import os +import re +import enum +from typing import List + +import dedupe +from taggit.models import Tag + +from curator.models import CanonicalTagMapping, CanonicalTag, TagCluster + + +class AbstractTagDeduper(abc.ABC): + TRAINING_FILE = "curator/clustering_training.json" + FIELDS = [{"field": "name", "type": "String"}] + + def uncertain_pairs(self): + return self.deduper.uncertain_pairs() + + def mark_pairs(self, pairs, is_distinct: bool): + labelled_examples = {"match": [], "distinct": []} + + example_key = "distinct" if is_distinct else "match" + labelled_examples[example_key] = pairs + + self.deduper.mark_pairs(labelled_examples) + + def prepare_training_data(self): + tags = Tag.objects.all().values() + data = {row["id"]: row for i, row in enumerate(tags)} + return data + + def console_label(self): + dedupe.console_label(self.deduper) + + def training_file_exists(self) -> bool: + return os.path.exists(self.TRAINING_FILE) + + def remove_training_file(self): + if os.path.isfile(self.TRAINING_FILE): + os.remove(self.TRAINING_FILE) + + def save_to_training_file(self): + with open(self.TRAINING_FILE, "w") as file: + self.deduper.write_training(file) + + +class TagClusterManager: + def reset(): + TagCluster.objects.all().delete() + + def has_unlabelled_clusters(): + return TagCluster.objects.exists() + + def modify_cluster(tag_cluster: TagCluster): + action = "" + while action != "q": + TagClusterManager.__display_cluster(tag_cluster) + action = input( + "What would you like to do?\n(c)hange canonical tag name\n(a)dd tags\n(r)emove tags\n(s)ave\n(f)inish\n" + ) + + if action == "c": + new_canonical_tag_name = input("What would you like to change it to?\n") + tag_cluster.canonical_tag_name = new_canonical_tag_name + elif action == "a": + tag_name = input("Input the name of the tag you would like to add\n") + if not tag_cluster.add_tag_by_name(tag_name): + print("The tag does not exist in Tags") + elif action == "r": + tag_index = input( + "Input the number of the tag you would like to remove.\n" + ) + + tags = list(tag_cluster.tags.all()) + if not tag_index.isnumeric() or int(tag_index) >= len(tags): + print("Index is out of bounds!") + continue + + tags.pop(int(tag_index)) + tag_cluster.tags.set(tags) + elif action == "s": + print("Published mapping") + canonical_tag, tag_mappings = tag_cluster.save_mapping() + + print( + f"The following was saved to the database: \n\n{tag_mappings}" + ) + + elif action == "f": + tag_cluster.delete() + break + else: + print("Invalid option!") + + def console_label(): + if not TagCluster.objects.exists(): + logging.info("There aren't any clusters to label.") + return + + tag_clusters = TagCluster.objects.all() + for index, tag_cluster in enumerate(tag_clusters): + TagClusterManager.modify_cluster(tag_cluster) + + def console_canonicalize_edit(): + quit = False + while not quit: + action = input( + "What would you like to do?\n(a)dd canonical tag\n(r)emove canonical tag\n(v)iew canonical list\n(m)odify canonical tag\n(q)uit\n" + ) + + if action == "a": + TagClusterManager.__console_add_new_canonical_tag() + elif action == "r": + TagClusterManager.__console_remove_canonical_tag() + elif action == "v": + TagClusterManager.__console_view_canonical_tag() + elif action == "m": + TagClusterManager.__console_modify_canonical_tag() + elif action == "q": + quit = True + + def __console_add_new_canonical_tag(): + new_canonical_tag_name = input("What is your new tag name?\n") + canonical_tag = CanonicalTag.objects.get_or_create(name=new_canonical_tag_name) + print("Created:\n", canonical_tag) + + def __console_remove_canonical_tag(): + canonical_tag_name = input("What is the name of the canonical tag to delete?\n") + canonical_tag = CanonicalTag.objects.filter(name=canonical_tag_name) + if canonical_tag.exists(): + canonical_tag.delete() + print("Successfully deleted!") + else: + print("Tag not found!") + + def __console_view_canonical_tag(): + canonical_tags = CanonicalTag.objects.all() + print("Canonical Tag List:") + for canonical_tag in canonical_tags: + print(canonical_tag.name) + print("") + + def __console_modify_canonical_tag(): + name = input("Which one would you like to modify?\n") + canonical_tag = CanonicalTag.objects.filter(name=name) + tags = Tag.objects.filter(canonicaltagmapping__canonical_tag=canonical_tag[0]) + + if canonical_tag.exists(): + try: + cluster = TagCluster( + canonical_tag_name=canonical_tag[0].name, confidence_score=1 + ) + canonical_tag[0].delete() + + cluster.save() + cluster.tags.set(tags) + TagClusterManager.modify_cluster(cluster) + except KeyboardInterrupt: + cluster.delete() + canonical_tag[0].save() + else: + print("Canonical tag not found!") + + def __display_cluster(tag_cluster: TagCluster): + tag_names = [tag.name for tag in tag_cluster.tags.all()] + print("Canonical tag name:", tag_cluster.canonical_tag_name, end="\n\n") + print("Tags:") + for index, tag_name in enumerate(tag_names): + print(f"{index}. {tag_name}") + print("") + + +class TagClusterer(AbstractTagDeduper): + def __init__(self, clustering_threshold): + self.clustering_threshold = clustering_threshold + + self.deduper = dedupe.Dedupe(AbstractTagDeduper.FIELDS) + self.prepare_training() + + # The training data is stored in a file. + # If it exists, load from the file + # Otherwise, start from scratch + def prepare_training(self): + data = self.prepare_training_data() + if os.path.exists(TagClusterer.TRAINING_FILE): + with open(TagClusterer.TRAINING_FILE, "r") as training_file: + self.deduper.prepare_training(data, training_file) + else: + self.deduper.prepare_training(data) + + # The model is trained and then the data is clustered + def cluster_tags(self): + self.deduper.train() + return self.deduper.partition( + self.prepare_training_data(), self.clustering_threshold + ) + + # Saves the clusters to the database + def save_clusters(self, clusters): + for id_list, confidence_list in clusters: + tags = Tag.objects.filter(id__in=id_list) + confidence_score = confidence_list[0] + + tag_cluster = TagCluster( + canonical_tag_name=tags[0].name, confidence_score=confidence_score + ) + tag_cluster.save() + + tag_cluster.tags.set(tags) + tag_cluster.save() + + def save_to_training_file(self): + with open(TagClusterer.TRAINING_FILE, "w") as file: + self.deduper.write_training(file) + + def training_file_exists(self) -> bool: + return os.path.exists(TagClusterer.TRAINING_FILE) + + def remove_training_file(self): + if os.path.isfile(TagClusterer.TRAINING_FILE): + os.remove(TagClusterer.TRAINING_FILE) + + +class TagGazetteer(AbstractTagDeduper): + def __init__(self, search_threshold): + self.search_threshold = search_threshold + + self.deduper = dedupe.Gazetteer(AbstractTagDeduper.FIELDS) + self.prepare_training() + self.deduper.train() + self.deduper.index(self.prepare_canonical_data()) + + # The training data is stored in a file. + # If it exists, load from the file + # Otherwise, start from scratch + def prepare_training(self): + data = self.prepare_training_data() + canonical_data = self.prepare_canonical_data() + + if self.training_file_exists(): + with open(TagGazetteer.TRAINING_FILE, "r") as training_file: + self.deduper.prepare_training(data, canonical_data, training_file) + else: + self.deduper.prepare_training(data, canonical_data) + + def prepare_canonical_data(self): + tags = CanonicalTag.objects.all().values() + data = {row["id"]: {**row} for i, row in enumerate(tags)} + return data + + def search(self, data): + return self.deduper.search(data, threshold=self.search_threshold) + + def text_search(self, name: str): + results = self.search({1: {"id": 1, "name": name}}) + matches = results[0][1] + + matches = [ + (CanonicalTag.objects.filter(pk=match[0])[0], match[1]) for match in matches + ] + + return matches diff --git a/django/curator/tests/test_tag_deduplication.py b/django/curator/tests/test_tag_deduplication.py new file mode 100644 index 000000000..5652d4e40 --- /dev/null +++ b/django/curator/tests/test_tag_deduplication.py @@ -0,0 +1,90 @@ +import random +import shortuuid +import string + +from django.test import TestCase +from django.utils.text import slugify +from taggit.models import Tag + +from curator.tag_deduplication import TagClusterer, TagGazetteer +from curator.models import CanonicalTag, CanonicalTagMapping + +random.seed(0) + + +def name_generator(unique_id, size=12): + return shortuuid.uuid()[:size] + str(unique_id) + + +class TestTagClustering(TestCase): + def setUp(self): + tags = [] + for index in range(300): + name = name_generator(index) + tags.append(Tag(name=name, slug=name)) + Tag.objects.bulk_create(tags) + + def test_uncertain_pairs(self): + tag_clustering = TagClusterer(clustering_threshold=0.5) + uncertain_pairs = tag_clustering.uncertain_pairs() + self.assertEqual(list, type(uncertain_pairs)) + self.assertLessEqual(1, len(uncertain_pairs)) + self.assertLessEqual(2, len(uncertain_pairs[0])) + self.assertEqual(int, type(uncertain_pairs[0][0]["id"])) + self.assertEqual(str, type(uncertain_pairs[0][0]["name"])) + self.assertEqual(str, type(uncertain_pairs[0][0]["slug"])) + + def test_cluster_tags(self): + clusters = self._cluster_tags() + self.assertEqual(list, type(clusters)) + + for cluster in clusters: + self.assertLess(0, len(cluster)) + self.assertLess(0, len(cluster[0])) + self.assertLess(0, len(cluster[1])) + + def _cluster_tags(self): + tag_clustering = TagClusterer(clustering_threshold=0.5) + for i in range(600): + uncertain_pair = tag_clustering.uncertain_pairs() + tag_clustering.mark_pairs(uncertain_pair, i % 3 != 0) + return tag_clustering.cluster_tags() + + +class TestTagGazetteering(TestCase): + def setUp(self): + canonical_tags = [] + tags = [] + for index in range(300): + name = name_generator(index) + canonical_tags.append(CanonicalTag(name=name)) + tags.append(Tag(name=name + "a", slug=name + "a")) + tags.append(Tag(name=name + "b", slug=name + "b")) + tags.append(Tag(name=name + "c", slug=name + "c")) + CanonicalTag.objects.bulk_create(canonical_tags) + Tag.objects.bulk_create(tags) + + def test_uncertain_pairs(self): + tag_clustering = TagGazetteer(search_threshold=0.5) + uncertain_pairs = tag_clustering.uncertain_pairs() + self.assertEqual(list, type(uncertain_pairs)) + self.assertEqual(1, len(uncertain_pairs)) + self.assertEqual(2, len(uncertain_pairs[0])) + self.assertEqual(int, type(uncertain_pairs[0][0]["id"])) + self.assertEqual(str, type(uncertain_pairs[0][0]["name"])) + self.assertEqual(str, type(uncertain_pairs[0][0]["slug"])) + + def test_search(self): + clusters = self._search() + self.assertEqual(list, type(clusters)) + + for cluster in clusters: + self.assertEqual(int, type(cluster[0])) + self.assertEqual(tuple, type(cluster[1])) + + def _search(self): + tag_clustering = TagGazetteer(search_threshold=0.5) + for i in range(600): + uncertain_pair = tag_clustering.uncertain_pairs() + tag_clustering.mark_pairs(uncertain_pair, i % 3 != 0) + return tag_clustering.search({1: {"id": 1, "name": "abms"}}) diff --git a/django/requirements.txt b/django/requirements.txt index 121243c39..8befd6d21 100644 --- a/django/requirements.txt +++ b/django/requirements.txt @@ -1,5 +1,6 @@ bagit==1.8.1 bleach==6.0.0 +dedupe==2.0.23 django-allauth==0.55.2 django-anymail[mailgun]==10.1 django-cookie-law==2.2.0