diff --git a/README.md b/README.md index 68ae4386e..7c00010c7 100644 --- a/README.md +++ b/README.md @@ -146,20 +146,6 @@ The left side is a basic rendering (which will likely be replaced by a proper rendering some day). The right side is the feature layers that the agent receives, with some coloring to make it more useful to us. -## Watch a replay - -Running the random agent and playing as a human save a replay by default. You -can watch that replay by running: - -```shell -$ python -m pysc2.bin.play --replay -``` - -This works for any replay as long as the map can be found by the game. - -The same controls work as for playing the game, so `F4` to exit, `pgup`/`pgdn` -to control the speed, etc. - ## List the maps [Maps](docs/maps.md) need to be configured before they're known to the @@ -190,11 +176,6 @@ configure your own, take a look [here](docs/maps.md). A replay lets you review what happened during a game. You can see the actions and observations that each player made as they played. -Blizzard is releasing a large number of anonymized 1v1 replays played on the -ladder. You can find instructions for how to get the -[replay files](https://github.com/Blizzard/s2client-proto#downloads) on their -site. You can also review your own replays. - Replays can be played back to get the observations and actions made during that game. The observations are rendered at the resolution you request, so may differ from what the human actually saw. Similarly the actions specify a point, which @@ -203,6 +184,55 @@ match in our observations, though they should be fairly similar. Replays are version dependent, so a 3.15 replay will fail in a 3.16 binary. +## Watch a replay + +Running the random agent and playing as a human will save a replay by default. You +can watch that replay by running: + +```shell +$ python -m pysc2.bin.play --replay +``` + +This works for any replay as long as the map can be found by the game. 
+ +The same controls work as for playing the game, so `F4` to exit, `pgup`/`pgdn` +to control the speed, etc. + You can visualize the replays with the full game, or with `pysc2.bin.play`. -Alternatively you can run `pysc2.bin.replay_actions` to process many replays -in parallel. +Alternatively, you can run `pysc2.bin.process_replays` to process many replays +in parallel by supplying a replay directory. Each replay in the supplied directory +will be processed. + +```shell +$ python -m pysc2.bin.process_replays --replays +``` +The default number of instances to run in parallel is 1, but this can be changed using +the `parallel` argument. + +```shell +$ python -m pysc2.bin.process_replays --replays --parallel +``` + +## Parse a replay + +To collect data from one or more replays, a replay parser can be used. Two example +replay parsers can be found in the `replay_parsers` folder: + +* `action_parser`: Collects statistics about actions and general replay stats and prints them to the console. +* `player_info_parser`: Collects general player info at each replay step and saves it to a file. + +To run a specific replay parser, pass the parser as the `parser` argument. If the replay parser +returns data to be stored in a file, a directory must be supplied to the `data_dir` argument. + +```shell +$ python -m pysc2.bin.process_replays --replays --parser pysc2.replay_parsers.action_parser.ActionParser --data_dir +``` + +Details on how to implement a custom replay parser can be found [here](docs/environment.md#replay-parsers). + +## Public Replays + +Blizzard is releasing a large number of anonymized 1v1 replays played on the +ladder. You can find instructions for how to get the +[replay files](https://github.com/Blizzard/s2client-proto#downloads) on their +site. diff --git a/docs/environment.md b/docs/environment.md index d42352ce2..a8bb13b48 100644 --- a/docs/environment.md +++ b/docs/environment.md @@ -439,3 +439,24 @@ There are a couple basic agents. 
* `random_agent`: Just plays randomly, shows how to make valid moves. * `scripted_agent`: These are scripted for a single easy map. + +## Replay Parsers + +Custom replay parsers can be built to collect data from replay files. Two example +replay parsers can be found in the `replay_parsers` folder: + +* `action_parser`: Collects statistics about actions and general replay stats and prints them to the console. +* `player_info_parser`: Collects general player info at each replay step and saves it to a file. + +To build a custom replay parser, define a class that inherits from `BaseParser`. +The main method of the replay parser is `parse_step`. It must take three arguments, +`obs`, `feat` and `info`: the game observation, the feature helper (`features.Features`) and the replay information for a single +step in the replay. The `process_replays` script passes these to the parser at each step in the +replay file, and the parser uses them to extract the desired data. If `parse_step` returns a non-empty (truthy) value, +that value is appended to a list containing the parsed data for each step in the replay. Once the +replay is finished, this list is saved to a JSON file in the supplied `data_dir` directory. +If no directory is supplied, the data is not saved to a file. + +The `valid_replay` method of the parent `BaseParser` class can be overridden to supply a custom +definition of a valid replay (e.g. to filter out replays in which a player has a low MMR). The `valid_replay` +method must take the arguments `info` and `ping`, supplied by `process_replays`, and return a boolean. \ No newline at end of file diff --git a/pysc2/bin/process_replays.py b/pysc2/bin/process_replays.py new file mode 100644 index 000000000..2129b6484 --- /dev/null +++ b/pysc2/bin/process_replays.py @@ -0,0 +1,296 @@ +#!/usr/bin/python +# Copyright 2017 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Collect data from a set of replays using a supplied replay parser.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import multiprocessing +import os +import signal +import sys +import threading +import time + +from future.builtins import range # pylint: disable=redefined-builtin +import six +from six.moves import queue + +from pysc2 import run_configs +from pysc2.lib import features +from pysc2.lib import point +from pysc2.lib import protocol +from pysc2.lib import remote_controller + +from absl import app +from absl import flags +from pysc2.lib import gfile +from s2clientprotocol import common_pb2 as sc_common +from s2clientprotocol import sc2api_pb2 as sc_pb + +import importlib +import json +import sys + +FLAGS = flags.FLAGS +flags.DEFINE_integer("parallel", 1, "How many instances to run in parallel.") +flags.DEFINE_integer("step_mul", 8, "How many game steps per observation.") +flags.DEFINE_string("replays", "None", "Path to a directory of replays.") +flags.DEFINE_string("parser", "pysc2.replay_parsers.action_parser.ActionParser", + "Which parser to use when scraping replay data.") +flags.DEFINE_string("data_dir", None, + "Path to the directory where the replay parser's data is saved.") +flags.DEFINE_integer("screen_resolution", 16, + "Resolution for screen feature layers.") +flags.DEFINE_integer("minimap_resolution", 16, + "Resolution for minimap feature layers.") + +interface = sc_pb.InterfaceOptions() +interface.raw = True +interface.score = False 
+interface.feature_layer.width = 24 +interface.feature_layer.resolution.x = FLAGS.screen_resolution +interface.feature_layer.resolution.y = FLAGS.screen_resolution +interface.feature_layer.minimap_resolution.x = FLAGS.minimap_resolution +interface.feature_layer.minimap_resolution.y = FLAGS.minimap_resolution + +class ProcessStats(object): + """Stats for a worker process.""" + + def __init__(self, proc_id, parser_cls): + self.proc_id = proc_id + self.time = time.time() + self.stage = "" + self.replay = "" + self.parser = parser_cls() + + def update(self, stage): + self.time = time.time() + self.stage = stage + + def __str__(self): + return ("[%2d] replay: %10s, replays: %5d, steps: %7d, game loops: %7s, " + "last: %12s, %3d s ago" % ( + self.proc_id, self.replay, self.parser.replays, + self.parser.steps, + self.parser.steps * FLAGS.step_mul, self.stage, + time.time() - self.time)) + + +class ReplayProcessor(multiprocessing.Process): + """A Process that pulls replays and processes them.""" + + def __init__(self, proc_id, run_config, replay_queue, stats_queue, parser_cls): + super(ReplayProcessor, self).__init__() + self.stats = ProcessStats(proc_id, parser_cls) + self.run_config = run_config + self.replay_queue = replay_queue + self.stats_queue = stats_queue + + def run(self): + signal.signal(signal.SIGTERM, lambda a, b: sys.exit()) # Exit quietly. 
+ self._update_stage("spawn") + replay_name = "none" + while True: + self._print("Starting up a new SC2 instance.") + self._update_stage("launch") + try: + with self.run_config.start() as controller: + self._print("SC2 Started successfully.") + ping = controller.ping() + for _ in range(300): + try: + replay_path = self.replay_queue.get() + except queue.Empty: + self._update_stage("done") + self._print("Empty queue, returning") + return + try: + self.load_replay(replay_path, controller, ping) + finally: + self.replay_queue.task_done() + self._update_stage("shutdown") + except (protocol.ConnectionError, protocol.ProtocolError, + remote_controller.RequestError): + self.stats.parser.crashing_replays.add(self.stats.replay) # Name set in load_replay. + except KeyboardInterrupt: + return + + def load_replay(self, replay_path, controller, ping): + replay_name = os.path.basename(replay_path) + self.stats.replay = replay_name + self._print("Got replay: %s" % replay_path) + self._update_stage("open replay file") + replay_data = self.run_config.replay_data(replay_path) + self._update_stage("replay_info") + info = controller.replay_info(replay_data) + self._print((" Replay Info %s " % replay_name).center(60, "-")) + self._print(info) + self._print("-" * 60) + if self.stats.parser.valid_replay(info, ping): + self.stats.parser.maps[info.map_name] += 1 + for player_info in info.player_info: + race_name = sc_common.Race.Name( + player_info.player_info.race_actual) + self.stats.parser.races[race_name] += 1 + map_data = None + if info.local_map_path: + self._update_stage("open map file") + map_data = self.run_config.map_data(info.local_map_path) + for player_id in [1, 2]: + self._print("Starting %s from player %s's perspective" % ( + replay_name, player_id)) + self.process_replay(controller, replay_data, map_data, + player_id, info, replay_name) + else: + self._print("Replay is invalid.") + self.stats.parser.invalid_replays.add(replay_name) + + def _print(self, s): + for line in str(s).strip().splitlines(): + 
print("[%s] %s" % (self.stats.proc_id, line)) + + def _update_stage(self, stage): + self.stats.update(stage) + self.stats_queue.put(self.stats) + + def process_replay(self, controller, replay_data, map_data, player_id, info, replay_name): + """Process a single replay, updating the stats.""" + self._print(replay_name) + self._update_stage("start_replay") + controller.start_replay(sc_pb.RequestStartReplay( + replay_data=replay_data, + map_data=map_data, + options=interface, + observed_player_id=player_id)) + + feat = features.Features(controller.game_info()) + + self.stats.parser.replays += 1 + self._update_stage("step") + controller.step() + data = [] + while True: + self.stats.parser.steps += 1 + self._update_stage("observe") + obs = controller.observe() + # If parser.parse_step returns a truthy value, it is appended + # to a data list, and this data list is saved to a json file + # in the data_dir directory with filename = replay_name_player_id.json + parsed_data = self.stats.parser.parse_step(obs, feat, info) + if parsed_data: + data.append(parsed_data) + + if obs.player_result: + # Save scraped replay data to file at end of replay if the parser + # returned data and data_dir was provided + if data: + if FLAGS.data_dir: + stripped_replay_name = replay_name.split(".")[0] + data_file = os.path.join(FLAGS.data_dir, + stripped_replay_name + "_" + str(player_id) + ".json") + with open(data_file, "w") as outfile: + json.dump(data, outfile) + else: + print("Please provide a directory as data_dir to save scraped data files") + break + + self._update_stage("step") + controller.step(FLAGS.step_mul) + + +def stats_printer(stats_queue, parser_cls): + """A thread that consumes stats_queue and prints them every 10 seconds.""" + proc_stats = [ProcessStats(i, parser_cls) for i in range(FLAGS.parallel)] + print_time = start_time = time.time() + width = 107 + + running = True + while running: + print_time += 10 + + while time.time() < print_time: + try: + s = stats_queue.get(True, print_time - 
time.time()) + if s is None: # Signal to print and exit NOW! + running = False + break + proc_stats[s.proc_id] = s + except queue.Empty: + pass + + parser = parser_cls() + for s in proc_stats: + parser.merge(s.parser) + + print((" Summary %0d secs " % (print_time - start_time)).center(width, "=")) + print(parser) + print(" Process stats ".center(width, "-")) + print("\n".join(str(s) for s in proc_stats)) + print("=" * width) + + +def replay_queue_filler(replay_queue, replay_list): + """A thread that fills the replay_queue with replay filenames.""" + for replay_path in replay_list: + replay_queue.put(replay_path) + + +def main(unused_argv): + """Collect data from a set of replays using supplied parser.""" + run_config = run_configs.get() + + parser_module, parser_name = FLAGS.parser.rsplit(".", 1) + parser_cls = getattr(importlib.import_module(parser_module), parser_name) + + if not gfile.Exists(FLAGS.replays): + sys.exit("Replay Path {} doesn't exist.".format(FLAGS.replays)) + + stats_queue = multiprocessing.Queue() + stats_thread = threading.Thread(target=stats_printer, args=(stats_queue,parser_cls)) + stats_thread.start() + try: + # For some reason buffering everything into a JoinableQueue makes the + # program not exit, so save it into a list then slowly fill it into the + # queue in a separate thread. Grab the list synchronously so we know there + # is work in the queue before the SC2 processes actually run, otherwise + # The replay_queue.join below succeeds without doing any work, and exits. 
+ print("Getting replay list:", FLAGS.replays) + replay_list = sorted(run_config.replay_paths(FLAGS.replays)) + print(len(replay_list), "replays found.\n") + replay_queue = multiprocessing.JoinableQueue(FLAGS.parallel * 10) + replay_queue_thread = threading.Thread(target=replay_queue_filler, + args=(replay_queue, replay_list)) + replay_queue_thread.daemon = True + replay_queue_thread.start() + + for i in range(FLAGS.parallel): + p = ReplayProcessor(i, run_config, replay_queue, stats_queue, parser_cls) + p.daemon = True + p.start() + time.sleep(1) # Stagger startups, otherwise they seem to conflict somehow + + replay_queue.join() # Wait for the queue to empty. + except KeyboardInterrupt: + print("Caught KeyboardInterrupt, exiting.") + finally: + stats_queue.put(None) # Tell the stats_thread to print and exit. + stats_thread.join() + + +if __name__ == "__main__": + app.run(main) diff --git a/pysc2/bin/replay_actions.py b/pysc2/bin/replay_actions.py deleted file mode 100755 index 93369a766..000000000 --- a/pysc2/bin/replay_actions.py +++ /dev/null @@ -1,372 +0,0 @@ -#!/usr/bin/python -# Copyright 2017 Google Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS-IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Dump out stats about all the actions that are in use in a set of replays.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import multiprocessing -import os -import signal -import sys -import threading -import time - -from future.builtins import range # pylint: disable=redefined-builtin -import six -from six.moves import queue - -from pysc2 import run_configs -from pysc2.lib import features -from pysc2.lib import point -from pysc2.lib import protocol -from pysc2.lib import remote_controller - -from absl import app -from absl import flags -from pysc2.lib import gfile -from s2clientprotocol import common_pb2 as sc_common -from s2clientprotocol import sc2api_pb2 as sc_pb - -FLAGS = flags.FLAGS -flags.DEFINE_integer("parallel", 1, "How many instances to run in parallel.") -flags.DEFINE_integer("step_mul", 8, "How many game steps per observation.") -flags.DEFINE_string("replays", None, "Path to a directory of replays.") -flags.mark_flag_as_required("replays") - - -size = point.Point(16, 16) -interface = sc_pb.InterfaceOptions( - raw=True, score=False, - feature_layer=sc_pb.SpatialCameraSetup(width=24)) -size.assign_to(interface.feature_layer.resolution) -size.assign_to(interface.feature_layer.minimap_resolution) - - -def sorted_dict_str(d): - return "{%s}" % ", ".join("%s: %s" % (k, d[k]) - for k in sorted(d, key=d.get, reverse=True)) - - -class ReplayStats(object): - """Summary stats of the replays seen so far.""" - - def __init__(self): - self.replays = 0 - self.steps = 0 - self.camera_move = 0 - self.select_pt = 0 - self.select_rect = 0 - self.control_group = 0 - self.maps = collections.defaultdict(int) - self.races = collections.defaultdict(int) - self.unit_ids = collections.defaultdict(int) - self.valid_abilities = collections.defaultdict(int) - self.made_abilities = collections.defaultdict(int) - self.valid_actions = collections.defaultdict(int) - self.made_actions = 
collections.defaultdict(int) - self.crashing_replays = set() - self.invalid_replays = set() - - def merge(self, other): - """Merge another ReplayStats into this one.""" - def merge_dict(a, b): - for k, v in six.iteritems(b): - a[k] += v - - self.replays += other.replays - self.steps += other.steps - self.camera_move += other.camera_move - self.select_pt += other.select_pt - self.select_rect += other.select_rect - self.control_group += other.control_group - merge_dict(self.maps, other.maps) - merge_dict(self.races, other.races) - merge_dict(self.unit_ids, other.unit_ids) - merge_dict(self.valid_abilities, other.valid_abilities) - merge_dict(self.made_abilities, other.made_abilities) - merge_dict(self.valid_actions, other.valid_actions) - merge_dict(self.made_actions, other.made_actions) - self.crashing_replays |= other.crashing_replays - self.invalid_replays |= other.invalid_replays - - def __str__(self): - len_sorted_dict = lambda s: (len(s), sorted_dict_str(s)) - len_sorted_list = lambda s: (len(s), sorted(s)) - return "\n\n".join(( - "Replays: %s, Steps total: %s" % (self.replays, self.steps), - "Camera move: %s, Select pt: %s, Select rect: %s, Control group: %s" % ( - self.camera_move, self.select_pt, self.select_rect, - self.control_group), - "Maps: %s\n%s" % len_sorted_dict(self.maps), - "Races: %s\n%s" % len_sorted_dict(self.races), - "Unit ids: %s\n%s" % len_sorted_dict(self.unit_ids), - "Valid abilities: %s\n%s" % len_sorted_dict(self.valid_abilities), - "Made abilities: %s\n%s" % len_sorted_dict(self.made_abilities), - "Valid actions: %s\n%s" % len_sorted_dict(self.valid_actions), - "Made actions: %s\n%s" % len_sorted_dict(self.made_actions), - "Crashing replays: %s\n%s" % len_sorted_list(self.crashing_replays), - "Invalid replays: %s\n%s" % len_sorted_list(self.invalid_replays), - )) - - -class ProcessStats(object): - """Stats for a worker process.""" - - def __init__(self, proc_id): - self.proc_id = proc_id - self.time = time.time() - self.stage = "" - 
self.replay = "" - self.replay_stats = ReplayStats() - - def update(self, stage): - self.time = time.time() - self.stage = stage - - def __str__(self): - return ("[%2d] replay: %10s, replays: %5d, steps: %7d, game loops: %7s, " - "last: %12s, %3d s ago" % ( - self.proc_id, self.replay, self.replay_stats.replays, - self.replay_stats.steps, - self.replay_stats.steps * FLAGS.step_mul, self.stage, - time.time() - self.time)) - - -def valid_replay(info, ping): - """Make sure the replay isn't corrupt, and is worth looking at.""" - if (info.HasField("error") or - info.base_build != ping.base_build or # different game version - info.game_duration_loops < 1000 or - len(info.player_info) != 2): - # Probably corrupt, or just not interesting. - return False - for p in info.player_info: - if p.player_apm < 10 or p.player_mmr < 1000: - # Low APM = player just standing around. - # Low MMR = corrupt replay or player who is weak. - return False - return True - - -class ReplayProcessor(multiprocessing.Process): - """A Process that pulls replays and processes them.""" - - def __init__(self, proc_id, run_config, replay_queue, stats_queue): - super(ReplayProcessor, self).__init__() - self.stats = ProcessStats(proc_id) - self.run_config = run_config - self.replay_queue = replay_queue - self.stats_queue = stats_queue - - def run(self): - signal.signal(signal.SIGTERM, lambda a, b: sys.exit()) # Exit quietly. 
- self._update_stage("spawn") - replay_name = "none" - while True: - self._print("Starting up a new SC2 instance.") - self._update_stage("launch") - try: - with self.run_config.start() as controller: - self._print("SC2 Started successfully.") - ping = controller.ping() - for _ in range(300): - try: - replay_path = self.replay_queue.get() - except queue.Empty: - self._update_stage("done") - self._print("Empty queue, returning") - return - try: - replay_name = os.path.basename(replay_path)[:10] - self.stats.replay = replay_name - self._print("Got replay: %s" % replay_path) - self._update_stage("open replay file") - replay_data = self.run_config.replay_data(replay_path) - self._update_stage("replay_info") - info = controller.replay_info(replay_data) - self._print((" Replay Info %s " % replay_name).center(60, "-")) - self._print(info) - self._print("-" * 60) - if valid_replay(info, ping): - self.stats.replay_stats.maps[info.map_name] += 1 - for player_info in info.player_info: - race_name = sc_common.Race.Name( - player_info.player_info.race_actual) - self.stats.replay_stats.races[race_name] += 1 - map_data = None - if info.local_map_path: - self._update_stage("open map file") - map_data = self.run_config.map_data(info.local_map_path) - for player_id in [1, 2]: - self._print("Starting %s from player %s's perspective" % ( - replay_name, player_id)) - self.process_replay(controller, replay_data, map_data, - player_id) - else: - self._print("Replay is invalid.") - self.stats.replay_stats.invalid_replays.add(replay_name) - finally: - self.replay_queue.task_done() - self._update_stage("shutdown") - except (protocol.ConnectionError, protocol.ProtocolError, - remote_controller.RequestError): - self.stats.replay_stats.crashing_replays.add(replay_name) - except KeyboardInterrupt: - return - - def _print(self, s): - for line in str(s).strip().splitlines(): - print("[%s] %s" % (self.stats.proc_id, line)) - - def _update_stage(self, stage): - self.stats.update(stage) - 
self.stats_queue.put(self.stats) - - def process_replay(self, controller, replay_data, map_data, player_id): - """Process a single replay, updating the stats.""" - self._update_stage("start_replay") - controller.start_replay(sc_pb.RequestStartReplay( - replay_data=replay_data, - map_data=map_data, - options=interface, - observed_player_id=player_id)) - - feat = features.Features(controller.game_info()) - - self.stats.replay_stats.replays += 1 - self._update_stage("step") - controller.step() - while True: - self.stats.replay_stats.steps += 1 - self._update_stage("observe") - obs = controller.observe() - - for action in obs.actions: - act_fl = action.action_feature_layer - if act_fl.HasField("unit_command"): - self.stats.replay_stats.made_abilities[ - act_fl.unit_command.ability_id] += 1 - if act_fl.HasField("camera_move"): - self.stats.replay_stats.camera_move += 1 - if act_fl.HasField("unit_selection_point"): - self.stats.replay_stats.select_pt += 1 - if act_fl.HasField("unit_selection_rect"): - self.stats.replay_stats.select_rect += 1 - if action.action_ui.HasField("control_group"): - self.stats.replay_stats.control_group += 1 - - try: - func = feat.reverse_action(action).function - except ValueError: - func = -1 - self.stats.replay_stats.made_actions[func] += 1 - - for valid in obs.observation.abilities: - self.stats.replay_stats.valid_abilities[valid.ability_id] += 1 - - for u in obs.observation.raw_data.units: - self.stats.replay_stats.unit_ids[u.unit_type] += 1 - - for ability_id in feat.available_actions(obs.observation): - self.stats.replay_stats.valid_actions[ability_id] += 1 - - if obs.player_result: - break - - self._update_stage("step") - controller.step(FLAGS.step_mul) - - -def stats_printer(stats_queue): - """A thread that consumes stats_queue and prints them every 10 seconds.""" - proc_stats = [ProcessStats(i) for i in range(FLAGS.parallel)] - print_time = start_time = time.time() - width = 107 - - running = True - while running: - print_time += 10 - 
- while time.time() < print_time: - try: - s = stats_queue.get(True, print_time - time.time()) - if s is None: # Signal to print and exit NOW! - running = False - break - proc_stats[s.proc_id] = s - except queue.Empty: - pass - - replay_stats = ReplayStats() - for s in proc_stats: - replay_stats.merge(s.replay_stats) - - print((" Summary %0d secs " % (print_time - start_time)).center(width, "=")) - print(replay_stats) - print(" Process stats ".center(width, "-")) - print("\n".join(str(s) for s in proc_stats)) - print("=" * width) - - -def replay_queue_filler(replay_queue, replay_list): - """A thread that fills the replay_queue with replay filenames.""" - for replay_path in replay_list: - replay_queue.put(replay_path) - - -def main(unused_argv): - """Dump stats about all the actions that are in use in a set of replays.""" - run_config = run_configs.get() - - if not gfile.Exists(FLAGS.replays): - sys.exit("{} doesn't exist.".format(FLAGS.replays)) - - stats_queue = multiprocessing.Queue() - stats_thread = threading.Thread(target=stats_printer, args=(stats_queue,)) - stats_thread.start() - try: - # For some reason buffering everything into a JoinableQueue makes the - # program not exit, so save it into a list then slowly fill it into the - # queue in a separate thread. Grab the list synchronously so we know there - # is work in the queue before the SC2 processes actually run, otherwise - # The replay_queue.join below succeeds without doing any work, and exits. 
- print("Getting replay list:", FLAGS.replays) - replay_list = sorted(run_config.replay_paths(FLAGS.replays)) - print(len(replay_list), "replays found.\n") - replay_queue = multiprocessing.JoinableQueue(FLAGS.parallel * 10) - replay_queue_thread = threading.Thread(target=replay_queue_filler, - args=(replay_queue, replay_list)) - replay_queue_thread.daemon = True - replay_queue_thread.start() - - for i in range(FLAGS.parallel): - p = ReplayProcessor(i, run_config, replay_queue, stats_queue) - p.daemon = True - p.start() - time.sleep(1) # Stagger startups, otherwise they seem to conflict somehow - - replay_queue.join() # Wait for the queue to empty. - except KeyboardInterrupt: - print("Caught KeyboardInterrupt, exiting.") - finally: - stats_queue.put(None) # Tell the stats_thread to print and exit. - stats_thread.join() - - -if __name__ == "__main__": - app.run(main) diff --git a/pysc2/replay_parsers/__init__.py b/pysc2/replay_parsers/__init__.py new file mode 100644 index 000000000..b448f59d9 --- /dev/null +++ b/pysc2/replay_parsers/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2017 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/pysc2/replay_parsers/action_parser.py b/pysc2/replay_parsers/action_parser.py new file mode 100644 index 000000000..d087efc41 --- /dev/null +++ b/pysc2/replay_parsers/action_parser.py @@ -0,0 +1,118 @@ +# Copyright 2017 Google Inc. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS-IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Action statistics parser for replays."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import six
+
+from pysc2.replay_parsers import base_parser
+
+class ActionParser(base_parser.BaseParser):
+  """Action statistics parser for replays."""
+
+  def __init__(self):
+    super(ActionParser, self).__init__()
+    self.camera_move = 0
+    self.select_pt = 0
+    self.select_rect = 0
+    self.control_group = 0
+    self.unit_ids = collections.defaultdict(int)
+    self.valid_abilities = collections.defaultdict(int)
+    self.made_abilities = collections.defaultdict(int)
+    self.valid_actions = collections.defaultdict(int)
+    self.made_actions = collections.defaultdict(int)
+
+  def merge(self, other):
+    """Merge another ActionParser's statistics into this one."""
+    def merge_dict(a, b):
+      for k, v in six.iteritems(b):
+        a[k] += v
+    super(ActionParser, self).merge(other)
+    self.camera_move += other.camera_move
+    self.select_pt += other.select_pt
+    self.select_rect += other.select_rect
+    self.control_group += other.control_group
+    merge_dict(self.unit_ids, other.unit_ids)
+    merge_dict(self.valid_abilities, other.valid_abilities)
+    merge_dict(self.made_abilities, other.made_abilities)
+    merge_dict(self.valid_actions, other.valid_actions)
+    merge_dict(self.made_actions, other.made_actions)
+
+  def valid_replay(self, info, ping):
+    """Make sure the replay isn't corrupt, and is worth looking at."""
+    if (info.HasField("error") or
+        info.base_build != ping.base_build or  # different game version
+        info.game_duration_loops < 1 or
+        len(info.player_info) != 2):
+      # Probably corrupt, or just not interesting.
+      return False
+    for p in info.player_info:
+      if p.player_apm < 10 or p.player_mmr < 0:
+        # Low APM = player just standing around.
+        # Low MMR = corrupt replay or player who is weak.
+        return False
+    return True
+
+  def __str__(self):
+    len_sorted_dict = lambda s: (len(s), self.sorted_dict_str(s))
+    len_sorted_list = lambda s: (len(s), sorted(s))
+    return "\n\n".join((
+        "Replays: %s, Steps total: %s" % (self.replays, self.steps),
+        "Camera move: %s, Select pt: %s, Select rect: %s, Control group: %s" % (
+            self.camera_move, self.select_pt, self.select_rect,
+            self.control_group),
+        "Maps: %s\n%s" % len_sorted_dict(self.maps),
+        "Races: %s\n%s" % len_sorted_dict(self.races),
+        "Unit ids: %s\n%s" % len_sorted_dict(self.unit_ids),
+        "Valid abilities: %s\n%s" % len_sorted_dict(self.valid_abilities),
+        "Made abilities: %s\n%s" % len_sorted_dict(self.made_abilities),
+        "Valid actions: %s\n%s" % len_sorted_dict(self.valid_actions),
+        "Made actions: %s\n%s" % len_sorted_dict(self.made_actions),
+        "Crashing replays: %s\n%s" % len_sorted_list(self.crashing_replays),
+        "Invalid replays: %s\n%s" % len_sorted_list(self.invalid_replays),
+    ))
+
+  def parse_step(self, obs, feat, info):
+    for action in obs.actions:
+      act_fl = action.action_feature_layer
+      if act_fl.HasField("unit_command"):
+        self.made_abilities[
+            act_fl.unit_command.ability_id] += 1
+      if act_fl.HasField("camera_move"):
+        self.camera_move += 1
+      if act_fl.HasField("unit_selection_point"):
+        self.select_pt += 1
+      if act_fl.HasField("unit_selection_rect"):
+        self.select_rect += 1
+      if action.action_ui.HasField("control_group"):
+        self.control_group += 1
+
+      try:
+        func = feat.reverse_action(action).function
+      except ValueError:
+        func = -1
+      self.made_actions[func] += 1
+
+    for valid in obs.observation.abilities:
+      self.valid_abilities[valid.ability_id] += 1
+
+    for u in obs.observation.raw_data.units:
+      self.unit_ids[u.unit_type] += 1
+
+    for ability_id in feat.available_actions(obs.observation):
+      self.valid_actions[ability_id] += 1
diff --git a/pysc2/replay_parsers/base_parser.py b/pysc2/replay_parsers/base_parser.py
new file mode 100644
index 000000000..c6a04f6ae
--- /dev/null
+++ b/pysc2/replay_parsers/base_parser.py
@@ -0,0 +1,64 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS-IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A base replay parser to write custom replay data scrapers."""
+
+import collections
+import six
+
+class BaseParser(object):
+  """A base replay parser to write custom replay data scrapers."""
+  def __init__(self):
+    self.replays = 0
+    self.steps = 0
+    self.maps = collections.defaultdict(int)
+    self.races = collections.defaultdict(int)
+    self.crashing_replays = set()
+    self.invalid_replays = set()
+
+  def merge(self, other):
+    """Merge another parser's statistics into this one."""
+
+    def merge_dict(a, b):
+      for k, v in six.iteritems(b):
+        a[k] += v
+    self.replays += other.replays
+    self.steps += other.steps
+    merge_dict(self.maps, other.maps)
+    merge_dict(self.races, other.races)
+    self.crashing_replays |= other.crashing_replays
+    self.invalid_replays |= other.invalid_replays
+
+  def __str__(self):
+    len_sorted_dict = lambda s: (len(s), self.sorted_dict_str(s))
+    len_sorted_list = lambda s: (len(s), sorted(s))
+    return "\n\n".join((
+        "Replays: %s, Steps total: %s" % (self.replays, self.steps),
+        "Maps: %s\n%s" % len_sorted_dict(self.maps),
+        "Races: %s\n%s" % len_sorted_dict(self.races),
+        "Crashing replays: %s\n%s" % len_sorted_list(self.crashing_replays),
+        "Invalid replays: %s\n%s" % len_sorted_list(self.invalid_replays),
+    ))
+
+  def valid_replay(self, info, ping):
+    # All replays are valid in the base parser.
+    return True
+
+  def parse_step(self, obs, feat, info):
+    # The base parser doesn't parse any data itself;
+    # subclasses are required to implement parse_step.
+    raise NotImplementedError()
+
+  def sorted_dict_str(self, d):
+    return "{%s}" % ", ".join("%s: %s" % (k, d[k])
+                              for k in sorted(d, key=d.get, reverse=True))
diff --git a/pysc2/replay_parsers/player_info_parser.py b/pysc2/replay_parsers/player_info_parser.py
new file mode 100644
index 000000000..a2e048e94
--- /dev/null
+++ b/pysc2/replay_parsers/player_info_parser.py
@@ -0,0 +1,56 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS-IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Example parser to collect some basic state data from replays.
+   The parser collects the general player information at each step,
+   along with the winning player_id of the replay."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import six
+import numpy as np
+
+from pysc2.replay_parsers import base_parser
+
+class PlayerInfoParser(base_parser.BaseParser):
+  """Example parser for collecting general player information
+     from replays."""
+  def valid_replay(self, info, ping):
+    """Make sure the replay isn't corrupt, and is worth looking at."""
+    if (info.HasField("error") or
+        info.base_build != ping.base_build or  # different game version
+        info.game_duration_loops < 1000 or
+        len(info.player_info) != 2):
+      # Probably corrupt, or just not interesting.
+      return False
+    for p in info.player_info:
+      if p.player_apm < 10 or p.player_mmr < 1000:
+        # Low APM = player just standing around.
+        # Low MMR = corrupt replay or player who is weak.
+        return False
+    return True
+
+  def parse_step(self, obs, feat, info):
+    # Obtain feature layers from the current step's observation.
+    all_features = feat.transform_obs(obs.observation)
+    player_resources = all_features['player'].tolist()
+
+    if info.player_info[0].player_result.result == 1:  # 1 == Victory (sc2api Result enum)
+      winner = 1
+    else:
+      winner = 2
+    # Return the current replay step's data to be appended and saved to file.
+    return [player_resources, winner]
diff --git a/pysc2/tests/replay_parser_test.py b/pysc2/tests/replay_parser_test.py
new file mode 100644
index 000000000..144466c40
--- /dev/null
+++ b/pysc2/tests/replay_parser_test.py
@@ -0,0 +1,77 @@
+#!/usr/bin/python
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS-IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Process a test replay using the replay parsers.
+   A replay named "test_replay.SC2Replay" is required to exist in
+   the StarCraft II install Replay directory."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import multiprocessing
+
+from pysc2 import run_configs
+from pysc2.replay_parsers import base_parser
+from pysc2.replay_parsers import action_parser
+from pysc2.bin import process_replays
+from pysc2.tests import utils
+
+from absl.testing import absltest as basetest
+
+
+class TestBaseParser(utils.TestCase):
+
+  def test_true_valid_replay(self):
+    '''BaseParser returns valid_replay = True for all replays.
+       Test the replay info loading and assert that BaseParser
+       returns True for the valid_replay call.'''
+
+    run_config = run_configs.get()
+    processor = process_replays.ReplayProcessor(proc_id=0,
+                                                run_config=run_config,
+                                                replay_queue=None,
+                                                stats_queue=None,
+                                                parser_cls=base_parser.BaseParser)
+    with run_config.start() as controller:
+      ping = controller.ping()
+      replay_path = "test_replay.SC2Replay"
+      replay_data = run_config.replay_data(replay_path)
+      info = controller.replay_info(replay_data)
+      self.assertTrue(processor.stats.parser.valid_replay(info, ping))
+
+  def test_parse_replay(self):
+    '''Run the process_replays script for the test replay file and ensure
+       consistency of the processing metadata.'''
+
+    run_config = run_configs.get()
+    stats_queue = multiprocessing.Queue()
+    processor = process_replays.ReplayProcessor(proc_id=0,
+                                                run_config=run_config,
+                                                replay_queue=None,
+                                                stats_queue=stats_queue,
+                                                parser_cls=action_parser.ActionParser)
+    with run_config.start() as controller:
+      ping = controller.ping()
+      replay_path = "test_replay.SC2Replay"
+      processor.load_replay(replay_path, controller, ping)
+      # Test replay count == 2 (one for each player perspective in the test replay)
+      self.assertEqual(processor.stats.parser.replays, 2)
+      # Ensure the test replay is valid for ActionParser
+      self.assertFalse(processor.stats.parser.invalid_replays)
+      # Test that the parser processes at least one step from the test replay
+      self.assertGreater(processor.stats.parser.steps, 0)
+
+if __name__ == "__main__":
+  basetest.main()
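
For readers of this change, the sketch below shows how a custom parser might plug into the interface the diff introduces (`parse_step` plus `merge`). It is only an illustration under stated assumptions: `BaseParser` here is a trimmed stand-in so the snippet runs without pysc2 or StarCraft II installed, and `UnitCountParser` and `fake_obs` are hypothetical names that do not appear in the PR.

```python
import collections
from types import SimpleNamespace


class BaseParser(object):
  """Trimmed stand-in for pysc2.replay_parsers.base_parser.BaseParser."""

  def __init__(self):
    self.replays = 0
    self.steps = 0

  def merge(self, other):
    self.replays += other.replays
    self.steps += other.steps

  def valid_replay(self, info, ping):
    return True

  def parse_step(self, obs, feat, info):
    raise NotImplementedError()


class UnitCountParser(BaseParser):
  """Hypothetical parser: counts how often each unit type is observed."""

  def __init__(self):
    super(UnitCountParser, self).__init__()
    self.unit_ids = collections.defaultdict(int)

  def merge(self, other):
    # Combine the base counters first, then this parser's own counter.
    super(UnitCountParser, self).merge(other)
    for unit_type, count in other.unit_ids.items():
      self.unit_ids[unit_type] += count

  def parse_step(self, obs, feat, info):
    # Same field access that ActionParser.parse_step uses in the diff.
    for u in obs.observation.raw_data.units:
      self.unit_ids[u.unit_type] += 1


def fake_obs(unit_types):
  """Build a fake observation exposing only the fields used above."""
  units = [SimpleNamespace(unit_type=t) for t in unit_types]
  return SimpleNamespace(
      observation=SimpleNamespace(raw_data=SimpleNamespace(units=units)))


# Two parsers standing in for two parallel workers, merged at the end.
worker_a = UnitCountParser()
worker_a.parse_step(fake_obs([48, 48, 45]), feat=None, info=None)
worker_b = UnitCountParser()
worker_b.parse_step(fake_obs([45]), feat=None, info=None)
worker_a.merge(worker_b)
print(dict(worker_a.unit_ids))
```

In the real code the parser would subclass `base_parser.BaseParser` and be passed to `process_replays` through the `--parser` argument described in the README part of this diff. The merge step matters because each parallel instance accumulates its own statistics.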