diff --git a/.git_oss_config_pub b/.git_oss_config_pub index 6299809aa..df66d8e7f 100644 --- a/.git_oss_config_pub +++ b/.git_oss_config_pub @@ -3,3 +3,4 @@ git_oss_data_dir = data/git_oss_sample_data host = oss-cn-beijing.aliyuncs.com git_oss_cache_dir = ${TMPDIR}/${PROJECT_NAME}/.git_oss_cache git_oss_private_config = ~/.git_oss_config_private +accl_endpoint = oss-accelerate.aliyuncs.com diff --git a/.github/workflows/ci_py3.yml b/.github/workflows/ci_py3.yml new file mode 100644 index 000000000..28887dc65 --- /dev/null +++ b/.github/workflows/ci_py3.yml @@ -0,0 +1,115 @@ +name: CI Build PY3 +on: + pull_request: + types: [opened, reopened, synchronize] + +jobs: + ci-test: + runs-on: EasyRec-py3-15 + defaults: + run: + shell: bash {0} + steps: + - name: FetchCommit ${{ github.event.pull_request.head.sha }} + uses: actions/checkout@v2 + with: + ref: ${{ github.event.pull_request.head.sha }} + submodules: recursive + - name: RunCiTest + id: run_ci_test + env: + TEST_DEVICES: "" + PULL_REQUEST_NUM: ${{ github.event.pull_request.number }} + run: | + source activate tf15_py3 + python git-lfs/git_lfs.py pull + source scripts/ci_test.sh + - name: LabelAndComment + env: + CI_TEST_PASSED: ${{steps.run_ci_test.outputs.ci_test_passed}} + uses: actions/github-script@v5 + with: + script: | + const { CI_TEST_PASSED } = process.env + labels = await github.rest.issues.listLabelsOnIssue({ + issue_number: context.issue.number, + repo:context.repo.repo, + owner:context.repo.owner + }) + console.log('labels.url=' + labels.url) + + labels = labels.data + + var label_names = [] + if (labels != null) { + labels.forEach(tmp_lbl => label_names.push(tmp_lbl.name)) + } + console.log(`ci_test_passed=${CI_TEST_PASSED} labels=${label_names}`); + + var pass_label = null; + if (labels != null) { + pass_label = labels.find(label=>label.name=='ci_py3_test_passed'); + } + + var fail_label = null; + if (labels != null) { + fail_label = labels.find(label=>label.name=='ci_py3_test_failed'); + } + + if (pass_label) { + github.rest.issues.removeLabel({ + issue_number: context.issue.number, + owner: context.repo.owner, + repo: context.repo.repo, + name: 'ci_py3_test_passed' + }) + } + + if (fail_label) { + github.rest.issues.removeLabel({ + issue_number: context.issue.number, + owner: context.repo.owner, + repo: context.repo.repo, + name: 'ci_py3_test_failed' + }) + } + + if (CI_TEST_PASSED == 1) { + github.rest.issues.addLabels({ + issue_number: context.issue.number, + owner: context.repo.owner, + repo: context.repo.repo, + labels: ['ci_py3_test_passed'] + }) + + github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body: "CI PY3 Test Passed" + }) + } else { + github.rest.issues.addLabels({ + issue_number: context.issue.number, + owner: context.repo.owner, + repo: context.repo.repo, + labels: ['ci_py3_test_failed'] + }) + + github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body: "CI PY3 Test Failed" + }) + } + - name: SignalFail + env: + CI_TEST_PASSED: ${{steps.run_ci_test.outputs.ci_test_passed}} + run: | + echo "CI_TEST_PASSED=${CI_TEST_PASSED}" + if [ $CI_TEST_PASSED -ne 1 ] + then + echo "ci_py3_test_failed, will exit" + exit 1 + fi diff --git a/easy_rec/python/model/multi_tower_bst.py b/easy_rec/python/model/multi_tower_bst.py index 6d93ebeda..11a21f98d 100644 --- a/easy_rec/python/model/multi_tower_bst.py +++ b/easy_rec/python/model/multi_tower_bst.py @@ -73,7 +73,7 @@ def attention_net(self, net, dim, cur_seq_len, seq_size, name): hist_mask = tf.sequence_mask( cur_seq_len, maxlen=seq_size - 1) # [B, seq_size-1] - cur_id_mask = tf.ones([tf.shape(hist_mask)[0], 1], dtype=tf.bool) # [B, 1] + cur_id_mask = tf.ones(tf.stack([tf.shape(hist_mask)[0], 1]), dtype=tf.bool) # [B, 1] mask = tf.concat([hist_mask, cur_id_mask], axis=1) # [B, seq_size] masks = tf.reshape(tf.tile(mask, [1, seq_size]), (-1, seq_size, seq_size)) # [B, seq_size, seq_size] diff --git a/easy_rec/python/test/train_eval_test.py b/easy_rec/python/test/train_eval_test.py index 017e02f17..93d18da34 100644 --- a/easy_rec/python/test/train_eval_test.py +++ b/easy_rec/python/test/train_eval_test.py @@ -4,6 +4,7 @@ import glob import logging import os +import sys import unittest from distutils.version import LooseVersion @@ -255,24 +256,36 @@ def test_metric_learning(self): 'samples/model_config/metric_learning_on_taobao.config', self._test_dir) self.assertTrue(self._success) + @unittest.skipIf((sys.version_info.major, sys.version_info.minor) > (3,6), + 'Currently graph-learn not support python3.7' + ) def test_dssm_neg_sampler(self): self._success = test_utils.test_single_train_eval( 'samples/model_config/dssm_neg_sampler_on_taobao.config', self._test_dir) self.assertTrue(self._success) + @unittest.skipIf((sys.version_info.major, sys.version_info.minor) > (3,6), + 'Currently graph-learn not support python3.7' + ) def test_dssm_neg_sampler_v2(self): self._success = test_utils.test_single_train_eval( 'samples/model_config/dssm_neg_sampler_v2_on_taobao.config', self._test_dir) self.assertTrue(self._success) + @unittest.skipIf((sys.version_info.major, sys.version_info.minor) > (3,6), + 'Currently graph-learn not support python3.7' + ) def test_dssm_hard_neg_sampler(self): self._success = test_utils.test_single_train_eval( 'samples/model_config/dssm_hard_neg_sampler_on_taobao.config', self._test_dir) self.assertTrue(self._success) + @unittest.skipIf((sys.version_info.major, sys.version_info.minor) > (3,6), + 'Currently graph-learn not support python3.7' + ) def test_dssm_hard_neg_sampler_v2(self): self._success = test_utils.test_single_train_eval( 'samples/model_config/dssm_hard_neg_sampler_v2_on_taobao.config', diff --git a/git-lfs/git_lfs.py b/git-lfs/git_lfs.py index 89e64d799..caa191c59 100644 --- a/git-lfs/git_lfs.py +++ b/git-lfs/git_lfs.py @@ -226,6 +226,8 @@ def get_yes_no(msg): host = None bucket_name = None git_oss_private_path = None + enable_accelerate = 0 + accl_endpoint = None for line_str in fin: line_str = line_str.strip() if len(line_str) == 0: @@ -248,6 +250,8 @@ def get_yes_no(msg): git_oss_private_path = os.path.join(os.environ['HOME'], git_oss_private_path[2:]) elif line_tok[0] == 'git_oss_cache_dir': git_oss_cache_dir = line_tok[1] + elif line_tok[0] == 'accl_endpoint': + accl_endpoint = line_tok[1] logging.info('git_oss_data_dir=%s, host=%s, bucket_name=%s' % ( git_oss_data_dir, host, bucket_name)) @@ -353,15 +357,43 @@ def get_yes_no(msg): remote_path = git_bin_url[leaf_path][1] _, file_name_with_sig = os.path.split(remote_path) tar_tmp_path = '%s/%s.tar.gz' % (git_oss_cache_dir, file_name_with_sig) - if not os.path.exists(tar_tmp_path): - if oss_bucket: - oss_bucket.get_object_to_file(remote_path, tar_tmp_path) - else: - url = 'http://%s.%s/%s' % (bucket_name, host, remote_path) - subprocess.check_output(['wget', url, '-O', tar_tmp_path]) - else: - logging.info('%s is in cache' % file_name_with_sig) - subprocess.check_output(['tar', '-zxf', tar_tmp_path]) + + max_retry = 5 + while max_retry > 0: + try: + if not os.path.exists(tar_tmp_path): + in_cache = False + if oss_bucket: + oss_bucket.get_object_to_file(remote_path, tar_tmp_path) + else: + url = 'http://%s.%s/%s' % (bucket_name, host, remote_path) + subprocess.check_output(['wget', url, '-O', tar_tmp_path]) + else: + in_cache = True + logging.info('%s is in cache' % file_name_with_sig) + subprocess.check_output(['tar', '-zxf', tar_tmp_path]) + local_sig = get_local_sig(leaf_files) + if local_sig == remote_sig: + break + if in_cache: + logging.warning('cache invalid, will download from remote') + os.remove(tar_tmp_path) + continue + logging.warning('download failed, local_sig(%s) != remote_sig(%s)' % ( + local_sig, remote_sig)) + except subprocess.CalledProcessError as ex: + logging.error("exception: %s" % str(ex)) + except oss2.exceptions.RequestError as ex: + logging.error("exception: %s" % str(ex)) + + os.remove(tar_tmp_path) + if accl_endpoint is not None and host != accl_endpoint: + logging.info('will try accelerate endpoint: %s' % accl_endpoint) + host = accl_endpoint + if oss_auth: + oss_bucket = oss2.Bucket(oss_auth, host, bucket_name) + max_retry -= 1 + logging.info('%s updated' % leaf_path) any_update = True if not any_update: