diff --git a/lib/components/Tray.js b/lib/components/Tray.js index e52d73b..b082178 100644 --- a/lib/components/Tray.js +++ b/lib/components/Tray.js @@ -14,7 +14,7 @@ export default React.createClass({ closeTimeoutMS: React.PropTypes.number, closeOnBlur: React.PropTypes.bool, maintainFocus: React.PropTypes.bool, - elementToFocus: React.PropTypes.string, + getElementToFocus: a11yFunction, getAriaHideElement: a11yFunction }, diff --git a/lib/components/TrayPortal.js b/lib/components/TrayPortal.js index b8589b3..317ec42 100644 --- a/lib/components/TrayPortal.js +++ b/lib/components/TrayPortal.js @@ -59,7 +59,7 @@ export default React.createClass({ closeTimeoutMS: PropTypes.number, children: PropTypes.any, maintainFocus: PropTypes.bool, - elementToFocus: PropTypes.string, + getElementToFocus: PropTypes.func, getAriaHideElement: PropTypes.func }, @@ -88,8 +88,8 @@ export default React.createClass({ componentDidUpdate() { if (this.focusAfterRender) { - if (this.props.elementToFocus) { - this.focusSelector(this.props.elementToFocus); + if (this.props.getElementToFocus) { + this.props.getElementToFocus().focus(); } else { this.focusContent(); } @@ -105,17 +105,6 @@ export default React.createClass({ this.refs.content.focus(); }, - findSingleElement(querySelectorToUse) { - const el = document.querySelectorAll(querySelectorToUse); - const element = (el.length) ? el[0] : el; - return element; - }, - - focusSelector(querySelectorToUse) { - const element = this.findSingleElement(querySelectorToUse); - element.focus(); - }, - toggleAriaHidden(element) { if (!element.getAttribute('aria-hidden')) { element.setAttribute('aria-hidden', true); diff --git a/lib/components/__tests__/Tray-test.js b/lib/components/__tests__/Tray-test.js index 246092e..a111b20 100644 --- a/lib/components/__tests__/Tray-test.js +++ b/lib/components/__tests__/Tray-test.js @@ -117,9 +117,13 @@ describe('react-tray', function() { }); }); - describe('elementToFocus prop', function() { + describe('getElementToFocus prop', function() { + const getElementToFocus = () => { + return document.getElementById('two'); + }; + beforeEach(function(done) { - const props = {isOpen: true, onBlur: function() {}, closeTimeoutMS: 0, maintainFocus: true, elementToFocus: '#two'}; + const props = {isOpen: true, onBlur: function() {}, closeTimeoutMS: 0, maintainFocus: true, getElementToFocus: getElementToFocus}; const children = (