#include "sputnik/block/dsd/cutlass/dsd.h"
#include "sputnik/block/cutlass/block_pitch_linear.h"
#include "sputnik/block/cutlass/default_block_gemm.h"
#include "sputnik/block/cutlass/kernel.h"
#include "sputnik/block/cutlass/threadblock_swizzle.h"
#include "sputnik/block/transpose/transpose.h"

namespace sputnik {
namespace block {
namespace cutlass {

namespace {

// CUTLASS kernel configuration for the block-sparse (A) x dense (B) -> dense
// (C) product used by the *_tt_align8 entry points below.
//
// The type name encodes the configuration chosen here:
//   "b128"        -> BlockSize::k128 sparse blocks,
//   "128x128x32"  -> the GemmShape<128, 128, 32> threadblock tile,
//   "x5"          -> the stage count of 5,
//   "tt"          -> both A and B operands are transposed,
//   "align8"      -> the alignment value 8 given for A and B,
//   "mixed"       -> presumably half_t operands with a float accumulator.
//
// NOTE(review): the positional template arguments are assumed to follow
// CUTLASS's DefaultGemm ordering (element / layout / alignment for A and B,
// element / layout for C, accumulator type, operator class, architecture,
// threadblock / warp / instruction shapes, epilogue, swizzle, stages, math
// operator) with the block size prepended -- confirm against
// default_block_gemm.h.
using dsd_mixed_b128_128x128x32x5_tt_align8_base =
  typename DefaultBlockGemm<
  BlockSize::k128,
  // Transposed A operand.
  ::cutlass::half_t,
  BlockColumnMajor,
  8,
  // Transposed B operand.
  ::cutlass::half_t,
  ::cutlass::layout::ColumnMajor,
  8,
  // C operand.
  ::cutlass::half_t,
  ::cutlass::layout::RowMajor,
  // Accumulator type (assumed position -- see note above).
  float,
  // Tensor-core math targeting Sm80.
  ::cutlass::arch::OpClassTensorOp,
  ::cutlass::arch::Sm80,
  // Threadblock, warp and instruction tile shapes (M, N, K).
  ::cutlass::gemm::GemmShape<128, 128, 32>,
  ::cutlass::gemm::GemmShape<64, 64, 32>,
  ::cutlass::gemm::GemmShape<16, 8, 16>,
  // Epilogue: linear combination producing half_t output, 8 elements per
  // access, with float accumulator / compute types.
  ::cutlass::epilogue::thread::LinearCombination<::cutlass::half_t, 8, float, float>,
  GemmVerticalThreadblockSwizzle,
  // Pipeline stage count ("x5" in the type name).
  5,
  ::cutlass::arch::OpMultiplyAdd
>::GemmKernel;

// Define named type
//
// A distinct struct (rather than the alias itself) gives this configuration
// its own short type name; presumably this keeps symbol names / error
// messages readable -- NOTE(review): confirm intent.
struct dsd_mixed_b128_128x128x32x5_tt_align8 :
  public dsd_mixed_b128_128x128x32x5_tt_align8_base { };

}  // namespace


bool can_launch_dsd_mixed_b128_128x128x32x5_tt_align8(
    const BlockMatrix a, bool transpose_a,
    const Matrix b, bool transpose_b, Matrix c) {
  using Dsd = Kernel<dsd_mixed_b128_128x128x32x5_tt_align8>;

  MatmulShape shape(a, transpose_a, b, transpose_b);
  Dsd::Arguments args({shape.m, shape.n, shape.k},
                      {1.0f, 0.0f},
                      {nullptr, 0},
                      {nullptr, 0},
                      {nullptr, 0},
                      {nullptr, 0});

  // Verify that we can implement the given problem.
  ::cutlass::Status status = Dsd::KernelFn::can_implement(args);
  bool can_implement = status == ::cutlass::Status::kSuccess;
  can_implement &= a.block_size == BlockSize::k128;
  can_implement &= transpose_a && transpose_b;
  can_implement &= ValidMatmul(a, transpose_a, b, transpose_b, c);
  return can_implement;
}

// Launches the block-sparse x dense matmul configured as
// dsd_mixed_b128_128x128x32x5_tt_align8 on `stream`, writing the result
// into `c` (alpha = 1, beta = 0).
//
// Requirements:
//   * `a.offsets_t`, `a.indices_t` and `a.block_offsets` must be set;
//     they are enforced with SPUTNIK_CHECK.
//   * `a` must use 128x128 blocks, and both `transpose_a` and `transpose_b`
//     must be true. This kernel is compiled only for that configuration;
//     anything else yields cudaErrorNotSupported.
//
// When `a.create_metadata` is set, the transposed metadata is produced by
// Transpose() on `stream` before the matmul. That step is skipped when the
// caller has cleared the flag because the metadata is already up to date.
//
// Returns:
//   * cudaErrorNotSupported if the configuration or problem is rejected.
//     All validation happens before any work is enqueued, so a rejected
//     call leaves `stream` untouched.
//   * Otherwise, the status of the metadata transpose or the kernel launch.
cudaError_t launch_dsd_mixed_b128_128x128x32x5_tt_align8(
    const BlockMatrix a, bool transpose_a,
    const Matrix b, bool transpose_b,
    Matrix c, cudaStream_t stream) {
  SPUTNIK_CHECK(a.offsets_t);
  SPUTNIK_CHECK(a.indices_t);
  SPUTNIK_CHECK(a.block_offsets);

  // The kernel type fixes both the block size and the operand layouts;
  // running it on any other configuration would silently compute garbage.
  if (a.block_size != BlockSize::k128 || !transpose_a || !transpose_b) {
    return cudaErrorNotSupported;
  }

  using Dsd = Kernel<dsd_mixed_b128_128x128x32x5_tt_align8>;

  // `c` serves as both the epilogue source and destination; with beta = 0
  // its previous contents are scaled by zero.
  MatmulShape shape(a, transpose_a, b, transpose_b);
  Dsd::Arguments args({shape.m, shape.n, shape.k},
                      {1.0f, 0.0f},
                      {a.data,
                       a.offsets_t,
                       a.indices_t,
                       a.block_offsets,
                       shape.lda},
                      {b.data, shape.ldb},
                      {c.data, shape.ldc},
                      {c.data, shape.ldc});

  // Verify that we can implement the given problem *before* enqueuing the
  // metadata transpose. `args` only holds pointers and sizes, which the
  // transpose does not change.
  ::cutlass::Status status = Dsd::KernelFn::can_implement(args);
  if (status != ::cutlass::Status::kSuccess) {
    return cudaErrorNotSupported;
  }

  // Produce the transposed metadata unless the caller indicated it is
  // already current.
  if (a.create_metadata) {
    cudaError_t custatus = Transpose(a, stream);
    if (custatus != cudaSuccess) {
      return custatus;
    }
  }

  Dsd dsd_operator;
  return dsd_operator(args, stream);
}

}  // namespace cutlass
}  // namespace block
}  // namespace sputnik
