diff --git a/invenio_pidstore/providers/base32.py b/invenio_pidstore/providers/base32.py new file mode 100644 index 0000000..32ef1e3 --- /dev/null +++ b/invenio_pidstore/providers/base32.py @@ -0,0 +1,124 @@ +# -*- coding: utf-8 -*- +# +# This file is part of Invenio. +# Copyright (C) 2015-2018 CERN. +# +# Invenio is free software; you can redistribute it and/or modify it +# under the terms of the MIT License; see LICENSE file for more details. + +"""Douglas Crockford Base32-URL encoder. + +This encoder/decoder: +- uses Douglas Crockford Base32 encoding +- allows for checksum +- encodes the checksum using only characters in the base32 set + (only digits in fact) +- produces string that are URI-friendly (no '=' or '/' for instance) + +This is based on: +- https://github.com/datacite/base32-url +- https://github.com/jbittel/base32-crockford +""" +import string + +import six + +# NO i, l, o or u +ENCODING_CHARS = '0123456789abcdefghjkmnpqrstvwxyz' +DECODING_CHARS = {c: i for i, c in enumerate(ENCODING_CHARS)} + + +def encode(number, split_every=0, checksum=False): + """Encodes `number` to URI-friendly Douglas Crockford base32 string. + + :param number: number to encode + :param split_every: if provided, insert '-' every `split_every` characters + going from left to right + :param checksum: append modulo 97-10 (ISO 7064) checksum to string + :returns: A random Douglas Crockford base32 encoded string composed only + of valid URI characters. + """ + assert isinstance(number, six.integer_types) + + if number < 0: + raise ValueError("Invalid 'number'. 'number' must > 0.") + + if split_every < 0: + raise ValueError("Invalid 'split_every'. 'split_every' must > 0.") + + encoded = '' + original_number = number + while number > 0: + remainder = number % 32 + number //= 32 + encoded = ENCODING_CHARS[remainder] + encoded + + if checksum: + # NOTE: 100 * original_number is used because datacite also uses it + computed_checksum = 97 - ((100 * original_number) % 97) + 1 + encoded_checksum = "{:02d}".format(computed_checksum) + encoded += encoded_checksum + + if split_every: + splits = [ + encoded[i:i+split_every] + for i in range(0, len(encoded), split_every) + ] + encoded = '-'.join(splits) + + return encoded + + +def normalize(encoded): + """Returns normalized encoded string. + + - string is lowercased + - '-' are removed + - I,i,l,L decodes to the digit 1 + - O,o decodes to the digit 0 + + :param encoded: string to decode + :returns: normalized string. + """ + table = ( + ''.maketrans('IiLlOo', '111100') if six.PY3 else + string.maketrans('IiLlOo', '111100') + ) + encoded = encoded.replace('-', '').translate(table).lower() + + if not all([c in ENCODING_CHARS for c in encoded]): + raise ValueError("'encoded' contains undecodable characters") + + return encoded + + +def decode(encoded, checksum=False): + """Decodes `encoded` string (via above) to a number. + + The string is normalized before decoding. + + If `checksum` is enabled, raises a ValueError on checksum error. + + :param encoded: string to decode + :param checksum: extract checksum and validate + :returns: original number. + """ + if checksum: + encoded_checksum = encoded[-2:] + encoded = encoded[:-2] + + encoded = normalize(encoded) + + number = 0 + for i, c in enumerate(reversed(encoded)): + number += DECODING_CHARS[c] * (32**i) + + if checksum: + verification_checksum = int(encoded_checksum, 10) + # NOTE: 100 * number is used because datacite also uses it + computed_checksum = 97 - ((100 * number) % 97) + 1 + + if verification_checksum != computed_checksum: + raise ValueError("Invalid checksum.") + + return number diff --git a/invenio_pidstore/providers/datacite.py b/invenio_pidstore/providers/datacite.py index 9200065..848b6fe 100644 --- a/invenio_pidstore/providers/datacite.py +++ b/invenio_pidstore/providers/datacite.py @@ -10,12 +10,19 @@ from __future__ import absolute_import +import codecs +import os +import random +import re + from datacite import DataCiteMDSClient from datacite.errors import DataCiteError, DataCiteGoneError, \ DataCiteNoContentError, DataCiteNotFoundError, HttpError from flask import current_app +from idutils import normalize_pid from ..models import PIDStatus, logger +from ..providers import base32 from .base import BaseProvider @@ -31,20 +38,71 @@ class DataCiteProvider(BaseProvider): default_status = PIDStatus.NEW """Default status for newly created PIDs by this provider.""" + doi_prefix_regexp = re.compile( + r"10\.\d+(\.\d+)*$" + ) + + @classmethod + def valid_doi_prefix(cls, prefix): + """Matches if prefix is a DOI prefix. + + Potential TODO: add to idutils module. + """ + return cls.doi_prefix_regexp.match(prefix) + + @classmethod + def generate_doi(cls, prefix, suffix_length=10, split_every=5, + checksum=True): + """Generate random DOI with `prefix`.""" + if not cls.valid_doi_prefix(prefix): + logger.exception("Invalid DOI prefix", extra={'prefix': prefix}) + return None + + if checksum and suffix_length < 3: + logger.exception( + "Invalid suffix_length. At least 3 if checksum enabled", + extra={'suffix_length': suffix_length} + ) + return None + + generator = random.SystemRandom() + length = suffix_length - 2 if checksum else suffix_length + number = generator.getrandbits(length * 5) + + return ( + prefix + + "/" + + base32.encode( + number, + split_every=split_every, + checksum=checksum + ) + ) + @classmethod - def create(cls, pid_value, **kwargs): + def create(cls, pid_value=None, prefix=None, **kwargs): """Create a new record identifier. For more information about parameters, see :meth:`invenio_pidstore.providers.BaseProvider.create`. :param pid_value: Persistent identifier value. + :param prefix: DOI prefix if pid_value is None :params **kwargs: See :meth:`invenio_pidstore.providers.base.BaseProvider.create` extra parameters. :returns: A :class:`invenio_pidstore.providers.DataCiteProvider` instance. """ + if pid_value is None: + prefix = ( + prefix or + current_app.config.get('PIDSTORE_DATACITE_DOI_PREFIX') + ) + pid_value = cls.generate_doi(prefix) + else: + pid_value = normalize_pid(pid_value, 'doi') + return super(DataCiteProvider, cls).create( pid_value=pid_value, **kwargs) diff --git a/run-tests.sh b/run-tests.sh index ee41155..040a381 100755 --- a/run-tests.sh +++ b/run-tests.sh @@ -1,3 +1,4 @@ +#!/usr/bin/env sh # -*- coding: utf-8 -*- # # This file is part of Invenio. diff --git a/setup.py b/setup.py index b415dea..16374bc 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ 'Flask-Menu>=0.5.1', 'invenio-access>=1.0.0', 'invenio-accounts>=1.0.0', - 'mock>=1.3.0', + 'mock>=3.0.0', 'pydocstyle>=1.0.0', 'pytest-cov>=1.8.0', 'pytest-pep8>=1.0.6', @@ -70,6 +70,8 @@ install_requires = [ 'Flask-BabelEx>=0.9.3', 'Flask>=0.11.1', + 'idutils>=1.1.4', + 'six>=1.12.0' ] packages = find_packages() diff --git a/tests/test_base32.py b/tests/test_base32.py new file mode 100644 index 0000000..e68f205 --- /dev/null +++ b/tests/test_base32.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# +# This file is part of Invenio. +# Copyright (C) 2015-2019 CERN. +# Copyright (C) 2019 Northwestern University. +# +# Invenio is free software; you can redistribute it and/or modify it +# under the terms of the MIT License; see LICENSE file for more details. + +"""Provider tests.""" + +import pytest + +from invenio_pidstore.providers import base32 + + +def test_basic_encode(): + assert base32.encode(32) == "10" + assert base32.encode(1234) == "16j" + + +def test_basic_decode(): + assert base32.decode("16j") == 1234 + + +def test_decode_normalizes_symbols(): + assert ( + base32.decode("abcdefghijklmnopqrstvwxyz") == + base32.decode("ABCDEFGHIJKLMNOPQRSTVWXYZ") + ) + assert base32.decode('IL1O0ilo') == base32.decode('11100110') + assert base32.decode('1-6-j') == base32.decode('16j') + + +def test_decode_raises_for_invalid_string(): + with pytest.raises(ValueError): + base32.decode("Ü'+?") + + +def test_encode_hyphenates(): + assert base32.encode(1234, split_every=1) == "1-6-j" + + with pytest.raises(ValueError): + assert base32.encode(1234, split_every=-1) + + +def test_encode_checksum(): + assert base32.encode(1234, checksum=True) == "16j82" + + +def test_decode_checksum(): + assert base32.decode("16j82", checksum=True) == 1234 + + +def test_decode_invalid_checksum(): + with pytest.raises(ValueError): + assert base32.decode("16j44", checksum=True) diff --git a/tests/test_providers.py b/tests/test_providers.py index 5a026e3..0c171ba 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -241,3 +241,71 @@ def test_datacite_sync(logger, app, db): assert provider.pid.status == PIDStatus.NEW assert logger.exception.call_args[0][0] == \ "Failed to sync status from DataCite" + + +def test_datacite_valid_doi_prefix(): + assert DataCiteProvider.valid_doi_prefix('10.1234') + assert DataCiteProvider.valid_doi_prefix('10.12.34') + assert not DataCiteProvider.valid_doi_prefix('101234') + assert not DataCiteProvider.valid_doi_prefix('101234.') + assert not DataCiteProvider.valid_doi_prefix('10.1234/') + assert not DataCiteProvider.valid_doi_prefix('100.1234') + assert not DataCiteProvider.valid_doi_prefix('10.12E45') + + +@patch('invenio_pidstore.providers.datacite.base32') +def test_datacite_generate_doi_calls_encode(patched_base32): + DataCiteProvider.generate_doi('10.1234') + + patched_base32.encode.assert_called() + + +@patch('invenio_pidstore.providers.datacite.logger') +def test_datacite_generate_doi_properties(logger): + doi = DataCiteProvider.generate_doi( + '10.1234', suffix_length=8, split_every=2 + ) + + # prefix + assert doi.startswith('10.1234/') + + # suffix_length + suffix_length = len(doi.split('/')[1]) + assert suffix_length == 8 + 8 / 2 - 1 + assert DataCiteProvider.generate_doi('10.1234', suffix_length=2) is None + logger.exception.assert_called() + + # hyphenated + assert doi.count('-') == 8 / 2 - 1 + + tmp_doi = DataCiteProvider.generate_doi('10.1234', split_every=0) + assert tmp_doi.count('-') == 0 + + tmp_doi = DataCiteProvider.generate_doi( + '10.1234', suffix_length=8, split_every=8 + ) + assert tmp_doi.count('-') == 0 + + tmp_doi = DataCiteProvider.generate_doi( + '10.1234', suffix_length=8, split_every=9 + ) + assert tmp_doi.count('-') == 0 + + +def test_datacite_provider_create_with_prefix(app, db): + original_config = app.config.get('PIDSTORE_DATACITE_DOI_PREFIX') + app.config['PIDSTORE_DATACITE_DOI_PREFIX'] = '10.4321' + + with app.app_context(): + provider = DataCiteProvider.create(prefix='10.1234') + assert provider.pid.status == PIDStatus.NEW + assert provider.pid.pid_provider == 'datacite' + assert provider.pid.pid_value.startswith('10.1234') + + # Test DataCiteProvider.create uses configuration setting + provider = DataCiteProvider.create() + assert provider.pid.status == PIDStatus.NEW + assert provider.pid.pid_provider == 'datacite' + assert provider.pid.pid_value.startswith('10.4321') + + app.config['PIDSTORE_DATACITE_DOI_PREFIX'] = original_config