diff --git a/rdmo/core/tests/utils.py b/rdmo/core/tests/utils.py
index 3bf4db8c4a..1fce8f674b 100644
--- a/rdmo/core/tests/utils.py
+++ b/rdmo/core/tests/utils.py
@@ -1,3 +1,5 @@
+import hashlib
+
from rdmo.core.models import Model
from rdmo.core.tests.constants import multisite_status_map, status_map_object_permissions
@@ -30,3 +32,7 @@ def get_obj_perms_status_code(instance, username, method):
except KeyError:
# not all users are defined in the method_instance_perms_map
return multisite_status_map[method][username]
+
+
+def compute_checksum(string):
+ return hashlib.sha1(string).hexdigest()
diff --git a/rdmo/projects/admin.py b/rdmo/projects/admin.py
index 1c82a6a5e8..926885e1c5 100644
--- a/rdmo/projects/admin.py
+++ b/rdmo/projects/admin.py
@@ -1,3 +1,4 @@
+from django import forms
from django.contrib import admin
from django.db.models import Prefetch
from django.urls import reverse
@@ -15,12 +16,35 @@
Snapshot,
Value,
)
+from .validators import ProjectParentValidator
+
+
+class ProjectAdminForm(forms.ModelForm):
+
+ class Meta:
+ model = Project
+ fields = [
+ 'parent',
+ 'site',
+ 'title',
+ 'description',
+ 'catalog',
+ 'views'
+ ]
+
+
+ def clean(self):
+ super().clean()
+ ProjectParentValidator(self.instance)(self.cleaned_data)
@admin.register(Project)
class ProjectAdmin(admin.ModelAdmin):
+ form = ProjectAdminForm
+
search_fields = ('title', 'user__username')
list_display = ('title', 'owners', 'updated', 'created')
+ readonly_fields = ('progress_count', 'progress_total')
def get_queryset(self, request):
return Project.objects.prefetch_related(
diff --git a/rdmo/projects/assets/js/projects/components/main/Projects.js b/rdmo/projects/assets/js/projects/components/main/Projects.js
index 93e38e65db..6f15f8e7b8 100644
--- a/rdmo/projects/assets/js/projects/components/main/Projects.js
+++ b/rdmo/projects/assets/js/projects/components/main/Projects.js
@@ -140,21 +140,27 @@ const Projects = ({ config, configActions, currentUserObject, projectsActions, p
return (
- {(isProjectManager || isProjectOwner || isManager) &&
window.location.href = `${rowUrl}/update/${params}`}
+ href={`${rowUrl}/copy/`}
+ className="fa fa-copy"
+ title={gettext('Copy project')}
+ onClick={() => window.location.href = `${rowUrl}/copy/${params}`}
/>
+ {(isProjectManager || isProjectOwner || isManager) &&
+ window.location.href = `${rowUrl}/update/${params}`}
+ />
}
{(isProjectOwner || isManager) &&
- window.location.href = `${rowUrl}/delete/${params}`}
- />
+ window.location.href = `${rowUrl}/delete/${params}`}
+ />
}
)
diff --git a/rdmo/projects/assets/scss/projects.scss b/rdmo/projects/assets/scss/projects.scss
index dbe1811707..f24a9fd12e 100644
--- a/rdmo/projects/assets/scss/projects.scss
+++ b/rdmo/projects/assets/scss/projects.scss
@@ -207,6 +207,7 @@ a.disabled {
display: flex;
gap: 5px;
margin-bottom: 10px;
+ justify-content: flex-end;
}
.projects {
diff --git a/rdmo/projects/forms.py b/rdmo/projects/forms.py
index 4594fca7fa..11adc0f1b9 100644
--- a/rdmo/projects/forms.py
+++ b/rdmo/projects/forms.py
@@ -13,6 +13,7 @@
from .constants import ROLE_CHOICES
from .models import Integration, IntegrationOption, Invite, Membership, Project, Snapshot
+from .validators import ProjectParentValidator
class CatalogChoiceField(forms.ModelChoiceField):
@@ -53,6 +54,8 @@ class ProjectForm(forms.ModelForm):
use_required_attribute = False
def __init__(self, *args, **kwargs):
+ self.copy = kwargs.pop('copy', False)
+
catalogs = kwargs.pop('catalogs')
projects = kwargs.pop('projects')
super().__init__(*args, **kwargs)
@@ -66,6 +69,11 @@ def __init__(self, *args, **kwargs):
if settings.NESTED_PROJECTS:
self.fields['parent'].queryset = projects
+ def clean(self):
+ if not self.copy:
+ ProjectParentValidator(self.instance)(self.cleaned_data)
+ super().clean()
+
class Meta:
model = Project
@@ -160,6 +168,10 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fields['parent'].queryset = projects
+ def clean(self):
+ ProjectParentValidator(self.instance)(self.cleaned_data)
+ super().clean()
+
class Meta:
model = Project
fields = ('parent', )
diff --git a/rdmo/projects/models/project.py b/rdmo/projects/models/project.py
index 6425f37f48..b992c0dc9d 100644
--- a/rdmo/projects/models/project.py
+++ b/rdmo/projects/models/project.py
@@ -1,6 +1,5 @@
from django.conf import settings
from django.contrib.sites.models import Site
-from django.core.exceptions import ValidationError
from django.db import models
from django.db.models.signals import pre_delete
from django.dispatch import receiver
@@ -88,11 +87,12 @@ def __str__(self):
def get_absolute_url(self):
return reverse('project', kwargs={'pk': self.pk})
- def clean(self):
+ def save(self, *args, **kwargs):
+ # ensure that the project hierarchy is not disturbed
if self.id and self.parent in self.get_descendants(include_self=True):
- raise ValidationError({
- 'parent': [_('A project may not be moved to be a child of itself or one of its descendants.')]
- })
+ raise RuntimeError('A project may not be moved to be a child of itself or one of its descendants.')
+
+ super().save(*args, **kwargs)
@property
def catalog_uri(self):
diff --git a/rdmo/projects/serializers/v1/__init__.py b/rdmo/projects/serializers/v1/__init__.py
index 8816dc9ca7..fe72127295 100644
--- a/rdmo/projects/serializers/v1/__init__.py
+++ b/rdmo/projects/serializers/v1/__init__.py
@@ -9,7 +9,7 @@
from rdmo.services.validators import ProviderValidator
from ...models import Integration, IntegrationOption, Invite, Issue, IssueResource, Membership, Project, Snapshot, Value
-from ...validators import ValueConflictValidator, ValueQuotaValidator, ValueTypeValidator
+from ...validators import ProjectParentValidator, ValueConflictValidator, ValueQuotaValidator, ValueTypeValidator
class UserSerializer(serializers.ModelSerializer):
@@ -78,6 +78,17 @@ class Meta:
read_only_fields = (
'snapshots',
)
+ validators = [
+ ProjectParentValidator()
+ ]
+
+
+class ProjectCopySerializer(ProjectSerializer):
+
+ class Meta:
+ model = Project
+ fields = ProjectSerializer.Meta.fields
+ read_only_fields = ProjectSerializer.Meta.read_only_fields
class ProjectMembershipSerializer(serializers.ModelSerializer):
diff --git a/rdmo/projects/templates/projects/project_detail_sidebar.html b/rdmo/projects/templates/projects/project_detail_sidebar.html
index 4c4be91ff7..c1886388a9 100644
--- a/rdmo/projects/templates/projects/project_detail_sidebar.html
+++ b/rdmo/projects/templates/projects/project_detail_sidebar.html
@@ -38,7 +38,6 @@ {% trans 'Options' %}
{% endif %}
-{% if can_change_project or can_delete_project %}
-{% endif %}
{% has_perm 'projects.add_membership_object' request.user project as can_add_membership %}
{% if can_add_membership %}
diff --git a/rdmo/projects/tests/test_utils.py b/rdmo/projects/tests/test_utils.py
index edbc5a8d3b..db2bfe2c51 100644
--- a/rdmo/projects/tests/test_utils.py
+++ b/rdmo/projects/tests/test_utils.py
@@ -1,9 +1,14 @@
import pytest
+from django.contrib.auth.models import User
+from django.contrib.sites.models import Site
from django.http import QueryDict
+from rdmo.core.tests.utils import compute_checksum
+
from ..filters import ProjectFilter
-from ..utils import set_context_querystring_with_filter_and_page
+from ..models import Project
+from ..utils import copy_project, set_context_querystring_with_filter_and_page
GET_queries = [
'page=2&title=project',
@@ -32,3 +37,94 @@ def test_set_context_querystring_with_filter_and_page(GET_query):
assert context.get('querystring', 'not-in-context') == ''
else:
assert context.get('querystring', 'not-in-context') == 'not-in-context'
+
+
+def test_copy_project(db, files):
+ project = Project.objects.get(id=1)
+ site = Site.objects.get(id=2)
+ user = User.objects.get(id=1)
+ project_copy = copy_project(project, site, [user])
+
+ # re fetch the original project
+ project = Project.objects.get(id=1)
+
+ # check that site, owners, tasks, and views are correct
+ assert project_copy.site == site
+ assert list(project_copy.owners) == [user]
+ assert list(project_copy.user.values('id')) == [{'id': user.id}]
+ assert list(project_copy.tasks.values('id')) == list(project.tasks.values('id'))
+ assert list(project_copy.views.values('id')) == list(project.views.values('id'))
+
+ # check that no ids are the same
+ assert project_copy.id != project.id
+ assert not set(project_copy.snapshots.values_list('id')).intersection(set(project.snapshots.values_list('id')))
+ assert not set(project_copy.values.values_list('id')).intersection(set(project.values.values_list('id')))
+
+ # check the snapshots
+ snapshot_fields = (
+ 'title',
+ 'description'
+ )
+ for snapshot_copy, snapshot in zip(
+ project_copy.snapshots.values(*snapshot_fields),
+ project.snapshots.values(*snapshot_fields)
+ ):
+ assert snapshot_copy == snapshot
+
+ # check the values
+ value_fields = (
+ 'attribute',
+ 'set_prefix',
+ 'set_collection',
+ 'set_index',
+ 'collection_index',
+ 'text',
+ 'option',
+ 'value_type',
+ 'unit',
+ 'external_id'
+ )
+ ordering = (
+ 'attribute',
+ 'set_prefix',
+ 'set_index',
+ 'collection_index'
+ )
+ for value_copy, value in zip(
+ project_copy.values.filter(snapshot=None).order_by(*ordering),
+ project.values.filter(snapshot=None).order_by(*ordering)
+ ):
+ for field in value_fields:
+ assert getattr(value_copy, field) == getattr(value, field), field
+
+ if value_copy.file:
+ assert value_copy.file.path != value.file.path
+ assert value_copy.file.path == value_copy.file.path.replace(
+ f'/projects/{project.id}/values/{value.id}/',
+ f'/projects/{project_copy.id}/values/{value_copy.id}/'
+ )
+ assert value_copy.file.size == value.file.size
+ assert compute_checksum(value_copy.file.open('rb').read()) == \
+ compute_checksum(value.file.open('rb').read())
+ else:
+ assert not value.file
+
+ for snapshot_copy, snapshot in zip(project_copy.snapshots.all(), project.snapshots.all()):
+ for value_copy, value in zip(
+ project_copy.values.filter(snapshot=snapshot_copy).order_by(*ordering),
+ project.values.filter(snapshot=snapshot).order_by(*ordering)
+ ):
+ for field in value_fields:
+ assert getattr(value_copy, field) == getattr(value, field)
+
+ if value_copy.file:
+ assert value_copy.file.path != value.file.path
+ assert value_copy.file.path == value_copy.file.path.replace(
+ f'/projects/{project.id}/snapshot/{snapshot.id}/values/{value.id}/',
+ f'/projects/{project_copy.id}/snapshot/{snapshot.id}/values/{value_copy.id}/'
+ )
+ assert value_copy.file.size == value.file.size
+ assert compute_checksum(value_copy.file.open('rb').read()) == \
+ compute_checksum(value.file.open('rb').read())
+ else:
+ assert not value.file
diff --git a/rdmo/projects/tests/test_view_project.py b/rdmo/projects/tests/test_view_project.py
index b43660cf14..8a1465c765 100644
--- a/rdmo/projects/tests/test_view_project.py
+++ b/rdmo/projects/tests/test_view_project.py
@@ -58,7 +58,9 @@
export_formats = ('rtf', 'odt', 'docx', 'html', 'markdown', 'tex', 'pdf')
site_id = 1
-parent_project_id = 1
+project_id = 1
+parent_id = 3
+parent_ancestors = [2, 3]
catalog_id = 1
@@ -266,7 +268,7 @@ def test_project_create_parent_post(db, client, username, password):
'title': 'A new project',
'description': 'Some description',
'catalog': catalog_id,
- 'parent': parent_project_id
+ 'parent': project_id
}
response = client.post(url, data)
@@ -335,17 +337,21 @@ def test_project_update_post_parent(db, client, username, password, project_id):
'title': project.title,
'description': project.description,
'catalog': project.catalog.pk,
- 'parent': parent_project_id
+ 'parent': parent_id
}
response = client.post(url, data)
if project_id in change_project_permission_map.get(username, []):
- if project_id == parent_project_id:
+ if parent_id in view_project_permission_map.get(username, []):
+ if project_id in parent_ancestors:
+ assert response.status_code == 200
+ assert Project.objects.get(pk=project_id).parent == project.parent
+ else:
+ assert response.status_code == 302
+ assert Project.objects.get(pk=project_id).parent_id == parent_id
+ else:
assert response.status_code == 200
assert Project.objects.get(pk=project_id).parent == project.parent
- else:
- assert response.status_code == 302
- assert Project.objects.get(pk=project_id).parent_id == parent_project_id
else:
if password:
assert response.status_code == 403
@@ -545,17 +551,21 @@ def test_project_update_parent_post(db, client, username, password, project_id):
url = reverse('project_update_parent', args=[project_id])
data = {
- 'parent': parent_project_id
+ 'parent': parent_id
}
response = client.post(url, data)
if project_id in change_project_permission_map.get(username, []):
- if project_id == parent_project_id:
+ if parent_id in view_project_permission_map.get(username, []):
+ if project_id in parent_ancestors:
+ assert response.status_code == 200
+ assert Project.objects.get(pk=project_id).parent == project.parent
+ else:
+ assert response.status_code == 302
+ assert Project.objects.get(pk=project_id).parent_id == parent_id
+ else:
assert response.status_code == 200
assert Project.objects.get(pk=project_id).parent == project.parent
- else:
- assert response.status_code == 302
- assert Project.objects.get(pk=project_id).parent_id == parent_project_id
else:
if password:
assert response.status_code == 403
diff --git a/rdmo/projects/tests/test_view_project_copy.py b/rdmo/projects/tests/test_view_project_copy.py
new file mode 100644
index 0000000000..8ecec135da
--- /dev/null
+++ b/rdmo/projects/tests/test_view_project_copy.py
@@ -0,0 +1,207 @@
+import pytest
+
+from django.contrib.auth.models import Group, User
+from django.urls import reverse
+
+from ..models import Project, Snapshot, Value
+
+users = (
+ ('owner', 'owner'),
+ ('manager', 'manager'),
+ ('author', 'author'),
+ ('guest', 'guest'),
+ ('user', 'user'),
+ ('site', 'site'),
+ ('anonymous', None),
+ ('editor', 'editor'),
+ ('reviewer', 'reviewer'),
+ ('api', 'api'),
+)
+
+view_project_permission_map = {
+ 'owner': [1, 2, 3, 4, 5, 10],
+ 'manager': [1, 3, 5, 7],
+ 'author': [1, 3, 5, 8],
+ 'guest': [1, 3, 5, 9],
+ 'api': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
+ 'site': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
+}
+
+change_project_permission_map = {
+ 'owner': [1, 2, 3, 4, 5, 10],
+ 'manager': [1, 3, 5, 7],
+ 'api': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
+ 'site': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
+}
+
+delete_project_permission_map = {
+ 'owner': [1, 2, 3, 4, 5, 10],
+ 'api': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
+ 'site': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
+}
+
+export_project_permission_map = {
+ 'owner': [1, 2, 3, 4, 5, 10],
+ 'manager': [1, 3, 5, 7],
+ 'api': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
+ 'site': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
+}
+
+projects = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+
+project_id = 1
+site_id = 1
+parent_id = 1
+catalog_id = 1
+
+
+@pytest.mark.parametrize('username,password', users)
+@pytest.mark.parametrize('project_id', projects)
+def test_project_copy_get(db, client, username, password, project_id):
+ client.login(username=username, password=password)
+
+ url = reverse('project_copy', args=[project_id])
+ response = client.get(url)
+
+ if project_id in view_project_permission_map.get(username, []):
+ assert response.status_code == 200
+ else:
+ if password:
+ assert response.status_code == 403
+ else:
+ assert response.status_code == 302
+
+
+def test_project_copy_restricted_get(db, client, settings):
+ settings.PROJECT_CREATE_RESTRICTED = True
+ settings.PROJECT_CREATE_GROUPS = ['projects']
+
+ group = Group.objects.create(name='projects')
+ guest = User.objects.get(username='guest')
+ guest.groups.add(group)
+
+ client.login(username='guest', password='guest')
+
+ url = reverse('project_copy', args=[project_id])
+ response = client.get(url)
+
+ assert response.status_code == 200
+
+
+def test_project_copy_forbidden_get(db, client, settings):
+ settings.PROJECT_CREATE_RESTRICTED = True
+
+ client.login(username='guest', password='guest')
+
+ url = reverse('project_copy', args=[project_id])
+ response = client.get(url)
+
+ assert response.status_code == 403
+
+
+@pytest.mark.parametrize('username,password', users)
+@pytest.mark.parametrize('project_id', projects)
+def test_project_copy_post(db, files, client, username, password, project_id):
+ client.login(username=username, password=password)
+
+ project_count = Project.objects.count()
+ snapshot_count = Snapshot.objects.count()
+ value_count = Value.objects.count()
+
+ project = Project.objects.get(id=project_id)
+ project_snapshots_count = project.snapshots.count()
+ project_values_count = project.values.count()
+
+ url = reverse('project_copy', args=[project_id])
+ data = {
+ 'title': 'A new project',
+ 'description': 'Some description',
+ 'catalog': catalog_id
+ }
+ response = client.post(url, data)
+
+ if project_id in view_project_permission_map.get(username, []):
+ assert response.status_code == 302
+ assert Project.objects.count() == project_count + 1
+ assert Snapshot.objects.count() == snapshot_count + project_snapshots_count
+ assert Value.objects.count() == value_count + project_values_count
+ else:
+ assert response.status_code == 403 if password else 302
+ assert Project.objects.count() == project_count
+ assert Value.objects.count() == value_count
+
+
+def test_project_copy_post_restricted(db, files, client, settings):
+ settings.PROJECT_CREATE_RESTRICTED = True
+ settings.PROJECT_CREATE_GROUPS = ['projects']
+
+ group = Group.objects.create(name='projects')
+ guest = User.objects.get(username='guest')
+ guest.groups.add(group)
+
+ client.login(username='guest', password='guest')
+
+ url = reverse('project_copy', args=[project_id])
+ data = {
+ 'title': 'A new project',
+ 'description': 'Some description',
+ 'catalog': catalog_id
+ }
+ response = client.post(url, data)
+
+ assert response.status_code == 302
+
+
+def test_project_copy_post_forbidden(db, files, client, settings):
+ settings.PROJECT_CREATE_RESTRICTED = True
+
+ client.login(username='guest', password='guest')
+
+ url = reverse('project_copy', args=[project_id])
+ data = {
+ 'title': 'A new project',
+ 'description': 'Some description',
+ 'catalog': catalog_id
+ }
+ response = client.post(url, data)
+
+ assert response.status_code == 403
+
+
+@pytest.mark.parametrize('username,password', users)
+@pytest.mark.parametrize('project_id', projects)
+def test_project_copy_parent_post(db, files, client, username, password, project_id):
+ client.login(username=username, password=password)
+ project_count = Project.objects.count()
+
+ project_count = Project.objects.count()
+ snapshot_count = Snapshot.objects.count()
+ value_count = Value.objects.count()
+
+ project = Project.objects.get(id=project_id)
+ project_snapshots_count = project.snapshots.count()
+ project_values_count = project.values.count()
+
+ url = reverse('project_copy', args=[project_id])
+ data = {
+ 'title': 'A new project',
+ 'description': 'Some description',
+ 'catalog': catalog_id,
+ 'parent': parent_id
+ }
+ response = client.post(url, data)
+
+ if project_id in view_project_permission_map.get(username, []):
+ if parent_id in view_project_permission_map.get(username, []):
+ assert response.status_code == 302
+ assert Project.objects.count() == project_count + 1
+ assert Snapshot.objects.count() == snapshot_count + project_snapshots_count
+ assert Value.objects.count() == value_count + project_values_count
+ else:
+ assert response.status_code == 200
+ assert Project.objects.count() == project_count
+ assert Value.objects.count() == value_count
+ else:
+ assert response.status_code == 403 if password else 302
+ assert Project.objects.count() == project_count
+ assert Value.objects.count() == value_count
diff --git a/rdmo/projects/tests/test_viewset_project.py b/rdmo/projects/tests/test_viewset_project.py
index 0069d25777..7427781f8d 100644
--- a/rdmo/projects/tests/test_viewset_project.py
+++ b/rdmo/projects/tests/test_viewset_project.py
@@ -3,7 +3,7 @@
from django.contrib.auth.models import Group, User
from django.urls import reverse
-from ..models import Project
+from ..models import Membership, Project, Snapshot, Value
users = (
('owner', 'owner'),
@@ -41,6 +41,7 @@
urlnames = {
'list': 'v1-projects:project-list',
'detail': 'v1-projects:project-detail',
+ 'copy': 'v1-projects:project-copy',
'overview': 'v1-projects:project-overview',
'navigation': 'v1-projects:project-navigation',
'options': 'v1-projects:project-options',
@@ -60,6 +61,8 @@
optionset_id = 4
project_id = 1
+parent_id = 3
+parent_ancestors = [2, 3]
page_size = 5
@@ -216,6 +219,133 @@ def test_create_parent(db, client, username, password, project_id):
assert response.status_code == 401
+@pytest.mark.parametrize('username,password', users)
+@pytest.mark.parametrize('project_id', projects)
+def test_copy(db, files, client, username, password, project_id):
+ client.login(username=username, password=password)
+
+ project_count = Project.objects.count()
+ snapshot_count = Snapshot.objects.count()
+ value_count = Value.objects.count()
+
+ project = Project.objects.get(id=project_id)
+ project_snapshots_count = project.snapshots.count()
+ project_values_count = project.values.count()
+
+ url = reverse(urlnames['copy'], args=[project_id])
+ data = {
+ 'title': 'New title',
+ 'description': project.description,
+ 'catalog': project.catalog.id
+ }
+ response = client.post(url, data, content_type='application/json')
+
+ if project_id in view_project_permission_map.get(username, []):
+ assert response.status_code == 201
+
+ for key, value in response.json().items():
+ if key in data:
+ assert value == data[key]
+
+ assert Project.objects.count() == project_count + 1
+ assert Snapshot.objects.count() == snapshot_count + project_snapshots_count
+ assert Value.objects.count() == value_count + project_values_count
+ else:
+ if password:
+ assert response.status_code == 404
+ else:
+ assert response.status_code == 401
+
+ assert Project.objects.count() == project_count
+ assert Value.objects.count() == value_count
+
+
+def test_copy_restricted(db, files, client, settings):
+ settings.PROJECT_CREATE_RESTRICTED = True
+ settings.PROJECT_CREATE_GROUPS = ['projects']
+
+ group = Group.objects.create(name='projects')
+ user = User.objects.get(username='user')
+ user.groups.add(group)
+
+ Membership.objects.create(user=user, project_id=project_id, role='guest')
+
+ client.login(username='user', password='user')
+
+ url = reverse(urlnames['copy'], args=[project_id])
+ data = {
+ 'title': 'Lorem ipsum dolor sit amet',
+ 'description': 'At vero eos et accusam et justo duo dolores et ea rebum.',
+ 'catalog': catalog_id
+ }
+ response = client.post(url, data, content_type='application/json')
+
+ assert response.status_code == 201
+
+
+def test_copy_forbidden(db, client, settings):
+ settings.PROJECT_CREATE_RESTRICTED = True
+
+ user = User.objects.get(username='user')
+
+ Membership.objects.create(user=user, project_id=project_id, role='guest')
+
+ client.login(username='user', password='user')
+
+ url = reverse(urlnames['copy'], args=[project_id])
+ data = {
+ 'title': 'Lorem ipsum dolor sit amet',
+ 'description': 'At vero eos et accusam et justo duo dolores et ea rebum.',
+ 'catalog': catalog_id
+ }
+ response = client.post(url, data)
+
+ assert response.status_code == 403
+
+
+def test_copy_catalog_missing(db, client):
+ client.login(username='guest', password='guest')
+
+ url = reverse(urlnames['copy'], args=[project_id])
+ data = {
+ 'title': 'Lorem ipsum dolor sit amet',
+ 'description': 'At vero eos et accusam et justo duo dolores et ea rebum.'
+ }
+ response = client.post(url, data)
+
+ assert response.status_code == 400
+
+
+def test_copy_catalog_not_available(db, client):
+ client.login(username='guest', password='guest')
+
+ url = reverse(urlnames['copy'], args=[project_id])
+ data = {
+ 'title': 'Lorem ipsum dolor sit amet',
+ 'description': 'At vero eos et accusam et justo duo dolores et ea rebum.',
+ 'catalog': catalog_id_not_available
+ }
+ response = client.post(url, data)
+
+ assert response.status_code == 400
+
+@pytest.mark.parametrize('project_id', projects)
+def test_copy_parent(db, files, client, project_id):
+ client.login(username='owner', password='owner')
+ project = Project.objects.get(pk=project_id)
+
+ url = reverse(urlnames['copy'], args=[project_id])
+ data = {
+ 'title': 'New title',
+ 'description': project.description,
+ 'catalog': project.catalog.id,
+ 'parent': parent_id
+ }
+ response = client.post(url, data, content_type='application/json')
+
+ assert response.status_code == 201
+
+
@pytest.mark.parametrize('username,password', users)
@pytest.mark.parametrize('project_id', projects)
def test_update(db, client, username, password, project_id):
@@ -246,6 +376,43 @@ def test_update(db, client, username, password, project_id):
assert Project.objects.get(id=project_id).description == project.description
+@pytest.mark.parametrize('username,password', users)
+@pytest.mark.parametrize('project_id', projects)
+def test_update_parent(db, client, username, password, project_id):
+ client.login(username=username, password=password)
+ project = Project.objects.get(pk=project_id)
+
+ url = reverse(urlnames['detail'], args=[project_id])
+ data = {
+ 'title': 'New title',
+ 'description': project.description,
+ 'catalog': project.catalog.id,
+ 'parent': parent_id
+ }
+ response = client.put(url, data, content_type='application/json')
+
+ if project_id in change_project_permission_map.get(username, []):
+ if parent_id in view_project_permission_map.get(username, []):
+ if project_id in parent_ancestors:
+ assert response.status_code == 400
+ assert Project.objects.get(pk=project_id).parent == project.parent
+ else:
+ assert response.status_code == 200
+ assert Project.objects.get(pk=project_id).parent_id == parent_id
+ else:
+ assert response.status_code == 404
+ assert Project.objects.get(pk=project_id).parent == project.parent
+ else:
+ if project_id in view_project_permission_map.get(username, []):
+ assert response.status_code == 403
+ elif password:
+ assert response.status_code == 404
+ else:
+ assert response.status_code == 401
+
+ assert Project.objects.get(pk=project_id).parent == project.parent
+
+
@pytest.mark.parametrize('username,password', users)
@pytest.mark.parametrize('project_id', projects)
def test_delete(db, client, username, password, project_id):
diff --git a/rdmo/projects/urls/__init__.py b/rdmo/projects/urls/__init__.py
index 0be1bdbd97..b7cc6d81f9 100644
--- a/rdmo/projects/urls/__init__.py
+++ b/rdmo/projects/urls/__init__.py
@@ -15,6 +15,7 @@
ProjectAnswersExportView,
ProjectAnswersView,
ProjectCancelView,
+ ProjectCopyView,
ProjectCreateImportView,
ProjectCreateView,
ProjectDeleteView,
@@ -56,6 +57,8 @@
re_path(r'^(?P[0-9]+)/$',
ProjectDetailView.as_view(), name='project'),
+ re_path(r'^(?P[0-9]+)/copy/$',
+ ProjectCopyView.as_view(), name='project_copy'),
re_path(r'^(?P[0-9]+)/update/$',
ProjectUpdateView.as_view(), name='project_update'),
re_path(r'^(?P[0-9]+)/update/information/$',
diff --git a/rdmo/projects/utils.py b/rdmo/projects/utils.py
index c23b0137fa..ab619b54b4 100644
--- a/rdmo/projects/utils.py
+++ b/rdmo/projects/utils.py
@@ -5,6 +5,7 @@
from django.contrib.sites.models import Site
from django.template.loader import render_to_string
from django.urls import reverse
+from django.utils.timezone import now
from rdmo.core.mail import send_mail
from rdmo.core.plugins import get_plugins
@@ -38,6 +39,84 @@ def check_conditions(conditions, values, set_prefix=None, set_index=None):
return True
+def copy_project(project, site, owners):
+ from .models import Membership, Value # to prevent circular inclusion
+
+ timestamp = now()
+
+ tasks = project.tasks.all()
+ views = project.views.all()
+
+ values = project.values.filter(snapshot=None)
+ snapshots = {
+ snapshot: project.values.filter(snapshot=snapshot)
+ for snapshot in project.snapshots.all()
+ }
+
+ # unset the id, set current site and update timestamps
+ project.id = None
+ project.site = site
+ project.created = timestamp
+
+ # save the new project
+ project.save()
+
+ # save project tasks
+ for task in tasks:
+ project.tasks.add(task)
+
+ # save project views
+ for view in views:
+ project.views.add(view)
+
+ # save current project values
+ project_values = []
+ for value in values:
+ value.id = None
+ value.project = project
+ value.created = timestamp
+
+ if value.file:
+ # file values cannot be bulk created since we need their id and only postgres provides that (reliably)
+ # https://docs.djangoproject.com/en/4.2/ref/models/querysets/#bulk-create
+ value.save()
+ value.copy_file(value.file_name, value.file)
+ else:
+ project_values.append(value)
+
+ # insert the new values using bulk_create
+ Value.objects.bulk_create(project_values)
+
+ # save project snapshots
+ for snapshot, snapshot_values in snapshots.items():
+ snapshot.id = None
+ snapshot.project = project
+ snapshot.created = timestamp
+ snapshot.save(copy_values=False)
+
+ project_snapshot_values = []
+ for value in snapshot_values:
+ value.id = None
+ value.project = project
+ value.snapshot = snapshot
+ value.created = timestamp
+
+ if value.file:
+ value.save()
+ value.copy_file(value.file_name, value.file)
+ else:
+ project_snapshot_values.append(value)
+
+ # insert the new snapshot values using bulk_create
+ Value.objects.bulk_create(project_snapshot_values)
+
+ for owner in owners:
+ membership = Membership(project=project, user=owner, role='owner')
+ membership.save()
+
+ return project
+
+
def save_import_values(project, values, checked):
for value in values:
if value.attribute:
diff --git a/rdmo/projects/validators.py b/rdmo/projects/validators.py
index da0ace2423..e3bd382923 100644
--- a/rdmo/projects/validators.py
+++ b/rdmo/projects/validators.py
@@ -20,6 +20,19 @@
VALUE_TYPE_URL,
)
from rdmo.core.utils import human2bytes
+from rdmo.core.validators import InstanceValidator
+
+
+class ProjectParentValidator(InstanceValidator):
+
+ def __call__(self, data, serializer=None):
+ super().__call__(data, serializer)
+
+ if self.instance and self.instance.id \
+ and data.get('parent') in self.instance.get_descendants(include_self=True):
+ raise self.raise_validation_error({
+ 'parent': [_('A project may not be moved to be a child of itself or one of its descendants.')]
+ })
class ValueConflictValidator:
diff --git a/rdmo/projects/views/__init__.py b/rdmo/projects/views/__init__.py
index 543481bc22..a6899c8671 100644
--- a/rdmo/projects/views/__init__.py
+++ b/rdmo/projects/views/__init__.py
@@ -14,6 +14,7 @@
ProjectsView,
)
from .project_answers import ProjectAnswersExportView, ProjectAnswersView
+from .project_copy import ProjectCopyView
from .project_create import ProjectCreateImportView, ProjectCreateView
from .project_update import (
ProjectUpdateCatalogView,
diff --git a/rdmo/projects/views/project_copy.py b/rdmo/projects/views/project_copy.py
new file mode 100644
index 0000000000..6d9df32c3e
--- /dev/null
+++ b/rdmo/projects/views/project_copy.py
@@ -0,0 +1,42 @@
+import logging
+
+from django.contrib.sites.shortcuts import get_current_site
+from django.http import HttpResponseRedirect
+from django.views.generic import UpdateView
+
+from rdmo.core.views import ObjectPermissionMixin, RedirectViewMixin
+from rdmo.questions.models import Catalog
+
+from ..forms import ProjectForm
+from ..models import Project
+from ..utils import copy_project
+
+logger = logging.getLogger(__name__)
+
+
+class ProjectCopyView(ObjectPermissionMixin, RedirectViewMixin, UpdateView):
+
+ model = Project
+ form_class = ProjectForm
+ permission_required = ('projects.add_project', 'projects.view_project_object')
+
+ def get_form_kwargs(self):
+ catalogs = Catalog.objects.filter_current_site() \
+ .filter_group(self.request.user) \
+ .filter_availability(self.request.user) \
+ .order_by('-available', 'order')
+ projects = Project.objects.filter_user(self.request.user)
+
+ form_kwargs = super().get_form_kwargs()
+ form_kwargs.update({
+ 'copy': True,
+ 'catalogs': catalogs,
+ 'projects': projects
+ })
+ return form_kwargs
+
+ def form_valid(self, form):
+ site = get_current_site(self.request)
+ owners = [self.request.user]
+ project = copy_project(form.instance, site, owners)
+ return HttpResponseRedirect(project.get_absolute_url())
diff --git a/rdmo/projects/viewsets.py b/rdmo/projects/viewsets.py
index 8c5d367a5b..0a0fde16d6 100644
--- a/rdmo/projects/viewsets.py
+++ b/rdmo/projects/viewsets.py
@@ -6,7 +6,7 @@
from django.http import Http404, HttpResponseRedirect
from django.utils.translation import gettext_lazy as _
-from rest_framework import serializers
+from rest_framework import serializers, status
from rest_framework.decorators import action
from rest_framework.exceptions import NotFound
from rest_framework.mixins import CreateModelMixin, ListModelMixin, RetrieveModelMixin, UpdateModelMixin
@@ -48,6 +48,7 @@
InviteSerializer,
IssueSerializer,
MembershipSerializer,
+ ProjectCopySerializer,
ProjectIntegrationSerializer,
ProjectInviteSerializer,
ProjectInviteUpdateSerializer,
@@ -63,7 +64,7 @@
)
from .serializers.v1.overview import CatalogSerializer, ProjectOverviewSerializer
from .serializers.v1.page import PageSerializer
-from .utils import check_conditions, get_upload_accept, send_invite_email
+from .utils import check_conditions, copy_project, get_upload_accept, send_invite_email
class ProjectPagination(PageNumberPagination):
@@ -116,6 +117,25 @@ def get_queryset(self):
return queryset
+ @action(detail=True, methods=['POST'],
+ permission_classes=(HasModelPermission | HasProjectPermission, ))
+ def copy(self, request, pk=None):
+ instance = self.get_object()
+ serializer = ProjectCopySerializer(instance, data=request.data, context=self.get_serializer_context())
+ serializer.is_valid(raise_exception=True)
+
+ # update instance
+ for key, value in serializer.validated_data.items():
+ setattr(instance, key, value)
+
+ site = get_current_site(self.request)
+ owners = [self.request.user]
+ project_copy = copy_project(instance, site, owners)
+
+ serializer = self.get_serializer(project_copy)
+ headers = self.get_success_headers(serializer.data)
+ return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
+
@action(detail=True, permission_classes=(HasModelPermission | HasProjectPermission, ))
def overview(self, request, pk=None):
project = self.get_object()