Skip to content
Snippets Groups Projects
Unverified Commit 7db3376e authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!2603 fix bug of rotate

Merge pull request !2603 from 周莉莉/master
parents e3f3c7e6 2c618afc
No related branches found
No related tags found
No related merge requests found
......@@ -30,7 +30,7 @@ def get_args():
parser.add_argument(
'--config_path',
type=str,
required=False,
required=True,
default='',
help="json file for dataset"
)
......
......@@ -228,7 +228,7 @@ bash scripts/run_eval.sh [DEVICE_ID] [DEVICE_TARGET] [EVAL_CHECKPOINT] [EVAL_LOG
在运行以下命令之前,请检查用于评估的检查点路径。
```bash
bash scripts/run_eval.sh 0 Ascend Data/wn18rr/ checkpoints/rotate-standalone-ascend/rotate.ckpt eval-standalone-ascend.log
bash scripts/run_eval.sh 0 Ascend checkpoints/rotate-standalone-ascend/rotate.ckpt eval-standalone-ascend.log
```
上述python命令将在后台运行,您可以通过ms_log/eval-standalone-ascend文件查看类似如下的结果:
......@@ -342,4 +342,4 @@ python export.py --eval_checkpoint [EVAL_CHECKPOINT] --file_format [FILE_FORMAT]
# [ModelZoo主页](#目录)
请浏览官网[主页](https://gitee.com/mindspore/models)
\ No newline at end of file
请浏览官网[主页](https://gitee.com/mindspore/models)
......@@ -42,20 +42,20 @@ class KGEModel(nn.Cell):
def __init__(self, network, mode='head-mode'):
super(KGEModel, self).__init__()
self.network = network
self.construct_head = self.network.construct_head
self.construct_tail = self.network.construct_tail
self.mode = mode
self.sort = P.Sort(axis=1, descending=True)
def construct(self, positive_sample, negative_sample, filter_bias):
""" Sort candidate entity id and positive sample entity id. """
if self.mode == 'head-mode':
score = self.network.construct_head((positive_sample, negative_sample))
positive_arg = positive_sample[:, 0]
score = self.construct_head((positive_sample, negative_sample))
else:
score = self.network.construct_tail((positive_sample, negative_sample))
positive_arg = positive_sample[:, 2]
score = self.construct_tail((positive_sample, negative_sample))
score += filter_bias
_, argsort = self.sort(score)
return argsort, positive_arg
return argsort
class EvalKGEMetric(nn.Cell):
......@@ -77,27 +77,34 @@ class EvalKGEMetric(nn.Cell):
def construct(self, positive_sample, negative_sample, filter_bias):
""" Calculate metrics. """
batch_size = positive_sample.shape[0]
argsort, positive_arg = self.kgemodel(positive_sample, negative_sample, filter_bias)
argsort, positive_arg = argsort.asnumpy(), positive_arg.asnumpy()
log = []
for i in range(batch_size):
ranking = np.where(argsort[i, :] == positive_arg[i])[0][0]
ranking = 1 + ranking
log.append({
'MRR': 1.0 / ranking,
'MR': ranking,
'HITS@1': 1.0 if ranking <= 1 else 0.0,
'HITS@3': 1.0 if ranking <= 3 else 0.0,
'HITS@10': 1.0 if ranking <= 10 else 0.0,
})
return log
argsort = self.kgemodel(positive_sample, negative_sample, filter_bias)
if self.mode == 'head-mode':
positive_arg = positive_sample[:, 0]
else:
positive_arg = positive_sample[:, 2]
return argsort, positive_arg
def modelarts_process():
pass
def generate_log(argsort, positive_arg, batch_size):
argsort, positive_arg = argsort.asnumpy(), positive_arg.asnumpy()
log = []
for i in range(batch_size):
ranking = np.where(argsort[i, :] == positive_arg[i])[0][0]
ranking = 1 + ranking
log.append({
'MRR': 1.0 / ranking,
'MR': ranking,
'HITS@1': 1.0 if ranking <= 1 else 0.0,
'HITS@3': 1.0 if ranking <= 3 else 0.0,
'HITS@10': 1.0 if ranking <= 10 else 0.0,
})
return log
@moxing_wrapper(pre_process=modelarts_process)
def eval_kge():
""" Link Prediction Task for Knowledge Graph Embedding Model """
......@@ -127,10 +134,16 @@ def eval_kge():
eval_model_tail = EvalKGEMetric(network=eval_net, mode='tail-mode')
for test_data in test_dataloader_head.create_dict_iterator():
log_head = eval_model_head.construct(test_data["positive"], test_data["negative"], test_data["filter_bias"])
argsort, positive_arg = eval_model_head.construct(test_data["positive"], test_data["negative"],
test_data["filter_bias"])
batch_size = test_data["positive"].shape[0]
log_head = generate_log(argsort, positive_arg, batch_size)
logs += log_head
for test_data in test_dataloader_tail.create_dict_iterator():
log_tail = eval_model_tail.construct(test_data["positive"], test_data["negative"], test_data["filter_bias"])
argsort, positive_arg = eval_model_tail.construct(test_data["positive"], test_data["negative"],
test_data["filter_bias"])
batch_size = test_data["positive"].shape[0]
log_tail = generate_log(argsort, positive_arg, batch_size)
logs += log_tail
metrics = {}
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment