diff --git a/README.md b/README.md index 3b20808..0cb9c2b 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # NAIL agent Navigate Acquire Interact Learn -NAIL is a general game-playing agent designed for parser-based interactive fiction games +NAIL is a general game-playing agent designed for parser-based interactive fiction games ([Hausknecht et al. 2019](https://arxiv.org/abs/1902.04259)). NAIL employs a simple heuristic: examine the current location to identify relevant objects, interact with the identified objects, navigate to a new location, and repeat. @@ -16,36 +16,33 @@ where it was evaluated on a set of twenty unknown parser-based IF games. * Python 3 ## Installation -* Install basic build tools. - * sudo apt-get update - * sudo apt-get install build-essential - * sudo apt-get install python3-dev - -* Install [fastText](https://github.com/facebookresearch/fastText#building-fasttext-for-python) - * pip3 install pybind11 - * git clone https://github.com/facebookresearch/fastText.git - * cd fastText - * pip3 install . - * cd .. - -* Install [Jericho](https://github.com/Microsoft/jericho) - * pip3 install jericho -* Clone this nail_agent repository to your Linux machine. -* Download the NAIL agent's language model to the nail_agent/agent/affordance_extractors/language_model directory: - * wget http://download.microsoft.com/download/B/8/8/B88DDDC1-F316-412A-94B3-025788436054/nail_agent_lm.zip -* unzip nail_agent_lm.zip - * The unzipped directory should contain 1028 files. - -* pip3 install numpy -* pip3 install fuzzywuzzy -* pip3 install spacy -* python3 -m spacy download en -* pip3 install python-Levenshtein +Install basic build tools. + + sudo apt-get update + sudo apt-get install build-essential + sudo apt-get install python3-dev + +> **Note:** We advise using virtual environments to prevent Python packages from different projects from interfering with each other. 
Popular choices are [Conda Environments](https://conda.io/projects/conda/en/latest/user-guide/getting-started.html) and [venv](https://docs.python.org/3/library/venv.html). + +Clone this **nail_agent** repository to your Linux machine. + + git clone https://github.com/microsoft/nail_agent.git + +Download the NAIL agent's language model (8.1 GB, 1028 files) to the `nail_agent/agent/affordance_extractors/language_model` directory: + + cd nail_agent/ + wget http://download.microsoft.com/download/B/8/8/B88DDDC1-F316-412A-94B3-025788436054/nail_agent_lm.zip + unzip nail_agent_lm.zip -d agent/affordance_extractors/language_model/ + +Install dependencies: + + pip install -r requirements.txt + python -m spacy download en_core_web_sm ## Usage -* Obtain a z-machine game (like zork1.z5) -* cd nail_agent -* python3 run_nail_agent.py +Obtain a Z-Machine game (like `zork1.z5`). Then, within the `nail_agent/` folder, run the following command: + + python run_nail_agent.py ## Contributing diff --git a/agent/entity_detectors/spacy_entity_detector.py b/agent/entity_detectors/spacy_entity_detector.py index 5ce6c0e..f9932d3 100644 --- a/agent/entity_detectors/spacy_entity_detector.py +++ b/agent/entity_detectors/spacy_entity_detector.py @@ -13,6 +13,9 @@ def __init__(self): def detect(self, observation_text): + # Spacy has trouble detecting entities ending with \n. + # Ref: https://github.com/explosion/spaCy/issues/4792#issuecomment-614295948 + observation_text = observation_text.replace("\n", " ") doc = gv.nlp(observation_text) nouns = [] for chunk in doc.noun_chunks: diff --git a/agent/gv.py b/agent/gv.py index 79b407e..156ef4f 100644 --- a/agent/gv.py +++ b/agent/gv.py @@ -27,9 +27,9 @@ # Spacy NLP instance try: - nlp = spacy.load('en') + nlp = spacy.load('en_core_web_sm') except Exception as e: - print("Failed to load \'en\' with exception {}. 
Try: python -m spacy download en_core_web_sm".format(e)) sys.exit(1) # Global Action Definitions diff --git a/agent/util.py b/agent/util.py index e343124..8c9303f 100644 --- a/agent/util.py +++ b/agent/util.py @@ -7,7 +7,7 @@ def first_sentence(text): """ Extracts the first sentence from text. """ tokens = gv.nlp(text) - return next(tokens.sents).merge().text + return next(tokens.sents).text def tokenize(description): diff --git a/agent/valid_detectors/learned_valid_detector.py b/agent/valid_detectors/learned_valid_detector.py index 14bddba..529ecb8 100644 --- a/agent/valid_detectors/learned_valid_detector.py +++ b/agent/valid_detectors/learned_valid_detector.py @@ -1,21 +1,25 @@ import os, sys -import fastText +import fasttext sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from valid_detector import ValidDetector import gv import util +# Monkey patch to remove warning message in fasttext 0.9.2. +# Ref: https://github.com/facebookresearch/fastText/issues/1067 +fasttext.FastText.eprint = lambda x: None + model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "valid_model.bin") class LearnedValidDetector(ValidDetector): """ - Uses a fastText classifier to predict the validity of the response text. + Uses a fasttext classifier to predict the validity of the response text. 
""" def __init__(self): super().__init__() - self.model = fastText.load_model(model_path) + self.model = fasttext.load_model(model_path) def action_valid(self, action, response_text): if not util.action_recognized(action, response_text): diff --git a/requirements.txt b/requirements.txt index bf5a8be..58cbba0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,29 +1,6 @@ -certifi==2018.10.15 -chardet==3.0.4 -cymem==2.0.2 -cytoolz==0.9.0.1 -dill==0.2.8.2 -en-core-web-sm==2.0.0 -fasttext==0.8.22 -fuzzywuzzy==0.17.0 -idna==2.7 -jericho==1.1.1 -msgpack==0.5.6 -msgpack-numpy==0.4.3.2 -murmurhash==1.0.1 -numpy==1.15.3 -pkg-resources==0.0.0 -plac==0.9.6 -preshed==2.0.1 -pybind11==2.2.4 -python-Levenshtein==0.12.0 -regex==2018.1.10 -requests==2.20.0 -six==1.11.0 -spacy==2.0.16 -thinc==6.12.0 -toolz==0.9.0 -tqdm==4.28.1 -ujson==1.35 -urllib3==1.24 -wrapt==1.10.11 +spacy==3.1.1 +fasttext==0.9.2 +jericho==3.0.5 +numpy==1.20.3 +python-Levenshtein==0.12.2 +fuzzywuzzy==0.18.0 \ No newline at end of file diff --git a/run_nail_agent.py b/run_nail_agent.py index 6f790a8..96b314e 100644 --- a/run_nail_agent.py +++ b/run_nail_agent.py @@ -28,7 +28,10 @@ def main(): agent = NailAgent(seed=args.seed, env=env, rom_name=os.path.basename(args.game)) # Get the first observation from the environment. - obs = env.reset() + obs, info = env.reset() + + # Keep track of the maximum score achieved. + max_score = 0 # Run the agent on the environment for the specified number of steps. for step_num in range(args.steps): @@ -36,23 +39,26 @@ def main(): action = agent.take_action(obs) # Pass the action to the environment. - new_obs, score, done, info = env.step(action) + new_obs, reward, done, info = env.step(action) + max_score = max(info["score"], max_score) # Update the agent. - agent.observe(obs, action, score, new_obs, done) + agent.observe(obs, action, reward, new_obs, done) obs = new_obs # Output this step. 
- print("Step {} Action [{}] Score {}\n{}".format(step_num, action, score, obs)) + print("Step {} Action [{}] Score {}\n{}".format(step_num, action, info["score"], obs)) # Check for done (such as on death). if done: print("Environment returned done=True. So reset the environment.\n") - obs = env.reset() + obs, info = env.reset() # Clean up the agent. agent.finalize() + print(f"Max score achieved: {max_score}") + if __name__ == "__main__": main()