ml-agents执行原理和流程


可参考文档:https://unity-technologies.github.io/ml-agents/Python-On-Off-Policy-Trainer-Documentation/

入口

在命令行执行mlagents-learn的时候,会调用对应python环境Scripts文件夹中的mlagents-learn.exe可执行文件,这个可执行文件会读取命令行的配置,并把这些配置发送给ml-agents包trainer文件夹中的learn.py文件,并开始执行,learn.py是ml-agents的入口。最后调用run_training函数,执行初始化的一系列操作。这个函数最后调用TrainerController中的start_learning函数开始训练。


def run_cli(options: RunOptions) -> None:
    """Top-level CLI handler.

    Prints the banner and version, configures logging from ``--debug``,
    warns about deprecated options, records timer metadata, resolves the
    run seed, and finally launches training via ``run_training``.

    :param options: Fully parsed run configuration from the command line.
    """
    # Print the ASCII-art Unity logo; fall back to plain text if the
    # console encoding cannot render the banner characters.
    try:
        print(
            """

                        ▄▄▄▓▓▓▓
                   ╓▓▓▓▓▓▓█▓▓▓▓▓
              ,▄▄▄m▀▀▀'  ,▓▓▓▀▓▓▄                           ▓▓▓  ▓▓▌
            ▄▓▓▓▀'      ▄▓▓▀  ▓▓▓      ▄▄     ▄▄ ,▄▄ ▄▄▄▄   ,▄▄ ▄▓▓▌▄ ▄▄▄    ,▄▄
          ▄▓▓▓▀        ▄▓▓▀   ▐▓▓▌     ▓▓▌   ▐▓▓ ▐▓▓▓▀▀▀▓▓▌ ▓▓▓ ▀▓▓▌▀ ^▓▓▌  ╒▓▓▌
        ▄▓▓▓▓▓▄▄▄▄▄▄▄▄▓▓▓      ▓▀      ▓▓▌   ▐▓▓ ▐▓▓    ▓▓▓ ▓▓▓  ▓▓▌   ▐▓▓▄ ▓▓▌
        ▀▓▓▓▓▀▀▀▀▀▀▀▀▀▀▓▓▄     ▓▓      ▓▓▌   ▐▓▓ ▐▓▓    ▓▓▓ ▓▓▓  ▓▓▌    ▐▓▓▐▓▓
          ^█▓▓▓        ▀▓▓▄   ▐▓▓▌     ▓▓▓▓▄▓▓▓▓ ▐▓▓    ▓▓▓ ▓▓▓  ▓▓▓▄    ▓▓▓▓`
            '▀▓▓▓▄      ^▓▓▓  ▓▓▓       └▀▀▀▀ ▀▀ ^▀▀    `▀▀ `▀▀   '▀▀    ▐▓▓▌
               ▀▀▀▀▓▄▄▄   ▓▓▓▓▓▓,                                      ▓▓▓▓▀
                   `▀█▓▓▓▓▓▓▓▓▓▌
                        ¬`▀▀▀█▓

        """
        )
    except Exception:
        print("\n\n\tUnity Technologies\n")
    print(get_version_string())

    # Logging verbosity is controlled by the --debug flag.
    if options.debug:
        log_level = logging_util.DEBUG
    else:
        log_level = logging_util.INFO

    logging_util.set_log_level(log_level)

    logger.debug("Configuration for this run:")
    logger.debug(json.dumps(options.as_dict(), indent=4))

    # Options deprecation warnings
    if options.checkpoint_settings.load_model:
        logger.warning(
            "The --load option has been deprecated. Please use the --resume option instead."
        )
    if options.checkpoint_settings.train_model:
        logger.warning(
            "The --train option has been deprecated. Train mode is now the default. Use "
            "--inference to run in inference mode."
        )

    run_seed = options.env_settings.seed
    num_areas = options.env_settings.num_areas

    # Add some timer metadata
    add_timer_metadata("mlagents_version", mlagents.trainers.__version__)
    add_timer_metadata("mlagents_envs_version", mlagents_envs.__version__)
    add_timer_metadata("communication_protocol_version", UnityEnvironment.API_VERSION)
    add_timer_metadata("pytorch_version", torch_utils.torch.__version__)
    add_timer_metadata("numpy_version", np.__version__)

    # A seed of -1 means "unseeded": pick a random seed for this run.
    if options.env_settings.seed == -1:
        run_seed = np.random.randint(0, 10000)
        logger.debug(f"run_seed set to {run_seed}")
    run_training(run_seed, options, num_areas)


def main():
    """Command-line entry point: parse arguments and hand off to run_cli."""
    options = parse_command_line()
    run_cli(options)


# Allow running this script directly (e.g. under a Python debugger).
if __name__ == "__main__":
    main()
def run_training(run_seed: int, options: RunOptions, num_areas: int) -> None:
    """
    Launches training session.

    Sets up output directories, stats writers, the environment manager,
    the trainer factory, and the TrainerController, then runs the
    training loop and flushes run artifacts on exit.

    :param run_seed: Random seed used for training.
    :param num_areas: Number of training areas to instantiate
    :param options: parsed command line arguments
    """
    with hierarchical_timer("run_training.setup"):
        torch_utils.set_torch_config(options.torch_settings)
        checkpoint_settings = options.checkpoint_settings
        env_settings = options.env_settings
        engine_settings = options.engine_settings

        run_logs_dir = checkpoint_settings.run_logs_dir
        port: Optional[int] = env_settings.base_port
        # Check if directory exists (guards against clobbering a previous
        # run unless --resume/--force is given)
        validate_existing_directories(
            checkpoint_settings.write_path,
            checkpoint_settings.resume,
            checkpoint_settings.force,
            checkpoint_settings.maybe_init_path,
        )
        # Make run logs directory
        os.makedirs(run_logs_dir, exist_ok=True)
        # Load any needed states in case of resume
        if checkpoint_settings.resume:
            GlobalTrainingStatus.load_state(
                os.path.join(run_logs_dir, "training_status.json")
            )
        # In case of initialization, set full init_path for all behaviors
        elif checkpoint_settings.maybe_init_path is not None:
            setup_init_path(options.behaviors, checkpoint_settings.maybe_init_path)

        # Configure Tensorboard Writers and StatsReporter
        stats_writers = register_stats_writer_plugins(options)
        for sw in stats_writers:
            StatsReporter.add_writer(sw)

        # No env_path: NOTE(review) presumably attaching to the Unity Editor,
        # where no explicit port override is wanted — confirm against
        # create_environment_factory.
        if env_settings.env_path is None:
            port = None
        # Build the factory used to create Unity environment instances.
        env_factory = create_environment_factory(
            env_settings.env_path,
            engine_settings.no_graphics,
            run_seed,
            num_areas,
            port,
            env_settings.env_args,
            os.path.abspath(run_logs_dir),  # Unity environment requires absolute path
        )

        # The SubprocessEnvManager runs num_envs environment workers.
        env_manager = SubprocessEnvManager(env_factory, options, env_settings.num_envs)
        env_parameter_manager = EnvironmentParameterManager(
            options.environment_parameters, run_seed, restore=checkpoint_settings.resume
        )

        trainer_factory = TrainerFactory(
            trainer_config=options.behaviors,
            output_path=checkpoint_settings.write_path,
            train_model=not checkpoint_settings.inference,
            load_model=checkpoint_settings.resume,
            seed=run_seed,
            param_manager=env_parameter_manager,
            init_path=checkpoint_settings.maybe_init_path,
            multi_gpu=False,
        )
        # Create controller and begin training.
        tc = TrainerController(
            trainer_factory,
            checkpoint_settings.write_path,
            checkpoint_settings.run_id,
            env_parameter_manager,
            not checkpoint_settings.inference,
            run_seed,
        )

    # Run the training loop; always close the environments and flush
    # run options / timings / training status, even if training raises.
    try:
        tc.start_learning(env_manager)
    finally:
        env_manager.close()
        write_run_options(checkpoint_settings.write_path, options)
        write_timing_tree(run_logs_dir)
        write_training_status(run_logs_dir)
训练

执行TrainerController中的start_learning,主要代码如下

def start_learning(self, env_manager: EnvManager) -> None:
    """Main training loop: reset the environment, then repeatedly advance
    the trainers until training is done.

    NOTE(review): this excerpt omits the ``except``/``finally`` clauses of
    the ``try`` block present in the original source, so it is not
    runnable as shown.

    :param env_manager: Manager wrapping the running Unity environment(s).
    """
    self._create_output_path(self.output_path)
    try:
        # Reset the environment and register the initial behaviors
        self._reset_env(env_manager)
        self.param_manager.log_current_lesson()
        while self._not_done_training():
            # Step the environment and advance the trainers
            n_steps = self.advance(env_manager)
            for _ in range(n_steps):
                self.reset_env_if_ready(env_manager)
        # Stop advancing trainers
        self.join_threads()

持续执行advance函数,更新每个trainer:

def advance(self, env_manager: EnvManager) -> int:
    """Collect one batch of environment steps and advance the trainers.

    :param env_manager: Manager providing steps from the environments.
    :return: Number of environment steps processed in this call.
    """
    # Pull new step results, registering any behavior seen for the first time.
    with hierarchical_timer("env_step"):
        step_infos = env_manager.get_steps()
        self._register_new_behaviors(env_manager, step_infos)
        step_count = env_manager.process_steps(step_infos)

    # Publish the current lesson number of every environment parameter
    # to each trainer's stats reporter.
    current_lessons = self.param_manager.get_current_lesson_number()
    for param_name, lesson_number in current_lessons.items():
        stat_key = f"Environment/Lesson Number/{param_name}"
        for trainer in self.trainers.values():
            trainer.stats_reporter.set_stat(stat_key, lesson_number)

    # Threaded trainers advance on their own worker thread (see
    # trainer_update_func); step the non-threaded ones inline here.
    unthreaded_trainers = (t for t in self.trainers.values() if not t.threaded)
    for trainer in unthreaded_trainers:
        with hierarchical_timer("trainer_advance"):
            trainer.advance()

    return step_count

trainer可以使用以下函数进行创建,创建trainer的同时会开启新线程,线程会执行advance函数:

def _create_trainer_and_manager(
    self, env_manager: EnvManager, name_behavior_id: str
) -> None:
    """Create (or reuse) the trainer for a behavior and wire up its policy,
    AgentManager, and queues with the environment manager.

    :param env_manager: Manager for the running environment(s).
    :param name_behavior_id: Fully qualified behavior id to register.
    """
    parsed_behavior_id = BehaviorIdentifiers.from_name_behavior_id(name_behavior_id)
    brain_name = parsed_behavior_id.brain_name
    trainerthread = None
    # Several behavior ids can share one brain: reuse its trainer if one
    # already exists, otherwise build a new one (and, for threaded
    # trainers, a daemon worker thread running trainer_update_func).
    if brain_name in self.trainers:
        trainer = self.trainers[brain_name]
    else:
        trainer = self.trainer_factory.generate(brain_name)
        self.trainers[brain_name] = trainer
        if trainer.threaded:
            # Only create trainer thread for new trainers
            trainerthread = threading.Thread(
                target=self.trainer_update_func, args=(trainer,), daemon=True
            )
            self.trainer_threads.append(trainerthread)
        env_manager.on_training_started(
            brain_name, self.trainer_factory.trainer_config[brain_name]
        )

    # Create the policy for this behavior and attach it to the trainer.
    policy = trainer.create_policy(
        parsed_behavior_id,
        env_manager.training_behaviors[name_behavior_id],
        create_graph=True,
    )
    trainer.add_policy(parsed_behavior_id, policy)

    # The AgentManager handles per-behavior bookkeeping; its queues are
    # connected to the trainer below.
    agent_manager = AgentManager(
        policy,
        name_behavior_id,
        trainer.stats_reporter,
        trainer.parameters.time_horizon,
        threaded=trainer.threaded,
    )
    env_manager.set_agent_manager(name_behavior_id, agent_manager)
    env_manager.set_policy(name_behavior_id, policy)
    self.brain_name_to_identifier[brain_name].add(name_behavior_id)

    # Wire the queues: policies flow trainer -> manager, trajectories
    # flow manager -> trainer.
    trainer.publish_policy_queue(agent_manager.policy_queue)
    trainer.subscribe_trajectory_queue(agent_manager.trajectory_queue)

    # Only start new trainers
    if trainerthread is not None:
        trainerthread.start()

def _create_trainers_and_managers(
    self, env_manager: EnvManager, behavior_ids: Set[str]
) -> None:
    """Create a trainer/agent-manager pair for every given behavior id.

    :param env_manager: Manager for the running environment(s).
    :param behavior_ids: Behavior ids needing trainers and managers.
    """
    for bid in behavior_ids:
        self._create_trainer_and_manager(env_manager, bid)
def trainer_update_func(self, trainer: Trainer) -> None:
    """Thread target: keep advancing ``trainer`` until shutdown is requested
    via ``self.kill_trainers``.

    :param trainer: The trainer to advance on this worker thread.
    """
    while True:
        if self.kill_trainers:
            break
        with hierarchical_timer("trainer_advance"):
            trainer.advance()

创建trainer的代码在_register_new_behaviors中调用,而这个函数在advance和_reset_env中调用,_reset_env正是一开始在start_learning中调用的函数。

def _register_new_behaviors(
    self, env_manager: EnvManager, step_infos: List[EnvironmentStep]
) -> None:
    """Register trainers and agent managers for any behavior ids that
    appear in ``step_infos`` for the first time.

    :param env_manager: Manager for the running environment(s).
    :param step_infos: Step results to scan for behavior ids.
    """
    # Every behavior id seen in this batch of steps.
    seen_ids: Set[str] = {
        behavior_id
        for info in step_infos
        for behavior_id in info.name_behavior_ids
    }
    # Only ids we have never registered need new trainers/managers.
    unregistered_ids = seen_ids - self.registered_behavior_ids
    self._create_trainers_and_managers(env_manager, unregistered_ids)
    self.registered_behavior_ids |= seen_ids

def _reset_env(self, env_manager: EnvManager) -> None:
    """Reset the environment with the current sampler configuration and
    register any behavior ids the reset produces.

    :param env_manager: Manager for the environment(s) to reset.
    """
    sampler_config = self.param_manager.get_current_samplers()
    env_manager.reset(config=sampler_config)
    # The reset may surface behaviors we have not seen before.
    self._register_new_behaviors(env_manager, env_manager.first_step_infos)

文章作者: 微笑紫瞳星
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 微笑紫瞳星 !
  目录