Skip to content

Commit

Permalink
release DLD code
Browse files Browse the repository at this point in the history
  • Loading branch information
qiaoliang6 committed Jul 13, 2022
1 parent 2bbb4c3 commit 89ff46f
Show file tree
Hide file tree
Showing 24 changed files with 2,369 additions and 7 deletions.
1 change: 1 addition & 0 deletions davarocr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .davar_nlp_common import *
from .davar_ner import *
from .davar_order import *
from .davar_distill import *
from .mmcv import *
from .version import __version__

Expand Down
13 changes: 13 additions & 0 deletions davarocr/davar_distill/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
##################################################################################################
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
# Filename : __init__.py
# Abstract :
# Current Version: 1.0.0
# Date : 2022-07-07
##################################################################################################
"""
from .models import *
from .dataset import *
from .core import *
11 changes: 11 additions & 0 deletions davarocr/davar_distill/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""
##################################################################################################
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
# Filename : __init__.py
# Abstract :
# Current Version: 1.0.0
# Date : 2022-07-07
##################################################################################################
"""
from .beam_search import beam_decode
139 changes: 139 additions & 0 deletions davarocr/davar_distill/core/beam_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
"""
##################################################################################################
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
# Filename : beam_search.py
# Abstract : Beam search for attention decode
# Current Version: 1.0.0
# Date : 2022-07-07
##################################################################################################
"""
import torch
from queue import PriorityQueue


class BeamSearchNode(object):
""" Beam search node class """
def __init__(self, previous_node, char_id, logProb, length):
"""
Args:
previous_node (obj:`BeamSearchNode`): node in queue
char_id (dict): character id
logProb (float): word probability
length (int): word length
"""
self.prev_node = previous_node
self.char_id = char_id
self.logp = logProb
self.leng = length

def eval(self):
""" Calculate beam search path score
Returns:
float: beam search path score
"""
return self.logp / float(self.leng - 1 + 1e-6)

def __lt__(self, other):
"""
Args:
self (obj:`BeamSearchNode`): beam search node
other (obj:`BeamSearchNode`): beam search node
"""
if self.eval() < other.eval():
return False
else:
return True


def beam_decode(encoder_outputs, beam_width=5, topk=1):
""" Beam search decode
Args:
encoder_outputs (Tensor): encoder outputs tensor of shape [B, T, C]
where B is the batch size and T is the maximum length of the output sentence
beam_width (int): beam search width
topk (int): select top-k beam search result
Returns:
list(list(Tensor)): beam search decoded path
"""
decoded_batch = []

# decoding goes sentence by sentence
for idx in range(encoder_outputs.size(0)):
# Start with the start of the sentence token
decoder_input = torch.tensor([[0]], device=encoder_outputs.device).long()

# Number of sentence to generate
endnodes = []
number_required = min((topk + 1), topk - len(endnodes))

# starting node - previous node, char id, logp, length
node = BeamSearchNode(None, decoder_input, 0, 1)
nodes = PriorityQueue()

# start the queue
nodes.put(node)
qsize = 1

# start beam search
while True:
# give up when decoding takes too long
if qsize > 2000:
break

# fetch the best node
priority_node = nodes.get()
decoder_input = priority_node.char_id

if priority_node.char_id.item() == 1 and priority_node.prev_node != None:
endnodes.append(priority_node)
# if we reached maximum # of sentences required
if len(endnodes) >= number_required:
break
else:
continue

# PUT HERE REAL BEAM SEARCH OF TOP
log_prob, indexes = torch.topk(encoder_outputs[idx][priority_node.leng-1], beam_width)
nextnodes = []

for new_k in range(beam_width):
decoded_t = indexes[new_k].view(1, -1)
log_p = log_prob[new_k].item()

node = BeamSearchNode(priority_node, decoded_t, priority_node.logp + log_p, priority_node.leng + 1)
# score = -node.eval()
nextnodes.append(node)

# put them into queue
for i in range(len(nextnodes)):
nextnode = nextnodes[i]
nodes.put(nextnode)
# increase qsize
qsize += len(nextnodes) - 1

# choose nbest paths, back trace them
if len(endnodes) == 0:
endnodes = [nodes.get() for _ in range(topk)]

utterances = []
for endnode in sorted(endnodes, key=lambda x: x.eval()):
utterance = []
utterance.append(endnode.char_id)
# back trace
while endnode.prev_node != None:
endnode = endnode.prev_node
utterance.append(endnode.char_id)

utterance = utterance[::-1]
utterances.append(utterance)

stack_utterances = []
for path_id in range(len(utterances)):
stack_utterances.append(torch.stack(utterances[path_id], dim=-1).squeeze(0).squeeze(0))
decoded_batch.append(stack_utterances)

return decoded_batch
13 changes: 13 additions & 0 deletions davarocr/davar_distill/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
##################################################################################################
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
# Filename : __init__.py
# Abstract :
# Current Version: 1.0.0
# Date : 2022-07-07
##################################################################################################
"""
from .pipelines import DistillFormatBundle

__all__ = ['DistillFormatBundle']
13 changes: 13 additions & 0 deletions davarocr/davar_distill/dataset/pipelines/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
##################################################################################################
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
# Filename : __init__.py
# Abstract :
# Current Version: 1.0.0
# Date : 2022-07-07
##################################################################################################
"""
from .distill_formating import DistillFormatBundle

__all__ = ['DistillFormatBundle']
59 changes: 59 additions & 0 deletions davarocr/davar_distill/dataset/pipelines/distill_formating.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""
##################################################################################################
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
# Filename : distill_formating.py
# Abstract : Definition of data formating process for knowledge distillation
# Current Version: 1.0.0
# Date : 2022-07-07
##################################################################################################
"""
import numpy as np
from mmcv.parallel import DataContainer as DC

from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines.formating import to_tensor, DefaultFormatBundle


@PIPELINES.register_module()
class DistillFormatBundle(DefaultFormatBundle):
""" The common data format pipeline used by DavarCustom dataset. including,
(1) transferred into Tensor (2) contained by DataContainer (3) put on device (GPU|CPU)
- keys in ['img', 'gt_semantic_seg'] will be transferred into Tensor and put on GPU
- keys in ['proposals', 'gt_bboxes', 'gt_bboxes_ignore','gt_labels', 'stn_params']
will be transferred into Tensor
- keys in ['gt_masks', 'gt_poly_bboxes', 'gt_poly_bboxes_ignore', 'gt_cbboxes',
'gt_cbboxes_ignore', 'gt_texts', 'gt_text'] will be put on CPU
"""

def __call__(self, results):
for key in ['img', 'hr_img']:
if key in results:
img = results[key]
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1))
results[key] = DC(to_tensor(img), stack=True)

for key in ['proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels', 'stn_params']:
if key in results:
results[key] = DC(to_tensor(results[key]))
ori_key = 'hr_' + key
if ori_key in results:
results[ori_key] = DC(to_tensor(results[ori_key]))

if 'gt_semantic_seg' in results:
results['gt_semantic_seg'] = DC(
to_tensor(results['gt_semantic_seg'][None, ...]), stack=True)

# Updated keys by DavarCustom dataset
for key in ['gt_masks', 'gt_poly_bboxes', 'gt_poly_bboxes_ignore', 'gt_cbboxes',
'gt_cbboxes_ignore', 'gt_texts', 'gt_text', 'array_gt_texts', 'gt_bieo_labels']:
if key in results:
results[key] = DC(results[key], cpu_only=True)
ori_key = 'hr_' + key
if ori_key in results:
results[ori_key] = DC(results[ori_key], cpu_only=True)

return results
13 changes: 13 additions & 0 deletions davarocr/davar_distill/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
##################################################################################################
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
# Filename : __init__.py
# Abstract :
# Current Version: 1.0.0
# Date : 2022-07-07
##################################################################################################
"""
from .connect import ResolutionSelector
from .distillation import SpotResolutionDistillation
from .spotters import KDTwoStageEndToEnd
13 changes: 13 additions & 0 deletions davarocr/davar_distill/models/connect/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
##################################################################################################
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
# Filename : __init__.py
# Abstract :
# Current Version: 1.0.0
# Date : 2022-07-07
##################################################################################################
"""
from .resolution_selector import ResolutionSelector

__all__ = ['ResolutionSelector']
Loading

0 comments on commit 89ff46f

Please sign in to comment.