// Copyright (c) 2018 Giorgos Vougioukas
//
// The license can be found in the LICENSE file.

#include "rsi_flac/flac_input.h"
#include "rsi_flac/flac_entry_point.h"

#include <climits>

#include "WDL/fpcmp.h"
#include "WDL/pcmfmtcvt.h"

// Creates an unopened FLAC source. No decoder or file exists yet. The format
// fields hold placeholder values (stereo, 16-bit, 48 kHz) until Open()
// captures the device rate and the metadata callback fills in the stream's
// real format.
RSI_FlacInput::RSI_FlacInput()
  : m_decoder(NULL), m_file(NULL),
    m_channels(2), m_bitspersample(16), m_samplerate(48000.0),
    m_totallength(0), m_currentpos(0.0),
    m_hwsamplerate(48000.0)
{
}

// Releases the libFLAC decoder (finish, then delete) and the file reader.
// Both pointers may still be NULL if Open() was never called or failed early.
RSI_FlacInput::~RSI_FlacInput()
{
  if (m_decoder != NULL)
  {
    FLAC__stream_decoder_finish(m_decoder);
    FLAC__stream_decoder_delete(m_decoder);
  }
  m_decoder = NULL;

  // delete on a NULL pointer is a no-op, so no separate check is needed.
  delete m_file;
}

bool RSI_FlacInput::Open(const char *filename)
{
  m_fn.Set(filename);

  m_hwsamplerate = RSI_GetAudioDeviceSamplerate();

  int rmode, rbufsize, rnbufs;
  RSI_GetDiskReadMode(&rmode, &rbufsize, &rnbufs);
  m_file = new WDL_FileRead(filename, rmode, rbufsize, rnbufs);

  if (!m_file->IsOpen())
  {
    return false;
  }

  m_decoder = FLAC__stream_decoder_new();

  if (WDL_NOT_NORMALLY(!m_decoder))
  {
    wdl_log("flac decoder cannot be allocated\n");
    return false;
  }

  int ret = FLAC__stream_decoder_init_stream(m_decoder, read_callback,
    seek_callback, tell_callback, length_callback, eof_callback,
    write_callback, metadata_callback, error_callback, this);

  if (WDL_NOT_NORMALLY(ret != FLAC__STREAM_DECODER_INIT_STATUS_OK))
  {
    wdl_log("flac stream cannot be initialized\n");
    return false;
  }

  FLAC__bool success = FLAC__stream_decoder_process_until_end_of_metadata(m_decoder);

  if (WDL_NOT_NORMALLY(!success))
  {
    wdl_log("flac metadata cannot be read\n");
    return false;
  }

  bool interp, sinc;
  int filtercnt, sinc_size, sinc_interpsize;
  RSI_GetResampleMode(&interp, &filtercnt, &sinc, &sinc_size, &sinc_interpsize);
  m_rs.SetMode(interp, filtercnt, sinc, sinc_size, sinc_interpsize);

  return true;
}

// Short tag identifying the input format handled by this source.
const char *RSI_FlacInput::GetType() const
{
  const char *type = "FLAC";
  return type;
}

// The path passed to Open(). The returned pointer is owned by m_fn.
const char *RSI_FlacInput::GetFileName() const
{
  const char *name = m_fn.Get();
  return name;
}

// Number of interleaved channels delivered by GetSamples().
//
// Always 2, not the file's own channel count (m_channels). write_callback
// emits stereo for every input: mono is copied to both outputs and any
// channels beyond the second are dropped.
int RSI_FlacInput::GetChannels() const
{
  const int kOutputChannels = 2;
  return kOutputChannels;
}

// The file's native sample rate, as reported by the STREAMINFO metadata
// (48000 until Open() has read it).
double RSI_FlacInput::GetSampleRate() const
{
  const double rate = m_samplerate;
  return rate;
}

// Stream duration in seconds: the STREAMINFO total sample count divided by
// the file's sample rate.
double RSI_FlacInput::GetLength() const
{
  const double seconds = m_totallength / m_samplerate;
  return seconds;
}

// Bit depth of the source file, as reported by the STREAMINFO metadata
// (16 until Open() has read it).
int RSI_FlacInput::GetBitsPerSample() const
{
  const int bits = m_bitspersample;
  return bits;
}

// Current playback position in seconds. Seek() sets it, and GetSamples()
// advances it by the number of frames delivered at the device sample rate.
double RSI_FlacInput::GetPosition() const
{
  const double seconds = m_currentpos;
  return seconds;
}

void RSI_FlacInput::Seek(double time)
{
  FLAC__uint64 pos = (FLAC__uint64)(time * GetSampleRate());

  FLAC__stream_decoder_flush(m_decoder);

  if (FLAC__stream_decoder_seek_absolute(m_decoder, pos))
  {
    m_samples.Clear();
    m_currentpos = time;
  }
}

// Copies up to `length` interleaved SAM values (stereo, already converted to
// the device rate by write_callback) into `buffer`. Returns how many values
// were copied.
//
// When the queue holds fewer than `length` values, at most one more FLAC
// frame is decoded. A call can therefore return fewer than `length` values.
// A return of 0 means nothing was buffered and the decoder produced nothing
// (for example at end of stream). Any samples still queued at end of stream
// are still delivered. The reported position advances by the number of
// frames copied.
int RSI_FlacInput::GetSamples(SAM *buffer, int length)
{
  if (!m_decoder || !buffer || length <= 0)
  {
    return 0;
  }

  if (m_samples.Available() < length)
  {
    const FLAC__StreamDecoderState state = FLAC__stream_decoder_get_state(m_decoder);

    // Nothing more can be decoded once the stream has ended or was aborted.
    if (state != FLAC__STREAM_DECODER_END_OF_STREAM &&
        state != FLAC__STREAM_DECODER_ABORTED)
    {
      const FLAC__bool decoded = FLAC__stream_decoder_process_single(m_decoder);

      if (WDL_NOT_NORMALLY(!decoded))
      {
        wdl_log("flac cannot decode single block\n");
      }
    }
  }

  const int available = m_samples.Available();
  const int copied = (available < length) ? available : length;

  if (copied > 0)
  {
    memcpy(buffer, m_samples.Get(), sizeof(SAM) * copied);
    m_samples.Advance(copied);
    m_samples.Compact();
  }

  // Use floating-point division so an odd `copied` is not truncated.
  m_currentpos += (double)copied / GetChannels() / m_hwsamplerate;

  return copied;
}

// This source always reads from a local file, so it never reports itself as
// a stream.
bool RSI_FlacInput::IsStreaming() const
{
  const bool streaming = false;
  return streaming;
}

// Extension hook. No extended calls are supported, so every request,
// regardless of `call` and its parameters, returns -1.
int RSI_FlacInput::Extended(int call, void *parm1, void *parm2, void *parm3)
{
  (void)call;
  (void)parm1;
  (void)parm2;
  (void)parm3;

  return -1;
}

// libFLAC read callback. Fills `buffer` with up to *bytes bytes from m_file
// and stores the count actually read back into *bytes.
//
// When nothing could be read, this reports END_OF_STREAM, as libFLAC's
// read-callback contract asks. The original CONTINUE-with-zero-bytes
// response could leave the decoder retrying an empty read when the failure
// was not at EOF.
FLAC__StreamDecoderReadStatus RSI_FlacInput::read_callback(
  const FLAC__StreamDecoder *m_decoder, FLAC__byte buffer[],
  size_t *bytes, void *client_data)
{
  RSI_FlacInput *cd = (RSI_FlacInput *)client_data;

  if (!cd || !cd->m_file || !bytes)
  {
    return FLAC__STREAM_DECODER_READ_STATUS_ABORT;
  }

  if (*bytes == 0)
  {
    // Nothing was requested, so there is nothing to read.
    return FLAC__STREAM_DECODER_READ_STATUS_CONTINUE;
  }

  // WDL_FileRead::Read() takes an int length, so clamp the request instead
  // of letting a large size_t wrap when it is narrowed.
  const int request = (*bytes > (size_t)INT_MAX) ? INT_MAX : (int)*bytes;
  const int got = cd->m_file->Read(buffer, request);

  if (got <= 0)
  {
    *bytes = 0;
    return FLAC__STREAM_DECODER_READ_STATUS_END_OF_STREAM;
  }

  *bytes = (size_t)got;
  return FLAC__STREAM_DECODER_READ_STATUS_CONTINUE;
}

// libFLAC seek callback: moves m_file to `absolute_byte_offset`.
//
// WDL_FileRead::SetPosition() returns a bool that is false (0) on success
// and true on failure. Its result must be tested for truth; a sign test can
// never see an error.
FLAC__StreamDecoderSeekStatus RSI_FlacInput::seek_callback(
  const FLAC__StreamDecoder *m_decoder, FLAC__uint64 absolute_byte_offset,
  void *client_data)
{
  RSI_FlacInput *cd = (RSI_FlacInput *)client_data;

  if (!cd || !cd->m_file)
  {
    return FLAC__STREAM_DECODER_SEEK_STATUS_ERROR;
  }

  if (cd->m_file->SetPosition((WDL_INT64)absolute_byte_offset))
  {
    return FLAC__STREAM_DECODER_SEEK_STATUS_ERROR;
  }

  return FLAC__STREAM_DECODER_SEEK_STATUS_OK;
}

// libFLAC tell callback: reports m_file's current byte offset through
// `absolute_byte_offset`. Fails only when no instance was passed as
// client_data.
FLAC__StreamDecoderTellStatus RSI_FlacInput::tell_callback(
  const FLAC__StreamDecoder *m_decoder, FLAC__uint64 *absolute_byte_offset,
  void *client_data)
{
  RSI_FlacInput *self = static_cast<RSI_FlacInput *>(client_data);

  if (self == NULL) return FLAC__STREAM_DECODER_TELL_STATUS_ERROR;

  const FLAC__uint64 offset = (FLAC__uint64)self->m_file->GetPosition();
  *absolute_byte_offset = offset;
  return FLAC__STREAM_DECODER_TELL_STATUS_OK;
}

// libFLAC length callback: reports the total size of m_file in bytes through
// `stream_length`. Fails only when no instance was passed as client_data.
FLAC__StreamDecoderLengthStatus RSI_FlacInput::length_callback(
  const FLAC__StreamDecoder *m_decoder, FLAC__uint64 *stream_length,
  void *client_data)
{
  RSI_FlacInput *self = static_cast<RSI_FlacInput *>(client_data);

  if (self == NULL) return FLAC__STREAM_DECODER_LENGTH_STATUS_ERROR;

  const FLAC__uint64 size = (FLAC__uint64)self->m_file->GetSize();
  *stream_length = size;
  return FLAC__STREAM_DECODER_LENGTH_STATUS_OK;
}

// libFLAC EOF callback: reports end of file when m_file's read position
// equals its size. With no instance in client_data, it reports EOF so the
// decoder stops.
FLAC__bool RSI_FlacInput::eof_callback(const FLAC__StreamDecoder *m_decoder,
    void *client_data)
{
  RSI_FlacInput *self = static_cast<RSI_FlacInput *>(client_data);

  if (self == NULL) return true;

  const bool atEnd = (self->m_file->GetPosition() == self->m_file->GetSize());
  return atEnd ? true : false;
}

// libFLAC write callback. Receives one decoded frame and appends it to
// m_samples as interleaved stereo SAM values scaled to [-1, 1). When the
// file rate (m_samplerate) and device rate (m_hwsamplerate) differ, the frame
// is first resampled to the device rate through m_rs.
//
// Channel mapping: mono is duplicated to both outputs, and channels beyond
// the second are dropped (see GetChannels()).
//
// Scaling: each 32-bit integer sample from libFLAC is multiplied by
// 2^-(bps-1), the reciprocal of full scale for a signed bps-bit value. This
// works for every bit depth libFLAC can deliver, including depths that are
// not a whole number of bytes (for example 12 or 20 bits).
FLAC__StreamDecoderWriteStatus RSI_FlacInput::write_callback(
  const FLAC__StreamDecoder *m_decoder, const FLAC__Frame *frame,
  const FLAC__int32 * const buffer[], void *client_data)
{
  RSI_FlacInput *cd = (RSI_FlacInput *)client_data;

  if (!cd || !frame || !buffer)
  {
    return FLAC__STREAM_DECODER_WRITE_STATUS_ABORT;
  }

  const int nch = 2; // output is always interleaved stereo
  const int bps = (int)frame->header.bits_per_sample;
  const int frames = (int)frame->header.blocksize;
  const int channels = (int)frame->header.channels;

  if (bps < 1 || bps > 32 || channels < 1 || frames < 0)
  {
    wdl_log("flac frame format not supported\n");
    return FLAC__STREAM_DECODER_WRITE_STATUS_ABORT;
  }

  WDL_TypedBuf<SAM> stereo;
  stereo.Resize(frames * nch);

  if (stereo.GetSize() != frames * nch)
  {
    wdl_log("flac frame buffer cannot be allocated\n");
    return FLAC__STREAM_DECODER_WRITE_STATUS_ABORT;
  }

  // Planar -> interleaved stereo, scaled by an exact power of two.
  const double scale = 1.0 / (double)((FLAC__uint64)1 << (bps - 1));
  const FLAC__int32 *left = buffer[0];
  const FLAC__int32 *right = (channels > 1) ? buffer[1] : buffer[0];
  SAM *dst = stereo.Get();

  for (int i = 0; i < frames; i++)
  {
    dst[i * nch + 0] = (SAM)(left[i] * scale);
    dst[i * nch + 1] = (SAM)(right[i] * scale);
  }

  if (WDL_ApproximatelyEqual(cd->m_samplerate, cd->m_hwsamplerate))
  {
    cd->m_samples.Add(stereo.Get(), stereo.GetSize());
    return FLAC__STREAM_DECODER_WRITE_STATUS_CONTINUE;
  }

  // Rates differ. Feed the frame into the resampler in feed mode (the input
  // count drives each call) and drain its output in chunks of at most
  // `out_frames` stereo frames. The output buffer is allocated once per
  // callback, not once per chunk.
  const int out_frames = 2048;
  WDL_TypedBuf<WDL_ResampleSample> out;
  out.Resize(out_frames * nch);

  if (out.GetSize() != out_frames * nch)
  {
    wdl_log("flac resample buffer cannot be allocated\n");
    return FLAC__STREAM_DECODER_WRITE_STATUS_ABORT;
  }

  cd->m_rs.SetRates(cd->m_samplerate, cd->m_hwsamplerate);
  cd->m_rs.SetFeedMode(true);

  int fed = 0; // frames of `stereo` already handed to the resampler

  for (;;)
  {
    const int remaining = frames - fed;
    WDL_ResampleSample *in = NULL;
    int nin = cd->m_rs.ResamplePrepare(remaining, nch, &in);

    if (nin > remaining) nin = remaining;
    if (!in || nin < 0) nin = 0;

    // Copy exactly the frames the resampler accepted, continuing after those
    // fed in earlier iterations.
    const SAM *src = dst + fed * nch;
    for (int k = 0; k < nin * nch; k++)
    {
      in[k] = src[k];
    }
    fed += nin;

    const int nout = cd->m_rs.ResampleOut(out.Get(), nin, out_frames, nch);

    if (nout < 1)
    {
      // Done once all input has been fed and nothing more comes out. Also
      // stop if the resampler accepted no input, since another pass could
      // make no progress.
      if (fed >= frames || nin < 1) break;
      continue;
    }

    cd->m_samples.Add(out.Get(), nout * nch);
  }

  return FLAC__STREAM_DECODER_WRITE_STATUS_CONTINUE;
}

void RSI_FlacInput::metadata_callback(const FLAC__StreamDecoder *m_decoder,
  const FLAC__StreamMetadata *metadata, void *client_data)
{
  RSI_FlacInput *cd = (RSI_FlacInput *)client_data;

  if (!cd)
  {
    return;
  }

  cd->m_samplerate = (double)metadata->data.stream_info.sample_rate;
  cd->m_bitspersample = (int)metadata->data.stream_info.bits_per_sample;
  cd->m_channels = (int)metadata->data.stream_info.channels;
  cd->m_totallength = (WDL_INT64)metadata->data.stream_info.total_samples;
}

void RSI_FlacInput::error_callback(const FLAC__StreamDecoder *m_decoder,
  FLAC__StreamDecoderErrorStatus status, void *client_data)
{
  wdl_log("flac callback error: %d\n", status);
}
