Reference documentation: https://unity-technologies.github.io/ml-agents/Python-On-Off-Policy-Trainer-Documentation/
Entry Point
When mlagents-learn is executed on the command line, the mlagents-learn.exe executable in the Scripts folder of the corresponding Python environment is invoked. That executable is only a thin wrapper: it reads the command-line options and hands them to learn.py in the trainers folder of the mlagents package, which is the real entry point of ml-agents. learn.py then calls run_training, which performs a series of initialization steps and finally calls start_learning on TrainerController to begin training.
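The mlagents-learn executable is generated by pip from a console_scripts entry point. A sketch of what that declaration looks like in the mlagents package's setup.py (paraphrased, not quoted verbatim; check the package for the exact file):

# Sketch of the console_scripts entry point in the mlagents package's
# setup.py: pip turns this mapping into the mlagents-learn executable
# that lands in the Python environment's Scripts folder.
from setuptools import setup

setup(
    name="mlagents",
    # ... other metadata omitted ...
    entry_points={
        "console_scripts": [
            # "command = module.path:function"
            "mlagents-learn=mlagents.trainers.learn:main",
        ]
    },
)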
def run_cli(options: RunOptions) -> None:
    try:
        print(
            """

                        ▄▄▄▓▓▓▓
                   ╓▓▓▓▓▓▓█▓▓▓▓▓
              ,▄▄▄m▀▀▀'  ,▓▓▓▀▓▓▄                           ▓▓▓  ▓▓▌
            ▄▓▓▓▀'      ▄▓▓▀  ▓▓▓      ▄▄     ▄▄ ,▄▄ ▄▄▄▄   ,▄▄ ▄▓▓▌▄ ▄▄▄    ,▄▄
          ▄▓▓▓▀        ▄▓▓▀   ▐▓▓▌     ▓▓▌   ▐▓▓ ▐▓▓▓▀▀▀▓▓▌ ▓▓▓ ▀▓▓▌▀ ^▓▓▌  ╒▓▓▌
        ▄▓▓▓▓▓▄▄▄▄▄▄▄▄▓▓▓      ▓▀      ▓▓▌   ▐▓▓ ▐▓▓    ▓▓▓ ▓▓▓  ▓▓▌   ▐▓▓▄ ▓▓▌
        ▀▓▓▓▓▀▀▀▀▀▀▀▀▀▀▓▓▄     ▓▓      ▓▓▌   ▐▓▓ ▐▓▓    ▓▓▓ ▓▓▓  ▓▓▌    ▐▓▓▐▓▓
          ^█▓▓▓        ▀▓▓▄   ▐▓▓▌     ▓▓▓▓▄▓▓▓▓ ▐▓▓    ▓▓▓ ▓▓▓  ▓▓▓▄    ▓▓▓▓`
            '▀▓▓▓▄      ^▓▓▓  ▓▓▓       └▀▀▀▀ ▀▀ ^▀▀    `▀▀ `▀▀   '▀▀    ▐▓▓▌
               ▀▀▀▀▓▄▄▄   ▓▓▓▓▓▓,                                      ▓▓▓▓▀
                   `▀█▓▓▓▓▓▓▓▓▓▌
                        ¬`▀▀▀█▓

        """
        )
    except Exception:
        print("\n\n\tUnity Technologies\n")
    print(get_version_string())

    if options.debug:
        log_level = logging_util.DEBUG
    else:
        log_level = logging_util.INFO
    logging_util.set_log_level(log_level)

    logger.debug("Configuration for this run:")
    logger.debug(json.dumps(options.as_dict(), indent=4))

    # Options deprecation warnings
    if options.checkpoint_settings.load_model:
        logger.warning(
            "The --load option has been deprecated. Please use the --resume option instead."
        )
    if options.checkpoint_settings.train_model:
        logger.warning(
            "The --train option has been deprecated. Train mode is now the default. Use "
            "--inference to run in inference mode."
        )

    run_seed = options.env_settings.seed
    num_areas = options.env_settings.num_areas

    # Add some timer metadata
    add_timer_metadata("mlagents_version", mlagents.trainers.__version__)
    add_timer_metadata("mlagents_envs_version", mlagents_envs.__version__)
    add_timer_metadata("communication_protocol_version", UnityEnvironment.API_VERSION)
    add_timer_metadata("pytorch_version", torch_utils.torch.__version__)
    add_timer_metadata("numpy_version", np.__version__)

    if options.env_settings.seed == -1:
        run_seed = np.random.randint(0, 10000)
        logger.debug(f"run_seed set to {run_seed}")
    run_training(run_seed, options, num_areas)


def main():
    run_cli(parse_command_line())


# For python debugger to directly run this script
if __name__ == "__main__":
    main()
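Because main() simply forwards parse_command_line() into run_cli, the whole pipeline can also be driven from a Python script or a debugger. A hypothetical example (the YAML path and run id are placeholders; this assumes parse_command_line accepts an explicit argv list, as its signature in learn.py suggests):

# Hypothetical debugging entry: bypass the mlagents-learn console script
# and call into learn.py directly. The config path and run id below are
# placeholders, not values from this article.
from mlagents.trainers.learn import parse_command_line, run_cli

options = parse_command_line(
    ["config/ppo/3DBall.yaml", "--run-id=debug_run", "--force"]
)
run_cli(options)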
def run_training(run_seed: int, options: RunOptions, num_areas: int) -> None:
    """
    Launches training session.
    :param run_seed: Random seed used for training.
    :param num_areas: Number of training areas to instantiate
    :param options: parsed command line arguments
    """
    with hierarchical_timer("run_training.setup"):
        torch_utils.set_torch_config(options.torch_settings)
        checkpoint_settings = options.checkpoint_settings
        env_settings = options.env_settings
        engine_settings = options.engine_settings

        run_logs_dir = checkpoint_settings.run_logs_dir
        port: Optional[int] = env_settings.base_port
        # Check if directory exists
        validate_existing_directories(
            checkpoint_settings.write_path,
            checkpoint_settings.resume,
            checkpoint_settings.force,
            checkpoint_settings.maybe_init_path,
        )
        # Make run logs directory
        os.makedirs(run_logs_dir, exist_ok=True)
        # Load any needed states in case of resume
        if checkpoint_settings.resume:
            GlobalTrainingStatus.load_state(
                os.path.join(run_logs_dir, "training_status.json")
            )
        # In case of initialization, set full init_path for all behaviors
        elif checkpoint_settings.maybe_init_path is not None:
            setup_init_path(options.behaviors, checkpoint_settings.maybe_init_path)

        # Configure Tensorboard Writers and StatsReporter
        stats_writers = register_stats_writer_plugins(options)
        for sw in stats_writers:
            StatsReporter.add_writer(sw)

        if env_settings.env_path is None:
            port = None
        # Create the factory used to spawn Unity environment instances
        env_factory = create_environment_factory(
            env_settings.env_path,
            engine_settings.no_graphics,
            run_seed,
            num_areas,
            port,
            env_settings.env_args,
            os.path.abspath(run_logs_dir),  # Unity environment requires absolute path
        )

        env_manager = SubprocessEnvManager(env_factory, options, env_settings.num_envs)
        env_parameter_manager = EnvironmentParameterManager(
            options.environment_parameters, run_seed, restore=checkpoint_settings.resume
        )

        trainer_factory = TrainerFactory(
            trainer_config=options.behaviors,
            output_path=checkpoint_settings.write_path,
            train_model=not checkpoint_settings.inference,
            load_model=checkpoint_settings.resume,
            seed=run_seed,
            param_manager=env_parameter_manager,
            init_path=checkpoint_settings.maybe_init_path,
            multi_gpu=False,
        )
        # Create controller and begin training.
        tc = TrainerController(
            trainer_factory,
            checkpoint_settings.write_path,
            checkpoint_settings.run_id,
            env_parameter_manager,
            not checkpoint_settings.inference,
            run_seed,
        )

    # Start training
    try:
        tc.start_learning(env_manager)
    finally:
        env_manager.close()
        write_run_options(checkpoint_settings.write_path, options)
        write_timing_tree(run_logs_dir)
        write_training_status(run_logs_dir)
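For reference, create_environment_factory returns a closure that builds one UnityEnvironment per worker, which is what lets SubprocessEnvManager spawn num_envs independent environment processes. A condensed sketch (based on learn.py; exact signatures may differ between ml-agents versions):

# Condensed sketch of create_environment_factory: the factory closes
# over the run settings and builds a UnityEnvironment for a given
# worker_id, one per subprocess.
from typing import Callable, List, Optional
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.side_channel.side_channel import SideChannel

def create_environment_factory(
    env_path: Optional[str],
    no_graphics: bool,
    seed: int,
    num_areas: int,
    start_port: Optional[int],
    env_args: Optional[List[str]],
    log_folder: str,
) -> Callable[[int, List[SideChannel]], UnityEnvironment]:
    def create_unity_environment(
        worker_id: int, side_channels: List[SideChannel]
    ) -> UnityEnvironment:
        # Offset the seed per worker so parallel environments
        # do not play out identical episodes.
        env_seed = seed + worker_id
        return UnityEnvironment(
            file_name=env_path,
            worker_id=worker_id,
            seed=env_seed,
            num_areas=num_areas,
            no_graphics=no_graphics,
            base_port=start_port,
            additional_args=env_args,
            side_channels=side_channels,
            log_folder=log_folder,
        )

    return create_unity_environment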
Training
Training is carried out by start_learning in TrainerController; its main code (abridged, exception handling omitted) is as follows:
def start_learning(self, env_manager: EnvManager) -> None:
    self._create_output_path(self.output_path)
    try:
        # Initial environment reset
        self._reset_env(env_manager)
        self.param_manager.log_current_lesson()
        while self._not_done_training():
            # Step the environments and the trainers
            n_steps = self.advance(env_manager)
            for _ in range(n_steps):
                self.reset_env_if_ready(env_manager)
        # Stop advancing trainers
        self.join_threads()
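_not_done_training decides when the while loop exits. Roughly (condensed from trainer_controller.py; check your version for the exact expression), the loop keeps running as long as any trainer still wants to train, or unconditionally when in inference mode:

# Condensed from trainer_controller.py (version-dependent): keep looping
# while at least one trainer reports should_still_train, while running
# in inference mode (train_model is False), or before any trainer exists.
def _not_done_training(self) -> bool:
    return (
        any(t.should_still_train for t in self.trainers.values())
        or not self.train_model
    ) or len(self.trainers) == 0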
This loop keeps calling advance, which steps the environments and updates every trainer:
def advance(self, env_manager: EnvManager) -> int:
    # Get steps
    with hierarchical_timer("env_step"):
        new_step_infos = env_manager.get_steps()
        self._register_new_behaviors(env_manager, new_step_infos)
        num_steps = env_manager.process_steps(new_step_infos)

    # Report current lesson for each environment parameter
    for (
        param_name,
        lesson_number,
    ) in self.param_manager.get_current_lesson_number().items():
        for trainer in self.trainers.values():
            trainer.stats_reporter.set_stat(
                f"Environment/Lesson Number/{param_name}", lesson_number
            )

    for trainer in self.trainers.values():
        if not trainer.threaded:
            with hierarchical_timer("trainer_advance"):
                trainer.advance()

    return num_steps
Trainers are created by the function below. If a trainer is marked as threaded, creating it also spawns a new thread that runs trainer_update_func, which in turn repeatedly calls the trainer's advance:
def _create_trainer_and_manager(
    self, env_manager: EnvManager, name_behavior_id: str
) -> None:
    parsed_behavior_id = BehaviorIdentifiers.from_name_behavior_id(name_behavior_id)
    brain_name = parsed_behavior_id.brain_name
    trainerthread = None
    if brain_name in self.trainers:
        trainer = self.trainers[brain_name]
    else:
        trainer = self.trainer_factory.generate(brain_name)
        self.trainers[brain_name] = trainer
        if trainer.threaded:
            # Only create trainer thread for new trainers
            trainerthread = threading.Thread(
                target=self.trainer_update_func, args=(trainer,), daemon=True
            )
            self.trainer_threads.append(trainerthread)
        env_manager.on_training_started(
            brain_name, self.trainer_factory.trainer_config[brain_name]
        )

    policy = trainer.create_policy(
        parsed_behavior_id,
        env_manager.training_behaviors[name_behavior_id],
        create_graph=True,
    )
    trainer.add_policy(parsed_behavior_id, policy)

    agent_manager = AgentManager(
        policy,
        name_behavior_id,
        trainer.stats_reporter,
        trainer.parameters.time_horizon,
        threaded=trainer.threaded,
    )
    env_manager.set_agent_manager(name_behavior_id, agent_manager)
    env_manager.set_policy(name_behavior_id, policy)
    self.brain_name_to_identifier[brain_name].add(name_behavior_id)

    trainer.publish_policy_queue(agent_manager.policy_queue)
    trainer.subscribe_trajectory_queue(agent_manager.trajectory_queue)

    # Only start new trainers
    if trainerthread is not None:
        trainerthread.start()


def _create_trainers_and_managers(
    self, env_manager: EnvManager, behavior_ids: Set[str]
) -> None:
    for behavior_id in behavior_ids:
        self._create_trainer_and_manager(env_manager, behavior_id)


def trainer_update_func(self, trainer: Trainer) -> None:
    while not self.kill_trainers:
        with hierarchical_timer("trainer_advance"):
            trainer.advance()
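What does trainer.advance do while the thread spins on it? A condensed sketch based on RLTrainer in ml-agents (the names _process_trajectory, _is_ready_update, and _update_policy come from that class; queue-draining and thread-yielding details are omitted): it drains trajectories from the subscribed queues, and once enough experience has accumulated it updates the policy and publishes the new weights back through the policy queues, closing the producer-consumer loop with the AgentManager.

# Condensed sketch of RLTrainer.advance (see rl_trainer.py for the real
# implementation; details are omitted here).
from mlagents.trainers.agent_processor import AgentManagerQueue

def advance(self) -> None:
    # Consume trajectories produced by the AgentManager(s).
    for traj_queue in self.trajectory_queues:
        try:
            trajectory = traj_queue.get_nowait()
            self._process_trajectory(trajectory)
        except AgentManagerQueue.Empty:
            pass
    # Once enough experience is buffered, update the policy and
    # publish it so the AgentManager starts acting with new weights.
    if self.should_still_train and self._is_ready_update():
        if self._update_policy():
            for q in self.policy_queues:
                q.put(self.get_policy(q.behavior_id))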
Trainer creation is triggered from _register_new_behaviors, which in turn is called from both advance and _reset_env; _reset_env is exactly the function invoked at the beginning of start_learning.
def _register_new_behaviors(
    self, env_manager: EnvManager, step_infos: List[EnvironmentStep]
) -> None:
    """
    Handle registration (adding trainers and managers) of new behaviors ids.
    :param env_manager:
    :param step_infos:
    :return:
    """
    step_behavior_ids: Set[str] = set()
    for s in step_infos:
        step_behavior_ids |= set(s.name_behavior_ids)
    new_behavior_ids = step_behavior_ids - self.registered_behavior_ids
    self._create_trainers_and_managers(env_manager, new_behavior_ids)
    self.registered_behavior_ids |= step_behavior_ids


def _reset_env(self, env_manager: EnvManager) -> None:
    """Resets the environment.

    Returns:
        A Data structure corresponding to the initial reset state of the
        environment.
    """
    new_config = self.param_manager.get_current_samplers()
    env_manager.reset(config=new_config)
    # Register any new behavior ids that were generated on the reset.
    self._register_new_behaviors(env_manager, env_manager.first_step_infos)
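Note that env_manager.first_step_infos is populated by the reset itself. Roughly, the base EnvManager.reset (paraphrased from env_manager.py; the exact code varies across ml-agents versions) ends all running episodes, resets the environments, and stashes the first batch of step infos so that the behaviors they contain can be registered before the first advance:

# Paraphrased from the EnvManager base class (env_manager.py);
# exact code varies across ml-agents versions.
from typing import Dict

def reset(self, config: Dict = None) -> int:
    for manager in self.agent_managers.values():
        manager.end_episode()
    # Save the first step infos produced by the reset; they are
    # consumed later, on the first call to advance().
    self.first_step_infos = self._reset_env(config)
    return len(self.first_step_infos)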