// Copyright (c) 2022 Giorgos Vougioukas
//
// The license can be found in the LICENSE file.

#include "terpsichore/antidote_input.h"
#include "terpsichore/definitions.h"
#include "terpsichore/antidote.h"

#include "WDL/fpcmp.h"
#include "WDL/pcmfmtcvt.h"

// Creates an input with no file attached. The sample format index defaults to
// 0 and the rate to 44100 Hz; Open() replaces both with values read from the
// preferences and the hardware sample rate.
RST_AntidoteInput::RST_AntidoteInput() :
  m_fr(NULL),        // no reader until Open()
  m_eof(false),
  m_position(0.0),   // seconds
  m_srate(44100.0),
  m_antidote(0),     // format index, see GetBitsPerSample()
  m_rev(false)       // forward playback
{
}

// Destroys the file reader (if any) and hands the filename back to
// g_antidote via ReleaseFilename().
RST_AntidoteInput::~RST_AntidoteInput()
{
  // Deleting a NULL pointer is a no-op, so no explicit check is needed.
  delete m_fr;
  m_fr = NULL;

  g_antidote.ReleaseFilename(m_fn.Get());
}

bool RST_AntidoteInput::Open(const char *filename)
{
  m_antidote = g_ini_file->read_int("antidote", 0, "preferences");
  int readmode = g_ini_file->read_int("read_mode", 2, "preferences");
  int readbuffer = g_ini_file->read_int("read_buffer", 262144, "preferences");
  int readbuffers = g_ini_file->read_int("read_buffers", 3, "preferences");

  m_fn.Set(filename);
  m_srate = GetHardwareSampleRate();
  if (m_fr) { delete m_fr; m_fr = NULL; }
  m_fr = new WDL_FileRead(m_fn.Get(), readmode, readbuffer, readbuffers);
  if (m_fr) return m_fr->IsOpen();
  return false;
}

// Type tag reported for this input: the string "ANTIDOTE".
const char *RST_AntidoteInput::GetType()
{
  static const char kType[] = "ANTIDOTE";
  return kType;
}

// Path passed to the most recent Open() call. The pointer is owned by m_fn.
const char *RST_AntidoteInput::GetFileName()
{
  const char *name = m_fn.Get();
  return name;
}

// Channel count of the data. It is fixed at 2; GetSamples() treats the data
// as interleaved with this stride.
int RST_AntidoteInput::GetChannels()
{
  const int stereo = 2;
  return stereo;
}

double RST_AntidoteInput::GetSampleRate()
{
  return m_srate;
}

double RST_AntidoteInput::GetLength()
{
  double len = 0.0;

  if (m_fr)
  {
    len = m_fr->GetSize() / (GetBitsPerSample() / 8)  / GetSampleRate() / GetChannels();
  }

  return len;
}

// Bits per sample for the format index m_antidote (the "antidote" preference):
//   0 -> 16-bit integer PCM
//   1 -> 24-bit integer PCM
//   2 -> 32-bit integer PCM
//   3 -> 32-bit float
//   4 -> 64-bit float (double)
// Any other index yields 0, meaning unsupported.
int RST_AntidoteInput::GetBitsPerSample()
{
  static const int bits_for_format[] = { 16, 24, 32, 32, 64 };
  const int nformats = (int)(sizeof(bits_for_format) / sizeof(bits_for_format[0]));

  if (m_antidote < 0 || m_antidote >= nformats) return 0;
  return bits_for_format[m_antidote];
}

double RST_AntidoteInput::GetPosition()
{
  return m_position;
}

void RST_AntidoteInput::Seek(double time)
{
  double st = time;
  double du = GetLength();

  if (WDL_DefinitelyLessThan(st, 0.0)) { st = 0.0; m_eof = false; }
  else if (WDL_DefinitelyGreaterThan(st, du)) { st = du; m_eof = true; }
  else { m_eof = false; }

  m_buffer_queue.Clear();

  WDL_INT64 timems = (WDL_INT64)(time * 1000);
  int csrate = (int)GetSampleRate() / 1000;
  WDL_INT64 pos = (WDL_INT64)(timems * (csrate * GetChannels() * (GetBitsPerSample() / 8)));
  m_fr->SetPosition(pos);

  m_position = time;
}

// True when GetSamples() reads the file backwards.
bool RST_AntidoteInput::IsReverse() const
{
  const bool reversed = m_rev;
  return reversed;
}

void RST_AntidoteInput::SetReverse(bool state)
{
  m_rev = state;
}

// Copies up to `length` interleaved samples (GetChannels() per frame) into
// `buffer`. The file is read forward or, when IsReverse(), backward.
//
// Raw data is read in chunks of 4096 samples, decoded to SAM according to the
// m_antidote format, and staged in m_buffer_queue. This decouples disk reads
// from `length`.
//
// Returns the number of samples written. That is fewer than `length` at the
// end of the data, and 0 when nothing is left, no file is open, `buffer` is
// NULL, `length` <= 0, or the format is unsupported. GetPosition() moves
// forward (or backward in reverse mode) by the returned amount. Once the
// data is exhausted it snaps to GetLength() (forward) or 0 (reverse).
int RST_AntidoteInput::GetSamples(SAM *buffer, int length)
{
  const int bps = GetBitsPerSample();
  const int bytes_per_sample = bps / 8;

  // An unsupported format used to make the read loop below spin forever.
  if (!m_fr || !buffer || length <= 0 || bytes_per_sample <= 0) return 0;

  // 4096 samples per disk read; for every supported format this is a whole
  // number of frames.
  const int chunk_bytes = 4096 * bytes_per_sample;

  while (m_buffer_queue.Available() < length && !m_eof)
  {
    // 1. Fetch one raw chunk.
    int read = 0;
    if (WDL_likely(!m_rev))
    {
      m_rawbuf.Resize(chunk_bytes);
      read = m_fr->Read(m_rawbuf.Get(), m_rawbuf.GetSize());
    }
    else
    {
      // Read the chunk that ends at the current position, then leave the
      // reader at that chunk's start for the next (earlier) chunk.
      const WDL_INT64 chunk_end = m_fr->GetPosition();
      const int want = chunk_end <= 0 ? 0
                     : (chunk_end < chunk_bytes ? (int)chunk_end : chunk_bytes);
      m_rawbuf.Resize(want);
      const WDL_INT64 chunk_start = chunk_end - m_rawbuf.GetSize();
      m_fr->SetPosition(chunk_start);
      read = m_fr->Read(m_rawbuf.Get(), m_rawbuf.GetSize());
      m_fr->SetPosition(chunk_start);
    }

    // No data in this direction (end/start of file or a read failure).
    // Stop rather than loop forever re-queuing stale m_buffer contents.
    if (read <= 0)
    {
      m_eof = true;
      break;
    }

    // 2. Decode the raw bytes into SAM values.
    m_buffer.Resize(read / bytes_per_sample);
    const int nsamples = m_buffer.GetSize();
    SAM *samples = m_buffer.Get();

    switch (m_antidote)
    {
      case 3: // 32-bit float
        {
          const float *src = (const float *)m_rawbuf.Get();
          for (int i = 0; i < nsamples; i++) samples[i] = (SAM)src[i];
        }
        break;
      case 4: // 64-bit float (double)
        {
          const double *src = (const double *)m_rawbuf.Get();
          for (int i = 0; i < nsamples; i++) samples[i] = (SAM)src[i];
        }
        break;
      default: // 16/24/32-bit integer PCM (formats 0..2)
        pcmToDoubles(m_rawbuf.Get(), nsamples, bps, 1, samples, 1);
        break;
    }

    // 3. In reverse mode, reverse the frame order of the chunk while keeping
    //    the channel order inside each frame.
    if (WDL_unlikely(m_rev))
    {
      const int nch = GetChannels();
      for (int ch = 0; ch < nch; ch++)
      {
        for (int lo = ch, hi = nsamples - nch + ch; lo < hi; lo += nch, hi -= nch)
        {
          const SAM tmp = samples[lo];
          samples[lo] = samples[hi];
          samples[hi] = tmp;
        }
      }
    }

    m_buffer_queue.Add(samples, nsamples);
    if (m_fr->GetPosition() == m_fr->GetSize()) m_eof = true;
    if (m_fr->GetPosition() == 0 && m_rev) m_eof = true;
  }

  // 4. Hand out up to `length` queued samples.
  int ret = m_buffer_queue.Available();
  if (ret > length)
  {
    ret = length;
    memcpy(buffer, m_buffer_queue.Get(), ret * sizeof(SAM));
    m_buffer_queue.Advance(ret);
    m_buffer_queue.Compact();
  }
  else
  {
    // Guard avoids memcpy from a possibly-NULL pointer when the queue is empty.
    if (ret > 0) memcpy(buffer, m_buffer_queue.Get(), ret * sizeof(SAM));
    m_buffer_queue.Clear();
  }

  // 5. Update the play position.
  if (m_eof && m_buffer_queue.GetSize() == 0)
  {
    m_position = m_rev ? 0.0 : GetLength();
  }
  else
  {
    const double advanced = ret / GetSampleRate() / GetChannels();
    if (WDL_likely(!m_rev)) m_position += advanced;
    else m_position -= advanced;
  }

  return ret;
}

// This input never reports itself as streaming.
bool RST_AntidoteInput::IsStreaming()
{
  const bool streaming = false;
  return streaming;
}

// Generic extension entry point. The meaning of `call` and the parm*
// arguments is defined by the caller-side interface, which is not visible
// here. This input handles none of them: the arguments are ignored and 0 is
// always returned.
int RST_AntidoteInput::Extended(int /*call*/, void * /*parm1*/, void * /*parm2*/, void * /*parm3*/)
{
  return 0;
}
