// SPDX-FileComment: This file is part of TNL - Template Numerical Library (https://tnl-project.org/)
// SPDX-License-Identifier: MIT

#pragma once

#include <TNL/Containers/ndarray/Indexing.h>
#include <TNL/Containers/ndarray/SizesHolder.h>        // make_sizes_holder
#include <TNL/Containers/ndarray/StaticSizesHolder.h>  // ConstStaticSizesHolder

namespace TNL::Containers {

// HACK for https://stackoverflow.com/q/74240374
#ifdef _MSC_VER
/**
 * \brief Free-function forwarder to the static \e getDimension() member of
 * a holder type.
 *
 * Only compiled for MSVC, where it is used in place of calling
 * \e Holder::getDimension() directly (see the linked question above).
 *
 * \tparam Holder Type providing a static constexpr \e getDimension() member.
 * \returns The value of \e Holder::getDimension().
 */
template< typename Holder >
[[nodiscard]] constexpr auto
getDimension() -> std::size_t
{
   return Holder::getDimension();
}
#endif

/**
 * \brief Indexer for N-dimensional arrays. It does not store any data, only
 * the sizes of each dimension.
 *
 * \tparam SizesHolder Instance of \ref SizesHolder that will represent the
 *                     array sizes.
 * \tparam StridesHolder Type of the base class which represents the strides of
 *                       the N-dimensional array.
 * \tparam Overlaps Sequence of integers representing the overlaps in each
 *                  dimension of a distributed N-dimensional array.
 *
 * \ingroup ndarray
 */
template< typename SizesHolder,
          typename StridesHolder,
          typename Overlaps = ConstStaticSizesHolder< typename SizesHolder::IndexType, SizesHolder::getDimension(), 0 > >
class NDArrayIndexer
{
public:
   //! \brief Type of the underlying object which represents the sizes of the N-dimensional array.
   using SizesHolderType = SizesHolder;

   //! \brief Type of the underlying object which represents the strides of the N-dimensional array.
   using StridesHolderType = StridesHolder;

   //! \brief Type of the underlying object which represents the overlaps in each dimension
   //! of a distributed N-dimensional array.
   using OverlapsType = Overlaps;

   //! \brief Type of indices used for addressing the array elements.
   using IndexType = typename SizesHolder::IndexType;

// HACK for https://stackoverflow.com/q/74240374
#ifdef _MSC_VER
   static_assert( Containers::getDimension< StridesHolder >() == Containers::getDimension< SizesHolder >(),
                  "Dimension of strides does not match the dimension of sizes." );
   static_assert( Containers::getDimension< Overlaps >() == Containers::getDimension< SizesHolder >(),
                  "Dimension of overlaps does not match the dimension of sizes." );
#else
   static_assert( StridesHolder::getDimension() == SizesHolder::getDimension(),
                  "Dimension of strides does not match the dimension of sizes." );
   static_assert( Overlaps::getDimension() == SizesHolder::getDimension(),
                  "Dimension of overlaps does not match the dimension of sizes." );
#endif

   //! \brief Constructs an indexer with default-constructed sizes, strides and overlaps.
   __cuda_callable__
   NDArrayIndexer() = default;

   /**
    * \brief Creates the indexer with given sizes, strides and overlaps.
    *
    * \param sizes Sizes of the N-dimensional array.
    * \param strides Strides of the N-dimensional array.
    * \param overlaps Overlaps of the N-dimensional array.
    */
   __cuda_callable__
   NDArrayIndexer( SizesHolderType sizes, StridesHolderType strides, OverlapsType overlaps )
   : sizes( std::move( sizes ) ),
     strides( std::move( strides ) ),
     overlaps( std::move( overlaps ) )
   {}

   //! \brief Returns the dimension of the \e N-dimensional array, i.e. \e N.
   [[nodiscard]] static constexpr std::size_t
   getDimension()
   {
// HACK for https://stackoverflow.com/q/74240374
#ifdef _MSC_VER
      return Containers::getDimension< SizesHolder >();
#else
      return SizesHolder::getDimension();
#endif
   }

   //! \brief Returns the N-dimensional array sizes held by the indexer.
   [[nodiscard]] __cuda_callable__
   const SizesHolderType&
   getSizes() const
   {
      return sizes;
   }

   /**
    * \brief Returns a specific component of the N-dimensional sizes.
    *
    * \tparam level Integer specifying the component of the sizes to be returned.
    */
   template< std::size_t level >
   [[nodiscard]] __cuda_callable__
   IndexType
   getSize() const
   {
      return sizes.template getSize< level >();
   }

   //! \brief Returns the N-dimensional strides holder instance.
   [[nodiscard]] __cuda_callable__
   const StridesHolderType&
   getStrides() const
   {
      return strides;
   }

   /**
    * \brief Returns a specific component of the N-dimensional strides.
    *
    * \tparam level Integer specifying the component of the strides to be returned.
    */
   template< std::size_t level >
   [[nodiscard]] __cuda_callable__
   IndexType
   getStride() const
   {
      return getStrides().template getSize< level >();
   }

   //! \brief Returns the N-dimensional overlaps holder instance.
   [[nodiscard]] __cuda_callable__
   const OverlapsType&
   getOverlaps() const
   {
      return overlaps;
   }

   /**
    * \brief Returns the overlap of a distributed N-dimensional array along the
    *        specified axis.
    *
    * \tparam level Integer specifying the axis of the array.
    */
   template< std::size_t level >
   [[nodiscard]] __cuda_callable__
   IndexType
   getOverlap() const
   {
      return getOverlaps().template getSize< level >();
   }

   /**
    * \brief Returns the size (number of elements) needed to store the N-dimensional array.
    *
    * \returns The product of the aligned sizes.
    */
   [[nodiscard]] __cuda_callable__
   IndexType
   getStorageSize() const
   {
      return detail::getStorageSize( getSizes(), getOverlaps() );
   }

   /**
    * \brief Computes the one-dimensional storage index for a specific element
    *        of the N-dimensional array.
    *
    * The indices are checked against the sizes and overlaps, and the
    * resulting storage index is asserted to be non-negative.
    *
    * \param indices Indices of the element in the N-dimensional array. The
    *                number of indices supplied must be equal to \e N, i.e.
    *                \ref getDimension().
    * \returns An index that can be used to address the element in a
    *          one-dimensional array.
    */
   template< typename... IndexTypes >
   [[nodiscard]] __cuda_callable__
   IndexType
   getStorageIndex( IndexTypes&&... indices ) const
   {
      static_assert( sizeof...( indices ) == getDimension(), "got wrong number of indices" );
      detail::assertIndicesInBounds( getSizes(), getOverlaps(), std::forward< IndexTypes >( indices )... );
      const IndexType result = detail::getStorageIndex( getStrides(), getOverlaps(), std::forward< IndexTypes >( indices )... );
      TNL_ASSERT_GE( result, (IndexType) 0, "storage index out of bounds - either input error or a bug in the indexer" );
      return result;
   }

   /**
    * \brief Checks whether the block of elements delimited by \e begins and
    *        \e ends is laid out contiguously according to the strides of
    *        this indexer.
    *
    * The extent of the block along each axis is computed as
    * `ends[i] - begins[i]`. Axes whose extent is at most 1 (degenerate axes)
    * are ignored, since their stride does not influence the layout of the
    * block. The remaining axes are ordered by their stride; the block is
    * contiguous when the smallest stride is 1 and every following stride
    * equals the product of the extents of all the preceding axes.
    *
    * A block with no non-degenerate axis (i.e. a block with at most one
    * element) is always reported as contiguous.
    *
    * \tparam BeginsHolder Holder type with the same dimension as the indexer.
    * \tparam EndsHolder Holder type with the same dimension as the indexer.
    * \param begins Begin of the block along each axis.
    * \param ends End of the block along each axis.
    * \returns \e true if the block is contiguous, \e false otherwise.
    */
   template< typename BeginsHolder, typename EndsHolder >
   [[nodiscard]] __cuda_callable__
   bool
   isContiguousBlock( const BeginsHolder& begins, const EndsHolder& ends ) const
   {
      static_assert( BeginsHolder::getDimension() == getDimension(), "invalid dimension of the begins parameter" );
      static_assert( EndsHolder::getDimension() == getDimension(), "invalid dimension of the ends parameter" );

      constexpr std::size_t dim = getDimension();
      // NOTE: the local arrays are deliberately named differently from the
      // data members `sizes` and `strides` to avoid shadowing them.
      IndexType blockSizes[ dim ];
      IndexType blockStrides[ dim ];
      int non_degenerate_indices[ dim ] = { 0 };
      int non_degenerate_count = 0;

      Algorithms::staticFor< std::size_t, 0, dim >(
         [ & ]( auto i )
         {
            // Collect the block extent and the stride for each dimension
            blockSizes[ i ] = ends.template getSize< i >() - begins.template getSize< i >();
            blockStrides[ i ] = getStride< i >();
            // Collect indices of non-degenerate dimensions
            if( blockSizes[ i ] > 1 ) {
               non_degenerate_indices[ non_degenerate_count++ ] = i;
            }
         } );

      TNL_ASSERT_LE( non_degenerate_count, (int) dim, "collecting indices of non-degenerate dimensions failed" );

      // Sort non-degenerate dimensions by stride (ascending) - insertion sort
      for( int i = 1; i < non_degenerate_count; i++ ) {
         // GCC incorrectly detects out-of-bounds index
#if defined( __GNUC__ ) && ! defined( __clang__ )
   #pragma GCC diagnostic push
   #pragma GCC diagnostic ignored "-Warray-bounds"
#endif
         int key = non_degenerate_indices[ i ];
#if defined( __GNUC__ ) && ! defined( __clang__ )
   #pragma GCC diagnostic pop
#endif
         int j = i - 1;
         while( j >= 0 && blockStrides[ non_degenerate_indices[ j ] ] > blockStrides[ key ] ) {
            non_degenerate_indices[ j + 1 ] = non_degenerate_indices[ j ];
            j--;
         }
         non_degenerate_indices[ j + 1 ] = key;
      }

      // Check the product condition: walking the non-degenerate dimensions
      // from the smallest stride up, the first stride must be 1 and each
      // following stride must equal the product of the preceding extents.
      // When there is no non-degenerate dimension, the loop body is never
      // entered and the block (at most one element) is trivially contiguous;
      // in particular, no stride is inspected in that case.
      IndexType expectedStride = 1;
      for( int i = 0; i < non_degenerate_count; i++ ) {
         const int idx = non_degenerate_indices[ i ];
         if( blockStrides[ idx ] != expectedStride )
            return false;
         expectedStride *= blockSizes[ idx ];
      }

      return true;
   }

protected:
   /**
    * \brief Returns a non-constant reference to the underlying \ref sizes.
    *
    * The function is not public -- only subclasses like \ref NDArrayStorage
    * may modify the sizes.
    */
   [[nodiscard]] __cuda_callable__
   SizesHolderType&
   getSizes()
   {
      return sizes;
   }

   /**
    * \brief Returns a non-constant reference to the underlying \ref strides.
    *
    * The function is not public -- only subclasses like \ref NDArrayStorage
    * may modify the strides.
    */
   [[nodiscard]] __cuda_callable__
   StridesHolderType&
   getStrides()
   {
      return strides;
   }

   /**
    * \brief Returns a non-constant reference to the underlying \ref overlaps.
    *
    * The function is not public -- only subclasses like \ref NDArrayStorage
    * may modify the overlaps.
    */
   [[nodiscard]] __cuda_callable__
   OverlapsType&
   getOverlaps()
   {
      return overlaps;
   }

   // TODO: use [[no_unique_address]] in C++20 - see https://www.cppstories.com/2021/no-unique-address/
   //! \brief Underlying object which represents the sizes of the N-dimensional array.
   SizesHolderType sizes;

   // TODO: use [[no_unique_address]] in C++20 - see https://www.cppstories.com/2021/no-unique-address/
   //! \brief Underlying object which represents the strides of the N-dimensional array.
   StridesHolderType strides;

   // TODO: use [[no_unique_address]] in C++20 - see https://www.cppstories.com/2021/no-unique-address/
   //! \brief Underlying object which represents the overlaps of the N-dimensional array.
   OverlapsType overlaps;
};

}  // namespace TNL::Containers
