diff --git a/official/nlp/pangu_alpha/serving_increment/pangu_distributed/pangu/servable_config.py b/official/nlp/pangu_alpha/serving_increment/pangu_distributed/pangu/servable_config.py
index a61b8f0cc1f53e5be729631c4094495316e6f9f8..0899cf0e435c1fc95ec1774b61c7afe418bb65e6 100644
--- a/official/nlp/pangu_alpha/serving_increment/pangu_distributed/pangu/servable_config.py
+++ b/official/nlp/pangu_alpha/serving_increment/pangu_distributed/pangu/servable_config.py
@@ -52,27 +52,10 @@ def topk_fun(logits, topk=5):
     return value, index
 
 
-distributed.declare_servable(rank_size=8, stage_size=1, with_batch_dim=False)
+model = distributed.declare_servable(rank_size=8, stage_size=1, with_batch_dim=False)
 
 
-@register.register_method(output_names=["logits"])
-def predict_sub0(input_ids, current_index, init, batch_valid_length):
-    logits = register.call_servable(input_ids, current_index, init, batch_valid_length, subgraph=0)
-    return logits
-
-
-@register.register_method(output_names=["logits"])
-def predict_sub1(input_id, current_index, init, batch_valid_length):
-    logits = register.call_servable(input_id, current_index, init, batch_valid_length, subgraph=1)
-    return logits
-
-
-sub0_servable = register.PipelineServable(servable_name="pangu", method="predict_sub0")
-sub1_servable = register.PipelineServable(servable_name="pangu", method="predict_sub1")
-
-
-@register.register_pipeline(output_names=["output_sentence"])
-def predict(input_sentence):
+def predict_stage(input_sentence):
     """generate sentence with given input_sentence"""
 
     print(f"----------------------------- begin {input_sentence} ---------", flush=True)
@@ -136,7 +119,7 @@ def generate_increment(origin_inputs):
     init_false = False
     init = init_false
     # Call a single inference with input size of (bs, seq_length)
-    logits = sub0_servable.run(np.array(input_ids, np.int32), current_index, init, batch_valid_length)
+    logits = model.call(np.array(input_ids, np.int32), current_index, init, batch_valid_length, subgraph=0)
     # Claim the second graph and set not_init to true
     init = init_true
 
@@ -198,6 +181,12 @@ def generate_increment(origin_inputs):
         outputs.append(int(target))
 
         # Call a single inference with input size of (bs, 1)
-        logits = sub1_servable.run(input_id, current_index, init, batch_valid_length)
+        logits = model.call(input_id, current_index, init, batch_valid_length, subgraph=1)
     # Return valid outputs out of padded outputs
     return outputs
+
+
+@register.register_method(output_names=["output_sentence"])
+def predict(input_sentence):
+    reply = register.add_stage(predict_stage, input_sentence, outputs_count=1)
+    return reply
diff --git a/official/nlp/pangu_alpha/serving_increment/pangu_standalone/pangu/servable_config.py b/official/nlp/pangu_alpha/serving_increment/pangu_standalone/pangu/servable_config.py
index 0668e834d45d59633cd14cb33dc211eab7246e80..e901f4ba9245cf2c8eb88c6352d2a6870dd5d303 100644
--- a/official/nlp/pangu_alpha/serving_increment/pangu_standalone/pangu/servable_config.py
+++ b/official/nlp/pangu_alpha/serving_increment/pangu_standalone/pangu/servable_config.py
@@ -51,28 +51,11 @@ def topk_fun(logits, topk=5):
     return value, index
 
 
-register.declare_servable(servable_file=["pangu_alpha_1024_graph.mindir", "pangu_alpha_1_graph.mindir"],
-                          model_format="MINDIR", with_batch_dim=False)
+model = register.declare_model(model_file=["pangu_alpha_1024_graph.mindir", "pangu_alpha_1_graph.mindir"],
+                               model_format="MINDIR", with_batch_dim=False)
 
 
-@register.register_method(output_names=["logits"])
-def predict_sub0(input_ids, current_index, init, batch_valid_length):
-    logits = register.call_servable(input_ids, current_index, init, batch_valid_length, subgraph=0)
-    return logits
-
-
-@register.register_method(output_names=["logits"])
-def predict_sub1(input_id, current_index, init, batch_valid_length):
-    logits = register.call_servable(input_id, current_index, init, batch_valid_length, subgraph=1)
-    return logits
-
-
-sub0_servable = register.PipelineServable(servable_name="pangu", method="predict_sub0")
-sub1_servable = register.PipelineServable(servable_name="pangu", method="predict_sub1")
-
-
-@register.register_pipeline(output_names=["output_sentence"])
-def predict(input_sentence):
+def predict_stage(input_sentence):
     """generate sentence with given input_sentence"""
 
     print(f"----------------------------- begin {input_sentence} ---------", flush=True)
@@ -136,7 +119,7 @@ def generate_increment(origin_inputs):
     init_false = False
     init = init_false
     # Call a single inference with input size of (bs, seq_length)
-    logits = sub0_servable.run(np.array(input_ids, np.int32), current_index, init, batch_valid_length)
+    logits = model.call(np.array(input_ids, np.int32), current_index, init, batch_valid_length, subgraph=0)
     # Claim the second graph and set not_init to true
     init = init_true
 
@@ -198,6 +181,12 @@ def generate_increment(origin_inputs):
         outputs.append(int(target))
 
         # Call a single inference with input size of (bs, 1)
-        logits = sub1_servable.run(input_id, current_index, init, batch_valid_length)
+        logits = model.call(input_id, current_index, init, batch_valid_length, subgraph=1)
     # Return valid outputs out of padded outputs
     return outputs
+
+
+@register.register_method(output_names=["output_sentence"])
+def predict(input_sentence):
+    reply = register.add_stage(predict_stage, input_sentence, outputs_count=1)
+    return reply