import PropTypes from 'prop-types'
import { useEffect, useRef } from 'react'
import classNames from 'classnames'

import styles from './TrapFocus.module.scss'

/**
 * Name: TrapFocus
 * Desc: Trap Focus On Element
 * For documentation go see in https://github.com/mui-org/material-ui/blob/master/packages/material-ui/src/Unstable_TrapFocus/Unstable_TrapFocus.js
 * @param {object} children
 * @param {bool} isActive
 */

const TrapFocus = ({ children, isActive = null, className = '' }) => {
  const nodeToRestore = useRef()
  const sentinelStart = useRef(null)
  const sentinelEnd = useRef(null)
  const isActiveTrapFocus = !!(isActive || isActive === null)

  useEffect(() => {
    const loopFocus = (event) => {
      // 9 = Tab
      if (event.keyCode !== 9 || !isActiveTrapFocus) {
        return
      }

      if (sentinelEnd.current === document.activeElement) {
        sentinelStart?.current?.focus()
      }

      if (event.shiftKey && sentinelStart.current === document.activeElement) {
        sentinelEnd?.current?.focus()
      }
    }

    document.addEventListener('keydown', loopFocus, true)
    if (isActive === null) {
      // Added to manage focus where TrapFocus will be unmounted with its child
      nodeToRestore.current = document.activeElement
    } else {
      // Add to manage focus where TrapFocus is maintained with isActive boolean value
      isActive
        ? (nodeToRestore.current = document.activeElement)
        : nodeToRestore?.current?.focus()
    }
    isActiveTrapFocus && sentinelStart?.current?.focus()

    return () => {
      document.removeEventListener('keydown', loopFocus, true)
      // Added to manage focus where TrapFocus will be unmounted with its child
      if (isActive === null) {
        nodeToRestore?.current?.focus()
        nodeToRestore.current = null
      }
    }
  }, [isActive, isActiveTrapFocus])

  const focusableButtonProps = isActiveTrapFocus
    ? {
        tabIndex: 0,
      }
    : {}

  const classes = classNames(styles.noFocus, className)

  return (
    <>
      <div {...focusableButtonProps} ref={sentinelStart} className={classes} />
      {children}
      <div {...focusableButtonProps} ref={sentinelEnd} className={classes} />
    </>
  )
}

// Runtime prop validation (development builds); keys sorted alphabetically.
TrapFocus.propTypes = {
  // Content the focus trap wraps.
  children: PropTypes.node,
  // Extra class names for the two sentinel elements.
  className: PropTypes.string,
  // Omit for an always-on trap; pass a boolean to toggle it.
  isActive: PropTypes.bool,
}

export default TrapFocus
