diff --git a/PyNvCodec/TC/inc/FFmpegDemuxer.h b/PyNvCodec/TC/inc/FFmpegDemuxer.h index 3a1ea41c..1ff2ea41 100644 --- a/PyNvCodec/TC/inc/FFmpegDemuxer.h +++ b/PyNvCodec/TC/inc/FFmpegDemuxer.h @@ -222,6 +222,10 @@ class DllExport FFmpegDemuxer double GetTimebase() const; + int64_t TsFromTime(double ts_sec); + + int64_t TsFromFrameNumber(int64_t frame_num); + uint32_t GetVideoStreamIndex() const; AVPixelFormat GetPixelFormat() const; diff --git a/PyNvCodec/TC/inc/Tasks.hpp b/PyNvCodec/TC/inc/Tasks.hpp index fb9ad026..8c73f835 100644 --- a/PyNvCodec/TC/inc/Tasks.hpp +++ b/PyNvCodec/TC/inc/Tasks.hpp @@ -221,6 +221,8 @@ class DllExport DemuxFrame final : public Task DemuxFrame& operator=(const DemuxFrame& other) = delete; void GetParams(struct MuxingParams& params) const; + int64_t TsFromTime(double ts_sec); + int64_t TsFromFrameNumber(int64_t frame_num); void Flush(); TaskExecStatus Run() final; ~DemuxFrame() final; diff --git a/PyNvCodec/TC/src/FFmpegDemuxer.cpp b/PyNvCodec/TC/src/FFmpegDemuxer.cpp index 449e6128..1addc1a4 100644 --- a/PyNvCodec/TC/src/FFmpegDemuxer.cpp +++ b/PyNvCodec/TC/src/FFmpegDemuxer.cpp @@ -236,67 +236,86 @@ void FFmpegDemuxer::Flush() { avformat_flush(fmtc); } -bool FFmpegDemuxer::Seek(SeekContext &seekCtx, uint8_t *&pVideo, - size_t &rVideoBytes, PacketData &pktData, - uint8_t **ppSEI, size_t *pSEIBytes) { +int64_t FFmpegDemuxer::TsFromTime(double ts_sec) +{ + /* Internal timestamp representation is integer, so multiply to AV_TIME_BASE + * and switch to fixed point precision arithmetics; */ + auto const ts_tbu = lround(ts_sec * AV_TIME_BASE); + + // Rescale the timestamp to value represented in stream base units; + AVRational factor; + factor.num = 1; + factor.den = AV_TIME_BASE; + return av_rescale_q(ts_tbu, factor, fmtc->streams[videoStream]->time_base); +} + +int64_t FFmpegDemuxer::TsFromFrameNumber(int64_t frame_num) +{ + auto const ts_sec = (double)frame_num / GetFramerate(); + return TsFromTime(ts_sec); +} + +bool 
FFmpegDemuxer::Seek(SeekContext& seekCtx, uint8_t*& pVideo, + size_t& rVideoBytes, PacketData& pktData, + uint8_t** ppSEI, size_t* pSEIBytes) +{ + /* !!! IMPORTANT !!! + * Across this function packet decode timestamp (DTS) values are used to + * compare given timestamp against. This is done for a reason. DTS values shall + * monotonically increase during the course of decoding unlike PTS values + * which may be affected by frame reordering due to B frames presence. + */ + if (!is_seekable) { cerr << "Seek isn't supported for this input." << endl; return false; } - // Convert timestamp in time units to timestamp in stream base units; - auto ts_from_time = [&](double ts_sec) { - auto const ts_tbu = (int64_t)(ts_sec * AV_TIME_BASE); - AVRational factor; - factor.num = 1; - factor.den = AV_TIME_BASE; - return av_rescale_q(ts_tbu, factor, fmtc->streams[videoStream]->time_base); - }; - - // Convert frame number to timestamp; - auto ts_from_num = [&](int64_t frame_num) { - auto const ts_sec = (double)seekCtx.seek_frame / GetFramerate(); - return ts_from_time(ts_sec); - }; + if (IsVFR() && (BY_NUMBER == seekCtx.crit)) { + cerr << "Can't seek by frame number in VFR sequences. Seek by timestamp " + "instead." + << endl; + return false; + } // Seek for single frame; - auto seek_frame = [&](SeekContext const &seek_ctx, int flags) { + auto seek_frame = [&](SeekContext const& seek_ctx, int flags) { bool seek_backward = false; + int64_t timestamp = 0; int ret = 0; switch (seek_ctx.crit) { case BY_NUMBER: - seek_backward = - last_packet_data.dts > seek_ctx.seek_frame * pktDst.duration; - ret = av_seek_frame(fmtc, GetVideoStreamIndex(), - ts_from_num(seek_ctx.seek_frame), + timestamp = TsFromFrameNumber(seek_ctx.seek_frame); + seek_backward = last_packet_data.dts > timestamp; + ret = av_seek_frame(fmtc, GetVideoStreamIndex(), timestamp, seek_backward ? 
AVSEEK_FLAG_BACKWARD | flags : flags); - if (ret < 0) - throw runtime_error("Error seeking for frame: " + AvErrorToString(ret)); break; case BY_TIMESTAMP: - seek_backward = - last_packet_data.dts > seek_ctx.seek_frame; - ret = av_seek_frame(fmtc, GetVideoStreamIndex(), - ts_from_time(seek_ctx.seek_frame), + timestamp = TsFromTime(seek_ctx.seek_frame); + seek_backward = last_packet_data.dts > timestamp; + ret = av_seek_frame(fmtc, GetVideoStreamIndex(), timestamp, seek_backward ? AVSEEK_FLAG_BACKWARD | flags : flags); break; default: throw runtime_error("Invalid seek mode"); } - return; + + if (ret < 0) { + throw runtime_error("Error seeking for frame: " + AvErrorToString(ret)); + } }; // Check if frame satisfies seek conditions; - auto is_seek_done = [&](PacketData &pkt_data, SeekContext const &seek_ctx) { + auto is_seek_done = [&](PacketData& pkt_data, SeekContext const& seek_ctx) { int64_t target_ts = 0; switch (seek_ctx.crit) { case BY_NUMBER: - target_ts = ts_from_num(seek_ctx.seek_frame); + target_ts = TsFromFrameNumber(seek_ctx.seek_frame); break; case BY_TIMESTAMP: - target_ts = ts_from_time(seek_ctx.seek_frame); + target_ts = TsFromTime(seek_ctx.seek_frame); break; default: throw runtime_error("Invalid seek criteria"); @@ -312,17 +331,18 @@ bool FFmpegDemuxer::Seek(SeekContext &seekCtx, uint8_t *&pVideo, }; }; - // This will seek for exact frame number; - // Note that decoder may not be able to decode such frame; - auto seek_for_exact_frame = [&](PacketData &pkt_data, - SeekContext &seek_ctx) { + /* This will seek for exact frame number; + * Note that decoder may not be able to decode such frame; */ + auto seek_for_exact_frame = [&](PacketData& pkt_data, SeekContext& seek_ctx) { // Repetititive seek until seek condition is satisfied; SeekContext tmp_ctx(seek_ctx.seek_frame); seek_frame(tmp_ctx, AVSEEK_FLAG_ANY); int condition = 0; do { - Demux(pVideo, rVideoBytes, pkt_data, ppSEI, pSEIBytes); + if (!Demux(pVideo, rVideoBytes, pkt_data, ppSEI, pSEIBytes)) { 
+ break; + } condition = is_seek_done(pkt_data, seek_ctx); // We've gone too far and need to seek backwards; @@ -341,11 +361,9 @@ bool FFmpegDemuxer::Seek(SeekContext &seekCtx, uint8_t *&pVideo, }; // Seek for closest key frame in the past; - auto seek_for_prev_key_frame = [&](PacketData &pkt_data, - SeekContext &seek_ctx) { - // Repetititive seek until seek condition is satisfied; - auto tmp_ctx = seek_ctx; - seek_frame(tmp_ctx, AVSEEK_FLAG_BACKWARD); + auto seek_for_prev_key_frame = [&](PacketData& pkt_data, + SeekContext& seek_ctx) { + seek_frame(seek_ctx, AVSEEK_FLAG_BACKWARD); Demux(pVideo, rVideoBytes, pkt_data, ppSEI, pSEIBytes); seek_ctx.out_frame_pts = pkt_data.pts; diff --git a/PyNvCodec/TC/src/NvDecoder.cpp b/PyNvCodec/TC/src/NvDecoder.cpp index f8b18763..43749a97 100644 --- a/PyNvCodec/TC/src/NvDecoder.cpp +++ b/PyNvCodec/TC/src/NvDecoder.cpp @@ -145,7 +145,7 @@ struct Dim { struct NvDecoderImpl { bool m_bReconfigExternal = false, m_bReconfigExtPPChange = false, - eos_set = false; + eos_set = false, decoder_recon = false; unsigned int m_nWidth = 0U, m_nLumaHeight = 0U, m_nChromaHeight = 0U, m_nNumChromaPlanes = 0U, m_nMaxWidth = 0U, m_nMaxHeight = 0U; @@ -195,6 +195,7 @@ cudaVideoCodec NvDecoder::GetCodec() const { return p_impl->m_eCodec; } int NvDecoder::HandleVideoSequence(CUVIDEOFORMAT* pVideoFormat) noexcept { try { + p_impl->decoder_recon = true; CudaCtxPush ctxPush(p_impl->m_cuContext); CudaStrSync strSync(p_impl->m_cuvidStream); @@ -753,9 +754,19 @@ bool NvDecoder::DecodeLockSurface(Buffer const* encFrame, */ auto ret = false; - // Prepare black packet data in case no frames are decoded yet; + // Prepare blank packet data in case no frames are decoded yet; memset(&decCtx.out_pdata, 0, sizeof(decCtx.out_pdata)); + /* In case decoder was reconfigured by cuvidParseVideoData() call made above, + * some previously decoded frames could have been pushed to decoded frames + * queue. 
Need to clean them up; */ + if (p_impl->decoder_recon) { + p_impl->decoder_recon = false; + while (!p_impl->m_DecFramesCtxQueue.empty()) { + p_impl->m_DecFramesCtxQueue.pop(); + } + } + if (!p_impl->m_DecFramesCtxQueue.empty()) { decCtx = p_impl->m_DecFramesCtxQueue.front(); p_impl->m_DecFramesCtxQueue.pop(); diff --git a/PyNvCodec/TC/src/Tasks.cpp b/PyNvCodec/TC/src/Tasks.cpp index 7f4d5a01..e14922b2 100644 --- a/PyNvCodec/TC/src/Tasks.cpp +++ b/PyNvCodec/TC/src/Tasks.cpp @@ -875,6 +875,16 @@ DemuxFrame::~DemuxFrame() { delete pImpl; } void DemuxFrame::Flush() { pImpl->demuxer->Flush(); } +int64_t DemuxFrame::TsFromTime(double ts_sec) +{ + return pImpl->demuxer->TsFromTime(ts_sec); +} + +int64_t DemuxFrame::TsFromFrameNumber(int64_t frame_num) +{ + return pImpl->demuxer->TsFromFrameNumber(frame_num); +} + TaskExecStatus DemuxFrame::Run() { NvtxMark tick(GetName()); diff --git a/PyNvCodec/src/PyNvDecoder.cpp b/PyNvCodec/src/PyNvDecoder.cpp index 60444ab8..e574e456 100644 --- a/PyNvCodec/src/PyNvDecoder.cpp +++ b/PyNvCodec/src/PyNvDecoder.cpp @@ -469,8 +469,7 @@ bool PyNvDecoder::DecodeSurface(DecodeContext& ctx) Surface* p_surf = nullptr; do { try { - auto const no_eos = true; - p_surf = getDecodedSurfaceFromPacket(nullptr, nullptr, no_eos); + p_surf = getDecodedSurfaceFromPacket(nullptr, nullptr); } catch (decoder_error& dec_exc) { dec_error = true; cerr << dec_exc.what() << endl; @@ -521,8 +520,7 @@ bool PyNvDecoder::DecodeSurface(DecodeContext& ctx) ctx.SetOutPacketData(pktDataBuf->GetDataAs()); } - auto is_seek_done = [&](DecodeContext const& ctx, double time_base, - double duration, double pts) { + auto is_seek_done = [&](DecodeContext const& ctx, int64_t pts) { auto seek_ctx = ctx.GetSeekContext(); if (!seek_ctx) throw runtime_error("No seek context."); @@ -531,13 +529,11 @@ bool PyNvDecoder::DecodeSurface(DecodeContext& ctx) switch (seek_ctx->crit) { case BY_NUMBER: - seek_pts = seek_ctx->seek_frame * duration; + seek_pts = 
upDemuxer->TsFromFrameNumber(seek_ctx->seek_frame); break; - case BY_TIMESTAMP: - seek_pts = seek_ctx->seek_frame / time_base; + seek_pts = upDemuxer->TsFromTime(seek_ctx->seek_frame); break; - default: throw runtime_error("Invalid seek criteria."); break; @@ -549,22 +545,12 @@ bool PyNvDecoder::DecodeSurface(DecodeContext& ctx) /* Check if seek is done. */ if (!use_seek) { loop_end = true; - } else { - MuxingParams params; - upDemuxer->GetParams(params); - - if (pktDataBuf) { - auto out_pkt_data = pktDataBuf->GetDataAs(); - if (AV_NOPTS_VALUE == out_pkt_data->pts) { - throw runtime_error( - "Decoded frame doesn't have valid PTS, can't seek."); - } - if (!out_pkt_data->duration) { - throw runtime_error("Decoded frames has zero duration, can't seek."); - } - loop_end = is_seek_done(ctx, params.videoContext.timeBase, - out_pkt_data->duration, out_pkt_data->pts); + } else if (pktDataBuf) { + auto out_pkt_data = pktDataBuf->GetDataAs(); + if (AV_NOPTS_VALUE == out_pkt_data->pts) { + throw runtime_error("Decoded frame doesn't have PTS, can't seek."); } + loop_end = is_seek_done(ctx, out_pkt_data->pts); } if (dmx_error) { diff --git a/Tests/test_PyNvDecoder.py b/Tests/test_PyNvDecoder.py index cd6d4153..e04a36c9 100644 --- a/Tests/test_PyNvDecoder.py +++ b/Tests/test_PyNvDecoder.py @@ -53,10 +53,10 @@ gt_pix_fmt = nvc.PixelFormat.NV12 gt_framerate = 30 gt_num_frames = 96 +gt_timebase = 8.1380e-5 gt_color_space = nvc.ColorSpace.BT_709 gt_color_range = nvc.ColorRange.MPEG - class TestDecoderBasic(unittest.TestCase): def __init__(self, methodName): super().__init__(methodName=methodName) @@ -94,7 +94,6 @@ def test_framesize(self): def test_timebase(self): epsilon = 1e-4 - gt_timebase = 8.1380e-5 self.assertLessEqual( np.abs(gt_timebase - self.nvDec.Timebase()), epsilon) @@ -109,73 +108,66 @@ def test_lastpacketdata(self): class TestDecoderStandalone(unittest.TestCase): def __init__(self, methodName): super().__init__(methodName=methodName) - gpu_id = 0 - enc_file = gt_file 
- self.nvDmx = nvc.PyFFmpegDemuxer(enc_file, {}) - self.nvDec = nvc.PyNvDecoder( - self.nvDmx.Width(), self.nvDmx.Height(), self.nvDmx.Format(), - self.nvDmx.Codec(), gpu_id) def test_decodesurfacefrompacket(self): + nvDmx = nvc.PyFFmpegDemuxer(gt_file, {}) + nvDec = nvc.PyNvDecoder(nvDmx.Width(), nvDmx.Height(), nvDmx.Format(), + nvDmx.Codec(), 0) + packet = np.ndarray(shape=(0), dtype=np.uint8) - while self.nvDmx.DemuxSinglePacket(packet): - surf = self.nvDec.DecodeSurfaceFromPacket(packet) + while nvDmx.DemuxSinglePacket(packet): + surf = nvDec.DecodeSurfaceFromPacket(packet) self.assertIsNotNone(surf) if not surf.Empty(): self.assertNotEqual(0, surf.PlanePtr().GpuMem()) - self.assertEqual(self.nvDmx.Width(), surf.Width()) - self.assertEqual(self.nvDmx.Height(), surf.Height()) - self.assertEqual(self.nvDmx.Format(), surf.Format()) + self.assertEqual(nvDmx.Width(), surf.Width()) + self.assertEqual(nvDmx.Height(), surf.Height()) + self.assertEqual(nvDmx.Format(), surf.Format()) return def test_decodesurfacefrompacket_outpktdata(self): + nvDmx = nvc.PyFFmpegDemuxer(gt_file, {}) + nvDec = nvc.PyNvDecoder( + nvDmx.Width(), nvDmx.Height(), nvDmx.Format(), nvDmx.Codec(), 0) + + dec_frames = 0 packet = np.ndarray(shape=(0), dtype=np.uint8) - in_pdata = nvc.PacketData() - last_pts = nvc.NO_PTS - # Decoded frames counter - dec_frame = 0 - # Size of Annex.B elementary bitstream in bytes we feed to decoder - inp_bst_size = 0 - # Size of Annex.B elementary bitstream in bytes decoder has consumed - # It may be smaller then input size, because some NALU are not VCL out_bst_size = 0 - while self.nvDmx.DemuxSinglePacket(packet): - self.nvDmx.LastPacketData(in_pdata) - inp_bst_size += packet.size + while nvDmx.DemuxSinglePacket(packet): + in_pdata = nvc.PacketData() + nvDmx.LastPacketData(in_pdata) out_pdata = nvc.PacketData() - surf = self.nvDec.DecodeSurfaceFromPacket( + surf = nvDec.DecodeSurfaceFromPacket( in_pdata, packet, out_pdata) self.assertIsNotNone(surf) if not 
surf.Empty(): - dec_frame += 1 + dec_frames += 1 out_bst_size += out_pdata.bsl - else: - break - if 0 != dec_frame: - self.assertGreaterEqual(out_pdata.pts, last_pts) - last_pts = out_pdata.pts while True: out_pdata = nvc.PacketData() - surf = self.nvDec.FlushSingleSurface(out_pdata) + surf = nvDec.FlushSingleSurface(out_pdata) if not surf.Empty(): out_bst_size += out_pdata.bsl else: break self.assertNotEqual(0, out_bst_size) - self.assertGreaterEqual(inp_bst_size, out_bst_size) def test_decode_all_surfaces(self): + nvDmx = nvc.PyFFmpegDemuxer(gt_file, {}) + nvDec = nvc.PyNvDecoder(nvDmx.Width(), nvDmx.Height(), nvDmx.Format(), + nvDmx.Codec(), 0) + dec_frames = 0 packet = np.ndarray(shape=(0), dtype=np.uint8) - while self.nvDmx.DemuxSinglePacket(packet): - surf = self.nvDec.DecodeSurfaceFromPacket(packet) + while nvDmx.DemuxSinglePacket(packet): + surf = nvDec.DecodeSurfaceFromPacket(packet) self.assertIsNotNone(surf) if not surf.Empty(): dec_frames += 1 while True: - surf = self.nvDec.FlushSingleSurface() + surf = nvDec.FlushSingleSurface() self.assertIsNotNone(surf) if not surf.Empty(): dec_frames += 1 @@ -187,24 +179,28 @@ def test_decode_all_surfaces(self): class TestDecoderBuiltin(unittest.TestCase): def __init__(self, methodName): super().__init__(methodName=methodName) - gpu_id = 0 - enc_file = gt_file - self.nvDec = nvc.PyNvDecoder(enc_file, gpu_id) def test_decodesinglesurface(self): + gpu_id = 0 + enc_file = gt_file + nvDec = nvc.PyNvDecoder(enc_file, gpu_id) try: - surf = self.nvDec.DecodeSingleSurface() + surf = nvDec.DecodeSingleSurface() self.assertIsNotNone(surf) self.assertFalse(surf.Empty()) except: self.fail("Test case raised exception unexpectedly!") def test_decodesinglesurface_outpktdata(self): + gpu_id = 0 + enc_file = gt_file + nvDec = nvc.PyNvDecoder(enc_file, gpu_id) + dec_frame = 0 last_pts = nvc.NO_PTS while True: pdata = nvc.PacketData() - surf = self.nvDec.DecodeSingleSurface(pdata) + surf = nvDec.DecodeSingleSurface(pdata) if 
surf.Empty(): break self.assertNotEqual(pdata.pts, nvc.NO_PTS) @@ -214,33 +210,74 @@ def test_decodesinglesurface_outpktdata(self): last_pts = pdata.pts def test_decodesinglesurface_sei(self): + gpu_id = 0 + enc_file = gt_file + nvDec = nvc.PyNvDecoder(enc_file, gpu_id) + total_sei_size = 0 while True: sei = np.ndarray(shape=(0), dtype=np.uint8) - surf = self.nvDec.DecodeSingleSurface(sei) + surf = nvDec.DecodeSingleSurface(sei) if surf.Empty(): break total_sei_size += sei.size self.assertNotEqual(0, total_sei_size) def test_decodesinglesurface_seek(self): + gpu_id = 0 + enc_file = gt_file + nvDec = nvc.PyNvDecoder(enc_file, gpu_id) + start_frame = random.randint(0, gt_num_frames-1) dec_frames = 1 seek_ctx = nvc.SeekContext( seek_frame=start_frame, seek_criteria=nvc.SeekCriteria.BY_NUMBER) - surf = self.nvDec.DecodeSingleSurface(seek_ctx) + surf = nvDec.DecodeSingleSurface(seek_ctx) self.assertNotEqual(True, surf.Empty()) while True: - surf = self.nvDec.DecodeSingleSurface() + surf = nvDec.DecodeSingleSurface() if surf.Empty(): break dec_frames += 1 self.assertEqual(gt_num_frames-start_frame, dec_frames) + def test_decodesinglesurface_cmp_vs_continuous(self): + gpu_id = 0 + enc_file = gt_file + nvDec = nvc.PyNvDecoder(enc_file, gpu_id) + + # First get reconstructed frame with seek + for idx in range(0, gt_num_frames): + seek_ctx = nvc.SeekContext( + seek_frame=idx, seek_criteria=nvc.SeekCriteria.BY_NUMBER) + frame_seek = np.ndarray(shape=(0), dtype=np.uint8) + pdata_seek = nvc.PacketData() + self.assertTrue(nvDec.DecodeSingleFrame( + frame_seek, seek_ctx, pdata_seek)) + + # Then get it with continuous decoding + nvDec = nvc.PyNvDecoder(gt_file, 0) + frame_cont = np.ndarray(shape=(0), dtype=np.uint8) + pdata_cont = nvc.PacketData() + for i in range(0, idx+1): + self.assertTrue(nvDec.DecodeSingleFrame( + frame_cont, pdata_cont)) + + # Compare frames + if not np.array_equal(frame_seek, frame_cont): + fail_msg = "" + fail_msg += 'Seek frame number: ' + str(idx) + '.\n' 
+ fail_msg += 'Seek frame pts: ' + str(pdata_seek.pts) + '.\n' + fail_msg += 'Cont frame pts: ' + str(pdata_cont.pts) + '.\n' + fail_msg += 'Video frames are not same\n' + self.fail(fail_msg) + def test_decode_all_surfaces(self): + nvDec = nvc.PyNvDecoder(gt_file, 0) + dec_frames = 0 while True: - surf = self.nvDec.DecodeSingleSurface() + surf = nvDec.DecodeSingleSurface() if not surf or surf.Empty(): break dec_frames += 1