// Copyright (c) 2022 Giorgos Vougioukas
//
// The license can be found in the LICENSE file.

#include "thalia_wavpack/wavpack_input.h"
#include "thalia_wavpack/wavpack_entry_point.h"

#include "WDL/fpcmp.h"
#include "WDL/pcmfmtcvt.h"
#include "WDL/win32_utf8.h"

// Constructs an input with no file and no WavPack context yet. The format
// fields hold placeholder defaults (2 channels, 16 bits, 48000 Hz); Open()
// replaces them with the device rate and the file's actual format.
THA_WavPackInput::THA_WavPackInput()
  : m_file(NULL)
  , m_channels(2)
  , m_bitspersample(16)
  , m_samplerate(48000.0)
  // NOTE(review): m_buffer is used like a WDL_TypedBuf (Resize/Get/GetSize);
  // if so, this NULL is passed as its integer allocation granularity, not as
  // a pointer -- confirm this is intended.
  , m_buffer(NULL)
  , m_totallength(0)
  , m_currentpos(0.0)
  , m_hwsamplerate(48000.0)
  , m_eof(false)
  , m_wvctx(NULL)
{}

// Closes the WavPack decoding context if Open() created one, then frees the
// WDL_FileRead allocated by Open().
THA_WavPackInput::~THA_WavPackInput()
{
  if (m_wvctx != NULL)
  {
    WavpackCloseFile(m_wvctx);
  }

  // delete on a null pointer is a no-op, so no separate check is required.
  delete m_file;
}

bool THA_WavPackInput::Open(const char *filename)
{
  m_fn.Set(filename);

  m_hwsamplerate = THA_GetAudioDeviceSamplerate();

  int rmode, rbufsize, rnbufs;
  THA_GetDiskReadMode(&rmode, &rbufsize, &rnbufs);
  m_file = new WDL_FileRead(filename, rmode, rbufsize, rnbufs);

  if (!m_file->IsOpen())
  {
    return false;
  }

  char error[2048];
  int flags = OPEN_NORMALIZE; //OPEN_WVC | OPEN_DSD_AS_PCM | OPEN_TAGS | OPEN_DSD_NATIVE | OPEN_ALT_TYPES;
  //flags |= OPEN_WRAPPER;
#if _WIN32
  flags |= OPEN_FILE_UTF8;
#endif

  m_sr.can_seek = tha_can_seek;
  m_sr.close = tha_close;
  m_sr.get_length = tha_get_length;
  m_sr.get_pos = tha_get_pos;
  m_sr.push_back_byte = tha_push_back_byte;
  m_sr.read_bytes = tha_read_bytes;
  m_sr.set_pos_abs = tha_set_pos_abs;
  m_sr.set_pos_rel = tha_set_pos_rel;
  m_sr.truncate_here = tha_truncate_here;
  m_sr.write_bytes = tha_write_bytes;

  m_wvctx = WavpackOpenFileInput(m_fn.Get(), error, flags, 0);
  //m_wvctx = WavpackOpenFileInputEx64(&m_sr, m_file, NULL, error, flags, 0);

  if (!m_wvctx) return false;

  m_channels = WavpackGetNumChannels(m_wvctx);
  m_samplerate = (double)WavpackGetSampleRate(m_wvctx);
  m_bitspersample = WavpackGetBitsPerSample(m_wvctx);
  m_totallength = (WDL_INT64)WavpackGetNumSamples(m_wvctx);

  bool interp, sinc;
  int filtercnt, sinc_size, sinc_interpsize;
  THA_GetResampleMode(&interp, &filtercnt, &sinc, &sinc_size, &sinc_interpsize);
  m_rs.SetMode(interp, filtercnt, sinc, sinc_size, sinc_interpsize);

  return true;
}

// Short format identifier reported for this input.
const char *THA_WavPackInput::GetType() const
{
  const char *type_id = "WV";
  return type_id;
}

const char *THA_WavPackInput::GetFileName() const
{
  return m_fn.Get();
}

int THA_WavPackInput::GetChannels() const
{
  return m_channels;
}

double THA_WavPackInput::GetSampleRate() const
{
  return m_samplerate;
}

// Duration in seconds: the total sample count read in Open() divided by the
// file's sample rate.
double THA_WavPackInput::GetLength() const
{
  const double total = (double)m_totallength;
  return total / m_samplerate;
}

int THA_WavPackInput::GetBitsPerSample() const
{
  return m_bitspersample;
}

double THA_WavPackInput::GetPosition() const
{
  return m_currentpos;
}

void THA_WavPackInput::Seek(double time)
{}

// Copies up to `length` decoded SAM values into `buffer`.
//
// `length` counts individual SAM values (interleaved across channels), as the
// position update divides the copied count by the channel count. Decodes
// more data via ReadNext() until `length` values are queued or the decoder
// hits EOF.
//
// At EOF, whatever remains in the queue is still delivered, even if it is
// shorter than `length`, so the tail of the file is not lost. Advances the
// playback position by the copied frames at the device sample rate.
//
// Returns the number of SAM values written: 0 when nothing is left, or when
// `buffer` is null or `length` is not positive.
int THA_WavPackInput::GetSamples(SAM *buffer, int length)
{
  if (!buffer || length <= 0) return 0;

  while (m_samples.Available() < length && !m_eof)
  {
    ReadNext();
  }

  const int available = m_samples.Available();
  const int copied = available < length ? available : length;
  if (copied <= 0) return 0;

  memcpy(buffer, m_samples.Get(), sizeof(SAM) * copied);
  m_samples.Advance(copied);
  m_samples.Compact();

  m_currentpos += copied / GetChannels() / m_hwsamplerate;

  return copied;
}

// This input is file-backed, so it never reports itself as a stream.
bool THA_WavPackInput::IsStreaming() const
{
  const bool streaming = false;
  return streaming;
}

// No extended calls are handled by this input; always returns -1.
int THA_WavPackInput::Extended(int call, void *parm1, void *parm2, void *parm3)
{
  (void)call;
  (void)parm1;
  (void)parm2;
  (void)parm3;
  return -1;
}

void THA_WavPackInput::ReadNext()
{
  m_buffer.Resize(4096);

  int read = WavpackUnpackSamples(m_wvctx, m_buffer.Get(), m_buffer.GetSize() / m_channels);
  if (read)
  {
    m_buffer.Resize(read * m_channels, false);

    WDL_TypedBuf<SAM> tmp_samples;
    tmp_samples.Resize(read * m_channels);
    tmp_samples.SetToZero();

    if (MODE_FLOAT & WavpackGetMode(m_wvctx))
    {
      if (127 != WavpackGetFloatNormExp(m_wvctx)) return;

      for (int i = 0; i < tmp_samples.GetSize(); i++)
      {
        float *fl = (float *)m_buffer.Get();
        tmp_samples.Get()[i] = (SAM)fl[i];
        tmp_samples.Get()[i] = wdl_clamp(tmp_samples.Get()[i], -1.0, 1.0);
      }
    }
    else
    {
      double scale = (double)(1LL << WavpackGetBitsPerSample(m_wvctx));

      for (int i = 0; i < m_buffer.GetSize(); i++)
      {
        if (m_buffer.Get()[i] <= 0)
        {
          tmp_samples.Get()[i] = (double)(m_buffer.Get()[i] / (scale - 1));
        }
        else
        {
          tmp_samples.Get()[i] = (double)(m_buffer.Get()[i] / scale);
        }

        tmp_samples.Get()[i] = wdl_clamp(tmp_samples.Get()[i], -1.0, 1.0);
      }
    }

    if (WDL_ApproximatelyEqual(m_samplerate, m_hwsamplerate))
    {
      m_samples.Add(tmp_samples.Get(), tmp_samples.GetSize());
    }
    else
    {
      const int nch = m_channels;
      int frames = read;

      m_rs.SetRates(m_samplerate, m_hwsamplerate);
      m_rs.SetFeedMode(true);

      for (;;)
      {
        WDL_ResampleSample *ob = NULL;
        int amt = m_rs.ResamplePrepare(frames, nch, &ob);
        if (amt > frames) amt = frames;
        if (ob)
        {
          for (int i = 0; i < frames; i++)
          {
            for (int j = 0; j < nch; j++)
            {
              *ob++ = tmp_samples.Get()[i * nch + j];
            }
          }
        }
        frames -= amt;

        WDL_TypedBuf<WDL_ResampleSample> tmp;
        tmp.Resize(2048 * nch);
        amt = m_rs.ResampleOut(tmp.Get(), amt, 2048, nch);

        if (frames < 1 && amt < 1) break;

        amt *= nch;
        m_samples.Add(tmp.Get(), amt);
      }
    }
  }
  else
  {
    m_eof = true;
  }
}
