Skip to content
This repository has been archived by the owner on Jun 10, 2024. It is now read-only.

Commit

Permalink
Cleaning up seek code.
Browse files Browse the repository at this point in the history
Replacing pts calculation based on duration
with that based on fps for higher precision
  • Loading branch information
rarzumanyan committed Jan 26, 2022
1 parent ec453ad commit e7ea766
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 112 deletions.
4 changes: 4 additions & 0 deletions PyNvCodec/TC/inc/FFmpegDemuxer.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,10 @@ class DllExport FFmpegDemuxer

double GetTimebase() const;

int64_t TsFromTime(double ts_sec);

int64_t TsFromFrameNumber(int64_t frame_num);

uint32_t GetVideoStreamIndex() const;

AVPixelFormat GetPixelFormat() const;
Expand Down
2 changes: 2 additions & 0 deletions PyNvCodec/TC/inc/Tasks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ class DllExport DemuxFrame final : public Task
DemuxFrame& operator=(const DemuxFrame& other) = delete;

void GetParams(struct MuxingParams& params) const;
int64_t TsFromTime(double ts_sec);
int64_t TsFromFrameNumber(int64_t frame_num);
void Flush();
TaskExecStatus Run() final;
~DemuxFrame() final;
Expand Down
102 changes: 60 additions & 42 deletions PyNvCodec/TC/src/FFmpegDemuxer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,67 +236,86 @@ void FFmpegDemuxer::Flush() {
avformat_flush(fmtc);
}

bool FFmpegDemuxer::Seek(SeekContext &seekCtx, uint8_t *&pVideo,
size_t &rVideoBytes, PacketData &pktData,
uint8_t **ppSEI, size_t *pSEIBytes) {
int64_t FFmpegDemuxer::TsFromTime(double ts_sec)
{
/* Internal timestamp representation is integer, so multiply to AV_TIME_BASE
* and switch to fixed point precision arithmetics; */
auto const ts_tbu = lround(ts_sec * AV_TIME_BASE);

// Rescale the timestamp to value represented in stream base units;
AVRational factor;
factor.num = 1;
factor.den = AV_TIME_BASE;
return av_rescale_q(ts_tbu, factor, fmtc->streams[videoStream]->time_base);
}

int64_t FFmpegDemuxer::TsFromFrameNumber(int64_t frame_num)
{
auto const ts_sec = (double)frame_num / GetFramerate();
return TsFromTime(ts_sec);
}

bool FFmpegDemuxer::Seek(SeekContext& seekCtx, uint8_t*& pVideo,
size_t& rVideoBytes, PacketData& pktData,
uint8_t** ppSEI, size_t* pSEIBytes)
{
/* !!! IMPORTANT !!!
* Across this function packet decode timestamp (DTS) values are used to
* compare given timestamp against. This is done for reason. DTS values shall
* monotonically increase during the course of decoding unlike PTS velues
* which may be affected by frame reordering due to B frames presence.
*/

if (!is_seekable) {
cerr << "Seek isn't supported for this input." << endl;
return false;
}

// Convert timestamp in time units to timestamp in stream base units;
auto ts_from_time = [&](double ts_sec) {
auto const ts_tbu = (int64_t)(ts_sec * AV_TIME_BASE);
AVRational factor;
factor.num = 1;
factor.den = AV_TIME_BASE;
return av_rescale_q(ts_tbu, factor, fmtc->streams[videoStream]->time_base);
};

// Convert frame number to timestamp;
auto ts_from_num = [&](int64_t frame_num) {
auto const ts_sec = (double)seekCtx.seek_frame / GetFramerate();
return ts_from_time(ts_sec);
};
if (IsVFR() && (BY_NUMBER == seekCtx.crit)) {
cerr << "Can't seek by frame number in VFR sequences. Seek by timestamp "
"instead."
<< endl;
return false;
}

// Seek for single frame;
auto seek_frame = [&](SeekContext const &seek_ctx, int flags) {
auto seek_frame = [&](SeekContext const& seek_ctx, int flags) {
bool seek_backward = false;
int64_t timestamp = 0;
int ret = 0;

switch (seek_ctx.crit) {
case BY_NUMBER:
seek_backward =
last_packet_data.dts > seek_ctx.seek_frame * pktDst.duration;
ret = av_seek_frame(fmtc, GetVideoStreamIndex(),
ts_from_num(seek_ctx.seek_frame),
timestamp = TsFromFrameNumber(seek_ctx.seek_frame);
seek_backward = last_packet_data.dts > timestamp;
ret = av_seek_frame(fmtc, GetVideoStreamIndex(), timestamp,
seek_backward ? AVSEEK_FLAG_BACKWARD | flags : flags);
if (ret < 0)
throw runtime_error("Error seeking for frame: " + AvErrorToString(ret));
break;
case BY_TIMESTAMP:
seek_backward =
last_packet_data.dts > seek_ctx.seek_frame;
ret = av_seek_frame(fmtc, GetVideoStreamIndex(),
ts_from_time(seek_ctx.seek_frame),
timestamp = TsFromTime(seek_ctx.seek_frame);
seek_backward = last_packet_data.dts > timestamp;
ret = av_seek_frame(fmtc, GetVideoStreamIndex(), timestamp,
seek_backward ? AVSEEK_FLAG_BACKWARD | flags : flags);
break;
default:
throw runtime_error("Invalid seek mode");
}
return;

if (ret < 0) {
throw runtime_error("Error seeking for frame: " + AvErrorToString(ret));
}
};

// Check if frame satisfies seek conditions;
auto is_seek_done = [&](PacketData &pkt_data, SeekContext const &seek_ctx) {
auto is_seek_done = [&](PacketData& pkt_data, SeekContext const& seek_ctx) {
int64_t target_ts = 0;

switch (seek_ctx.crit) {
case BY_NUMBER:
target_ts = ts_from_num(seek_ctx.seek_frame);
target_ts = TsFromFrameNumber(seek_ctx.seek_frame);
break;
case BY_TIMESTAMP:
target_ts = ts_from_time(seek_ctx.seek_frame);
target_ts = TsFromTime(seek_ctx.seek_frame);
break;
default:
throw runtime_error("Invalid seek criteria");
Expand All @@ -312,17 +331,18 @@ bool FFmpegDemuxer::Seek(SeekContext &seekCtx, uint8_t *&pVideo,
};
};

// This will seek for exact frame number;
// Note that decoder may not be able to decode such frame;
auto seek_for_exact_frame = [&](PacketData &pkt_data,
SeekContext &seek_ctx) {
/* This will seek for exact frame number;
* Note that decoder may not be able to decode such frame; */
auto seek_for_exact_frame = [&](PacketData& pkt_data, SeekContext& seek_ctx) {
// Repetititive seek until seek condition is satisfied;
SeekContext tmp_ctx(seek_ctx.seek_frame);
seek_frame(tmp_ctx, AVSEEK_FLAG_ANY);

int condition = 0;
do {
Demux(pVideo, rVideoBytes, pkt_data, ppSEI, pSEIBytes);
if (!Demux(pVideo, rVideoBytes, pkt_data, ppSEI, pSEIBytes)) {
break;
}
condition = is_seek_done(pkt_data, seek_ctx);

// We've gone too far and need to seek backwards;
Expand All @@ -341,11 +361,9 @@ bool FFmpegDemuxer::Seek(SeekContext &seekCtx, uint8_t *&pVideo,
};

// Seek for closest key frame in the past;
auto seek_for_prev_key_frame = [&](PacketData &pkt_data,
SeekContext &seek_ctx) {
// Repetititive seek until seek condition is satisfied;
auto tmp_ctx = seek_ctx;
seek_frame(tmp_ctx, AVSEEK_FLAG_BACKWARD);
auto seek_for_prev_key_frame = [&](PacketData& pkt_data,
SeekContext& seek_ctx) {
seek_frame(seek_ctx, AVSEEK_FLAG_BACKWARD);

Demux(pVideo, rVideoBytes, pkt_data, ppSEI, pSEIBytes);
seek_ctx.out_frame_pts = pkt_data.pts;
Expand Down
15 changes: 13 additions & 2 deletions PyNvCodec/TC/src/NvDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ struct Dim {

struct NvDecoderImpl {
bool m_bReconfigExternal = false, m_bReconfigExtPPChange = false,
eos_set = false;
eos_set = false, decoder_recon = false;

unsigned int m_nWidth = 0U, m_nLumaHeight = 0U, m_nChromaHeight = 0U,
m_nNumChromaPlanes = 0U, m_nMaxWidth = 0U, m_nMaxHeight = 0U;
Expand Down Expand Up @@ -195,6 +195,7 @@ cudaVideoCodec NvDecoder::GetCodec() const { return p_impl->m_eCodec; }
int NvDecoder::HandleVideoSequence(CUVIDEOFORMAT* pVideoFormat) noexcept
{
try {
p_impl->decoder_recon = true;
CudaCtxPush ctxPush(p_impl->m_cuContext);
CudaStrSync strSync(p_impl->m_cuvidStream);

Expand Down Expand Up @@ -753,9 +754,19 @@ bool NvDecoder::DecodeLockSurface(Buffer const* encFrame,
*/
auto ret = false;

// Prepare black packet data in case no frames are decoded yet;
// Prepare blank packet data in case no frames are decoded yet;
memset(&decCtx.out_pdata, 0, sizeof(decCtx.out_pdata));

/* In case decoder was reconfigured by cuvidParseVideoData() call made above,
* some previously decoded frames could have been pushed to decoded frames
* queue. Need to clean them up; */
if (p_impl->decoder_recon) {
p_impl->decoder_recon = false;
while (!p_impl->m_DecFramesCtxQueue.empty()) {
p_impl->m_DecFramesCtxQueue.pop();
}
}

if (!p_impl->m_DecFramesCtxQueue.empty()) {
decCtx = p_impl->m_DecFramesCtxQueue.front();
p_impl->m_DecFramesCtxQueue.pop();
Expand Down
10 changes: 10 additions & 0 deletions PyNvCodec/TC/src/Tasks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,16 @@ DemuxFrame::~DemuxFrame() { delete pImpl; }

void DemuxFrame::Flush() { pImpl->demuxer->Flush(); }

int64_t DemuxFrame::TsFromTime(double ts_sec)
{
return pImpl->demuxer->TsFromTime(ts_sec);
}

int64_t DemuxFrame::TsFromFrameNumber(int64_t frame_num)
{
return pImpl->demuxer->TsFromFrameNumber(frame_num);
}

TaskExecStatus DemuxFrame::Run()
{
NvtxMark tick(GetName());
Expand Down
32 changes: 9 additions & 23 deletions PyNvCodec/src/PyNvDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,7 @@ bool PyNvDecoder::DecodeSurface(DecodeContext& ctx)
Surface* p_surf = nullptr;
do {
try {
auto const no_eos = true;
p_surf = getDecodedSurfaceFromPacket(nullptr, nullptr, no_eos);
p_surf = getDecodedSurfaceFromPacket(nullptr, nullptr);
} catch (decoder_error& dec_exc) {
dec_error = true;
cerr << dec_exc.what() << endl;
Expand Down Expand Up @@ -521,8 +520,7 @@ bool PyNvDecoder::DecodeSurface(DecodeContext& ctx)
ctx.SetOutPacketData(pktDataBuf->GetDataAs<PacketData>());
}

auto is_seek_done = [&](DecodeContext const& ctx, double time_base,
double duration, double pts) {
auto is_seek_done = [&](DecodeContext const& ctx, int64_t pts) {
auto seek_ctx = ctx.GetSeekContext();
if (!seek_ctx)
throw runtime_error("No seek context.");
Expand All @@ -531,13 +529,11 @@ bool PyNvDecoder::DecodeSurface(DecodeContext& ctx)

switch (seek_ctx->crit) {
case BY_NUMBER:
seek_pts = seek_ctx->seek_frame * duration;
seek_pts = upDemuxer->TsFromFrameNumber(seek_ctx->seek_frame);
break;

case BY_TIMESTAMP:
seek_pts = seek_ctx->seek_frame / time_base;
seek_pts = upDemuxer->TsFromTime(seek_ctx->seek_frame);
break;

default:
throw runtime_error("Invalid seek criteria.");
break;
Expand All @@ -549,22 +545,12 @@ bool PyNvDecoder::DecodeSurface(DecodeContext& ctx)
/* Check if seek is done. */
if (!use_seek) {
loop_end = true;
} else {
MuxingParams params;
upDemuxer->GetParams(params);

if (pktDataBuf) {
auto out_pkt_data = pktDataBuf->GetDataAs<PacketData>();
if (AV_NOPTS_VALUE == out_pkt_data->pts) {
throw runtime_error(
"Decoded frame doesn't have valid PTS, can't seek.");
}
if (!out_pkt_data->duration) {
throw runtime_error("Decoded frames has zero duration, can't seek.");
}
loop_end = is_seek_done(ctx, params.videoContext.timeBase,
out_pkt_data->duration, out_pkt_data->pts);
} else if (pktDataBuf) {
auto out_pkt_data = pktDataBuf->GetDataAs<PacketData>();
if (AV_NOPTS_VALUE == out_pkt_data->pts) {
throw runtime_error("Decoded frame doesn't have PTS, can't seek.");
}
loop_end = is_seek_done(ctx, out_pkt_data->pts);
}

if (dmx_error) {
Expand Down
Loading

0 comments on commit e7ea766

Please sign in to comment.