Skip to content

Commit

Permalink
Dev for ext (DeepLink-org#469)
Browse files Browse the repository at this point in the history
add diopi interface for InternLM

---------

Co-authored-by: gqw <[email protected]>
Co-authored-by: Gong-air <[email protected]>
Co-authored-by: Lingjie Li <[email protected]>
Co-authored-by: hellozmz <[email protected]>
Co-authored-by: xue-yun-liang <[email protected]>
Co-authored-by: jfxu-st <[email protected]>
Co-authored-by: liwenjian-sensetime <[email protected]>
Co-authored-by: POI-WX <[email protected]>
Co-authored-by: yangbo <[email protected]>
Co-authored-by: yangbo1 <[email protected]>
Co-authored-by: Zhangzefeng <[email protected]>
Co-authored-by: Chengyuan Li <[email protected]>
Co-authored-by: lhy <[email protected]>
Co-authored-by: yangbofun <[email protected]>
Co-authored-by: wanglei <[email protected]>
Co-authored-by: lrz-relief <[email protected]>
Co-authored-by: CokeDong <[email protected]>
Co-authored-by: Peter Ye <[email protected]>
Co-authored-by: Jincong Chen <[email protected]>
Co-authored-by: wugeshui <[email protected]>
  • Loading branch information
21 people authored Nov 7, 2023
1 parent 49364d2 commit c2d3157
Show file tree
Hide file tree
Showing 21 changed files with 1,940 additions and 54 deletions.
72 changes: 55 additions & 17 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ env:
ENV_NAME_MMCV: 'pt1.11v1'
GPU_REQUESTS: 1
SLURM_PAR_SH1988: ${{ vars.SLURM_PAR_SH1984 != '' && vars.SLURM_PAR_SH1984 || 'pat_rd -x SH-IDC1-10-198-8-58,SH-IDC1-10-198-8-87' }}
SLURM_PAR_SH1424: ${{ vars.SLURM_PAR_SH1424 != '' && vars.SLURM_PAR_SH1424 || 'pat_rd' }}
SLURM_PAR_CAMB: ${{ vars.SLURM_PAR_CAMB != '' && vars.SLURM_PAR_CAMB || 'camb_mlu370_m8 --exclude HOST-10-142-11-120,HOST-10-142-11-126' }}
CLUSTER_1988: SH1988
CLUSTER_1424: SH1424
CLUSTER_CAMB: CAMB
CLUSTER_ASCEND: ASCEND
CLUSTER_TOPSRIDER: TOPSRIDER
Expand Down Expand Up @@ -82,6 +84,9 @@ jobs:
&& rsync -a --delete ${GITHUB_WORKSPACE}/DIOPI/ ${CLUSTER_TOPSRIDER}:${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/source/ || echo "failure to connect to topsrider"
ssh ${CLUSTER_SUPA} "mkdir -p ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/source" \
&& rsync -a --delete ${GITHUB_WORKSPACE}/DIOPI/ ${CLUSTER_SUPA}:${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/source/ || echo "failure to connect to supa"
ssh ${CLUSTER_1424} "mkdir -p ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/source" \
&& rsync -a --delete ${GITHUB_WORKSPACE}/DIOPI/ ${CLUSTER_1424}:${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/source/ || echo "failure to connect to sh1424"
lint:
name: lint
Expand Down Expand Up @@ -117,16 +122,24 @@ jobs:
export CI=true
srun --job-name=${GITHUB_JOB} --partition=${SLURM_PAR_SH1988} --time=10 bash -c 'cd impl && bash scripts/build_impl.sh torch' || ( cd ${NFS_PATH}/${GITHUB_RUN_NUMBER}/ && rm -rf ${BUILD_TEST1} && exit 1 )
"""
- name: build-dyload
Build-Nvidia-A100:
name: Build-Nvidia-A100
runs-on: tps-diopi-ci
needs: [Rsync]
steps:
- name: build
run: |
ssh ${CLUSTER_1988} """
set -e
source ${ENV_PATH}/github_bashrc && source /mnt/cache/share/platform/env/${ENV_NAME}
cd ${NFS_PATH}/${GITHUB_RUN_NUMBER} && rm -rf ${BUILD_TEST2} && cp -R source ${BUILD_TEST2} && cd ${BUILD_TEST2}
ssh ${CLUSTER_1424} """
set -ex
export USE_COVERAGE=ON
source /mnt/cache/share/platform/env/${ENV_NAME}
cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER} && rm -rf ${BUILD_TEST1} && cp -R source ${BUILD_TEST1} && cd ${BUILD_TEST1}
export CI=true
srun --job-name=${GITHUB_JOB} --partition=${SLURM_PAR_SH1988} --time=10 bash -c 'cd impl && bash scripts/build_impl.sh torch_dyload' || ( cd ${NFS_PATH}/${GITHUB_RUN_NUMBER}/ && rm -rf ${BUILD_TEST2} && exit 1 )
srun --job-name=${GITHUB_JOB} --partition=${SLURM_PAR_SH1424} --time=10 bash -c 'cd impl && bash scripts/build_impl.sh torch' || ( cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/ && rm -rf ${BUILD_TEST1} && exit 1 )
"""
Build-Camb:
name: Build-Camb
runs-on: tps-diopi-ci
Expand Down Expand Up @@ -203,6 +216,32 @@ jobs:
|| ( cd ${NFS_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1} && git clean -xdf ${GEN_DATA} && exit 1 )
"""
Gen-Data-Op-Test-A100:
name: Gen-Data-Op-Test-A100
runs-on: tps-diopi-ci
needs: [Build-Nvidia-A100]
steps:
- name: gen-test-data
run: |
ssh ${CLUSTER_1424} """
set -e
export CI=true
source /mnt/cache/share/platform/env/${ENV_NAME}
cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER} && cd ${BUILD_TEST1} && cp -f scripts/ci/diopi_config_A100.py diopi_test/python/configs/diopi_configs.py && cd diopi_test/python && ls &&
srun --job-name=${GITHUB_JOB} --partition=${SLURM_PAR_SH1424} --time=10 --gres=gpu:${GPU_REQUESTS} bash -c 'python main.py --mode gen_data' \
|| ( cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1} && git clean -xdf ${GEN_DATA} && exit 1 )
"""
- name: test-op
run: |
ssh ${CLUSTER_1424} """
set -e
export CI=true
source /mnt/cache/share/platform/env/${ENV_NAME} && cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER} && cd ${BUILD_TEST1}
export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/impl/lib
echo \$LD_LIBRARY_PATH
srun --job-name=${GITHUB_JOB} --partition=${SLURM_PAR_SH1424} --time=20 --gres=gpu:${GPU_REQUESTS} bash -c 'cd diopi_test/python && python main.py --mode gen_case && python main.py --mode run_test'
"""
Op-Test-Nvidia:
name: Op-Test-Nvidia
runs-on: tps-diopi-ci
Expand All @@ -224,25 +263,24 @@ jobs:
bash /mnt/cache/share/platform/dep/sonar/coverage_DIOPI_nv.sh ${NFS_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1} ${GITHUB_RUN_NUMBER} || echo "get coverage fail"
fi
"""
- name: increment coverage check
if: ${{ contains( github.event_name, 'pull_request' ) && contains( github.base_ref, 'main' ) }}
- name: test
run: |
ssh ${CLUSTER_1988} """
set -e
export CI=true
source ${ENV_PATH}/github_bashrc && source /mnt/cache/share/platform/env/${ENV_NAME} && cd ${NFS_PATH}/${GITHUB_RUN_NUMBER} && cd ${BUILD_TEST1}
bash scripts/increment_coverage.sh ${REQUIRE_COVERAGE}
export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:${NFS_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/impl/lib
echo \$LD_LIBRARY_PATH
srun --job-name=${GITHUB_JOB} --partition=${SLURM_PAR_SH1988} --time=20 --gres=gpu:${GPU_REQUESTS} bash -c 'cd diopi_test/python && python main.py --mode gen_case &&
python main.py --mode run_test'
"""
- name: dyload-test
- name: increment coverage check
if: ${{ contains( github.event_name, 'pull_request' ) && contains( github.base_ref, 'main' ) }}
run: |
ssh ${CLUSTER_1988} """
set -e
export CI=true
source ${ENV_PATH}/github_bashrc && source /mnt/cache/share/platform/env/${ENV_NAME} && cd ${NFS_PATH}/${GITHUB_RUN_NUMBER} && cd ${BUILD_TEST2}
rm -rf ${GEN_DATA} && ln -s ${NFS_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST1}/${GEN_DATA} ${GEN_DATA}
export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:${NFS_PATH}/${GITHUB_RUN_NUMBER}/${BUILD_TEST2}/impl/lib
echo \$LD_LIBRARY_PATH
srun --job-name=${GITHUB_JOB} --partition=${SLURM_PAR_SH1988} --time=20 --gres=gpu:${GPU_REQUESTS} bash -c 'cd diopi_test/python && python main.py --mode gen_case &&
python main.py --mode run_test'
source ${ENV_PATH}/github_bashrc && source /mnt/cache/share/platform/env/${ENV_NAME} && cd ${NFS_PATH}/${GITHUB_RUN_NUMBER} && cd ${BUILD_TEST1}
bash scripts/increment_coverage.sh ${REQUIRE_COVERAGE}
"""
Rt-Test-Nvidia:
Expand Down
28 changes: 18 additions & 10 deletions diopi_test/diopi_stub/codegen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,7 @@ def get_func_info(content):
return type_change, args, attr_types, paras_can_be_none, ins_vector, outs_vector, out_ptr


def gen_functions(options, functions_fm):
_cur_dir = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(_cur_dir, options.get('source_dir'), 'functions.h'), 'r', encoding='utf8')as f:
content = f.readlines()
exports = []
device = options.get('device')
if device == 'ascend':
ft = OT.function_ascend_template
else:
ft = OT.function_template
def get_export(content, ft, exports):
for idx, row in enumerate(content):
if row.startswith("DIOPI_API"):
row = row[10:]
Expand Down Expand Up @@ -167,6 +158,23 @@ def gen_functions(options, functions_fm):
else:
exports.append(ft.substitute(env=dict(func_name=func_name, attrs=', '.join(arg_def), convert='',
out_copy='', call_func=call_func)))
return exports


def gen_functions(options, functions_fm):
_cur_dir = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(_cur_dir, options.get('source_dir'), 'functions.h'), 'r', encoding='utf8')as f:
content = f.readlines()
exports = []
device = options.get('device')
if device == 'ascend':
ft = OT.function_ascend_template
else:
ft = OT.function_template
exports = get_export(content, ft, exports)
with open(os.path.join(_cur_dir, options.get('source_dir'), 'functions_ext.h'), 'r', encoding='utf8')as f:
content_ext = f.readlines()
exports = get_export(content_ext, ft, exports)

functions_fm.write("export_functions.cpp", OT.operators_template, env=dict(export_functions=exports))

Expand Down
1 change: 1 addition & 0 deletions diopi_test/diopi_stub/codegen/op_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class OpTemplate(object):
#include "litert.hpp"
#include <diopi/diopirt.h>
#include <diopi/functions.h>
#include <diopi/functions_ext.h>
namespace py = pybind11;
Expand Down
90 changes: 90 additions & 0 deletions diopi_test/python/configs/diopi_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8018,4 +8018,94 @@
],
),
),

'rotary_emb': dict(
name=['rotary_emb'],
interface=['CustomizedTest'],
dtype=[np.float64, np.float32, np.float16],
para=dict(
conj=[False, True, False, True],
),
tensor_para=dict(
gen_fn='Genfunc.randn',
args=[
{
"ins": ['input'],
"shape": ((1, 125, 16, 32), (1, 125, 16, 32), (2, 64, 16, 32), (3, 100, 8, 64)),
},
{
"ins": ['cos'],
"shape": ((125, 1, 16), (125, 1, 16), (64, 1, 16), (100, 1, 32)),
},
{
"ins": ['sin'],
"shape": ((125, 1, 16), (125, 1, 16), (64, 1, 16), (100, 1, 32)),
},
],
),
),

'rms_norm': dict(
name=['rms_norm'],
interface=['CustomizedTest'],
dtype=[np.float32],
para=dict(
eps=[1e-6, 1e-6, 1e-6, 1e-6],
normalized_shape=[(5, ), (32, ), (64, ), (8, )],
),
tensor_para=dict(
gen_fn='Genfunc.randn',
args=[
{
"ins": ['input'],
"shape": ((5, 5), (35, 125, 32), (16, 64, 64), (1, 32, 32, 8)),
},
{
"ins": ['weight'],
"shape": ((5, ), (32, ), (64, ), (8, )),
},
{
"ins": ['bias'],
"shape": ((5, ), (32, ), (64, ), (8, )),
},
],
),
),

# 'multihead_attention_forward': dict(
# name=['multihead_attention_forward'],
# interface=['CustomizedTest'],
# dtype=[np.float16],
# atol=1e-3,
# rtol=1e-4,
# para=dict(
# dropout_p=[0, 0],
# is_causal=[False, False],
# return_debug_mask=[False, False],
# scale=[None, None]
# ),
# tensor_para=dict(
# gen_fn='Genfunc.randn',
# args=[
# {
# "ins": ['q'],
# "shape": ((2, 2, 2, 8), (2, 5, 7, 8)),
# "dtype": [np..float16],
# "gen_fn": Genfunc.randn,
# },
# {
# "ins": ['k'],
# "shape": ((2, 2, 2, 8), (2, 5, 7, 8)),
# "dtype": [np.float16],
# "gen_fn": Genfunc.randn,
# },
# {
# "ins": ['v'],
# "shape": ((2, 2, 2, 8), (2, 5, 7, 8)),
# "dtype": [np.float16],
# "gen_fn": Genfunc.randn,
# },
# ],
# ),
# ),
}
37 changes: 37 additions & 0 deletions diopi_test/python/conformance/diopi_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3974,3 +3974,40 @@ def linalgqr(input, mode):
out = [q, r]
check_returncode(ret)
return out


def rotary_emb(input, cos, sin, conj):
call = "diopiRotaryEmbedding"
func = check_function(call)
size = list(input.size().data)
out = Tensor(size, input.get_dtype())
ret = func(input.context(), out, input, cos, sin, conj, False)
check_returncode(ret)
return out


def rms_norm(input, normalized_shape, weight, bias, eps):
call = "diopiRMSNorm"
func = check_function(call)
size = list(input.size().data)
out = Tensor(size, input.get_dtype())
inv_rms = Tensor(size, input.get_dtype())
normalized_shape = Sizes(list(normalized_shape))
ret = func(input.context(), out, inv_rms, input, normalized_shape, weight, bias, eps)
check_returncode(ret)
return out


def multihead_attention_forward(q, k, v, dropout_p, is_causal, return_debug_mask, scale):
call = "diopiMultiHeadAttention"
func = check_function(call)
q_size = list(q.size().data)
k_size = list(k.size().data)
out = Tensor(q_size, q.get_dtype())
softmax_lse = Tensor([q_size[0], q_size[2], q_size[1]], q.get_dtype())
gen = None
debug_attn_mask = Tensor([0], q.get_dtype())
softmax_scale = 1.0 / math.sqrt(q.shape().data[-1]) if not scale else scale
ret = func(q.context(), q, k, v, dropout_p, is_causal, return_debug_mask, softmax_scale, out, softmax_lse, gen, debug_attn_mask)
check_returncode(ret)
return out
41 changes: 41 additions & 0 deletions diopi_test/python/conformance/gen_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,47 @@ def batch_norm_elemt(input, weight, bias, mean, invstd, eps):
out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
return out

def rotary_emb(input, cos, sin, conj):
x1, x2 = input.chunk(2, dim=-1)
data_type = input.dtype
x1 = x1.to(torch.float32)
x2 = x2.to(torch.float32)
cos = cos.to(torch.float32)
sin = sin.to(torch.float32)
if not conj:
out1 = x1 * cos - x2 * sin
out2 = x1 * sin + x2 * cos
else:
out1 = x1 * cos + x2 * sin
out2 = -x1 * sin + x2 * cos
out1 = out1.to(data_type)
out2 = out2.to(data_type)
out = torch.cat((out1, out2), dim=-1)
return out

def rms_norm(input, normalized_shape, weight, bias, eps):
variance = input.to(torch.float32).pow(2).mean(-1, keepdim=True)
input = input * torch.rsqrt(variance + eps)
out = weight * input
return out

def multihead_attention_forward(q, k, v, dropout_p, is_causal, return_debug_mask, scale):
# 为了保证精度,因此在test的时候不使用dropout
from einops import rearrange
import math

_, seqlen = q.shape[0], q.shape[1]
softmax_scale = 1.0 / math.sqrt(q.shape[-1]) if not scale else scale
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
if is_causal:
causal_mask = torch.triu(
torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
)
scores = scores + causal_mask.to(dtype=scores.dtype)
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
output = torch.einsum("bhts,bshd->bthd", attention, v)
return output


class GenOutputData(object):
r'''
Expand Down
24 changes: 24 additions & 0 deletions impl/ascend/device_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2440,5 +2440,29 @@
interface=['torch'],
atol=1e-4,
rtol=1e-4,
),

'rotary_emb': dict(
name=["rotary_emb"],
tensor_para=dict(
args=[
{
"ins": ['input'],
"dtype": [Skip(np.float64), Skip(np.float32), Skip(np.float16)],
},
],
),
),

'rms_norm': dict(
name=["rms_norm"],
tensor_para=dict(
args=[
{
"ins": ['input'],
"dtype": [Skip(np.float32)],
},
],
),
)
}
Loading

0 comments on commit c2d3157

Please sign in to comment.