diff --git a/rdmo/core/tests/utils.py b/rdmo/core/tests/utils.py index 3bf4db8c4a..1fce8f674b 100644 --- a/rdmo/core/tests/utils.py +++ b/rdmo/core/tests/utils.py @@ -1,3 +1,5 @@ +import hashlib + from rdmo.core.models import Model from rdmo.core.tests.constants import multisite_status_map, status_map_object_permissions @@ -30,3 +32,7 @@ def get_obj_perms_status_code(instance, username, method): except KeyError: # not all users are defined in the method_instance_perms_map return multisite_status_map[method][username] + + +def compute_checksum(string): + return hashlib.sha1(string).hexdigest() diff --git a/rdmo/projects/admin.py b/rdmo/projects/admin.py index 1c82a6a5e8..926885e1c5 100644 --- a/rdmo/projects/admin.py +++ b/rdmo/projects/admin.py @@ -1,3 +1,4 @@ +from django import forms from django.contrib import admin from django.db.models import Prefetch from django.urls import reverse @@ -15,12 +16,35 @@ Snapshot, Value, ) +from .validators import ProjectParentValidator + + +class ProjectAdminForm(forms.ModelForm): + + class Meta: + model = Project + fields = [ + 'parent', + 'site', + 'title', + 'description', + 'catalog', + 'views' + ] + + + def clean(self): + super().clean() + ProjectParentValidator(self.instance)(self.cleaned_data) @admin.register(Project) class ProjectAdmin(admin.ModelAdmin): + form = ProjectAdminForm + search_fields = ('title', 'user__username') list_display = ('title', 'owners', 'updated', 'created') + readonly_fields = ('progress_count', 'progress_total') def get_queryset(self, request): return Project.objects.prefetch_related( diff --git a/rdmo/projects/assets/js/projects/components/main/Projects.js b/rdmo/projects/assets/js/projects/components/main/Projects.js index 93e38e65db..6f15f8e7b8 100644 --- a/rdmo/projects/assets/js/projects/components/main/Projects.js +++ b/rdmo/projects/assets/js/projects/components/main/Projects.js @@ -140,21 +140,27 @@ const Projects = ({ config, configActions, currentUserObject, projectsActions, p return (
- {(isProjectManager || isProjectOwner || isManager) && window.location.href = `${rowUrl}/update/${params}`} + href={`${rowUrl}/copy/`} + className="fa fa-copy" + title={gettext('Copy project')} + onClick={() => window.location.href = `${rowUrl}/copy/${params}`} /> + {(isProjectManager || isProjectOwner || isManager) && + window.location.href = `${rowUrl}/update/${params}`} + /> } {(isProjectOwner || isManager) && - window.location.href = `${rowUrl}/delete/${params}`} - /> + window.location.href = `${rowUrl}/delete/${params}`} + /> }
) diff --git a/rdmo/projects/assets/scss/projects.scss b/rdmo/projects/assets/scss/projects.scss index dbe1811707..f24a9fd12e 100644 --- a/rdmo/projects/assets/scss/projects.scss +++ b/rdmo/projects/assets/scss/projects.scss @@ -207,6 +207,7 @@ a.disabled { display: flex; gap: 5px; margin-bottom: 10px; + justify-content: flex-end; } .projects { diff --git a/rdmo/projects/forms.py b/rdmo/projects/forms.py index 4594fca7fa..11adc0f1b9 100644 --- a/rdmo/projects/forms.py +++ b/rdmo/projects/forms.py @@ -13,6 +13,7 @@ from .constants import ROLE_CHOICES from .models import Integration, IntegrationOption, Invite, Membership, Project, Snapshot +from .validators import ProjectParentValidator class CatalogChoiceField(forms.ModelChoiceField): @@ -53,6 +54,8 @@ class ProjectForm(forms.ModelForm): use_required_attribute = False def __init__(self, *args, **kwargs): + self.copy = kwargs.pop('copy', False) + catalogs = kwargs.pop('catalogs') projects = kwargs.pop('projects') super().__init__(*args, **kwargs) @@ -66,6 +69,11 @@ def __init__(self, *args, **kwargs): if settings.NESTED_PROJECTS: self.fields['parent'].queryset = projects + def clean(self): + if not self.copy: + ProjectParentValidator(self.instance)(self.cleaned_data) + super().clean() + class Meta: model = Project @@ -160,6 +168,10 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.fields['parent'].queryset = projects + def clean(self): + ProjectParentValidator(self.instance)(self.cleaned_data) + super().clean() + class Meta: model = Project fields = ('parent', ) diff --git a/rdmo/projects/models/project.py b/rdmo/projects/models/project.py index 6425f37f48..b992c0dc9d 100644 --- a/rdmo/projects/models/project.py +++ b/rdmo/projects/models/project.py @@ -1,6 +1,5 @@ from django.conf import settings from django.contrib.sites.models import Site -from django.core.exceptions import ValidationError from django.db import models from django.db.models.signals import pre_delete from django.dispatch import receiver @@ -88,11 +87,12 @@ def __str__(self): def get_absolute_url(self): return reverse('project', kwargs={'pk': self.pk}) - def clean(self): + def save(self, *args, **kwargs): + # ensure that the project hierarchy is not disturbed if self.id and self.parent in self.get_descendants(include_self=True): - raise ValidationError({ - 'parent': [_('A project may not be moved to be a child of itself or one of its descendants.')] - }) + raise RuntimeError('A project may not be moved to be a child of itself or one of its descendants.') + + super().save(*args, **kwargs) @property def catalog_uri(self): diff --git a/rdmo/projects/serializers/v1/__init__.py b/rdmo/projects/serializers/v1/__init__.py index 8816dc9ca7..fe72127295 100644 --- a/rdmo/projects/serializers/v1/__init__.py +++ b/rdmo/projects/serializers/v1/__init__.py @@ -9,7 +9,7 @@ from rdmo.services.validators import ProviderValidator from ...models import Integration, IntegrationOption, Invite, Issue, IssueResource, Membership, Project, Snapshot, Value -from ...validators import ValueConflictValidator, ValueQuotaValidator, ValueTypeValidator +from ...validators import ProjectParentValidator, ValueConflictValidator, ValueQuotaValidator, ValueTypeValidator class UserSerializer(serializers.ModelSerializer): @@ -78,6 +78,17 @@ class Meta: read_only_fields = ( 'snapshots', ) + validators = [ + ProjectParentValidator() + ] + + +class ProjectCopySerializer(ProjectSerializer): + + class Meta: + model = Project + fields = ProjectSerializer.Meta.fields + read_only_fields = ProjectSerializer.Meta.read_only_fields class ProjectMembershipSerializer(serializers.ModelSerializer): diff --git a/rdmo/projects/templates/projects/project_detail_sidebar.html b/rdmo/projects/templates/projects/project_detail_sidebar.html index 4c4be91ff7..c1886388a9 100644 --- a/rdmo/projects/templates/projects/project_detail_sidebar.html +++ b/rdmo/projects/templates/projects/project_detail_sidebar.html @@ -38,7 +38,6 @@

{% trans 'Options' %}

{% endif %} -{% if can_change_project or can_delete_project %} -{% endif %} {% has_perm 'projects.add_membership_object' request.user project as can_add_membership %} {% if can_add_membership %} diff --git a/rdmo/projects/tests/test_utils.py b/rdmo/projects/tests/test_utils.py index edbc5a8d3b..db2bfe2c51 100644 --- a/rdmo/projects/tests/test_utils.py +++ b/rdmo/projects/tests/test_utils.py @@ -1,9 +1,14 @@ import pytest +from django.contrib.auth.models import User +from django.contrib.sites.models import Site from django.http import QueryDict +from rdmo.core.tests.utils import compute_checksum + from ..filters import ProjectFilter -from ..utils import set_context_querystring_with_filter_and_page +from ..models import Project +from ..utils import copy_project, set_context_querystring_with_filter_and_page GET_queries = [ 'page=2&title=project', @@ -32,3 +37,94 @@ def test_set_context_querystring_with_filter_and_page(GET_query): assert context.get('querystring', 'not-in-context') == '' else: assert context.get('querystring', 'not-in-context') == 'not-in-context' + + +def test_copy_project(db, files): + project = Project.objects.get(id=1) + site = Site.objects.get(id=2) + user = User.objects.get(id=1) + project_copy = copy_project(project, site, [user]) + + # re fetch the original project + project = Project.objects.get(id=1) + + # check that site, owners, tasks, and views are correct + assert project_copy.site == site + assert list(project_copy.owners) == [user] + assert list(project_copy.user.values('id')) == [{'id': user.id}] + assert list(project_copy.tasks.values('id')) == list(project.tasks.values('id')) + assert list(project_copy.views.values('id')) == list(project.views.values('id')) + + # check that no ids are the same + assert project_copy.id != project.id + assert not set(project_copy.snapshots.values_list('id')).intersection(set(project.snapshots.values_list('id'))) + assert not set(project_copy.values.values_list('id')).intersection(set(project.values.values_list('id'))) + + # check the snapshots + snapshot_fields = ( + 'title', + 'description' + ) + for snapshot_copy, snapshot in zip( + project_copy.snapshots.values(*snapshot_fields), + project.snapshots.values(*snapshot_fields) + ): + assert snapshot_copy == snapshot + + # check the values + value_fields = ( + 'attribute', + 'set_prefix', + 'set_collection', + 'set_index', + 'collection_index', + 'text', + 'option', + 'value_type', + 'unit', + 'external_id' + ) + ordering = ( + 'attribute', + 'set_prefix', + 'set_index', + 'collection_index' + ) + for value_copy, value in zip( + project_copy.values.filter(snapshot=None).order_by(*ordering), + project.values.filter(snapshot=None).order_by(*ordering) + ): + for field in value_fields: + assert getattr(value_copy, field) == getattr(value, field), field + + if value_copy.file: + assert value_copy.file.path != value.file.path + assert value_copy.file.path == value_copy.file.path.replace( + f'/projects/{project.id}/values/{value.id}/', + f'/projects/{project_copy.id}/values/{value_copy.id}/' + ) + assert value_copy.file.size == value.file.size + assert compute_checksum(value_copy.file.open('rb').read()) == \ + compute_checksum(value.file.open('rb').read()) + else: + assert not value.file + + for snapshot_copy, snapshot in zip(project_copy.snapshots.all(), project.snapshots.all()): + for value_copy, value in zip( + project_copy.values.filter(snapshot=snapshot_copy).order_by(*ordering), + project.values.filter(snapshot=snapshot).order_by(*ordering) + ): + for field in value_fields: + assert getattr(value_copy, field) == getattr(value, field) + + if value_copy.file: + assert value_copy.file.path != value.file.path + assert value_copy.file.path == value_copy.file.path.replace( + f'/projects/{project.id}/snapshot/{snapshot.id}/values/{value.id}/', + f'/projects/{project_copy.id}/snapshot/{snapshot.id}/values/{value_copy.id}/' + ) + assert value_copy.file.size == value.file.size + assert compute_checksum(value_copy.file.open('rb').read()) == \ + compute_checksum(value.file.open('rb').read()) + else: + assert not value.file diff --git a/rdmo/projects/tests/test_view_project.py b/rdmo/projects/tests/test_view_project.py index b43660cf14..8a1465c765 100644 --- a/rdmo/projects/tests/test_view_project.py +++ b/rdmo/projects/tests/test_view_project.py @@ -58,7 +58,9 @@ export_formats = ('rtf', 'odt', 'docx', 'html', 'markdown', 'tex', 'pdf') site_id = 1 -parent_project_id = 1 +project_id = 1 +parent_id = 3 +parent_ancestors = [2, 3] catalog_id = 1 @@ -266,7 +268,7 @@ def test_project_create_parent_post(db, client, username, password): 'title': 'A new project', 'description': 'Some description', 'catalog': catalog_id, - 'parent': parent_project_id + 'parent': project_id } response = client.post(url, data) @@ -335,17 +337,21 @@ def test_project_update_post_parent(db, client, username, password, project_id): 'title': project.title, 'description': project.description, 'catalog': project.catalog.pk, - 'parent': parent_project_id + 'parent': parent_id } response = client.post(url, data) if project_id in change_project_permission_map.get(username, []): - if project_id == parent_project_id: + if parent_id in view_project_permission_map.get(username, []): + if project_id in parent_ancestors: + assert response.status_code == 200 + assert Project.objects.get(pk=project_id).parent == project.parent + else: + assert response.status_code == 302 + assert Project.objects.get(pk=project_id).parent_id == parent_id + else: assert response.status_code == 200 assert Project.objects.get(pk=project_id).parent == project.parent - else: - assert response.status_code == 302 - assert Project.objects.get(pk=project_id).parent_id == parent_project_id else: if password: assert response.status_code == 403 @@ -545,17 +551,21 @@ def test_project_update_parent_post(db, client, username, password, project_id): url = reverse('project_update_parent', args=[project_id]) data = { - 'parent': parent_project_id + 'parent': parent_id } response = client.post(url, data) if project_id in change_project_permission_map.get(username, []): - if project_id == parent_project_id: + if parent_id in view_project_permission_map.get(username, []): + if project_id in parent_ancestors: + assert response.status_code == 200 + assert Project.objects.get(pk=project_id).parent == project.parent + else: + assert response.status_code == 302 + assert Project.objects.get(pk=project_id).parent_id == parent_id + else: assert response.status_code == 200 assert Project.objects.get(pk=project_id).parent == project.parent - else: - assert response.status_code == 302 - assert Project.objects.get(pk=project_id).parent_id == parent_project_id else: if password: assert response.status_code == 403 diff --git a/rdmo/projects/tests/test_view_project_copy.py b/rdmo/projects/tests/test_view_project_copy.py new file mode 100644 index 0000000000..8ecec135da --- /dev/null +++ b/rdmo/projects/tests/test_view_project_copy.py @@ -0,0 +1,207 @@ +import pytest + +from django.contrib.auth.models import Group, User +from django.urls import reverse + +from ..models import Project, Snapshot, Value + +users = ( + ('owner', 'owner'), + ('manager', 'manager'), + ('author', 'author'), + ('guest', 'guest'), + ('user', 'user'), + ('site', 'site'), + ('anonymous', None), + ('editor', 'editor'), + ('reviewer', 'reviewer'), + ('api', 'api'), +) + +view_project_permission_map = { + 'owner': [1, 2, 3, 4, 5, 10], + 'manager': [1, 3, 5, 7], + 'author': [1, 3, 5, 8], + 'guest': [1, 3, 5, 9], + 'api': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + 'site': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] +} + +change_project_permission_map = { + 'owner': [1, 2, 3, 4, 5, 10], + 'manager': [1, 3, 5, 7], + 'api': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + 'site': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] +} + +delete_project_permission_map = { + 'owner': [1, 2, 3, 4, 5, 10], + 'api': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + 'site': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], +} + +export_project_permission_map = { + 'owner': [1, 2, 3, 4, 5, 10], + 'manager': [1, 3, 5, 7], + 'api': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + 'site': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], +} + +projects = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + +project_id = 1 +site_id = 1 +parent_id = 1 +catalog_id = 1 + + +@pytest.mark.parametrize('username,password', users) +@pytest.mark.parametrize('project_id', projects) +def test_project_copy_get(db, client, username, password, project_id): + client.login(username=username, password=password) + + url = reverse('project_copy', args=[project_id]) + response = client.get(url) + + if project_id in view_project_permission_map.get(username, []): + assert response.status_code == 200 + else: + if password: + assert response.status_code == 403 + else: + assert response.status_code == 302 + + +def test_project_copy_restricted_get(db, client, settings): + settings.PROJECT_CREATE_RESTRICTED = True + settings.PROJECT_CREATE_GROUPS = ['projects'] + + group = Group.objects.create(name='projects') + guest = User.objects.get(username='guest') + guest.groups.add(group) + + client.login(username='guest', password='guest') + + url = reverse('project_copy', args=[project_id]) + response = client.get(url) + + assert response.status_code == 200 + + +def test_project_copy_forbidden_get(db, client, settings): + settings.PROJECT_CREATE_RESTRICTED = True + + client.login(username='guest', password='guest') + + url = reverse('project_copy', args=[project_id]) + response = client.get(url) + + assert response.status_code == 403 + + +@pytest.mark.parametrize('username,password', users) +@pytest.mark.parametrize('project_id', projects) +def test_project_copy_post(db, files, client, username, password, project_id): + client.login(username=username, password=password) + + project_count = Project.objects.count() + snapshot_count = Snapshot.objects.count() + value_count = Value.objects.count() + + project = Project.objects.get(id=project_id) + project_snapshots_count = project.snapshots.count() + project_values_count = project.values.count() + + url = reverse('project_copy', args=[project_id]) + data = { + 'title': 'A new project', + 'description': 'Some description', + 'catalog': catalog_id + } + response = client.post(url, data) + + if project_id in view_project_permission_map.get(username, []): + assert response.status_code == 302 + assert Project.objects.count() == project_count + 1 + assert Snapshot.objects.count() == snapshot_count + project_snapshots_count + assert Value.objects.count() == value_count + project_values_count + else: + assert response.status_code == 403 if password else 302 + assert Project.objects.count() == project_count + assert Value.objects.count() == value_count + + +def test_project_copy_post_restricted(db, files, client, settings): + settings.PROJECT_CREATE_RESTRICTED = True + settings.PROJECT_CREATE_GROUPS = ['projects'] + + group = Group.objects.create(name='projects') + guest = User.objects.get(username='guest') + guest.groups.add(group) + + client.login(username='guest', password='guest') + + url = reverse('project_copy', args=[project_id]) + data = { + 'title': 'A new project', + 'description': 'Some description', + 'catalog': catalog_id + } + response = client.post(url, data) + + assert response.status_code == 302 + + +def test_project_copy_post_forbidden(db, files, client, settings): + settings.PROJECT_CREATE_RESTRICTED = True + + client.login(username='guest', password='guest') + + url = reverse('project_copy', args=[project_id]) + data = { + 'title': 'A new project', + 'description': 'Some description', + 'catalog': catalog_id + } + response = client.post(url, data) + + assert response.status_code == 403 + + +@pytest.mark.parametrize('username,password', users) +@pytest.mark.parametrize('project_id', projects) +def test_project_copy_parent_post(db, files, client, username, password, project_id): + client.login(username=username, password=password) + project_count = Project.objects.count() + + project_count = Project.objects.count() + snapshot_count = Snapshot.objects.count() + value_count = Value.objects.count() + + project = Project.objects.get(id=project_id) + project_snapshots_count = project.snapshots.count() + project_values_count = project.values.count() + + url = reverse('project_copy', args=[project_id]) + data = { + 'title': 'A new project', + 'description': 'Some description', + 'catalog': catalog_id, + 'parent': parent_id + } + response = client.post(url, data) + + if project_id in view_project_permission_map.get(username, []): + if parent_id in view_project_permission_map.get(username, []): + assert response.status_code == 302 + assert Project.objects.count() == project_count + 1 + assert Snapshot.objects.count() == snapshot_count + project_snapshots_count + assert Value.objects.count() == value_count + project_values_count + else: + assert response.status_code == 200 + assert Project.objects.count() == project_count + assert Value.objects.count() == value_count + else: + assert response.status_code == 403 if password else 302 + assert Project.objects.count() == project_count + assert Value.objects.count() == value_count diff --git a/rdmo/projects/tests/test_viewset_project.py b/rdmo/projects/tests/test_viewset_project.py index 0069d25777..7427781f8d 100644 --- a/rdmo/projects/tests/test_viewset_project.py +++ b/rdmo/projects/tests/test_viewset_project.py @@ -3,7 +3,7 @@ from django.contrib.auth.models import Group, User from django.urls import reverse -from ..models import Project +from ..models import Membership, Project, Snapshot, Value users = ( ('owner', 'owner'), @@ -41,6 +41,7 @@ urlnames = { 'list': 'v1-projects:project-list', 'detail': 'v1-projects:project-detail', + 'copy': 'v1-projects:project-copy', 'overview': 'v1-projects:project-overview', 'navigation': 'v1-projects:project-navigation', 'options': 'v1-projects:project-options', @@ -60,6 +61,8 @@ optionset_id = 4 project_id = 1 +parent_id = 3 +parent_ancestors = [2, 3] page_size = 5 @@ -216,6 +219,133 @@ def test_create_parent(db, client, username, password, project_id): assert response.status_code == 401 +@pytest.mark.parametrize('username,password', users) +@pytest.mark.parametrize('project_id', projects) +def test_copy(db, files, client, username, password, project_id): + client.login(username=username, password=password) + + project_count = Project.objects.count() + snapshot_count = Snapshot.objects.count() + value_count = Value.objects.count() + + project = Project.objects.get(id=project_id) + project_snapshots_count = project.snapshots.count() + project_values_count = project.values.count() + + url = reverse(urlnames['copy'], args=[project_id]) + data = { + 'title': 'New title', + 'description': project.description, + 'catalog': project.catalog.id + } + response = client.post(url, data, content_type='application/json') + + if project_id in view_project_permission_map.get(username, []): + assert response.status_code == 201 + + for key, value in response.json().items(): + if key in data: + assert value == data[key] + + assert Project.objects.count() == project_count + 1 + assert Snapshot.objects.count() == snapshot_count + project_snapshots_count + assert Value.objects.count() == value_count + project_values_count + else: + if password: + assert response.status_code == 404 + else: + assert response.status_code == 401 + + assert Project.objects.count() == project_count + assert Value.objects.count() == value_count + + +def test_copy_restricted(db, files, client, settings): + settings.PROJECT_CREATE_RESTRICTED = True + settings.PROJECT_CREATE_GROUPS = ['projects'] + + group = Group.objects.create(name='projects') + user = User.objects.get(username='user') + user.groups.add(group) + + Membership.objects.create(user=user, project_id=project_id, role='guest') + + client.login(username='user', password='user') + + url = reverse(urlnames['copy'], args=[project_id]) + data = { + 'title': 'Lorem ipsum dolor sit amet', + 'description': 'At vero eos et accusam et justo duo dolores et ea rebum.', + 'catalog': catalog_id + } + response = client.post(url, data, content_type='application/json') + + assert response.status_code == 201 + + +def test_copy_forbidden(db, client, settings): + settings.PROJECT_CREATE_RESTRICTED = True + + user = User.objects.get(username='user') + + Membership.objects.create(user=user, project_id=project_id, role='guest') + + client.login(username='user', password='user') + + url = reverse(urlnames['copy'], args=[project_id]) + data = { + 'title': 'Lorem ipsum dolor sit amet', + 'description': 'At vero eos et accusam et justo duo dolores et ea rebum.', + 'catalog': catalog_id + } + response = client.post(url, data) + + assert response.status_code == 403 + + +def test_copy_catalog_missing(db, client): + client.login(username='guest', password='guest') + + url = reverse(urlnames['copy'], args=[project_id]) + data = { + 'title': 'Lorem ipsum dolor sit amet', + 'description': 'At vero eos et accusam et justo duo dolores et ea rebum.' + } + response = client.post(url, data) + + assert response.status_code == 400 + + +def test_copy_catalog_not_available(db, client): + client.login(username='guest', password='guest') + + url = reverse(urlnames['copy'], args=[project_id]) + data = { + 'title': 'Lorem ipsum dolor sit amet', + 'description': 'At vero eos et accusam et justo duo dolores et ea rebum.', + 'catalog': catalog_id_not_available + } + response = client.post(url, data) + + assert response.status_code == 400 + +@pytest.mark.parametrize('project_id', projects) +def test_copy_parent(db, files, client, project_id): + client.login(username='owner', password='owner') + project = Project.objects.get(pk=project_id) + + url = reverse(urlnames['copy'], args=[project_id]) + data = { + 'title': 'New title', + 'description': project.description, + 'catalog': project.catalog.id, + 'parent': parent_id + } + response = client.post(url, data, content_type='application/json') + + assert response.status_code == 201 + + @pytest.mark.parametrize('username,password', users) @pytest.mark.parametrize('project_id', projects) def test_update(db, client, username, password, project_id): @@ -246,6 +376,43 @@ def test_update(db, client, username, password, project_id): assert Project.objects.get(id=project_id).description == project.description +@pytest.mark.parametrize('username,password', users) +@pytest.mark.parametrize('project_id', projects) +def test_update_parent(db, client, username, password, project_id): + client.login(username=username, password=password) + project = Project.objects.get(pk=project_id) + + url = reverse(urlnames['detail'], args=[project_id]) + data = { + 'title': 'New title', + 'description': project.description, + 'catalog': project.catalog.id, + 'parent': parent_id + } + response = client.put(url, data, content_type='application/json') + + if project_id in change_project_permission_map.get(username, []): + if parent_id in view_project_permission_map.get(username, []): + if project_id in parent_ancestors: + assert response.status_code == 400 + assert Project.objects.get(pk=project_id).parent == project.parent + else: + assert response.status_code == 200 + assert Project.objects.get(pk=project_id).parent_id == parent_id + else: + assert response.status_code == 404 + assert Project.objects.get(pk=project_id).parent == project.parent + else: + if project_id in view_project_permission_map.get(username, []): + assert response.status_code == 403 + elif password: + assert response.status_code == 404 + else: + assert response.status_code == 401 + + assert Project.objects.get(pk=project_id).parent == project.parent + + @pytest.mark.parametrize('username,password', users) @pytest.mark.parametrize('project_id', projects) def test_delete(db, client, username, password, project_id): diff --git a/rdmo/projects/urls/__init__.py b/rdmo/projects/urls/__init__.py index 0be1bdbd97..b7cc6d81f9 100644 --- a/rdmo/projects/urls/__init__.py +++ b/rdmo/projects/urls/__init__.py @@ -15,6 +15,7 @@ ProjectAnswersExportView, ProjectAnswersView, ProjectCancelView, + ProjectCopyView, ProjectCreateImportView, ProjectCreateView, ProjectDeleteView, @@ -56,6 +57,8 @@ re_path(r'^(?P[0-9]+)/$', ProjectDetailView.as_view(), name='project'), + re_path(r'^(?P[0-9]+)/copy/$', + ProjectCopyView.as_view(), name='project_copy'), re_path(r'^(?P[0-9]+)/update/$', ProjectUpdateView.as_view(), name='project_update'), re_path(r'^(?P[0-9]+)/update/information/$', diff --git a/rdmo/projects/utils.py b/rdmo/projects/utils.py index c23b0137fa..ab619b54b4 100644 --- a/rdmo/projects/utils.py +++ b/rdmo/projects/utils.py @@ -5,6 +5,7 @@ from django.contrib.sites.models import Site from django.template.loader import render_to_string from django.urls import reverse +from django.utils.timezone import now from rdmo.core.mail import send_mail from rdmo.core.plugins import get_plugins @@ -38,6 +39,84 @@ def check_conditions(conditions, values, set_prefix=None, set_index=None): return True +def copy_project(project, site, owners): + from .models import Membership, Value # to prevent circular inclusion + + timestamp = now() + + tasks = project.tasks.all() + views = project.views.all() + + values = project.values.filter(snapshot=None) + snapshots = { + snapshot: project.values.filter(snapshot=snapshot) + for snapshot in project.snapshots.all() + } + + # unset the id, set current site and update timestamps + project.id = None + project.site = site + project.created = timestamp + + # save the new project + project.save() + + # save project tasks + for task in tasks: + project.tasks.add(task) + + # save project views + for view in views: + project.views.add(view) + + # save current project values + project_values = [] + for value in values: + value.id = None + value.project = project + value.created = timestamp + + if value.file: + # file values cannot be bulk created since we need their id and only postgres provides that (reliably) + # https://docs.djangoproject.com/en/4.2/ref/models/querysets/#bulk-create + value.save() + value.copy_file(value.file_name, value.file) + else: + project_values.append(value) + + # insert the new values using bulk_create + Value.objects.bulk_create(project_values) + + # save project snapshots + for snapshot, snapshot_values in snapshots.items(): + snapshot.id = None + snapshot.project = project + snapshot.created = timestamp + snapshot.save(copy_values=False) + + project_snapshot_values = [] + for value in snapshot_values: + value.id = None + value.project = project + value.snapshot = snapshot + value.created = timestamp + + if value.file: + value.save() + value.copy_file(value.file_name, value.file) + else: + project_snapshot_values.append(value) + + # insert the new snapshot values using bulk_create + Value.objects.bulk_create(project_snapshot_values) + + for owner in owners: + membership = Membership(project=project, user=owner, role='owner') + membership.save() + + return project + + def save_import_values(project, values, checked): for value in values: if value.attribute: diff --git a/rdmo/projects/validators.py b/rdmo/projects/validators.py index da0ace2423..e3bd382923 100644 --- a/rdmo/projects/validators.py +++ b/rdmo/projects/validators.py @@ -20,6 +20,19 @@ VALUE_TYPE_URL, ) from rdmo.core.utils import human2bytes +from rdmo.core.validators import InstanceValidator + + +class ProjectParentValidator(InstanceValidator): + + def __call__(self, data, serializer=None): + super().__call__(data, serializer) + + if self.instance and self.instance.id \ + and data.get('parent') in self.instance.get_descendants(include_self=True): + raise self.raise_validation_error({ + 'parent': [_('A project may not be moved to be a child of itself or one of its descendants.')] + }) class ValueConflictValidator: diff --git a/rdmo/projects/views/__init__.py b/rdmo/projects/views/__init__.py index 543481bc22..a6899c8671 100644 --- a/rdmo/projects/views/__init__.py +++ b/rdmo/projects/views/__init__.py @@ -14,6 +14,7 @@ ProjectsView, ) from .project_answers import ProjectAnswersExportView, ProjectAnswersView +from .project_copy import ProjectCopyView from .project_create import ProjectCreateImportView, ProjectCreateView from .project_update import ( ProjectUpdateCatalogView, diff --git a/rdmo/projects/views/project_copy.py b/rdmo/projects/views/project_copy.py new file mode 100644 index 0000000000..6d9df32c3e --- /dev/null +++ b/rdmo/projects/views/project_copy.py @@ -0,0 +1,42 @@ +import logging + +from django.contrib.sites.shortcuts import get_current_site +from django.http import HttpResponseRedirect +from django.views.generic import UpdateView + +from rdmo.core.views import ObjectPermissionMixin, RedirectViewMixin +from rdmo.questions.models import Catalog + +from ..forms import ProjectForm +from ..models import Project +from ..utils import copy_project + +logger = logging.getLogger(__name__) + + +class ProjectCopyView(ObjectPermissionMixin, RedirectViewMixin, UpdateView): + + model = Project + form_class = ProjectForm + permission_required = ('projects.add_project', 'projects.view_project_object') + + def get_form_kwargs(self): + catalogs = Catalog.objects.filter_current_site() \ + .filter_group(self.request.user) \ + .filter_availability(self.request.user) \ + .order_by('-available', 'order') + projects = Project.objects.filter_user(self.request.user) + + form_kwargs = super().get_form_kwargs() + form_kwargs.update({ + 'copy': True, + 'catalogs': catalogs, + 'projects': projects + }) + return form_kwargs + + def form_valid(self, form): + site = get_current_site(self.request) + owners = [self.request.user] + project = copy_project(form.instance, site, owners) + return HttpResponseRedirect(project.get_absolute_url()) diff --git a/rdmo/projects/viewsets.py b/rdmo/projects/viewsets.py index 8c5d367a5b..0a0fde16d6 100644 --- a/rdmo/projects/viewsets.py +++ b/rdmo/projects/viewsets.py @@ -6,7 +6,7 @@ from django.http import Http404, HttpResponseRedirect from django.utils.translation import gettext_lazy as _ -from rest_framework import serializers +from rest_framework import serializers, status from rest_framework.decorators import action from rest_framework.exceptions import NotFound from rest_framework.mixins import CreateModelMixin, ListModelMixin, RetrieveModelMixin, UpdateModelMixin @@ -48,6 +48,7 @@ InviteSerializer, IssueSerializer, MembershipSerializer, + ProjectCopySerializer, ProjectIntegrationSerializer, ProjectInviteSerializer, ProjectInviteUpdateSerializer, @@ -63,7 +64,7 @@ ) from .serializers.v1.overview import CatalogSerializer, ProjectOverviewSerializer from .serializers.v1.page import PageSerializer -from .utils import check_conditions, get_upload_accept, send_invite_email +from .utils import check_conditions, copy_project, get_upload_accept, send_invite_email class ProjectPagination(PageNumberPagination): @@ -116,6 +117,25 @@ def get_queryset(self): return queryset + @action(detail=True, methods=['POST'], + permission_classes=(HasModelPermission | HasProjectPermission, )) + def copy(self, request, pk=None): + instance = self.get_object() + serializer = ProjectCopySerializer(instance, data=request.data, context=self.get_serializer_context()) + serializer.is_valid(raise_exception=True) + + # update instance + for key, value in serializer.validated_data.items(): + setattr(instance, key, value) + + site = get_current_site(self.request) + owners = [self.request.user] + project_copy = copy_project(instance, site, owners) + + serializer = self.get_serializer(project_copy) + headers = self.get_success_headers(serializer.data) + return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) + @action(detail=True, permission_classes=(HasModelPermission | HasProjectPermission, )) def overview(self, request, pk=None): project = self.get_object()