Commit 9bcdf707 authored by lixinqi

rm job_conf.num_of_batches_in_snapshot

parent 645cc145
@@ -30,11 +30,6 @@ bool IsLastRegstInPieceWithOrder(const Regst* regst, ColIdOrder order) {
          || (order == ColIdOrder::kDescending && regst->col_id() == 0);
 }
-bool NeedModelSave(const JobDesc& job_desc, int64_t model_version_id) {
-  return model_version_id + 1 == job_desc.TotalBatchNum()
-         || (model_version_id + 1) % job_desc.NumOfBatchesInSnapshot() == 0;
-}
 void Actor::Init(const JobDesc* job_desc, const TaskProto& task_proto,
                  const ThreadCtx& thread_ctx) {
   job_desc_ = job_desc;
......
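For context, the predicate deleted in the hunk above saved a model either on the last batch or every NumOfBatchesInSnapshot batches. Below is a minimal standalone sketch of that condition, assuming illustrative values of 250 total batches and a snapshot interval of 100 (neither value comes from this commit):

#include <cstdint>
#include <iostream>

// Stand-ins for JobDesc::TotalBatchNum() and JobDesc::NumOfBatchesInSnapshot();
// both values are illustrative, not taken from any config in this commit.
constexpr int64_t kTotalBatchNum = 250;
constexpr int64_t kNumOfBatchesInSnapshot = 100;

// Mirrors the predicate removed from actor.cpp: save on the last batch, or
// every kNumOfBatchesInSnapshot batches.
bool NeedModelSaveSketch(int64_t model_version_id) {
  return model_version_id + 1 == kTotalBatchNum
         || (model_version_id + 1) % kNumOfBatchesInSnapshot == 0;
}

int main() {
  for (int64_t id = 0; id < kTotalBatchNum; ++id) {
    if (NeedModelSaveSketch(id)) { std::cout << "save at model_version_id " << id << "\n"; }
  }
  // Prints: 99, 199, 249.
  return 0;
}

With a 100-batch interval this fires at model versions 99 and 199, and once more at 249 because that is the final batch.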
@@ -19,7 +19,7 @@ enum class ColIdOrder { kUnCertain = 0, kAscending, kDescending };
 bool IsFirstRegstInPieceWithOrder(const Regst*, ColIdOrder);
 bool IsLastRegstInPieceWithOrder(const Regst*, ColIdOrder);
-bool NeedModelSave(const JobDesc& job_desc, int64_t model_version_id);
+inline bool NeedModelSave(const JobDesc& job_desc, int64_t model_version_id) { return false; }
 class Actor {
  public:
......
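In the header, the declaration is replaced by an inline stub that always returns false, so any remaining caller still compiles but never triggers a save through this path. A hypothetical call site sketch (MaybeSaveModel and the commented SaveModel are assumptions for illustration, not code from the repository):

#include <cstdint>

struct JobDesc {};  // placeholder for oneflow::JobDesc

// The stub now defined in actor.h: the snapshot decision no longer lives here.
inline bool NeedModelSave(const JobDesc& /*job_desc*/, int64_t /*model_version_id*/) {
  return false;
}

// Hypothetical call site: with the stub in place, the save branch is never
// taken through this path.
void MaybeSaveModel(const JobDesc& job_desc, int64_t model_version_id) {
  if (NeedModelSave(job_desc, model_version_id)) {
    // SaveModel(model_version_id);  // unreachable after this change
  }
}

int main() { return 0; }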
@@ -277,10 +277,6 @@ std::function<const HashMap<int64_t, double>&(int64_t)> MakeGetterPathDurations4
   };
 }
-uint64_t NumOfPiecesInSnapshot() {
-  return GlobalJobDesc().NumOfBatchesInSnapshot() * GlobalJobDesc().NumOfPiecesInBatch();
-}
 std::function<const HashMap<int64_t, double>&(int64_t)> MakeGetterPathIIScales4RegstDescId(
     const ChainActGraph& graph) {
   auto regst_desc_id2consumer_id2ii_scale =
......
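The helper removed here converted the snapshot interval from batches to pieces by multiplying it with NumOfPiecesInBatch. A small sketch of that arithmetic with assumed values (100 batches per snapshot, 4 pieces per batch; neither number comes from this commit):

#include <cstdint>
#include <iostream>

int main() {
  // Illustrative values; the real ones came from GlobalJobDesc() and are gone
  // with this commit.
  const uint64_t num_of_batches_in_snapshot = 100;
  const uint64_t num_of_pieces_in_batch = 4;
  const uint64_t num_of_pieces_in_snapshot =
      num_of_batches_in_snapshot * num_of_pieces_in_batch;
  std::cout << "pieces per snapshot: " << num_of_pieces_in_snapshot << "\n";  // 400
  return 0;
}

With those numbers a snapshot would have covered 400 pieces; since num_of_batches_in_snapshot no longer exists in TrainConf, the helper has nothing to read and is dropped.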
@@ -14,7 +14,6 @@ import "oneflow/core/job/sbp_parallel.proto";
 message TrainConf {
   required int64 batch_size = 1;  // batch_size % piece_size = 0
   required NormalModelUpdateOpUserConf model_update_conf = 3;
-  required int32 num_of_batches_in_snapshot = 5;
   repeated string loss_lbn = 6;
   optional int32 loss_scale_factor = 7 [default = 1];
   optional string train_step_lbn = 8;
......
@@ -46,9 +46,6 @@ bool JobDesc::enable_experiment_run() const {
   return job_conf_.exp_run_conf().enable_experiment_run();
 }
-int32_t JobDesc::NumOfBatchesInSnapshot() const {
-  return job_conf_.train_conf().num_of_batches_in_snapshot();
-}
 int64_t JobDesc::TotalBatchNum() const { return job_conf_.total_batch_num(); }
 int64_t JobDesc::BatchSize() const { return job_conf_.train_conf().batch_size(); }
 int64_t JobDesc::NumOfPiecesInBatch() const {
......
@@ -52,7 +52,6 @@ class JobDesc final {
   int64_t cudnn_buf_limit_mbyte() const { return job_conf_.cudnn_buf_limit_mbyte(); }
   // Train conf
-  int32_t NumOfBatchesInSnapshot() const;
   int64_t TotalBatchNum() const;
   int64_t BatchSize() const;
   int64_t NumOfPiecesInBatch() const;
......
@@ -210,7 +210,6 @@ def alexnet_train_job():
   job_conf.train_conf()
   job_conf.train_conf().batch_size = 12
   job_conf.train_conf().primary_lr = 0.00001
-  job_conf.train_conf().num_of_batches_in_snapshot = 100
   job_conf.train_conf().model_update_conf.naive_conf.SetInParent()
   job_conf.train_conf().loss_lbn.extend(["softmax_loss/out"])
......
@@ -131,7 +131,6 @@ def PretrainJob():
   job_conf.train_conf()
   job_conf.train_conf().primary_lr = 1e-4
   job_conf.train_conf().weight_l2 = 0.01
-  job_conf.train_conf().num_of_batches_in_snapshot = 1000
   job_conf.model_update_conf(_BERT_MODEL_UPDATE_CONF)
   job_conf.train_conf().loss_lbn.extend(["identity_loss/loss"])
   job_conf.enable_inplace(False)
......
@@ -15,7 +15,6 @@ def UpdateVariable(x, scope_name, enable_all_reduce_group = True):
   job_conf.batch_size(1).data_part_num(1).default_data_type(flow.float)
   job_conf.train_conf()
   job_conf.train_conf().primary_lr = 0.01
-  job_conf.train_conf().num_of_batches_in_snapshot = 100
   job_conf.train_conf().model_update_conf.naive_conf.SetInParent()
   job_conf.train_conf().loss_lbn.extend([scope_name + "-loss_op/out"])
   job_conf.enable_all_reduce_group(enable_all_reduce_group)
......
@@ -131,7 +131,6 @@ def PretrainJob():
   job_conf.train_conf()
   job_conf.train_conf().primary_lr = 1e-4
   job_conf.train_conf().weight_l2 = 0.01
-  job_conf.train_conf().num_of_batches_in_snapshot = 1000
   job_conf.model_update_conf(_BERT_MODEL_UPDATE_CONF)
   job_conf.train_conf().loss_lbn.extend(["identity_loss/loss"])
   job_conf.enable_inplace(False)
......