九游平台/ ai开发平台modelarts/ 常见问题/ standard notebook/ 在modelarts的notebook中使用moxing时,如何进行增量训练?
更新时间:2025-01-22 gmt 08:00

在modelarts的notebook中使用moxing时,如何进行增量训练?-九游平台

在使用moxing构建模型时,如果您对前一次训练结果不满意,可以在更改部分数据和标注信息后,进行增量训练。

“mox.run”添加增量训练参数

在完成标注数据或数据集的修改后,您可以在“mox.run”中,修改“log_dir”参数,并新增“checkpoint_path”参数。其中“log_dir”参数建议设置为一个新的目录,“checkpoint_path”参数设置为上一次训练结果输出路径,如果是obs目录,路径填写时建议使用“obs://”开头。

如果标注数据中的标签发生了变化,在运行“mox.run”前先执行如果标签发生变化的操作。

  mox.run(input_fn=input_fn,
          model_fn=model_fn,
          optimizer_fn=optimizer_fn,
          run_mode=flags.run_mode,
          inter_mode=mox.modekeys.eval if use_eval_data else none,
          log_dir=log_dir,
          batch_size=batch_size_per_device,
          auto_batch=false,
          max_number_of_steps=max_number_of_steps,
          log_every_n_steps=flags.log_every_n_steps,
          save_summary_steps=save_summary_steps,
          save_model_secs=save_model_secs,
          checkpoint_path=flags.checkpoint_url,
          export_model=mox.exportkeys.tf_serving)

如果标签发生变化

当数据集中的标签发生变化时,需要执行如下语句。此语句需在“mox.run”之前运行。

语句中的“logits”,表示根据不同网络中分类层权重的变量名,配置不同的参数。此处填写其对应的关键字。

mox.set_flag('checkpoint_exclude_patterns', 'logits')

如果使用的是moxing内置网络,其对应的关键字需使用如下api获取。此示例将打印resnet_v1_50的关键字,为“logits”

import moxing.tensorflow as mox
model_meta = mox.get_model_meta(mox.networkkeys.resnet_v1_50)
logits_pattern = model_meta.default_logits_pattern
print(logits_pattern)

您也可以通过如下接口,获取moxing支持的网络名称列表。

import moxing.tensorflow as mox
print(help(mox.networkkeys))

打印出来的示例如下所示:

help on class networkkeys in module 
moxing.tensorflow.nets.nets_factory:
class networkkeys(builtins.object)
 |  data descriptors defined here:
 |  
 |  __dict__
 |      dictionary for instance variables (if defined)
 |  
 |  __weakref__
 |      list of weak references to the object (if defined)
 |  
 |  ----------------------------------------------------------------------
 |  data and other attributes defined here:
 |  
 |  alexnet_v2 = 'alexnet_v2'
 |  
 |  cifarnet = 'cifarnet'
 |  
 |  inception_resnet_v2 = 'inception_resnet_v2'
 |  
 |  inception_v1 = 'inception_v1'
 |  
 |  inception_v2 = 'inception_v2'
 |  
 |  inception_v3 = 'inception_v3'
 |  
 |  inception_v4 = 'inception_v4'
 |  
 |  lenet = 'lenet'
 |  
 |  mobilenet_v1 = 'mobilenet_v1'
 |  
 |  mobilenet_v1_025 = 'mobilenet_v1_025'
 |  
 |  mobilenet_v1_050 = 'mobilenet_v1_050'
 |  
 |  mobilenet_v1_075 = 'mobilenet_v1_075'
 |  
 |  mobilenet_v2 = 'mobilenet_v2'
 |  
 |  mobilenet_v2_035 = 'mobilenet_v2_035'
 |  
 |  mobilenet_v2_140 = 'mobilenet_v2_140'
 |  
 |  nasnet_cifar = 'nasnet_cifar'
 |  
 |  nasnet_large = 'nasnet_large'
 |  
 |  nasnet_mobile = 'nasnet_mobile'
 |  
 |  overfeat = 'overfeat'
 |  
 |  pnasnet_large = 'pnasnet_large'
 |  
 |  pnasnet_mobile = 'pnasnet_mobile'
 |  
 |  pvanet = 'pvanet'
 |  
 |  resnet_v1_101 = 'resnet_v1_101'
 |  
 |  resnet_v1_110 = 'resnet_v1_110'
 |  
 |  resnet_v1_152 = 'resnet_v1_152'
 |  
 |  resnet_v1_18 = 'resnet_v1_18'
 |  
 |  resnet_v1_20 = 'resnet_v1_20'
 |  
 |  resnet_v1_200 = 'resnet_v1_200'
 |  
 |  resnet_v1_50 = 'resnet_v1_50'
 |  
 |  resnet_v1_50_8k = 'resnet_v1_50_8k'
 |  
 |  resnet_v1_50_mox = 'resnet_v1_50_mox'
 |  
 |  resnet_v1_50_oct = 'resnet_v1_50_oct'
 |  
 |  resnet_v2_101 = 'resnet_v2_101'
 |  
 |  resnet_v2_152 = 'resnet_v2_152'
 |  
 |  resnet_v2_200 = 'resnet_v2_200'
 |  
 |  resnet_v2_50 = 'resnet_v2_50'
 |  
 |  resnext_b_101 = 'resnext_b_101'
 |  
 |  resnext_b_50 = 'resnext_b_50'
 |  
 |  resnext_c_101 = 'resnext_c_101'
 |  
 |  resnext_c_50 = 'resnext_c_50'
 |  
 |  vgg_16 = 'vgg_16'
 |  
 |  vgg_16_bn = 'vgg_16_bn'
 |  
 |  vgg_19 = 'vgg_19'
 |  
 |  vgg_19_bn = 'vgg_19_bn'
 |  
 |  vgg_a = 'vgg_a'
 |  
 |  vgg_a_bn = 'vgg_a_bn'
 |  
 |  xception_41 = 'xception_41'
 |  
 |  xception_65 = 'xception_65'
 |  
 |  xception_71 = 'xception_71'

相关文档

网站地图