// Copyright (c) 2022 Giorgos Vougioukas
//
// The license can be found in the LICENSE file.

#include "rhea_vorbis/vorbis_input.h"
#include "rhea_vorbis/vorbis_entry_point.h"

#include "WDL/fpcmp.h"
#include "WDL/pcmfmtcvt.h"
#include "WDL/win32_utf8.h"

// Construct an input with no file attached. Every libogg/libvorbis state
// pointer starts out NULL and is allocated by Open(); the format fields hold
// placeholder values until Open() parses the stream header.
RHEA_VorbisInput::RHEA_VorbisInput() :
  // libogg / libvorbis decoder state (allocated in Open)
  m_oy(NULL), m_os(NULL), m_og(NULL), m_op(NULL),
  m_vi(NULL), m_vc(NULL), m_vd(NULL), m_vb(NULL),
  // source file (created in Open)
  m_file(NULL),
  // stream format placeholders, overwritten by Open()
  m_channels(2),
  m_bitspersample(16),
  m_samplerate(48000.0),
  // decode / playback bookkeeping
  m_buffer(NULL),
  m_totallength(0),
  m_currentpos(0.0),
  m_hwsamplerate(48000.0),
  m_eof(false)
{
}

// Release all decoder state. Each libogg/libvorbis structure is cleared only
// if Open() got far enough to allocate it, so a partially opened input is
// torn down safely.
RHEA_VorbisInput::~RHEA_VorbisInput()
{
  // Vorbis decoder working state.
  if (m_vb != NULL)
  {
    vorbis_block_clear(m_vb);
    delete m_vb;
  }
  if (m_vd != NULL)
  {
    vorbis_dsp_clear(m_vd);
    delete m_vd;
  }

  // Logical Ogg stream and Vorbis header data.
  if (m_os != NULL)
  {
    ogg_stream_clear(m_os);
    delete m_os;
  }
  if (m_vc != NULL)
  {
    vorbis_comment_clear(m_vc);
    delete m_vc;
  }
  if (m_vi != NULL)
  {
    vorbis_info_clear(m_vi); // must come after the other vorbis_* clears
    delete m_vi;
  }

  // Page framer.
  if (m_oy != NULL)
  {
    ogg_sync_clear(m_oy);
    delete m_oy;
  }

  // ogg_page / ogg_packet only point into storage owned by libogg, so the
  // holder structs are deleted but never cleared. delete on NULL is a no-op.
  delete m_og;
  delete m_op;

  delete m_file;
}

// Open `filename`, verify that it holds an Ogg Vorbis stream and prepare the
// decoder so that ReadNext()/GetSamples() can produce audio.
//
// Steps:
//   1. open the file through WDL_FileRead using the host's disk read mode;
//   2. parse the three Vorbis header packets (identification, comment,
//      codebook) with libogg/libvorbis;
//   3. reopen the file with libvorbisfile to obtain the total PCM length;
//   4. initialise the Vorbis synthesis state and configure the resampler.
//
// Returns true on success and false on any failure. Structures allocated
// before a failure are released by the destructor.
bool RHEA_VorbisInput::Open(const char *filename)
{
  m_fn.Set(filename);

  m_hwsamplerate = RHEA_GetAudioDeviceSamplerate();

  int rmode, rbufsize, rnbufs;
  RHEA_GetDiskReadMode(&rmode, &rbufsize, &rnbufs);
  m_file = new WDL_FileRead(filename, rmode, rbufsize, rnbufs);

  if (!m_file->IsOpen())
  {
    return false;
  }

  // --- Framer and first page -------------------------------------------

  if (!m_oy) m_oy = new ogg_sync_state;
  ogg_sync_init(m_oy);

  m_buffer = ogg_sync_buffer(m_oy, 4096);
  int bytes = m_file->Read(m_buffer, 4096);
  ogg_sync_wrote(m_oy, bytes);

  if (!m_og) m_og = new ogg_page;
  if (ogg_sync_pageout(m_oy, m_og) != 1)
  {
    // No complete page in the first 4096 bytes: not an Ogg stream.
    wdl_log("Input does not appear to be an Ogg bitstream.\n");
    return false;
  }

  // The serial number of the first page selects the logical stream.
  if (!m_os) m_os = new ogg_stream_state;
  ogg_stream_init(m_os, ogg_page_serialno(m_og));

  // --- Identification header -------------------------------------------
  // Handling the first header on its own is a cheap way to reject Ogg
  // streams that do not carry Vorbis audio before reading any further.

  if (!m_vi) m_vi = new vorbis_info;
  vorbis_info_init(m_vi);
  if (!m_vc) m_vc = new vorbis_comment;
  vorbis_comment_init(m_vc);
  if (ogg_stream_pagein(m_os, m_og) < 0)
  {
    // error; stream version mismatch perhaps
    wdl_log("Error reading first page of Ogg bitstream data.\n");
    return false;
  }

  if (!m_op) m_op = new ogg_packet;
  if (ogg_stream_packetout(m_os, m_op) != 1)
  {
    wdl_log("Error reading initial header packet.\n");
    return false;
  }

  if (vorbis_synthesis_headerin(m_vi, m_vc, m_op) < 0)
  {
    wdl_log("This Ogg bitstream does not contain Vorbis audio data.\n");
    return false;
  }

  // --- Comment and codebook headers ------------------------------------
  // These two packets may span several pages, so keep feeding data until
  // both have been submitted. A missing or corrupt header page is fatal.

  int headers_read = 0;
  while (headers_read < 2)
  {
    while (headers_read < 2)
    {
      int result = ogg_sync_pageout(m_oy, m_og);
      if (result == 0) break; // need more data
      // Page-level holes are not reported here; they surface as packet
      // errors below.
      if (result == 1)
      {
        ogg_stream_pagein(m_os, m_og);
        while (headers_read < 2)
        {
          result = ogg_stream_packetout(m_os, m_op);
          if (result == 0) break;
          if (result < 0)
          {
            wdl_log("Corrupt secondary header.\n");
            return false;
          }
          result = vorbis_synthesis_headerin(m_vi, m_vc, m_op);
          if (result < 0)
          {
            wdl_log("Corrupt secondary header.\n");
            return false;
          }
          headers_read++;
        }
      }
    }

    // Always top up the sync buffer; extra data is used by ReadNext().
    m_buffer = ogg_sync_buffer(m_oy, 4096);
    bytes = m_file->Read(m_buffer, 4096);
    if (bytes == 0 && headers_read < 2)
    {
      wdl_log("End of file before finding all Vorbis headers!\n");
      return false;
    }
    ogg_sync_wrote(m_oy, bytes);
  }

  // Log the user comments and the basic stream parameters.
  {
    char **ptr = m_vc->user_comments;
    while (*ptr)
    {
      wdl_log("%s\n", *ptr);
      ++ptr;
    }
    wdl_log("\nBitstream is %d channel, %ldHz\n", m_vi->channels, m_vi->rate);
    wdl_log("Encoded by: %s\n\n", m_vc->vendor);
  }

  m_channels = m_vi->channels;
  m_samplerate = m_vi->rate;
  m_bitspersample = 16;

  // --- Total length ----------------------------------------------------
  // Reopen the file through libvorbisfile only to query ov_pcm_total().
  FILE *file = fopen(m_fn.Get(), "rb");
  if (!file)
  {
    // Previously fell through to fclose(NULL).
    return false;
  }
  // With OV_CALLBACKS_NOCLOSE the FILE stays ours to close on both paths.
  if (ov_open_callbacks(file, &m_vf, NULL, 0, OV_CALLBACKS_NOCLOSE) < 0)
  {
    fclose(file);
    return false;
  }
  // ov_pcm_total() returns a negative error code (e.g. OV_EINVAL) on
  // failure; report an unknown length as 0 rather than a negative duration.
  const ogg_int64_t total = ov_pcm_total(&m_vf, -1);
  m_totallength = (total > 0) ? total : 0;
  ov_clear(&m_vf);
  fclose(file);

  // --- Decoder ---------------------------------------------------------
  if (!m_vd) m_vd = new vorbis_dsp_state;
  if (vorbis_synthesis_init(m_vd, m_vi) != 0)
  {
    return false;
  }

  // Block state for the decoder (a single block; several could be used to
  // decode in parallel).
  if (!m_vb) m_vb = new vorbis_block;
  vorbis_block_init(m_vd, m_vb);

  bool interp, sinc;
  int filtercnt, sinc_size, sinc_interpsize;
  RHEA_GetResampleMode(&interp, &filtercnt, &sinc, &sinc_size, &sinc_interpsize);
  m_rs.SetMode(interp, filtercnt, sinc, sinc_size, sinc_interpsize);

  return true;
}

// Short tag naming the container format handled by this input.
const char *RHEA_VorbisInput::GetType() const
{
  const char *type_tag = "OGG";
  return type_tag;
}

// Path passed to the most recent Open() call; it is recorded even if the
// open subsequently fails.
const char *RHEA_VorbisInput::GetFileName() const
{
  const char *path = m_fn.Get();
  return path;
}

int RHEA_VorbisInput::GetChannels() const
{
  return m_channels;
}

double RHEA_VorbisInput::GetSampleRate() const
{
  return m_samplerate;
}

// Stream duration in seconds: the PCM frame count obtained in Open()
// divided by the stream's native sample rate.
double RHEA_VorbisInput::GetLength() const
{
  const double seconds = m_totallength / m_samplerate;
  return seconds;
}

int RHEA_VorbisInput::GetBitsPerSample() const
{
  return m_bitspersample;
}

double RHEA_VorbisInput::GetPosition() const
{
  return m_currentpos;
}

// Seeking is not implemented: the request is ignored and neither the decoder
// state nor the reported position changes.
void RHEA_VorbisInput::Seek(double time)
{
  (void)time; // unused until seeking is implemented
}

// Copy up to `length` interleaved samples (already resampled to the device
// rate when needed) into `buffer`, decoding more of the stream as necessary.
//
// Returns the number of SAM values written. This is `length` while enough
// audio remains. Once the stream has ended it is whatever tail is left
// (possibly fewer than `length`), and 0 after that. Previously a short tail
// at end of stream was discarded and 0 returned.
int RHEA_VorbisInput::GetSamples(SAM *buffer, int length)
{
  if (length <= 0) return 0;

  // Decode until the request can be satisfied or the stream is exhausted.
  while (m_samples.Available() < length && !m_eof)
  {
    ReadNext();
  }

  const int available = m_samples.Available();
  if (available <= 0) return 0;

  int copied = 0;

  if (available < length)
  {
    // End of stream: hand out the remaining tail.
    memcpy(buffer, m_samples.Get(), sizeof(SAM) * available);
    copied = available;
    m_samples.Clear();
  }
  else
  {
    memcpy(buffer, m_samples.Get(), sizeof(SAM) * length);
    copied = length;
    m_samples.Advance(length);
    m_samples.Compact();
  }

  // Floating-point division so a partial frame count is not truncated.
  m_currentpos += static_cast<double>(copied) / GetChannels() / m_hwsamplerate;

  return copied;
}

// This input never reports itself as streaming.
bool RHEA_VorbisInput::IsStreaming() const
{
  const bool streaming = false;
  return streaming;
}

// No extended calls are handled; every request returns -1 without touching
// its arguments.
int RHEA_VorbisInput::Extended(int call, void *parm1, void *parm2, void *parm3)
{
  (void)call;
  (void)parm1;
  (void)parm2;
  (void)parm3;
  return -1;
}

void RHEA_VorbisInput::ReadNext()
{
  int convsize = 4096 / m_vi->channels;

  // The rest is just a straight decode loop until end of stream

  while(!m_eof)
  {
    int result = ogg_sync_pageout(m_oy, m_og);

    if (result == 0) break; /* need more data */

    if (result < 0)
    {
      // missing or corrupt data at this page position
      wdl_log("Corrupt or missing data in bitstream; continuing...\n");
    }
    else
    {
      ogg_stream_pagein(m_os, m_og); // can safely ignore errors at this point
      while(1)
      {
        result = ogg_stream_packetout(m_os, m_op);

        if (result == 0) break; /* need more data */
        if (result < 0)
        {
          // missing or corrupt data at this page position
          // no reason to complain; already complained above
        }
        else
        {
          // we have a packet.  Decode it
          float **pcm;
          int samples;

          if (vorbis_synthesis(m_vb, m_op) == 0) // test for success!
            vorbis_synthesis_blockin(m_vd, m_vb);

          // **pcm is a multichannel float vector. In stereo, for
          // example, pcm[0] is left, and pcm[1] is right. samples is
          // the size of each channel. Convert the float values
          // (-1. <= range <= 1.) to whatever PCM format and write it out

          while ((samples = vorbis_synthesis_pcmout(m_vd, &pcm)) > 0)
          {
            //int j;
            //int clipflag = 0;
            int bout = (samples < convsize ? samples : convsize);

            WDL_TypedBuf<SAM> tmp_samples;
            tmp_samples.Resize(bout * m_vi->channels);

            for (int i = 0; i < m_vi->channels; i++)
            {
              for (int j = 0; j < bout; j++)
              {
                tmp_samples.Get()[j * m_vi->channels + i] = (SAM)pcm[i][j];
              }
            }

            if (WDL_ApproximatelyEqual(m_samplerate, m_hwsamplerate))
            {
              m_samples.Add(tmp_samples.Get(), tmp_samples.GetSize());
            }
            else
            {
              const int nch = m_channels;
              int frames = samples;

              m_rs.SetRates(m_samplerate, m_hwsamplerate);
              m_rs.SetFeedMode(true);

              for (;;)
              {
                WDL_ResampleSample *ob = NULL;
                int amt = m_rs.ResamplePrepare(frames, nch, &ob);
                if (amt > frames) amt = frames;
                if (ob)
                {
                  for (int i = 0; i < frames; i++)
                  {
                    for (int j = 0; j < nch; j++)
                    {
                      *ob++ = tmp_samples.Get()[i * nch + j];
                    }
                  }
                }
                frames -= amt;

                WDL_TypedBuf<WDL_ResampleSample> tmp;
                tmp.Resize(2048 * nch);
                amt = m_rs.ResampleOut(tmp.Get(), amt, 2048, nch);

                if (frames < 1 && amt < 1) break;

                amt *= nch;
                m_samples.Add(tmp.Get(), amt);
              }
            }

            vorbis_synthesis_read(m_vd, bout); // tell libvorbis how many samples we actually consumed
          }
        }
      }
      if (ogg_page_eos(m_og)) m_eof = true;
    }
  }

  if (!m_eof)
  {
    m_buffer = ogg_sync_buffer(m_oy, 4096);
    int bytes = m_file->Read(m_buffer, 4096);
    ogg_sync_wrote(m_oy, bytes);
    if (bytes == 0) m_eof = true;
  }
}
